From 7ee0902f86c61c010f3760c53e388efbd38783dd Mon Sep 17 00:00:00 2001 From: Chenfan Date: Tue, 26 May 2020 13:42:13 +0800 Subject: [PATCH 01/78] Code migration Start (#1) * Init commit: Code migration Start * Add loop_state.cc/h * Add ComputeDAG basic test --- CMakeLists.txt | 1 + src/ansor/compute_dag.cc | 1245 +++++++++++++++++++++++++++ src/ansor/compute_dag.h | 161 ++++ src/ansor/expr_hasher.h | 97 +++ src/ansor/loop_state.cc | 1729 ++++++++++++++++++++++++++++++++++++++ src/ansor/loop_state.h | 732 ++++++++++++++++ src/ansor/utils.cc | 102 +++ src/ansor/utils.h | 482 +++++++++++ tests/cpp/ansor_test.cc | 95 +++ 9 files changed, 4644 insertions(+) create mode 100644 src/ansor/compute_dag.cc create mode 100644 src/ansor/compute_dag.h create mode 100644 src/ansor/expr_hasher.h create mode 100644 src/ansor/loop_state.cc create mode 100644 src/ansor/loop_state.h create mode 100644 src/ansor/utils.cc create mode 100644 src/ansor/utils.h create mode 100644 tests/cpp/ansor_test.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index d7faa8a4b666..5550b5f6b3a8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -185,6 +185,7 @@ assign_source_group("Include" ${GROUP_INCLUDE}) # Source file lists file(GLOB_RECURSE COMPILER_SRCS + src/ansor/*.cc src/node/*.cc src/ir/*.cc src/arith/*.cc diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc new file mode 100644 index 000000000000..31136985b330 --- /dev/null +++ b/src/ansor/compute_dag.cc @@ -0,0 +1,1245 @@ +/*! + * Copyright (c) 2020 by Contributors + */ +#include "compute_dag.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +// #include "loop_state.h" +#include "utils.h" +// #include "../relay/pass/kernel_layout_transform.h" + +namespace tvm { +namespace ansor { + +using namespace tvm::tir; + +TVM_REGISTER_NODE_TYPE(ComputeDAGNode); + +template +using OperationMap = AccessAnalyzerNode::OperationMap; + +using OperationSet = std::unordered_set; + +// Topo-sort ops from tensors according to their read-write relations. 
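+// The implementation below counts each op's in-degree with a DFS from the
+// output tensors, then pops ready ops from a priority queue keyed by
+// first-visit order, so the resulting order is deterministic.
+// Illustrative usage (a sketch, not part of this patch; A, B, C are assumed
+// to be te::Tensors with C(i) = A(i) + B(i)):
+//   std::vector<te::Operation> ops;
+//   TopoSortOps({C}, &ops);   // ops == {A->op, B->op, C->op}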
+// Results are stored in ops +void TopoSortOps(const Array& tensors, std::vector* ops) { + std::unordered_map degree; + std::unordered_map > edge_set; + std::unordered_map priority; + std::unordered_set visited; + + // traverse to build edge_set and count degree + std::vector stack; + stack.reserve(tensors.size()); + for (const auto& x : tensors) { + stack.push_back(x->op.operator->()); + } + + int ct = 0; + while (!stack.empty()) { + const te::OperationNode* op = stack.back(); + stack.pop_back(); + if (visited.count(op)) { + continue; + } + + priority[op] = ct; + ct++; + visited.insert(op); + + if (op->IsInstance()) { + degree[op] = 0; + } else if (auto cop = GetRef(op).as()) { + const Array& input_tensors = cop->InputTensors(); + degree[op] = input_tensors.size(); + for (const auto& ten : input_tensors) { + edge_set[ten->op.operator->()].push_back(op); + stack.push_back(ten->op.operator->()); + } + } else { + LOG(FATAL) << "Unsupported op " << GetRef(op); + } + } + + // topo sort + ops->clear(); + + using Item = std::pair; + auto cmp = [](const Item& left, const Item& right) { + return left.second < right.second; + }; + std::priority_queue, decltype(cmp)> queue(cmp); + for (const auto& iter : degree) { + if (iter.second == 0) { + queue.push(Item(iter.first, priority[iter.first])); + } + } + + ops->reserve(degree.size()); + while (!queue.empty()) { + Item item = queue.top(); + queue.pop(); + ops->push_back(GetRef(item.first)); + for (const auto& dst : edge_set[item.first]) { + degree[dst] -= 1; + if (degree[dst] == 0) { + queue.push(Item(dst, priority[dst])); + } + } + } +} + +// Extract all tensor accesses in an expr +class TensorAccessExtractor : public StmtExprVisitor { + public: + void Extract(PrimExpr expr) { + this->VisitExpr(expr); + } + + void VisitExpr_(const CallNode *op) final { + if (op->call_type == CallNode::CallType::Halide) { + buf_accesses[Downcast(op->func)].emplace_back( + op->args.begin(), op->args.end()); + } + if (op->name == tir::intrinsic::tvm_if_then_else) { + has_branch = true; + } + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const IfThenElseNode* op) final { + has_branch = true; + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const SelectNode* op) final { + has_branch = true; + StmtExprVisitor::VisitExpr_(op); + } + + OperationMap > > buf_accesses; + bool has_branch{false}; +}; + +// Returns whether the expr equals to the var with a const shift +bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) { + if (auto pv = expr.as()) { + return pv == var.get(); + } else if (auto padd = expr.as()) { + return ((padd->a.get() == var.get() && padd->b->IsInstance()) || + (padd->b.get() == var.get() && padd->a->IsInstance())); + } else if (auto psub = expr.as()) { + return ((psub->a.get() == var.get() && psub->b->IsInstance()) || + (psub->b.get() == var.get() && psub->a->IsInstance())); + } else { + return false; + } +} + +// Return whether the access is injective +bool IsInjective(const te::Operation& op, const std::vector& index, + bool* axis_missing, bool* axis_duplicated, bool* same_order) { + auto cop = op.as(); + if (cop == nullptr) { return false; } + + std::vector index_to_var_idx; + std::vector var_idx_ct(cop->axis.size(), 0); + + for (const auto& expr : index) { + if (!is_const(expr)) { + bool found = false; + for (size_t i = 0; i < cop->axis.size(); ++i) { + if (IsConstShiftEqual(cop->axis[i]->var, expr)) { + index_to_var_idx.push_back(i); + var_idx_ct[i]++; + found = true; + break; + } + } + if (!found) { + return false; + } + } + 
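+    // (Constant indices are skipped by the is_const check above, so an access
+    // like B[0][i] can still be injective; only non-const indices must match
+    // a single axis variable with an optional constant shift.)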
} + + *axis_missing = false; // Some axes are missing + *axis_duplicated = false; // Some axes appear more than once + *same_order = true; // The axis order is the same as op->axis + for (int ct : var_idx_ct) { + if (ct == 0) { + *axis_missing = true; + } else if (ct > 1) { + *axis_duplicated = true; + } + } + for (size_t i = 1; i < index_to_var_idx.size(); ++i) { + if (index_to_var_idx[i] < index_to_var_idx[i - 1]) { + *same_order = false; + break; + } + } + + return true; +} + +// Gather all VarNodes in an expr +static void GatherVars(const PrimExpr& expr, std::unordered_set* vars) { + PostOrderVisit(expr, [&vars](const ObjectRef &node) { + if (const VarNode* op = node.as()) { + vars->insert(op); + } + }); +} + +// Check whether an expr has expensive operations (e.g. exp) +static bool HasExpensiveOp(const PrimExpr& expr) { + bool found = false; + PostOrderVisit(expr, [&found](const ObjectRef &node) { + if (const CallNode* op = node.as()) { + if (op->call_type == CallNode::CallType::PureIntrinsic && op->name == "exp") { + found = true; + } + } + }); + return found; +} + +AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) { + auto node = make_object(); + OperationMap has_branch; + + // get all ops + TopoSortOps(tensors, &node->ops_topo_order); + + // build read & write access map + for (const auto& op : node->ops_topo_order) { + if (op->IsInstance()) { + node->read_from[op] = OperationMap > >(); + } else if (auto cop = op.as()) { + TensorAccessExtractor extractor; + for (const auto& exp : cop->body) { + extractor.Extract(exp); + } + + for (const auto& iter : extractor.buf_accesses) { + std::vector >& accesses = node->read_by[iter.first][op]; + accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end()); + } + + node->read_from[op] = std::move(extractor.buf_accesses); + has_branch[op] = extractor.has_branch; + } else { + LOG(FATAL) << "Invalid op: " << op; + } + } + + // do some static analysis + for (const auto& op : node->ops_topo_order) { + if (op->IsInstance()) { + node->is_injective[op] = true; + node->needs_multi_level_tiling[op] = false; + node->is_strict_inlineable[op] = false; + node->is_output[op] = false; + } else if (auto pop = op.as()) { + // check whether is element-wise and strict-inlineable (see definition in compute_dag.h) + bool is_injective = true; + bool is_strict_inlineable = true; + + bool axis_missing, axis_duplicated, same_order; + for (const auto& pair : node->read_from[op]) { + const std::vector >& access = pair.second; + for (const auto& index : access) { + if (!IsInjective(op, index, &axis_missing, &axis_duplicated, &same_order)) { + is_injective = false; + is_strict_inlineable = false; + break; + } + if (!same_order || axis_duplicated) { // do not strictly inline transpose + is_strict_inlineable = false; + } + } + if (!is_injective) { break; } + } + if (has_branch[op]) { + is_strict_inlineable = false; + } + + // don't strictly inline expensive op (e.g. 
exp) + bool has_expensive_op = false; + for (const auto& expr : pop->body) { + has_expensive_op |= HasExpensiveOp(expr); + } + + node->is_injective[op] = is_injective; + node->is_strict_inlineable[op] = is_strict_inlineable && !has_expensive_op; + + // check whether the op needs multi-level tiling (see definition in compute_dag.h) + bool needs_multi_level_tiling = false; + int n_missing = 0; + + for (const auto& pair : node->read_from[op]) { + const std::vector > &access = pair.second; + std::unordered_set vars; + for (const std::vector &indices : access) { + for (const PrimExpr& expr : indices) { + GatherVars(expr, &vars); + } + } + bool missing = false; + for (const auto& axis : pop->axis) { + if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get()) == 0) { + missing = true; + } + } + if (missing) { + n_missing++; + } + + if (n_missing >= 2 || (n_missing >= 1 && !pop->reduce_axis.empty())) { + needs_multi_level_tiling = true; + break; + } + } + + node->needs_multi_level_tiling[op] = needs_multi_level_tiling; + + // check whether is output + node->is_output[op] = node->read_by[op].empty(); + } else { + LOG(FATAL) << "Invalid op" << op; + } + } + + return AccessAnalyzer(node); +} + +bool AccessAnalyzer::NeedsMultiLevelTiling(const te::Operation &op) const { + return operator->()->needs_multi_level_tiling.at(op); +} + +bool AccessAnalyzer::IsOutput(const te::Operation& op) const { + return operator->()->is_output.at(op); +} + +bool AccessAnalyzer::IsInjective(const te::Operation& op) const { + return operator->()->is_injective.at(op); +} + +bool AccessAnalyzer::IsStrictInlineable(const te::Operation &op) const { + return operator->()->is_strict_inlineable.at(op); +} + +void AccessAnalyzer::GetProducers(const State& state, const te::Operation& op, + OperationSet* producers) const { + producers->clear(); + for (const auto& iter : operator->()->read_from.at(op)) { + producers->insert(iter.first); + } +} + +// void AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op, +// OperationSet* consumers) const { +// OperationSet inlined_ops; + +// for (const auto& stage : state->stages) { +// if (stage->compute_at == kInlined) { +// inlined_ops.insert(stage->op); +// } +// } +// std::function collect; + +// collect = [this, &collect, &inlined_ops, &consumers](const Operation& op) { +// for (const auto& iter : operator->()->read_by.at(op)) { +// if (inlined_ops.count(iter.first)) { +// collect(iter.first); +// } else { +// consumers->insert(iter.first); +// } +// } +// }; + +// consumers->clear(); +// collect(op); +// } + +bool IntArrayEqual(const Array& arr1, const Array& arr2) { + if (arr1.size() != arr2.size()) { + return false; + } + + for (size_t i = 0; i < arr1.size(); ++i) { + auto int1 = arr1[i].as(); + auto int2 = arr2[i].as(); + CHECK(int1 != nullptr); + CHECK(int2 != nullptr); + if (int1->value != int2->value) { + return false; + } + } + return true; +} + +bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op, + const te::Operation& target_op) const { + te::Operation cur_op = op; + while (cur_op != target_op) { + const AccessAnalyzerNode::OperationMap > >& map = + operator->()->read_by.at(cur_op); + + if (map.size() != 1) { + return false; + } + te::Operation next_op = map.begin()->first; + + // Check condition 1: has the same output size + auto p_cur = cur_op.as(); + auto p_next = next_op.as(); + if (p_cur == nullptr || p_next == nullptr) { + return false; + } + + Array output_shape = p_cur->output_shape(0); + for (int i = 1; i < p_cur->num_outputs(); 
++i) { + if (!IntArrayEqual(p_cur->output_shape(i), output_shape)) { + return false; + } + } + for (int i = 0; i < p_next->num_outputs(); ++i) { + if (!IntArrayEqual(p_next->output_shape(i), output_shape)) { + return false; + } + } + + // Check condition 2: read is elementwise + const std::vector > reads = map.begin()->second; + bool is_injective, axis_missing, axis_duplicated, same_order; + for (const auto& read : reads) { + is_injective = ::tvm::ansor::IsInjective( + next_op, read, &axis_missing, &axis_duplicated, &same_order); + if (!is_injective || axis_missing || axis_duplicated || !same_order) { + return false; + } + } + + cur_op = std::move(next_op); + } + return true; +} + +// Estimate number of float operations in an expression +class FlopEstimator: public ExprFunctor { + public: + double EstimateFlop(const Array& ops) { + double ret = 0; + for (const auto& op : ops) { + if (auto pop = op.as()) { + double num_element = AxisLengthProd(pop->axis); + if (num_element == -1) { + fail = true; + break; + } + double op_per_element = 0; + for (const auto& x : pop->body) { + op_per_element += VisitExpr(x); + } + ret += num_element * op_per_element; + } else if (op->IsInstance()) { + {} // do nothing + } else { + LOG(FATAL) << "Invalid op type " << op; + } + } + + return fail ? -1 : ret; + } + + double VisitExpr_(const ReduceNode* op) final { + uint64_t num_iter = 1; + for (const auto& x : op->axis) { + if (auto imm = x->dom->extent.as()) { + num_iter *= imm->value; + } else { + fail = true; + num_iter = -1; + } + } + double body_flop = 0; + for (size_t i = 0; i < op->combiner->result.size(); ++i) { + body_flop += VisitExpr(op->combiner->result[i]); + body_flop += VisitExpr(op->source[i]); + } + return num_iter * body_flop; + } + + double VisitExpr_(const FloatImmNode* op) final { return 0.0; } + double VisitExpr_(const IntImmNode* op) final { return 0.0; } +// double VisitExpr_(const UIntImm* op) final { return 0.0; } + + double VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } + double VisitExpr_(const VarNode* op) final { return 0.0; } + + double VisitExpr_(const SelectNode* op) final { + return VisitExpr(op->condition) + std::max(VisitExpr(op->true_value), + VisitExpr(op->false_value)); + } + +#define VisitBinary(Node) \ + double VisitExpr_(const Node* op) final { \ + return 1.0 + VisitExpr(op->a) + VisitExpr(op->b); \ + } +#define VisitUnary(Node) \ + double VisitExpr_(const Node* op) final { \ + return 1.0 + VisitExpr(op->a); \ + } + + VisitBinary(AddNode); VisitBinary(SubNode); VisitBinary(MulNode) + VisitBinary(DivNode); VisitBinary(ModNode); VisitBinary(FloorDivNode) + VisitBinary(FloorModNode); VisitBinary(MaxNode); VisitBinary(MinNode); + VisitBinary(EQNode); VisitBinary(NENode); VisitBinary(LTNode); + VisitBinary(LENode); VisitBinary(GTNode); VisitBinary(GENode); + VisitBinary(AndNode); VisitBinary(OrNode); VisitUnary(NotNode); + + double VisitExpr_(const CallNode* op) final { + if (op->call_type == CallNode::CallType::Halide) { + // ignore flops in index expressions + return 0.0; + } + + double ret = 0.0; + for (const auto&x : op->args) { + ret += VisitExpr(x); + } + return ret; + } + + double VisitExprDefault_(const Object* op) final { + fail = true; + return -1.0; + } + + bool fail{false}; +}; + +void UpdateStageAxis(const te::Stage& stage, StageToAxesMap *stage_to_axes) { + if (auto pop = stage->op.as()) { + std::vector& axes = (*stage_to_axes)[stage]; + axes.clear(); + for (const auto& axis : pop->axis) { + axes.push_back(axis); + } + for (const auto& 
axis : pop->reduce_axis) { + axes.push_back(axis); + } + } else if (stage->op->IsInstance()) { + {} // do nothing + } else { + LOG(FATAL) << "Invalid op " << stage->op; + } +} + +// State ComputeDAG::GetInitState() const { +// return Downcast(operator->()->init_state); +// } + +ComputeDAG ComputeDAGNode::make(Array tensors) { + auto node = make_object(); + FlopEstimator estimator; + + node->tensors = std::move(tensors); + node->access_analyzer = AccessAnalyzerNode::make(node->tensors); + node->ops = Array(node->access_analyzer->ops_topo_order); + node->flop_ct = estimator.EstimateFlop(node->ops); +// node->init_state = StateNode::make(node->ops); + + return ComputeDAG(node); +} + +ComputeDAG ComputeDAGNode::make_by_workload_key(const std::string& workload_key) { + Array tens; + // Call python function to decode the workload_key and get the I/O tensors + if (const auto* f = runtime::Registry::Get("ansor.workload_key_to_tensors")) { + tens = (*f)(workload_key); + } else { + LOG(FATAL) << "ansor.workload_key_to_tensors is not registered"; + } + return ComputeDAGNode::make(std::move(tens)); +} + +void ComputeDAGNode::VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("tensors", &tensors); + v->Visit("ops", &ops); + v->Visit("flop_ct", &flop_ct); + v->Visit("access_analyzer", &access_analyzer); +// State s = Downcast(init_state); +// v->Visit("init_state", &s); +} + +// Implemented in multi_stage_policy.cc +// Extract primitive iterators from a nested fused or splitted iterator's name +extern void ExtractOriginalIterators(const std::string& name, std::set* rets); + +// Implemented in loop_state.cc +extern std::string CleanName(const std::string& str); + +std::string BaseName(const std::string& str) { + return str.substr(0, str.rfind("_")); +} + +// class IndexRewriter : public ExprMutator { +// public: +// IndexRewriter(const OperationMap >& placeholder_new_names, +// const OperationMap >& placeholder_new_shapes): +// placeholder_new_names_(placeholder_new_names), +// placeholder_new_shapes_(placeholder_new_shapes) {} + +// Expr Mutate_(const Call* op, const Expr& e) { +// Expr op_ = IRMutator::Mutate_(op, e); + +// const Call* call = op_.as(); + +// if (call->call_type == Call::CallType::Halide) { +// Tensor t = Downcast(call->func).output(call->value_index); +// auto it = placeholder_new_names_.find(t->op); +// if (it != placeholder_new_names_.end()) { +// const std::vector& new_names = it->second; +// const Array& new_shape = placeholder_new_shapes_.at(t->op); +// std::unordered_map name_to_arg; +// for (const auto& arg : call->args) { +// std::string axis_name; +// if (const auto* pimm = arg.as()) { +// CHECK_EQ(pimm->value, 0); +// axis_name = "IntImm"; +// } else { +// axis_name = BaseName(CleanName(Downcast(arg)->name_hint)); +// CHECK_EQ(name_to_arg.count(axis_name), 0); +// name_to_arg[axis_name] = arg; +// } +// } + +// std::unordered_map div_factors; +// std::vector r_new_args; +// for (int i = new_names.size() - 1; i >= 0; --i) { +// auto ori_iter_name = new_names[i]; +// auto name_it = name_to_arg.find(ori_iter_name); +// CHECK(name_it != name_to_arg.end()); +// Expr ori_arg = name_it->second; + +// Expr mod_factor = new_shape[i]; + +// Expr div_factor = 1; +// if (div_factors.count(ori_iter_name)) { +// div_factor = div_factors[ori_iter_name]; +// } +// div_factors[ori_iter_name] = div_factor * new_shape[i]; + +// Expr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor); + +// r_new_args.push_back(new_arg); +// } + +// Array 
new_args(std::make_move_iterator(r_new_args.rbegin()), +// std::make_move_iterator(r_new_args.rend())); + +// return Call::make(call->type, call->name, new_args, call->call_type, +// call->func, call->value_index); +// } +// } +// return op_; +// } + +// private: +// const OperationMap >& placeholder_new_names_; +// const OperationMap >& placeholder_new_shapes_; +// }; + +// // TODO(minminsun): spill out new functions +// void ComputeDAG::RewriteLayout( +// const std::vector &transform_steps, LayoutRewriteLevel layout_rewrite_level) const { +// ComputeDAGNode* pdag = const_cast(this)->CopyOnWrite(); +// const State& state = ReplayAndInferBound(transform_steps); + +// OperationMap > placeholder_new_names; +// OperationMap > placeholder_new_shapes; +// int stage_id = -1; +// for (const auto& stage : state->stages) { +// stage_id += 1; +// const Operation& op = stage->op; +// if (op->IsInstance()) { +// const Map& attrs = op->attrs; +// if (attrs.count(_layout_free_placeholders_key)) { +// const ObjectRef& attr_value = attrs[_layout_free_placeholders_key]; +// Array placeholders = Downcast>(attr_value); +// for (auto& placeholder : placeholders) { +// const auto placeholder_op = placeholder->op; + +// // Check whether this placeholder has already been handled +// if (placeholder_new_names.count(placeholder_op)) { +// continue; +// } + +// // skip the op that is not direct consumer of this placeholder, +// // mostly due to cache read/write. +// bool direct_consumer = false; +// for (auto& t : op->InputTensors()) { +// if (t->op == placeholder_op) { +// direct_consumer = true; +// break; +// } +// } +// if (!direct_consumer) { +// continue; +// } + +// std::set placeholder_axis_names; +// TensorAccessExtractor extractor; +// for (const auto& exp : op.as()->body) { +// extractor.Extract(exp); +// } +// bool rewrite_placeholder = (layout_rewrite_level == kPlaceholderRewrite || +// layout_rewrite_level == kBothRewrite); +// bool rewrite_body = (layout_rewrite_level == kComputeRewrite || +// layout_rewrite_level == kBothRewrite); +// std::ostringstream os; + +// uint i = 0; +// if (extractor.buf_accesses.count(placeholder_op)) { +// for (const auto& ev : extractor.buf_accesses[placeholder_op]) { +// for (const auto& e : ev) { +// // TODO(minminsun): check whether the extents match the shape of placeholder +// std::string axis_name; +// if (const auto* pimm = e.as()) { +// CHECK_EQ(pimm->value, 0); +// // CHECK_EQ(placeholder->shape[i].as()->value, 1); +// axis_name = "IntImm"; +// } else { +// axis_name = BaseName(CleanName(Downcast(e)->name_hint)); +// } + +// placeholder_axis_names.insert(axis_name); +// if (rewrite_placeholder) { +// os << placeholder->shape[i++] << axis_name; +// } +// } +// } + +// if (rewrite_placeholder) { +// CHECK_EQ(placeholder_axis_names.size(), placeholder->shape.size()); +// std::string ori_layout = os.str(); +// os.str(""); +// ::tvm::relay::KernelLayoutVisitor::global_ori_layouts_queue.push_back(ori_layout); +// } +// } + +// std::vector stage_iters; + +// auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id); +// int attach_pos = -1; +// size_t iters_before_attach = 0; +// if (attach_it != state->attach_map->stage_to_attach_iter.end()) { +// auto attach = attach_it->second; +// const auto& attach_stage = state->stages[attach.first]; +// attach_pos = attach.second; +// stage_iters.insert(stage_iters.end(), +// attach_stage->iters.begin(), +// attach_stage->iters.begin() + attach_pos + 1); +// } + +// stage_iters.insert(stage_iters.end(), 
stage->iters.begin(), stage->iters.end()); + +// std::vector iters; +// for (size_t i = 0; i < stage_iters.size(); ++i) { +// const auto& iter = stage_iters[i]; +// if (iter->ori_iters.empty()) { +// iters.push_back(iter); +// } else { +// for (const Iterator& ori_iter : iter->ori_iters) { +// iters.push_back(ori_iter); +// } +// } +// if (static_cast(i) == attach_pos) { +// iters_before_attach = iters.size(); +// } +// } + +// std::vector new_names; +// Array new_shape; +// std::vector new_axis_names; +// for (const Iterator& iter : iters) { +// std::set ori_iter_names; +// ExtractOriginalIterators(iter->name, &ori_iter_names); +// // fused iters have been replaced with iter->ori_iters. +// // So there should be only one ori iter name extracted from iter->name. +// CHECK_EQ(ori_iter_names.size(), 1); +// auto ori_iter_name = BaseName(*ori_iter_names.begin()); +// new_axis_names.push_back(ori_iter_name); +// } +// for (size_t i = 0; i < new_axis_names.size(); ++i) { +// auto iter = iters[i]; +// std::string ori_iter_name; +// if (i < iters_before_attach) { +// ori_iter_name = new_axis_names[i + iters_before_attach]; +// } else { +// ori_iter_name = new_axis_names[i]; +// } +// if (placeholder_axis_names.count(ori_iter_name)) { +// os << iter->range->extent << ori_iter_name; +// new_names.push_back(ori_iter_name); +// new_shape.push_back(iter->range->extent); +// } +// } +// std::string new_layout = os.str(); +// os.str(""); +// ::tvm::relay::KernelLayoutVisitor::global_new_layouts_queue.push_back(new_layout); +// placeholder_new_names[placeholder_op] = new_names; +// placeholder_new_shapes[placeholder_op] = new_shape; + +// Array old_ops = pdag->ops; +// ArrayNode* pops = pdag->ops.CopyOnWrite(); + +// // Create new placeholder +// Operation new_placeholder_op; +// if (rewrite_placeholder) { +// new_placeholder_op = +// te::PlaceholderOpNode::make(placeholder_op->name, +// new_shape, +// placeholder_op.as()->dtype); +// } else { +// new_placeholder_op = placeholder_op; +// } + +// Operation new_compute_op, old_compute_op; +// if (rewrite_body) { +// Array new_body; +// IndexRewriter index_rewriter(placeholder_new_names, +// placeholder_new_shapes); +// for (auto& op : old_ops) { +// if (auto* pop = op.as()) { +// bool need_update = false; +// for (auto& t : op->InputTensors()) { +// if (t->op == placeholder_op) { +// need_update = true; +// break; +// } +// } +// if (need_update) { +// for (auto& body : pop->body) { +// new_body.push_back(index_rewriter.Mutate(body)); +// } +// old_compute_op = op; +// CHECK(!new_compute_op.defined()); +// new_compute_op = ComputeOpNode::make( +// pop->name, pop->tag, pop->attrs, pop->axis, new_body); +// } +// } +// } +// } + +// // construct the map from old_op to new_op +// std::unordered_map updated_ops; +// for (size_t i = 0; i < old_ops.size(); ++i) { +// auto old_op = old_ops[i]; +// if (rewrite_placeholder && old_op == placeholder_op) { +// pops->data[i] = new_placeholder_op; +// updated_ops[placeholder_op] = new_placeholder_op; +// } else if (rewrite_body && old_op == old_compute_op) { +// pops->data[i] = new_compute_op; +// updated_ops[old_compute_op] = new_compute_op; +// } else { +// pops->data[i] = old_op; +// } +// } + +// // Because ops is sorted in topo-order, only do one pass linear scan here. 
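+// // Note: updated_ops may chain (e.g. A -> A' -> A'' when both the placeholder
+// // and the compute body were rewritten), so the while-loops below follow the
+// // chain to the newest op before rewiring each consumer's inputs.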
+// for (size_t i = 0; i < pops->data.size(); ++i) { +// auto old_op = Downcast(pops->data[i]); +// if (auto* pop = old_op.as()) { +// auto inputs = pop->InputTensors(); +// std::unordered_map rmap; +// for (auto input : inputs) { +// auto it = updated_ops.find(input->op); +// Operation new_op; +// while (it != updated_ops.end()) { +// new_op = it->second; +// it = updated_ops.find(new_op); +// } +// if (new_op.defined()) { +// int index = input->value_index; +// rmap[input] = new_op.output(index); +// } +// } +// if (!rmap.empty()) { +// Operation new_op = pop->ReplaceInputs(old_op, rmap); +// updated_ops[old_op] = new_op; +// pops->data[i] = new_op; +// } +// } +// } + +// pdag->init_state = StateNode::make(pdag->ops); + +// Array old_tensors = pdag->tensors; +// ArrayNode* ptensors = pdag->tensors.CopyOnWrite(); + +// for (size_t i = 0; i < old_tensors.size(); ++i) { +// const auto& old_tensor = old_tensors[i]; +// auto it = updated_ops.find(old_tensor->op); +// Operation new_op; +// while (it != updated_ops.end()) { +// new_op = it->second; +// it = updated_ops.find(new_op); +// } +// if (new_op.defined()) { +// if (layout_rewrite_level == kBothRewrite) { +// auto index = old_tensor->value_index; +// ptensors->data[i] = new_op.output(index); +// } else if (layout_rewrite_level == kComputeRewrite) { +// TensorNode* old_tensor_node = const_cast(old_tensor.as()); +// old_tensor_node->op = new_op; +// } +// } +// } +// } // end for placeholder +// } +// } +// } // end for stage +// } + +std::pair > ComputeDAG::ApplySteps( + const std::vector& transform_steps, + LayoutRewriteLevel layout_rewrite_level) const { + std::vector stages; + StageToAxesMap stage_to_axes; + if (layout_rewrite_level != kNoRewrite && !transform_steps.empty()) { + ComputeDAG new_dag = *this; + new_dag.RewriteLayout(transform_steps, layout_rewrite_level); + return new_dag.ReplaySteps(transform_steps, &stages, &stage_to_axes); + } else { + return ReplaySteps(transform_steps, &stages, &stage_to_axes); + } +} + +// std::string ComputeDAG::PrintStepsAsPython( +// const std::vector& transform_steps) const { +// std::vector stages; +// StageToAxesMap stage_to_axes; +// Array ops; +// for (const auto& op : operator->()->ops) { +// if (!op->IsInstance()) { +// ops.push_back(op); +// } +// } +// te::Schedule schedule = te::create_schedule({ops.back()}); + +// // init axes +// for (const auto& x : operator->()->ops) { +// const te::Stage& stage = schedule.operator[](x); +// stages.push_back(stage); +// UpdateStageAxis(stage, &stage_to_axes); +// } + +// std::stringstream ss; + +// for (const auto& stage : stages) { +// if (stage->op->IsInstance()) { +// for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { +// ss << stage->leaf_iter_vars[i]->var->name_hint; +// if (i != stage->leaf_iter_vars.size() - 1) { +// ss << ", "; +// } +// } +// ss << " = " << "tuple(" << stage->op->func_name() << ".op.axis)" +// << " + " << "tuple(" << stage->op->func_name() << ".op.reduce_axis)\n"; +// } +// } + +// for (const auto& step : transform_steps) { +// ss << step->PrintAsPythonAPI(&stages, &stage_to_axes, &schedule, transform_steps); +// } + +// return ss.str(); +// } + +// State ComputeDAG::ReplayAndInferBound(const std::vector& transform_steps) const { +// State ret_state = GetInitState(); +// StateNode* pstate = ret_state.CopyOnWrite(); +// pstate->transform_steps = transform_steps; +// ret_state.DoSteps(transform_steps, *this); + +// InferBoundCommon(pstate); + +// return ret_state; +// } + +// State ComputeDAG::InferBound(const 
State& state) const { +// State ret_state = state; +// StateNode* pstate = ret_state.CopyOnWrite(); + +// InferBoundCommon(pstate); + +// return ret_state; +// } + +// void ComputeDAG::InferBound(std::vector* states) const { +// std::vector out_states(states->size(), State()); + +// auto worker_func = [&states, &out_states, this](int idx) { +// try { +// out_states[idx] = this->InferBound((*states)[idx]); +// } catch (dmlc::Error &e) { +// LOG(WARNING) << "InferBound fails on the state:\n" << (*states)[idx] +// << "\n" << e.what() << std::endl; +// } +// }; + +// // Lower states in parallel +// ThreadPool& pool = ThreadPool::Global(); +// pool.BeginBatch(states->size()); +// for (size_t i = 0; i < states->size(); ++i) { +// pool.Enqueue(worker_func, i); +// } +// pool.WaitBatch(); + +// *states = std::move(out_states); +// } + +void ComputeDAG::ReplayAndGetDAG(const std::vector &transform_steps, + ComputeDAG *task_dag) const { + std::vector stages; + StageToAxesMap stage_to_axes; + te::Schedule sch; + Array old_tensors; + + std::tie(sch, old_tensors) = ReplaySteps(transform_steps, &stages, &stage_to_axes); + + Array new_tensors; + for (auto stage : sch->stages) { + if (stage->op->IsInstance() || + stage->is_output) { + for (auto i = 0; i < stage->op->num_outputs(); ++i) { + new_tensors.push_back(stage->op.output(i)); + } + } + } + + *task_dag = ComputeDAGNode::make(new_tensors); +} + + +// void ComputeDAG::InferBoundCommon(StateNode* pstate) const { +// std::vector stages; +// StageToAxesMap stage_to_axes; +// te::Schedule sch; +// Array tensors; +// Map bounds; + +// std::tie(sch, tensors) = ReplaySteps(pstate->transform_steps, &stages, &stage_to_axes); +// sch = sch.normalize(); +// bounds = schedule::InferBound(sch); + +// for (size_t i = 0; i < pstate->stages.size(); ++i) { +// const Stage& stage = pstate->stages[i]; + +// if (stage->compute_at == kInlined) { +// continue; +// } + +// std::vector new_iters; +// new_iters.reserve(stage->iters.size()); +// for (size_t j = 0; j < stage->iters.size(); ++j) { +// const Iterator& iter = stage->iters[j]; +// const IterVar& axis = stage_to_axes.at(stages[i])[j]; + +// auto find_res = bounds.find(axis); +// if (find_res != bounds.end()) { +// new_iters.push_back(IteratorNode::make(iter->name, (*find_res).second, +// iter->iter_type, iter->annotation, +// &iter->ori_iters)); +// } else { +// LOG(FATAL) << "Infer bound fails"; +// } +// } + +// pstate->stages[i] = StageNode::make(stage->op, stage->op_type, +// std::move(new_iters), stage->compute_at, stage->auto_unroll_max_step, +// stage->storage_offset); +// } +// } + +// std::pair > ComputeDAG::ReplaySteps( +// const std::vector &transform_steps, +// std::vector *stages, +// StageToAxesMap *stage_to_axes) const { +// std::vector ops; +// for (const auto& op : operator->()->ops) { +// if (!op->IsInstance()) { +// ops.push_back(op); +// } +// } + +// te::Schedule schedule = te::create_schedule({ops.back()}); + +// // init axes +// stages->reserve(operator->()->ops.size()); +// for (const auto& x : operator->()->ops) { +// const te::Stage& stage = schedule.operator[](x); +// stages->push_back(stage); +// UpdateStageAxis(stage, stage_to_axes); +// } + +// // todo(lmzheng): should we maintain the attach_map and keep the validity of compute_at +// // an splitted axis? 
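+// // (i.e. if an axis that another stage is compute_at-ed to is later split,
+// // the recorded attach point may become stale.)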
+ +// // Use complete rate for the study in the paper +// const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); +// double complete_rate = -1.0; +// if (complete_rate_str) { +// complete_rate = std::stod(complete_rate_str); +// } +// size_t ct = 0; + +// // replay history +// for (const auto& step : transform_steps) { +// if (complete_rate >= 0 && ct++ > transform_steps.size() * complete_rate) { +// break; +// } + +// if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes, &schedule); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes, &schedule); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes, &schedule); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else { +// LOG(FATAL) << "Invalid Step"; +// } +// } + +// return std::make_pair(schedule, operator->()->tensors); +// } + + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { + auto* node = static_cast(ref.get()); + std::stringstream ss; + + for (const auto& op : node->ops) { + if (op->IsInstance()) { + ss << op->func_name() << " = PLACEHOLDER " << op.output(0)->shape << "\n"; + } else if (auto pop = op.as()) { + for (size_t k = 0; k < pop->body.size(); ++k) { + ss << op->func_name() << "("; + for (size_t i = 0; i < pop->axis.size(); i++) { + ss << pop->axis[i]->var->name_hint; + if (i != pop->axis.size() - 1) { + ss << ", "; + } + } + ss << ")"; + if (pop->body.size() > 1) { + ss << ".v" << k; + } + if (auto preduce = pop->body[k].as()) { + CHECK_LT(k, preduce->combiner->result.size()); + PrimExpr combiner = preduce->combiner->result[k]; + if (combiner->IsInstance()) { + ss << " += " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance()) { + ss << " max= " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance()) { + ss << " min= " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance()) { + const auto& select = combiner.as(); + ss << " select(" << select->condition << ", " << select->true_value + << ", " << select->false_value << ")= " + << '(' << preduce->source[0] << ',' << preduce->source[1] << ")\n"; + } else { + LOG(FATAL) << "Unsupported reduction operator" << combiner; + } + } else { + ss << " = " << pop->body[k] << "\n"; + } + } + } else { + LOG(FATAL) << "Invalid op"; + } + } + + p->stream << ss.str(); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { + auto* node = static_cast(ref.get()); + for (const auto& op : node->ops_topo_order) { + p->stream << op << 
std::endl; + p->stream << "is_injective:\t" << node->is_injective.at(op) << "\t\t"; + p->stream << "needs_multi_level_tiling:\t" + << node->needs_multi_level_tiling.at(op) << std::endl; + p->stream << "is_strict_inlinable:\t" << node->is_strict_inlineable.at(op) << "\t"; + p->stream << "is_output:\t" << node->is_output.at(op) << std::endl; + p->stream << "Read from:\t"; + for (const auto& pair : node->read_from.at(op)) { + for (const auto& index : pair.second) { + p->stream << pair.first->func_name() << Array(index) << ", "; + } + } + p->stream << "\n"; + p->stream << "Read by:\t"; + for (const auto& pair : node->read_by.at(op)) { + for (const auto& index : pair.second) { + p->stream << pair.first->func_name() << Array(index) << ", "; + } + } + p->stream << "\n"; + p->stream << "==================================================\n"; + } + + AccessAnalyzer ana = GetRef(node); + + p->stream << "ElementwiseMatch: \n"; + for (size_t i = 0; i < node->ops_topo_order.size(); ++i) { + for (size_t j = 0; j < node->ops_topo_order.size(); ++j) { + if (i == j) { continue; } + if (ana.ElementWiseMatch(node->ops_topo_order[i], node->ops_topo_order[j])) { + p->stream << node->ops_topo_order[i]->func_name() << " -> " + << node->ops_topo_order[j]->func_name() << "\n"; + } + } + } +}); + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h new file mode 100644 index 000000000000..c8da44fee828 --- /dev/null +++ b/src/ansor/compute_dag.h @@ -0,0 +1,161 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file ansor/compute_dag.h + * \brief Compute declaration graph and its related analysis tools + */ + +#ifndef TVM_ANSOR_COMPUTE_DAG_H_ +#define TVM_ANSOR_COMPUTE_DAG_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "utils.h" + +namespace tvm { +namespace ansor { + +class ComputeDAG; class AccessAnalyzer; +class StateNode; class State; class Step; + +typedef std::unordered_map, ObjectHash, ObjectEqual> + StageToAxesMap; + +// Update StageToAxes Map during replay +void UpdateStageAxis(const tvm::te::Stage& stage, StageToAxesMap *stage_to_axes); + +/*! \brief Read/Write access static analysis result */ +class AccessAnalyzerNode : public Object { + public: + template + using OperationMap = std::unordered_map; + + OperationMap > > > read_from; + OperationMap > > > read_by; + OperationMap is_injective; + OperationMap is_strict_inlineable; + OperationMap needs_multi_level_tiling; + OperationMap is_output; + std::vector ops_topo_order; + + static AccessAnalyzer make(const Array& tensors); + + static constexpr const char* _type_key = "ansor.AccessAnalyzer"; + TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object); +}; + +/*! \brief Read/Write access static analysis result */ +class AccessAnalyzer : public ObjectRef { + public: + // read/write access analysis + bool NeedsMultiLevelTiling(const te::Operation& op) const; + bool IsInjective(const te::Operation& op) const; + bool IsStrictInlineable(const te::Operation& op) const; + bool IsOutput(const te::Operation& op) const; + + // Get all producers of an op + void GetProducers(const State& state, const te::Operation& op, + std::unordered_set* producers) const; + // Get all consumers of an op. This func deals with inlined op correctly. + void GetConsumers(const State& state, const te::Operation& op, + std::unordered_set* consumers) const; + // Check whether two ops are elementwise matched + // (e.g. 
conv2d and relu are elementwise matched)
+  bool ElementWiseMatch(const te::Operation& op,
+                        const te::Operation& target_op) const;
+
+  /*! \note The current implementation follows these (rough) definitions.
+   *
+   * Definition of data-reuse : There exists an axis in (op->axis union op->reduce_axis)
+   * and an acc in the read accesses, such that axis does not appear in acc.
+   * (e.g. A[i][j] = B[i] has data reuse, while A[i][j] = B[i][j] does not)
+   * Definition of NeedsMultiLevelTiling : There exist two accesses, each of which
+   * gives this op data reuse.
+   * Definition of injective : Every index expression is a single axis variable
+   * plus an optional const shift.
+   * (e.g. A[i][j] = B[i][j], A[i][j] = B[i+1][j] are injective, while A[i][j] = B[i*j] is not)
+   * Definition of strict-inlineable : All read accesses are elementwise, and there is
+   * no branch in the body.
+   * (e.g. A[i][j] = B[i][j] + C[i][j] is strict-inlineable,
+   * while A[i][j] = tvm_if_then_else(B[i][j] > 0, C[i][j], 0) is not)
+   */
+  TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode);
+};
+
+/*! \brief Compute declaration graph */
+class ComputeDAGNode : public Object {
+ public:
+  Array<te::Tensor> tensors;        // Input and output tensors
+  Array<te::Operation> ops;         // All related operations in topo order
+  double flop_ct;                   // Number of float operations
+  AccessAnalyzer access_analyzer;   // Read/Write access static analyzer
+  ObjectRef init_state;             // Initial state
+
+  void VisitAttrs(tvm::AttrVisitor* v);
+
+  static ComputeDAG make(Array<te::Tensor> tensors);
+  static ComputeDAG make_by_workload_key(const std::string& workload_key);
+
+  static constexpr const char* _type_key = "ansor.ComputeDAG";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object);
+};
+
+enum LayoutRewriteLevel {
+  kNoRewrite = 0,           // No layout rewrite
+  kPlaceholderRewrite = 1,  // Only rewrite layout of placeholder in the compute dag
+  kComputeRewrite = 2,      // Only rewrite compute body for new layout in the compute dag
+  kBothRewrite = 3,         // Rewrite both placeholder and compute body in the compute dag
+};
+
+/*! \brief Compute declaration graph */
+class ComputeDAG: public ObjectRef {
+ public:
+  // Apply transform steps to the init state of this DAG, and get the equivalent tvm::te::Schedule.
+  // The return values can be used as arguments to tvm.build or tvm.lower
+  std::pair<te::Schedule, Array<te::Tensor> > ApplySteps(
+      const std::vector<Step>& transform_steps,
+      LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const;
+
+  // Rewrite the layout of "layout free" placeholders according to transform steps
+  void RewriteLayout(const std::vector<Step>& transform_steps,
+                     LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const {};
+
+  // Print transform steps as equivalent python schedule API
+  std::string PrintStepsAsPython(const std::vector<Step>& steps) const;
+
+  // Replay the transform steps and call ir_pass::InferBound to fill correct bound information
+  State ReplayAndInferBound(const std::vector<Step>& transform_steps) const;
+
+  // Fill the correct bound information for a given state
+  State InferBound(const State& state) const;
+
+  // Fill the correct bound information for a list of given states.
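+  // The batch version is intended to lower the states in parallel with a
+  // thread pool. Illustrative usage (a sketch, not part of this patch):
+  //   std::vector<State> states = ...;
+  //   dag.InferBound(&states);  // fills concrete iterator ranges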
+ // Return the new states inplace + void InferBound(std::vector* states) const; + + // Replay the transform steps and get the new ops + void ReplayAndGetDAG(const std::vector& steps, ComputeDAG* task_dag) const; + + // Get the init state + State GetInitState() const; + + TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode); + + private: + // Internal common parts for replaying steps + std::pair > ReplaySteps( + const std::vector& transform_steps, std::vector* stages, + StageToAxesMap* stage_to_axes) const {}; + static constexpr const char* _layout_free_placeholders_key = "layout_free_placeholders"; + + // Internal common parts for inferring bound + void InferBoundCommon(StateNode* pstate) const; +}; + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_COMPUTE_DAG_H_ diff --git a/src/ansor/expr_hasher.h b/src/ansor/expr_hasher.h new file mode 100644 index 000000000000..1c743ed9a5c4 --- /dev/null +++ b/src/ansor/expr_hasher.h @@ -0,0 +1,97 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file auto_scheduler/expr_hasher.h + * \brief Hash function for a tvm::Expr + */ + +#ifndef TVM_ANSOR_EXPR_HASHER_H_ +#define TVM_ANSOR_EXPR_HASHER_H_ + +#include +#include +#include +#include + +namespace tvm { + +/*! \brief Assign a hash value for a tvm::Expr */ +class ExprHasher: public tir::ExprFunctor { + public: + size_t VisitExpr_(const tir::AddNode* op) final { + return VisitExpr(op->a) + VisitExpr(op->b); + } + + size_t VisitExpr_(const tir::SubNode* op) final { + return VisitExpr(op->a) - VisitExpr(op->b); + } + + size_t VisitExpr_(const tir::MulNode* op) final { + return VisitExpr(op->a) * VisitExpr(op->b); + } + + size_t VisitExpr_(const tir::DivNode* op) final { + size_t t = VisitExpr(op->b); + if (t != 0) { + return VisitExpr(op->a) / t; + } else { + return dmlc::HashCombine(VisitExpr(op->a), 0x5A); + } + } + + size_t VisitExpr_(const tir::FloorDivNode* op) final { + size_t t = VisitExpr(op->b); + if (t != 0) { + return VisitExpr(op->a) / t; + } else { + return dmlc::HashCombine(VisitExpr(op->a), 0x5B); + } + } + + size_t VisitExpr_(const tir::ModNode* op) final { + size_t t = VisitExpr(op->b); + if (t != 0) { + return VisitExpr(op->a) % t; + } else { + return dmlc::HashCombine(VisitExpr(op->a), 0x5C); + } + } + + size_t VisitExpr_(const tir::FloorModNode* op) final { + size_t t = VisitExpr(op->b); + if (t != 0) { + return VisitExpr(op->a) % t; + } else { + return dmlc::HashCombine(VisitExpr(op->a), 0x5D); + } + } + + size_t VisitExpr_(const tir::CallNode* op) final { + size_t ret = ObjectHash()(op->func); + for (size_t i = 0; i < op->args.size(); ++i) { + ret = dmlc::HashCombine(ret, VisitExpr(op->args[i])); + } + return ret; + } + + size_t VisitExpr_(const tir::VarNode* op) final { + return std::hash()(op); + } + + size_t VisitExpr_(const tir::FloatImmNode* op) final { + return std::hash()(op->value); + } + + size_t VisitExpr_(const tir::IntImmNode* op) final { + return std::hash()(op->value); + } + + size_t VisitExprDefault_(const Object* op) final { + LOG(WARNING) << "Encounter undefined node in ExprHasher: " + << Object::_type_key; + return std::hash()(op); + } +}; + +} // namespace tvm + +#endif // TVM_ANSOR_EXPR_HASHER_H_ diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc new file mode 100644 index 000000000000..92157edc463d --- /dev/null +++ b/src/ansor/loop_state.cc @@ -0,0 +1,1729 @@ +/*! 
+ * Copyright (c) 2020 by Contributors + */ +#include "loop_state.h" +#include +#include "utils.h" + +namespace tvm { +namespace ansor { + +TVM_REGISTER_OBJECT_TYPE(StepNode); +TVM_REGISTER_NODE_TYPE(StateNode); + +inline std::string CleanName(const std::string& str) { + // to make the name valid in python code + std::string ret = str; + StrReplace(&ret, ".", "_"); + StrReplace(&ret, "@", "_"); + StrReplace(&ret, "outer", "o"); + StrReplace(&ret, "inner", "i"); + return ret; +} + +/********** Reorder **********/ +ReorderStep ReorderStepNode::make(int stage_id, const std::vector& after_ids) { + auto node = make_object(); + node->stage_id = stage_id; + node->after_ids = after_ids; + return ReorderStep(node); +} + +void ReorderStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + CHECK_EQ(after_ids.size(), axes.size()); + + std::vector new_axes; + new_axes.reserve(axes.size()); + for (auto i : after_ids) { + new_axes.push_back(axes[i]); + } + stage.reorder(new_axes); + (*stage_to_axes)[stage] = std::move(new_axes); +} + +std::string ReorderStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + const te::Stage& stage = (*stages)[stage_id]; + std::stringstream ss; + + ss << "s[" << CleanName(stage->op->func_name()) << "].reorder("; + for (size_t i = 0; i < after_ids.size(); ++i) { + ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint); + if (i != after_ids.size() - 1) { + ss << ", "; + } + } + ss << ")\n"; + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + +/********** Split **********/ +std::vector ApplySplitToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, + int stage_id, + int iter_id, + const std::vector& lengths, + bool inner_to_outer) { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + + std::vector outs; + if (inner_to_outer) { + IterVar outer = axes[iter_id], inner; + for (int i = static_cast(lengths.size()) - 1; i >= 0; i--) { + IterVar to_split = outer; + stage.split(to_split, lengths[i], &outer, &inner); + outs.push_back(inner); + } + outs.push_back(outer); + } else { + IterVar outer, inner = axes[iter_id]; + for (size_t i = 0; i < lengths.size(); i++) { + IterVar to_split = inner; + stage.split_by_nparts(to_split, lengths[i], &outer, &inner); + outs.push_back(outer); + } + outs.push_back(inner); + } + + std::vector new_axes; + new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + iter_id); + if (inner_to_outer) { + new_axes.insert(new_axes.end(), outs.rbegin(), outs.rend()); + } else { + new_axes.insert(new_axes.end(), outs.begin(), outs.end()); + } + new_axes.insert(new_axes.end(), axes.begin() + iter_id + 1, axes.end()); + (*stage_to_axes)[stage] = std::move(new_axes); + + return outs; +} + +std::string PrintSplitAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + int stage_id, + int iter_id, + const std::vector& lengths, + bool inner_to_outer) { + te::Stage& stage = (*stages)[stage_id]; + auto to_split = (*stage_to_axes)[stage][iter_id]; + const auto& func_name = CleanName(stage->op->func_name()); + const auto& outs = ApplySplitToSchedule(stages, stage_to_axes, stage_id, + iter_id, lengths, inner_to_outer); + + std::stringstream ss; + int size = static_cast(lengths.size()); + if (inner_to_outer) { + for (int i = size - 1; i >= 0; i--) 
{
+      ss << CleanName(outs[size - i]->var->name_hint) << ", "
+         << CleanName(outs[size - i - 1]->var->name_hint)
+         << " = s[" << func_name << "].split("
+         << CleanName(to_split->var->name_hint)
+         << ", factor=" << lengths[i] << ")\n";
+      to_split = outs[size - i];
+    }
+  } else {
+    for (int i = 0; i < size; i++) {
+      ss << CleanName(outs[i]->var->name_hint) << ", "
+         << CleanName(outs[i + 1]->var->name_hint)
+         << " = s[" << func_name << "].split("
+         << CleanName(to_split->var->name_hint)
+         << ", nparts=" << lengths[i] << ")\n";
+      to_split = outs[i + 1];
+    }
+  }
+
+  return ss.str();
+}
+
+SplitStep SplitStepNode::make(int stage_id, int iter_id,
+                              PrimExpr extent, const std::vector<PrimExpr>& lengths,
+                              bool inner_to_outer) {
+  auto node = make_object<SplitStepNode>();
+  node->stage_id = stage_id;
+  // Extent can be an irreducible expression in some special cases
+  if (extent->IsInstance<IntImmNode>()) {
+    node->extent = std::move(extent);
+  }
+  node->iter_id = iter_id;
+  node->lengths = lengths;
+  node->inner_to_outer = inner_to_outer;
+  return SplitStep(node);
+}
+
+std::vector<IterVar> SplitStepNode::ApplyToSchedule(
+    std::vector<te::Stage> *stages, StageToAxesMap *stage_to_axes) const {
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id,
+                              lengths, inner_to_outer);
+}
+
+std::string SplitStepNode::PrintAsPythonAPI(
+    std::vector<te::Stage> *stages, StageToAxesMap *stage_to_axes,
+    te::Schedule *schedule, const std::vector<Step>& transform_steps) const {
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id,
+                               lengths, inner_to_outer);
+}
+
+/********** Follow Split **********/
+FollowSplitStep FollowSplitStepNode::make(int stage_id, int iter_id,
+                                          int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  return FollowSplitStep(node);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const std::vector<Step>& transform_steps,
+                                              std::vector<PrimExpr>* lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // Get lengths from the source split step; the last factor is the product of
+  // the remaining factors (or undefined if any of them is undefined)
+  lengths->reserve(n_split);
+  int j = 0;
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);
+  }
+  PrimExpr last_factor = 1;
+  for (; j < static_cast<int>(ps->lengths.size()); ++j) {
+    if (ps->lengths[j].defined()) {
+      last_factor *= ps->lengths[j];
+    } else {
+      last_factor = PrimExpr();
+      break;
+    }
+  }
+  lengths->push_back(std::move(last_factor));
+}
+
+std::vector<IterVar> FollowSplitStepNode::ApplyToSchedule(
+    std::vector<te::Stage> *stages, StageToAxesMap *stage_to_axes,
+    const std::vector<Step>& transform_steps) const {
+  std::vector<PrimExpr> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id,
+                              lengths, true);
+}
+
+std::string FollowSplitStepNode::PrintAsPythonAPI(
+    std::vector<te::Stage> *stages, StageToAxesMap *stage_to_axes,
+    te::Schedule *schedule, const std::vector<Step>& transform_steps) const {
+  std::vector<PrimExpr> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id,
+                               lengths, true);
+}
+
+/********** Follow Fused Split **********/
+FollowFusedSplitStep FollowFusedSplitStepNode::make(int stage_id, int iter_id,
+    const std::vector<int>& src_step_ids, int level, bool factor_or_nparts) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_ids = src_step_ids;
+  node->level = level;
+  node->factor_or_nparts = factor_or_nparts;
+  return
FollowFusedSplitStep(node); +} + +PrimExpr FollowFusedSplitStepNode::ExtractSplitLength(const std::vector& transform_steps) const { + PrimExpr ret(1); + + for (int src_step_id : src_step_ids) { + CHECK_LT(src_step_id, transform_steps.size()); + auto ps = transform_steps[src_step_id].as(); + CHECK(ps != nullptr); + if (ps->lengths[level].defined() && ret.defined()) { + ret *= ps->lengths[level]; + } else { + return PrimExpr(); + } + } + + return ret; +} + +std::vector FollowFusedSplitStepNode::ApplyToSchedule( + std::vector *stages, StageToAxesMap *stage_to_axes, + const std::vector& transform_steps) const { + const PrimExpr& length = ExtractSplitLength(transform_steps); + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, + {length}, factor_or_nparts); +} + +std::string FollowFusedSplitStepNode::PrintAsPythonAPI( + std::vector *stages, StageToAxesMap *stage_to_axes, + te::Schedule *schedule, const std::vector& transform_steps) const { + const PrimExpr& length = ExtractSplitLength(transform_steps); + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, + {length}, factor_or_nparts); +} + + +/********** Fuse **********/ +FuseStep FuseStepNode::make(int stage_id, const std::vector& fused_ids) { + auto node = make_object(); + node->stage_id = stage_id; + node->fused_ids = fused_ids; + return FuseStep(node); +} + +IterVar FuseStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + + Array to_fuse; + for (auto i : fused_ids) { + to_fuse.push_back(axes[i]); + } + IterVar fused_axis; + stage.fuse(to_fuse, &fused_axis); + std::vector new_axes; + new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids[0]); + new_axes.push_back(fused_axis); + new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, + axes.end()); + (*stage_to_axes)[stage] = std::move(new_axes); + + return fused_axis; +} + +std::string FuseStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + const auto& stage = (*stages)[stage_id]; + std::stringstream to_fuse; + + for (size_t i = 0; i < fused_ids.size(); ++i) { + to_fuse << CleanName((*stage_to_axes)[stage][fused_ids[i]]->var->name_hint); + if (i != fused_ids.size() - 1) { + to_fuse << ", "; + } + } + + std::stringstream ss; + const auto& fused = ApplyToSchedule(stages, stage_to_axes); + + ss << CleanName(fused->var->name_hint) << " = s[" + << CleanName(stage->op->func_name()) << "].fuse(" + << to_fuse.str() << ")\n"; + + return ss.str(); +} + +/********** Annotation **********/ +AnnotationStep AnnotationStepNode::make(int stage_id, int iter_id, IteratorAnnotation ann) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->annotation = ann; + return AnnotationStep(node); +} + +void AnnotationStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + + switch (annotation) { + case kUnroll: stage.unroll(axes[iter_id]); break; + case kVectorize: stage.vectorize(axes[iter_id]); break; + case kParallel: stage.parallel(axes[iter_id]); break; + case kVThread: stage.bind(axes[iter_id], te::thread_axis(Range(), "vthread")); break; + case kBlockX: stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.x")); break; + case kBlockY: 
stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.y")); break;
+    case kThreadX:
+      if (axes[iter_id]->iter_type == kCommReduce) {
+        const auto &thread_x = te::thread_axis(Range(), "threadIdx.x");
+        stage.bind(axes[iter_id], thread_x);
+        stage.set_store_predicate(thread_x->var == 0);
+      } else {
+        stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.x"));
+      }
+      break;
+    case kThreadY: stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.y")); break;
+    case kNone: break;
+    default: LOG(FATAL) << "Invalid annotation " << annotation; break;
+  }
+}
+
+std::string AnnotationStepNode::PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                                                 StageToAxesMap *stage_to_axes,
+                                                 te::Schedule *schedule,
+                                                 const std::vector<Step>& transform_steps) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  const auto& iter = (*stage_to_axes)[stage][iter_id];
+
+  bool bind_reduce_iter = iter->iter_type == kCommReduce && annotation == kThreadX;
+  if (bind_reduce_iter) {
+    ss << "thread_x = tvm.thread_axis(\"threadIdx.x\")\n";
+  }
+
+  ss << "s[" << CleanName(stage->op->func_name()) << "].";
+  switch (annotation) {
+    case kUnroll: ss << "unroll("; break;
+    case kVectorize: ss << "vectorize("; break;
+    case kParallel: ss << "parallel("; break;
+    case kVThread:
+    case kBlockX:
+    case kBlockY:
+    case kThreadX:
+    case kThreadY: ss << "bind("; break;
+    case kNone: break;
+    default:
+      LOG(FATAL) << "Invalid annotation " << annotation; break;
+  }
+  ss << CleanName(iter->var->name_hint);
+  switch (annotation) {
+    case kVThread: ss << ", tvm.thread_axis(\"vthread\")"; break;
+    case kBlockX: ss << ", tvm.thread_axis(\"blockIdx.x\")"; break;
+    case kBlockY: ss << ", tvm.thread_axis(\"blockIdx.y\")"; break;
+    case kThreadX:
+      if (bind_reduce_iter) {
+        ss << ", thread_x";
+      } else {
+        ss << ", tvm.thread_axis(\"threadIdx.x\")";
+      }
+      break;
+    case kThreadY: ss << ", tvm.thread_axis(\"threadIdx.y\")"; break;
+    default: break;
+  }
+  ss << ")\n";
+
+  if (bind_reduce_iter) {
+    ss << "s[" << CleanName(stage->op->func_name()) << "]"
+       << ".set_store_predicate(thread_x.var.equal(0))\n";
+  }
+
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
+
+/********** Compute at **********/
+ComputeAtStep ComputeAtStepNode::make(int stage_id, int target_stage_id, int target_iter_id) {
+  auto node = make_object<ComputeAtStepNode>();
+  node->stage_id = stage_id;
+  node->target_stage_id = target_stage_id;
+  node->target_iter_id = target_iter_id;
+  return ComputeAtStep(node);
+}
+
+void ComputeAtStepNode::ApplyToSchedule(std::vector<te::Stage> *stages,
+                                        StageToAxesMap *stage_to_axes) const {
+  te::Stage& stage = (*stages)[stage_id];
+  const IterVar& target_axis =
+      (*stage_to_axes)[(*stages)[target_stage_id]][target_iter_id];
+  stage.compute_at((*stages)[target_stage_id], target_axis);
+}
+
+std::string ComputeAtStepNode::PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                                                StageToAxesMap *stage_to_axes,
+                                                te::Schedule *schedule,
+                                                const std::vector<Step>& transform_steps) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  const auto& target_stage = (*stages)[target_stage_id];
+
+  ss << "s[" << CleanName(stage->op->func_name()) << "].compute_at(s["
+     << CleanName(target_stage->op->func_name()) << "], "
+     << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint);
+
+  ss << ")\n";
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
+
+/********** Compute Root **********/
+ComputeRootStep ComputeRootStepNode::make(int stage_id) {
+  auto node = make_object<ComputeRootStepNode>();
+  node->stage_id = stage_id;
+  return
ComputeRootStep(node); +} + +void ComputeRootStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + (*stages)[stage_id].compute_root(); +} + +std::string ComputeRootStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + ss << "s[" << CleanName(stage->op->func_name()) << "].compute_root()\n"; + ApplyToSchedule(stages, stage_to_axes); + + return ss.str(); +} + +/********** Compute Inline **********/ +ComputeInlineStep ComputeInlineStepNode::make(int stage_id) { + auto node = make_object(); + node->stage_id = stage_id; + return ComputeInlineStep(node); +} + +void ComputeInlineStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + (*stages)[stage_id].compute_inline(); +} + +std::string ComputeInlineStepNode::PrintAsPythonAPI( + std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + ss << "s[" << CleanName(stage->op->func_name()) << "].compute_inline()\n"; + ApplyToSchedule(stages, stage_to_axes); + + return ss.str(); +} + +/********** Pack for vec **********/ +PackForVecStep PackForVecStepNode::make(int stage_id, int iter_id, int vec_size) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->vec_size = vec_size; + return PackForVecStep(node); +} + +void PackForVecStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { + LOG(FATAL) << "Not implemented"; +} + +std::string PackForVecStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + LOG(FATAL) << "Not implemented"; + return ""; +} + +/********** Cache read **********/ +CacheReadStep CacheReadStepNode::make(int stage_id, std::string scope_name, + const std::vector& reader_stage_ids) { + auto node = make_object(); + node->stage_id = stage_id; + node->scope_name = std::move(scope_name); + node->reader_stage_ids = reader_stage_ids; + return CacheReadStep(node); +} + +te::Tensor CacheReadStepNode::ApplyToSchedule(std::vector* stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { + te::Stage& stage = (*stages)[stage_id]; + + Array readers; + for (const auto& i : reader_stage_ids) { + readers.push_back((*stages)[i]->origin_op); + } + auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, readers); + + const auto& new_stage = (*schedule)[out->op]; + UpdateStageAxis(new_stage, stage_to_axes); + stages->insert(stages->begin() + stage_id + 1, new_stage); + + return out; +} + +std::string CacheReadStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + // copy stage here, for the original stage will change after apply + auto stage = (*stages)[stage_id]; + std::vector reader_stages; + for (size_t i = 0; i < reader_stage_ids.size(); ++i) { + reader_stages.push_back((*stages)[reader_stage_ids[i]]); + } + + auto out = ApplyToSchedule(stages, stage_to_axes, schedule); + + ss << CleanName(out->op->func_name()) << " = " + << "s.cache_read(" << CleanName(stage->op->func_name()) << ", \"" + << scope_name << "\", [" + << 
CleanName(reader_stages[0]->op->func_name());
+  for (size_t i = 1; i < reader_stage_ids.size(); ++i) {
+    ss << ", " << CleanName(reader_stages[i]->op->func_name());
+  }
+  ss << "])\n";
+
+  const auto& iters = out->op->root_iter_vars();
+  for (size_t i = 0; i < iters.size(); ++i) {
+    ss << CleanName(iters[i]->var->name_hint);
+    if (i != iters.size() - 1) {
+      ss << ", ";
+    }
+  }
+  ss << " = " << "tuple(" << CleanName(out->op->func_name())
+     << ".op.axis)\n";
+
+  return ss.str();
+}
+
+/********** Cache write **********/
+CacheWriteStep CacheWriteStepNode::make(int stage_id, std::string scope_name) {
+  auto node = make_object<CacheWriteStepNode>();
+  node->stage_id = stage_id;
+  node->scope_name = std::move(scope_name);
+  return CacheWriteStep(node);
+}
+
+Array<te::Tensor> CacheWriteStepNode::ApplyToSchedule(
+    std::vector<te::Stage> *stages, StageToAxesMap *stage_to_axes,
+    te::Schedule *schedule) const {
+  te::Stage& stage = (*stages)[stage_id];
+
+  Array<te::Tensor> tensor_array;
+  // If the target stage has multiple outputs, TVM requires us to cache_write
+  // all of them, otherwise schedule.cache_write will raise an error
+  for (auto i = 0; i < stage->op->num_outputs(); ++i) {
+    tensor_array.push_back(stage->origin_op.output(i));
+  }
+  auto outs = schedule->cache_write(tensor_array, scope_name);
+
+  UpdateStageAxis(stage, stage_to_axes);
+  // Even if there are multiple outputs, the TVM schedule only generates one
+  // new stage
+  const auto& new_stage = (*schedule)[outs[0]->op];
+  UpdateStageAxis(new_stage, stage_to_axes);
+  stages->insert(stages->begin() + stage_id, new_stage);
+
+  return outs;
+}
+
+std::string CacheWriteStepNode::PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                                                 StageToAxesMap *stage_to_axes,
+                                                 te::Schedule *schedule,
+                                                 const std::vector<Step>& transform_steps) const {
+  std::stringstream ss;
+  // Copy the stage here, because the original stage will change after this step is applied
+  te::Stage stage = (*stages)[stage_id];
+
+  auto outs = ApplyToSchedule(stages, stage_to_axes, schedule);
+
+  for (size_t i = 0; i < outs.size(); ++i) {
+    ss << CleanName(outs[i]->op->func_name()) << ", ";
+  }
+  ss << "= " << "s.cache_write(["
+     << CleanName(stage->op.output(0)->op->name);
+  for (auto i = 1; i < stage->op->num_outputs(); ++i) {
+    ss << ", " << CleanName(stage->op.output(i)->op->name);
+  }
+  ss << "], \"" << scope_name << "\")\n";
+
+  for (const auto& out : outs) {
+    const auto& iters = out->op->root_iter_vars();
+    for (size_t i = 0; i < iters.size(); ++i) {
+      ss << CleanName(iters[i]->var->name_hint);
+      if (i != iters.size() - 1) {
+        ss << ", ";
+      }
+    }
+    ss << " = " << "tuple(" << CleanName(out->op->func_name())
+       << ".op.axis)"
+       << " + " << "tuple(" << CleanName(out->op->func_name())
+       << ".op.reduce_axis)\n";
+  }
+
+  return ss.str();
+}
+
+/********** Pragma **********/
+PragmaStep PragmaStepNode::make(int stage_id, int iter_id,
+                                std::string pragma_type) {
+  auto node = make_object<PragmaStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->pragma_type = std::move(pragma_type);
+  return PragmaStep(node);
+}
+
+void PragmaStepNode::ApplyToSchedule(std::vector<te::Stage> *stages,
+                                     StageToAxesMap *stage_to_axes) const {
+  te::Stage& stage = (*stages)[stage_id];
+  const std::vector<IterVar>& axes = (*stage_to_axes)[stage];
+  if (StrStartsWith(pragma_type, "auto_unroll_max_step")) {
+    size_t pos = pragma_type.find('$');
+    int value = atoi(pragma_type.c_str() + pos + 1);
+    stage.pragma(axes[iter_id], "auto_unroll_max_step", value);
+    stage.pragma(axes[iter_id], "unroll_explicit", true);
+  } else {
+    stage.pragma(axes[iter_id], pragma_type);
+  }
+}
+
+std::string
PragmaStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { + size_t pos = pragma_type.find('$'); + int value = atoi(pragma_type.c_str() + pos + 1); + ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) + << ", \"auto_unroll_max_step\", " << value << ")\n"; + ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) + << ", \"unroll_explicit\", True)\n"; + } else { + ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" + << pragma_type << "\")\n"; + } + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + +/********** Rfactor **********/ +RfactorStep RfactorStepNode::make(int stage_id, int iter_id, int factor_iter_id) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->factor_iter_id = factor_iter_id; + return RfactorStep(node); +} + +Array RfactorStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { + const auto& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + + const te::Tensor& tensor = stage->origin_op.output(0); + const IterVar& axis = axes[iter_id]; + auto outs = schedule->rfactor(tensor, axis, factor_iter_id); + + UpdateStageAxis(stage, stage_to_axes); + + const auto& new_stage = (*schedule)[outs[0]->op]; + UpdateStageAxis(new_stage, stage_to_axes); + stages->insert(stages->begin() + stage_id, new_stage); + + return outs; +} + +std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + const auto& tensor_name = CleanName(stage->origin_op.output(0)->op->name); + const auto& axis_name = CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint); + + const auto& outs = ApplyToSchedule(stages, stage_to_axes, schedule); + + for (size_t i = 0; i < outs.size(); ++i) { + ss << CleanName(outs[i]->op->func_name()); + if (i != outs.size() - 1) { + ss << ", "; + } + } + ss << " = " << "s.rfactor(" + << tensor_name << ", " + << axis_name << ", " + << factor_iter_id << ")\n"; + + for (const auto& out : outs) { + const auto& iters = out->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != iters.size() - 1) { + ss << ", "; + } + } + ss << " = " << "tuple(" << CleanName(out->op->func_name()) + << ".op.axis)" + << " + " << "tuple(" << CleanName(out->op->func_name()) + << ".op.reduce_axis)\n"; + } + + const auto& output = (*stages)[stage_id + 1]->op.output(0); + const auto& iters = output->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != iters.size() - 1) { + ss << ", "; + } + } + ss << " = " << "tuple(s[" << CleanName(output->op->func_name()) + << "].op.axis)" + << " + " << "tuple(s[" << CleanName(output->op->func_name()) + << "].op.reduce_axis)\n"; + + return ss.str(); +} + +/********** StorageAlign **********/ + +StorageAlignStep StorageAlignStepNode::make(int stage_id, int 
iter_id, + int factor, int offset) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->factor = factor; + node->offset = offset; + return StorageAlignStep(node); +} + +void StorageAlignStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + stage.storage_align(axes[iter_id], factor, offset); +} + +std::string StorageAlignStepNode::PrintAsPythonAPI( + std::vector *stages, StageToAxesMap *stage_to_axes, + te::Schedule *schedule, const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + ss << "s[" << CleanName(stage->op->func_name()) << "].storage_align(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " + << factor << ", " << offset << ")\n"; + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + +// Maker for other classes +Iterator IteratorNode::make(std::string name, Range range, + IteratorType iter_type, IteratorAnnotation annotation, + const std::vector* ori_iters) { + auto node = make_object(); + node->name = std::move(name); + node->range = std::move(range); + node->iter_type = iter_type; + node->annotation = annotation; + if (ori_iters != nullptr) { + node->ori_iters = *ori_iters; + } + return Iterator(node); +} + +Stage StageNode::make(te::Operation op) { + auto node = make_object(); + if (op->IsInstance()) { + node->op_type = kCompute; + auto *pop = op.as(); + + for (const auto& axis : pop->axis) { + node->iters.push_back(IteratorNode::make(CleanName(axis->var->name_hint), + axis->dom, kSpace, kNone)); + } + for (const auto& axis : pop->reduce_axis) { + node->iters.push_back(IteratorNode::make(CleanName(axis->var->name_hint), + axis->dom, kReduce, kNone)); + } + } else if (op->IsInstance()) { + node->op_type = kPlaceholder; + } else { + LOG(FATAL) << "Unsupported operator type" << op->_type_key; + } + + node->compute_at = kRoot; + node->op = std::move(op); + node->auto_unroll_max_step = 0; + node->storage_offset = 0; + return Stage(node); +} + +Stage StageNode::make(te::Operation op, StageType op_type, const std::vector& iters, + ComputeAtType compute_at, int16_t auto_unroll_max_step, int storage_offset) { + auto node = make_object(); + node->op = std::move(op); + node->op_type = op_type; + node->iters = iters; + node->compute_at = compute_at; + node->auto_unroll_max_step = auto_unroll_max_step; + node->storage_offset = storage_offset; + return Stage(node); +} + +Stage StageNode::make(te::Operation op, StageType op_type, std::vector&& iters, + ComputeAtType compute_at, int16_t auto_unroll_max_step, int storage_offset) { + auto node = make_object(); + node->op = std::move(op); + node->op_type = op_type; + node->iters = std::move(iters); + node->compute_at = compute_at; + node->auto_unroll_max_step = auto_unroll_max_step; + node->storage_offset = storage_offset; + return Stage(node); +} + +State StateNode::make_empty_state() { + auto node = make_object(); + node->attach_map = AttachMapNode::make(); + node->complete = false; + node->aux_info = ObjectRef(); + return State(node); +} + +State StateNode::make(const Array& ops) { + auto node = make_object(); + for (const auto& op : ops) { + node->stages.push_back(StageNode::make(op)); + } + node->attach_map = AttachMapNode::make(); + node->complete = true; + node->aux_info = ObjectRef(); + return State(node); +} + +State StateNode::make(const std::vector& stages, + const 
std::vector& transform_steps, + bool complete, ObjectRef aux_info) { + auto node = make_object(); + node->stages = stages; + node->transform_steps = transform_steps; + node->attach_map = AttachMapNode::make(); + node->complete = complete; + node->aux_info = std::move(aux_info); + return State(node); +} + +AttachMap AttachMapNode::make() { + auto node = make_object(); + return AttachMap(node); +} + +// Schedule primitives api +void State::reorder(int stage_id, const std::vector& order) { + const Stage& stage = operator->()->stages[stage_id]; + + CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " + "should be specified"; + std::vector after_ids; + GetIndices(stage->iters, order, &after_ids); + ReorderStep step = ReorderStepNode::make(stage_id, after_ids); + CopyOnWrite()->transform_steps.push_back(step); + DoReorderStep(step); +} + +std::vector State::split(int stage_id, + const Iterator& it, const std::vector& lengths, bool inner_to_outer) { + const Stage& stage = operator->()->stages[stage_id]; + + SplitStep step = SplitStepNode::make(stage_id, GetIndex(stage->iters, it), + it->range.defined() ? it->range->extent : PrimExpr(), lengths, + inner_to_outer); + CopyOnWrite()->transform_steps.push_back(step); + return DoSplitStep(step); +} + +std::vector State::follow_split(int stage_id, + const Iterator& it, int src_step_id, int n_split) { + const Stage& stage = operator->()->stages[stage_id]; + + FollowSplitStep step = FollowSplitStepNode::make(stage_id, + GetIndex(stage->iters, it), src_step_id, n_split); + CopyOnWrite()->transform_steps.push_back(step); + return DoFollowSplitStep(step); +} + + +std::vector State::follow_fused_split(int stage_id, const Iterator& it, + const std::vector& src_step_ids, int level, bool factor_or_nparts) { + const Stage& stage = operator->()->stages[stage_id]; + + FollowFusedSplitStep step = FollowFusedSplitStepNode::make(stage_id, + GetIndex(stage->iters, it), src_step_ids, level, factor_or_nparts); + CopyOnWrite()->transform_steps.push_back(step); + return DoFollowFusedSplitStep(step); +} + +Iterator State::fuse(int stage_id, const std::vector& iters) { + const Stage& stage = operator->()->stages[stage_id]; + std::vector indices; + GetIndices(stage->iters, iters, &indices); + FuseStep step = FuseStepNode::make(stage_id, indices); + CopyOnWrite()->transform_steps.push_back(step); + return DoFuseStep(step); +} + +Iterator State::vectorize(int stage_id, const Iterator& it) { + const Stage& stage = operator->()->stages[stage_id]; + AnnotationStep step = AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), + kVectorize); + CopyOnWrite()->transform_steps.push_back(step); + return DoAnnotationStep(step); +} + +Iterator State::parallel(int stage_id, const Iterator& it) { + const Stage& stage = operator->()->stages[stage_id]; + AnnotationStep step = AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), + kParallel); + CopyOnWrite()->transform_steps.push_back(step); + return DoAnnotationStep(step); +} + +Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { + const Stage& stage = operator->()->stages[stage_id]; + AnnotationStep step = AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), + kUnroll); + + // don't unroll if the extent is larger than max_unroll + if (max_unroll != -1 && it->range.defined()) { + if (auto imm = it->range->extent.as()) { + if (imm->value > max_unroll) { + return it; + } + } + } + + CopyOnWrite()->transform_steps.push_back(step); + return DoAnnotationStep(step); +} + 
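+// The schedule primitives above all follow the same pattern: record a
+// transform step in `transform_steps`, then apply it incrementally to the
+// in-memory loop state. A minimal usage sketch (for illustration only; the
+// state, the stage index, and the iterator indices are assumed to come from
+// an existing ComputeDAG):
+//
+//   State state = ...;                       // an initialized state
+//   const Stage& s = state->stages[1];       // a compute stage with axes i, j
+//   std::vector<Iterator> outs =
+//       state.split(1, s->iters[0], {PrimExpr(8)}, true);  // i -> i.0, i.1
+//   Iterator j = state->stages[1]->iters[2];
+//   Iterator fused = state.fuse(1, {outs[1], j});   // fuse adjacent i.1 and j
+//   state.parallel(1, state->stages[1]->iters[0]);  // parallelize i.0
+//
+// Because every primitive appends one step, the whole schedule can later be
+// replayed on a real te::Schedule or printed as the equivalent python API.
+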
+void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) {
+  const Stage& target_stage = operator->()->stages[target_stage_id];
+  ComputeAtStep step = ComputeAtStepNode::make(stage_id, target_stage_id,
+      GetIndex(target_stage->iters, target_iter));
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoComputeAtStep(step);
+}
+
+void State::compute_root(int stage_id) {
+  ComputeRootStep step = ComputeRootStepNode::make(stage_id);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoComputeRootStep(step);
+}
+
+void State::compute_inline(int stage_id) {
+  ComputeInlineStep step = ComputeInlineStepNode::make(stage_id);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoComputeInlineStep(step);
+}
+
+void State::pack_for_vec(int stage_id, const Iterator& target_iter, int vec_size) {
+  const Stage& stage = operator->()->stages[stage_id];
+  PackForVecStep step = PackForVecStepNode::make(stage_id,
+      GetIndex(stage->iters, target_iter), vec_size);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoPackForVecStep(step);
+}
+
+Iterator State::bind_thread(int stage_id, const Iterator& it,
+                            IteratorAnnotation thread_type) {
+  const Stage& stage = operator->()->stages[stage_id];
+  if (thread_type < kVThread || thread_type > kThreadY) {
+    LOG(FATAL) << "thread_type error, valid types are: kVThread, kBlockX, "
+               << "kThreadX, kBlockY, kThreadY";
+  }
+  AnnotationStep step = AnnotationStepNode::make(stage_id,
+      GetIndex(stage->iters, it), thread_type);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoAnnotationStep(step);
+}
+
+int State::cache_read(int stage_id, const std::string& scope_name,
+                      const std::vector<int>& reader_stage_ids, const ComputeDAG& task_dag) {
+  CacheReadStep step = CacheReadStepNode::make(stage_id, scope_name, reader_stage_ids);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoCacheReadStep(step, task_dag);
+}
+
+int State::cache_write(int stage_id, const std::string& scope_name,
+                       const ComputeDAG& task_dag) {
+  CacheWriteStep step = CacheWriteStepNode::make(stage_id, scope_name);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoCacheWriteStep(step, task_dag);
+}
+
+void State::pragma(int stage_id, const Iterator& it, const std::string& pragma_type) {
+  const Stage& stage = operator->()->stages[stage_id];
+  PragmaStep step = PragmaStepNode::make(stage_id, GetIndex(stage->iters, it),
+                                         pragma_type);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoPragmaStep(step);
+}
+
+int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id,
+                   const ComputeDAG& task_dag) {
+  const Stage& stage = operator->()->stages[stage_id];
+  RfactorStep step = RfactorStepNode::make(stage_id, GetIndex(stage->iters, it), factor_iter_id);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoRfactorStep(step, task_dag);
+}
+
+void State::storage_align(int stage_id, const Iterator& it, int factor,
+                          int offset) {
+  const Stage& stage = operator->()->stages[stage_id];
+  StorageAlignStep step = StorageAlignStepNode::make(stage_id,
+      GetIndex(stage->iters, it), factor, offset);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoStorageAlignStep(step);
+}
+
+// Steps' implementations
+void State::DoReorderStep(const ReorderStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+
+  std::vector<Iterator> iters;
+  for (auto x : step->after_ids) {
+    iters.push_back(stage->iters[x]);
+  }
+
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages[step->stage_id] = StageNode::make(stage->op,
stage->op_type, + std::move(iters), stage->compute_at, + stage->auto_unroll_max_step, + stage->storage_offset); +} + +// common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep +std::vector State::DoSplitStepCommon(int stage_id, int iter_id, + const std::vector& lengths, + bool inner_to_outer) { + const Stage& stage = operator->()->stages[stage_id]; + const Iterator& it = stage->iters[iter_id]; + size_t old_iter_size = stage->iters.size(); + + PrimExpr tosplit_min, tosplit_extent; + if (it->range.defined()) { + tosplit_min = it->range->min; + tosplit_extent = it->range->extent; + } else { + tosplit_min = tosplit_extent = PrimExpr(); + } + + std::vector outs; + for (size_t i = 0; i < lengths.size(); ++i) { + PrimExpr l; std::string name; + if (inner_to_outer) { + l = lengths[lengths.size() - i - 1]; + name = it->name + "." + std::to_string(lengths.size() - i); + } else { + l = lengths[i]; + name = it->name + "." + std::to_string(i); + } + Iterator res; + if (l.defined() && tosplit_min.defined() && tosplit_extent.defined()) { + res = IteratorNode::make(name, Range::make_by_min_extent(tosplit_min, l), + it->iter_type, kNone); + tosplit_min = 0; + tosplit_extent = indexdiv(tosplit_extent + l - 1, l); + } else { + res = IteratorNode::make(name, Range(), it->iter_type, kNone); + tosplit_min = tosplit_extent = PrimExpr(); + } + outs.push_back(std::move(res)); + } + + Range range; + if (tosplit_min.defined() && tosplit_extent.defined()) { + range = Range::make_by_min_extent(tosplit_min, tosplit_extent); + } + if (inner_to_outer) { + outs.push_back(IteratorNode::make(it->name + ".0", range, it->iter_type, kNone)); + std::reverse(outs.begin(), outs.end()); + } else { + outs.push_back(IteratorNode::make(it->name + "." + std::to_string(lengths.size()), + range, it->iter_type, kNone)); + } + + std::vector new_iters; + new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id); + new_iters.insert(new_iters.end(), outs.begin(), outs.end()); + new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id+1, stage->iters.end()); + + StateNode* pstate = CopyOnWrite(); + pstate->stages[stage_id] = StageNode::make(stage->op, stage->op_type, + std::move(new_iters), stage->compute_at, stage->auto_unroll_max_step, + stage->storage_offset); + + // we have to replace the iterators in attach map, these two vectors keep the replacement mapping + std::vector from_iters; + std::vector to_iters; + for (size_t i = iter_id; i < old_iter_size; ++i) { + from_iters.emplace_back(stage_id, i); + to_iters.emplace_back(stage_id, i + lengths.size()); + } + pstate->attach_map.ReplaceIters(from_iters, to_iters); + return outs; +} + +std::vector State::DoSplitStep(const SplitStep& step) { + return DoSplitStepCommon(step->stage_id, step->iter_id, step->lengths, + step->inner_to_outer); +} + +std::vector State::DoFollowSplitStep(const FollowSplitStep& step) { + std::vector lengths; + step->ExtractSplitLengths(operator->()->transform_steps, &lengths); + return DoSplitStepCommon(step->stage_id, step->iter_id, lengths, true); +} + +std::vector State::DoFollowFusedSplitStep(const FollowFusedSplitStep& step) { + const PrimExpr& length = step->ExtractSplitLength(operator->()->transform_steps); + return DoSplitStepCommon(step->stage_id, step->iter_id, {length}, step->factor_or_nparts); +} + +Iterator State::DoFuseStep(const FuseStep& step) { + int stage_id = step->stage_id; + const Stage& stage = operator->()->stages[stage_id]; + int old_iter_size = static_cast(stage->iters.size()); + + 
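// Scan the iterators to be fused: concatenate their names, multiply their
+  // extents (the result becomes undefined once any extent is unknown), and
+  // derive the fused iterator type. The fused ids must be consecutive.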
+  std::string new_name;
+  PrimExpr new_extent = 1;
+  IteratorType new_iter_type = kSpecial;
+
+  std::vector<Iterator> ori_iters;
+  for (size_t i = 0; i < step->fused_ids.size(); ++i) {
+    if (i > 0) {
+      CHECK_EQ(step->fused_ids[i], step->fused_ids[i-1] + 1);
+    }
+
+    if (i != step->fused_ids.size() - 1) {
+      const auto& iter_to_attached_stage = operator->()->attach_map->iter_to_attached_stages;
+      if (iter_to_attached_stage.find(std::make_pair(stage_id, step->fused_ids[i]))
+          != iter_to_attached_stage.end()) {
+        LOG(FATAL) << "Invalid fuse: this would fuse iterators that "
+                      "have other stages attached to them";
+      }
+    }
+
+    const Iterator& it = stage->iters[step->fused_ids[i]];
+    ori_iters.push_back(it);
+    new_name += it->name + "@";
+
+    if (it->range.defined() && new_extent.defined()) {
+      new_extent = new_extent * it->range->extent;
+    } else {
+      new_extent = PrimExpr();
+    }
+
+    if (i == 0) {
+      new_iter_type = it->iter_type;
+    } else {
+      if (new_iter_type != it->iter_type) {
+        new_iter_type = kMixed;
+      }
+    }
+  }
+
+  Range range;
+  if (new_extent.defined()) {
+    range = Range::make_by_min_extent(0, new_extent);
+  }
+  Iterator new_it = IteratorNode::make(new_name, range, new_iter_type, kNone, &ori_iters);
+  std::vector<Iterator> new_iters;
+  new_iters.insert(new_iters.end(), stage->iters.begin(),
+                   stage->iters.begin() + step->fused_ids.front());
+  new_iters.push_back(new_it);
+  new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back() + 1,
+                   stage->iters.end());
+
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages[stage_id] = StageNode::make(stage->op, stage->op_type,
+      std::move(new_iters), stage->compute_at, stage->auto_unroll_max_step,
+      stage->storage_offset);
+
+  // We have to replace the iterators in the attach map;
+  // these two vectors keep the replacement mapping
+  std::vector<AttachMap::IterKey> from_iters;
+  std::vector<AttachMap::IterKey> to_iters;
+  const int begin_id = step->fused_ids.front(), end_id = step->fused_ids.back();
+  for (int i = 0; i < old_iter_size; ++i) {
+    if (i <= begin_id) {
+      continue;
+    } else if (i > end_id) {  // move forward
+      from_iters.emplace_back(stage_id, i);
+      to_iters.emplace_back(stage_id, i - end_id + begin_id);
+    } else {  // move to the fused id
+      from_iters.emplace_back(stage_id, i);
+      to_iters.emplace_back(stage_id, begin_id);
+    }
+  }
+  pstate->attach_map.ReplaceIters(from_iters, to_iters);
+  return new_it;
+}
+
+Iterator State::DoAnnotationStep(const AnnotationStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+  Iterator it = stage->iters[step->iter_id];
+
+  Iterator new_it = IteratorNode::make(it->name, it->range, it->iter_type,
+      step->annotation, &it->ori_iters);
+  Stage new_stage = stage;
+  new_stage.CopyOnWrite()->iters[step->iter_id] = new_it;
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages[step->stage_id] = std::move(new_stage);
+  return new_it;
+}
+
+void State::DoComputeAtStep(const ComputeAtStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+
+  // After compute_at, we don't know the accurate length information any more
+  // If we do want to know the accurate lengths, we can call ComputeDAG::ReplayAndInferBound
+  std::vector<Iterator> new_iters;
+  for (const Iterator& it : stage->iters) {
+    size_t s = it->name.size();
+    if (s >= 2 && it->name[s-2] == '.' && it->name[s-1] >= '1' && it->name[s-1] <= '4') {
+      // We use a dangerous heuristic rule here: for multi-level split iterators, we assume
+      // their lengths do not change after compute_at.
+      // Reason: These iterators are generated in MultiStagePolicy by multi-level tiling, and
+      // they will be carefully attached (compute_at) to their consumers. In this case, their
+      // lengths do not change.
+      // We do this to let the AnnotateCPU pass annotate more efficiently.
+      new_iters.push_back(it);
+    } else {
+      new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type,
+          it->annotation, &it->ori_iters));
+    }
+  }
+
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type,
+      std::move(new_iters), kIter, stage->auto_unroll_max_step,
+      stage->storage_offset);
+  pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, step->target_iter_id);
+}
+
+void State::DoComputeRootStep(const ComputeRootStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+
+  // After compute_root, we don't know the accurate length information any more
+  // If we do want to know the accurate lengths, we can call ComputeDAG::ReplayAndInferBound
+  std::vector<Iterator> new_iters;
+  for (const Iterator& it : stage->iters) {
+    new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type,
+        it->annotation, &it->ori_iters));
+  }
+
+  // Update the attach map
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type,
+      std::move(new_iters), kRoot, stage->auto_unroll_max_step,
+      stage->storage_offset);
+  pstate->attach_map.DeleteStage(step->stage_id);
+}
+
+void State::DoComputeInlineStep(const ComputeInlineStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+
+  StateNode* pstate = CopyOnWrite();
+
+  // Check the validity of compute_inline
+  const auto& iter_to_attached_stages = pstate->attach_map->iter_to_attached_stages;
+  for (size_t i = 0; i < stage->iters.size(); ++i) {
+    CHECK_EQ(iter_to_attached_stages.count(std::make_pair(step->stage_id, i)), 0)
+        << "Invalid compute_inline: there are other stages "
+           "attached to the target stage";
+  }
+
+  pstate->stages[step->stage_id].CopyOnWrite()->compute_at = kInlined;
+  pstate->attach_map.DeleteStage(step->stage_id);
+}
+
+void State::DoPackForVecStep(const PackForVecStep& step) {
+  LOG(FATAL) << "Not implemented";
+}
+
+// Common part for steps that add new stages (e.g.
CacheReadStep, CacheWriteStep, RfactorStep) +void AddStageModificationSteps(size_t step_id, const std::vector& transform_steps, + std::vector* replay_steps) { + const Step& step = transform_steps[step_id]; + if (step->IsInstance() || step->IsInstance()) { + replay_steps->push_back(step); + } else if (step->IsInstance()) { + // add FuseStepNode required by rfactor + if (step_id >= 2 && transform_steps[step_id - 2]->IsInstance()) { + const Step& fuse_step = transform_steps[step_id - 2]; + if (fuse_step->stage_id == step->stage_id) { + replay_steps->push_back(fuse_step); + } + } + // add SplitStepNode required by rfactor + CHECK_GE(step_id, 1); + CHECK(transform_steps[step_id - 1]->IsInstance()); + const Step& split_step = transform_steps[step_id - 1]; + CHECK_EQ(split_step->stage_id, step->stage_id); + replay_steps->push_back(split_step); + // add RfactorStepNode + replay_steps->push_back(step); + } +} + +int State::DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag) { + StateNode* pstate = CopyOnWrite(); + std::vector replay_steps; + for (size_t i = 0; i < pstate->transform_steps.size(); ++i) { + AddStageModificationSteps(i, pstate->transform_steps, &replay_steps); + if (pstate->transform_steps[i].same_as(step)) { + break; + } + } + dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag)); + + // target -> target + target_store + // Should update target's op, insert new stage, update the later stage's op + pstate->stages[step->stage_id].CopyOnWrite()->op = + operator->()->task_dag->ops[step->stage_id]; + pstate->stages.insert(pstate->stages.begin() + step->stage_id + 1, + StageNode::make(operator->()->task_dag->ops[step->stage_id + 1])); + for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { + pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; + } + pstate->attach_map = + operator->()->attach_map.ApplyStageIdOfffset(step->stage_id + 1, 1); + + return step->stage_id + 1; +} + +int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { + StateNode* pstate = CopyOnWrite(); + std::vector replay_steps; + for (size_t i = 0; i < pstate->transform_steps.size(); ++i) { + AddStageModificationSteps(i, pstate->transform_steps, &replay_steps); + if (pstate->transform_steps[i].same_as(step)) { + break; + } + } + dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag)); + + // target -> target_compute + target + // Assume target stage has never been applied any steps before cache_write + // Should insert new stage, update target stage, update the later stage's op + pstate->stages.insert(pstate->stages.begin() + step->stage_id, + StageNode::make(operator->()->task_dag->ops[step->stage_id])); + pstate->stages[step->stage_id + 1] = + StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]); + for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { + pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; + } + pstate->attach_map = + operator->()->attach_map.ApplyStageIdOfffset(step->stage_id, 1); + + return step->stage_id; +} + +void State::DoPragmaStep(const PragmaStep& step) { + if (step->pragma_type == "debug_skip_region") { + StateNode* pstate = CopyOnWrite(); + pstate->attach_map.DeleteStage(step->stage_id); + } else if (StrStartsWith(step->pragma_type, "auto_unroll_max_step")) { + StateNode* pstate = CopyOnWrite(); + StageNode* stage = pstate->stages[step->stage_id].CopyOnWrite(); + size_t pos = step->pragma_type.find('$'); + stage->auto_unroll_max_step = atoi(step->pragma_type.c_str() + 
pos + 1); + } else if (step->pragma_type == "tensor_core") { + // Nothing needs to be done here + } else { + LOG(FATAL) << "Invalid pragma: " << step->pragma_type; + } +} + +int State::DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag) { + StateNode* pstate = CopyOnWrite(); + const auto compute_at_type = pstate->stages[step->stage_id]->compute_at; + std::vector replay_steps; + for (size_t i = 0; i < pstate->transform_steps.size(); ++i) { + AddStageModificationSteps(i, pstate->transform_steps, &replay_steps); + if (pstate->transform_steps[i].same_as(step)) { + break; + } + } + dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag)); + + // target -> target_compute + target + // Should insert new stage, update target stage, update the later stage's op + pstate->stages.insert(pstate->stages.begin() + step->stage_id, + StageNode::make(operator->()->task_dag->ops[step->stage_id])); + // maintain the compute_at type of target stage + Stage target_stage = StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]); + target_stage.CopyOnWrite()->compute_at = compute_at_type; + pstate->stages[step->stage_id + 1] = target_stage; + + for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { + pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; + } + pstate->attach_map = + operator->()->attach_map.ApplyStageIdOfffset(step->stage_id, 1); + + return step->stage_id; +} + +void State::DoStorageAlignStep(const StorageAlignStep& step) { + StateNode* pstate = CopyOnWrite(); + StageNode* stage = pstate->stages[step->stage_id].CopyOnWrite(); + stage->storage_offset = step->offset; +} + +void State::DoStep(const Step& step, const ComputeDAG& dag) { + if (auto ps = step.as()) { + DoReorderStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoSplitStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoFollowSplitStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoFollowFusedSplitStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoFuseStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoAnnotationStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoComputeAtStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoComputeRootStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoComputeInlineStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoPackForVecStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoCacheReadStep(GetRef(ps), dag); + } else if (auto ps = step.as()) { + DoCacheWriteStep(GetRef(ps), dag); + } else if (auto ps = step.as()) { + DoPragmaStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoRfactorStep(GetRef(ps), dag); + } else if (auto ps = step.as()) { + DoStorageAlignStep(GetRef(ps)); + } else { + LOG(FATAL) << "Invalid step: " << step; + } +} + +void State::DoSteps(const std::vector& steps, const ComputeDAG& dag) { + // Use complete rate for the study in the paper + const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); + double complete_rate = -1.0; + if (complete_rate_str) { + complete_rate = std::stod(complete_rate_str); + } + size_t ct = 0; + + for (const auto& step : steps) { + if (complete_rate >= 0 && ct++ > steps.size() * complete_rate) { + break; + } + DoStep(step, dag); + } +} + + +void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t base_indent, + bool delete_trivial_loop) { + const Stage& stage = state->stages[stage_id]; + + if (stage->auto_unroll_max_step != 0) { + for (size_t j = 0; j < base_indent; ++j) { + *os << " "; + } + *os << 
stage->op->func_name() << " auto_unroll: " + << stage->auto_unroll_max_step << "\n"; + } + if (stage->storage_offset != 0) { + for (size_t j = 0; j < base_indent; ++j) { + *os << " "; + } + *os << stage->op->func_name() << " storage_offset: " + << stage->storage_offset << "\n"; + } + + size_t indent = 0; + for (size_t i = 0; i < stage->iters.size(); ++i) { + const Iterator& iter = stage->iters[i]; + + if (!(delete_trivial_loop && iter->range.defined() && is_one(iter->range->extent))) { + for (size_t j = 0; j < base_indent + indent; ++j) { + *os << " "; + } + switch (iter->annotation) { + case kNone: *os << "for "; break; + case kUnroll: *os << "unroll "; break; + case kParallel: *os << "parallel "; break; + case kVectorize: *os << "vectorize "; break; + case kVThread: *os << "vthread "; break; + case kBlockX: *os << "gpu.blockIdx.x "; break; + case kBlockY: *os << "gpu.blockIdx.y "; break; + case kThreadX: *os << "gpu.threadIdx.x "; break; + case kThreadY: *os << "gpu.threadIdx.y "; break; + } + if (iter->range.defined()) { + *os << iter->name << " (" << iter->range->min << "," << iter->range->extent << ")" << "\n"; + } else { + *os << iter->name << " (None)" << "\n"; + } + + indent += 2; + } + + if (state != nullptr) { + AttachMap::IterKey iter_key(stage_id, i); + auto pair = state->attach_map->iter_to_attached_stages.find(iter_key); + if (pair != state->attach_map->iter_to_attached_stages.end()) { + for (const auto& attach_stage_id : pair->second) { + PrintStage(os, attach_stage_id, state, base_indent + indent, delete_trivial_loop); + } + } + } + } + + for (size_t j = 0; j < base_indent + indent; ++j) { + *os << " "; + } + *os << stage->op->func_name() << " = ...\n"; +} + +void PrintState(std::ostream* os, const StateNode* node, bool delete_trivial_loop) { + // Gather placeholders + std::vector placeholders; + for (const auto& stage : node->stages) { + if (stage->op_type == kPlaceholder) { + placeholders.push_back(stage->op->name); + } + } + + *os << "Placeholder: "; + for (size_t i = 0; i < placeholders.size(); ++i) { + *os << placeholders[i]; + if (i != placeholders.size() - 1) { + *os << ", "; + } + } + *os << "\n"; + + // Print all stages + for (size_t i = 0; i < node->stages.size(); ++i) { + const Stage& stage = node->stages[i]; + if (stage->op_type == kPlaceholder) { + continue; + } else if (stage->op_type == kCompute) { + if (stage->compute_at == kRoot) { + PrintStage(os, i, node, 0, delete_trivial_loop); + } + } else { + LOG(FATAL) << "Invalid op type"; + } + } +} + +std::string State::ToStr(bool delete_trivial_loop) const { + std::ostringstream os; + PrintState(&os, operator->(), delete_trivial_loop); + return os.str(); +} + +void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) { + AttachMapNode* pnode = CopyOnWrite(); + + // delete the current entry of stage + DeleteStageEntry(pnode, stage_id); + + // store the new relation + IterKey iter_key(target_stage_id, target_iter_id); + pnode->stage_to_attach_iter[stage_id] = std::make_pair(target_stage_id, target_iter_id); + pnode->iter_to_attached_stages[iter_key].push_back(stage_id); +} + +void AttachMap::DeleteStage(int stage_id) { + AttachMapNode* pnode = CopyOnWrite(); + + // delete the entry of old stage + DeleteStageEntry(pnode, stage_id); +} + +void AttachMap::ReplaceIters(const std::vector& old_iters, + const std::vector& new_iters) { + AttachMapNode* pnode = CopyOnWrite(); + + CHECK_EQ(old_iters.size(), new_iters.size()); + for (size_t i = 0; i < old_iters.size(); ++i) { + auto entry = 
pnode->iter_to_attached_stages.find(old_iters[i]);
+    if (entry == pnode->iter_to_attached_stages.end()) {
+      continue;
+    }
+
+    // replace the iter in the value of `stage_to_attach_iter`
+    for (const auto& s : entry->second) {
+      pnode->stage_to_attach_iter[s] = new_iters[i];
+    }
+
+    // replace the iter in the key of `iter_to_attached_stages`
+    std::vector<int> attached_stages = std::move(entry->second);
+    pnode->iter_to_attached_stages.erase(entry);
+    pnode->iter_to_attached_stages[new_iters[i]] = std::move(attached_stages);
+  }
+}
+
+void AttachMap::DeleteStageEntry(AttachMapNode *pnode, int stage_id) {
+  auto old_entry = pnode->stage_to_attach_iter.find(stage_id);
+  if (old_entry != pnode->stage_to_attach_iter.end()) {
+    // delete the value in `iter_to_attached_stages`
+    auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second);
+    DeleteItem(&entry2->second, stage_id);
+    if (entry2->second.size() == 0) {
+      pnode->iter_to_attached_stages.erase(entry2);
+    }
+    // delete the key in `stage_to_attach_iter`
+    pnode->stage_to_attach_iter.erase(old_entry);
+  }
+}
+
+AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const {
+  AttachMap map = AttachMapNode::make();
+  auto pmap = map.CopyOnWrite();
+  for (const auto& x : operator->()->stage_to_attach_iter) {
+    auto key = x.first;
+    if (key >= start_id) {
+      key += offset;
+    }
+    auto value = x.second;
+    if (value.first >= start_id) {
+      value.first += offset;
+    }
+    pmap->stage_to_attach_iter.insert(std::make_pair(key, value));
+  }
+  for (const auto& x : operator->()->iter_to_attached_stages) {
+    auto key = x.first;
+    if (key.first >= start_id) {
+      key.first += offset;
+    }
+    auto value = x.second;
+    for (auto& i : value) {
+      if (i >= start_id) {
+        i += offset;
+      }
+    }
+    pmap->iter_to_attached_stages.insert(std::make_pair(key, value));
+  }
+  return map;
+}
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+.set_dispatch<StateNode>([](const ObjectRef& ref, ReprPrinter *p) {
+  auto* node = static_cast<const StateNode*>(ref.get());
+  PrintState(&p->stream, node, true);
+});
+
+}  // namespace ansor
+}  // namespace tvm
diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h
new file mode 100644
index 000000000000..3ffe8a7feafb
--- /dev/null
+++ b/src/ansor/loop_state.h
@@ -0,0 +1,732 @@
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file ansor/loop_state.h
+ * \brief Data structures for loop transformations
+
+ * Basically this is a simplified TVM IR with schedule primitives.
+ * We don't use the existing TVM IR because
+ * 1. We want fast incremental changes to the loop structures
+ * 2. We want a serializable history for replay and backtracking
+ * 3. We want a simplified IR for easy and clean feature extraction
+ * 4. We may create some macro schedule primitives
+
+ * After the search is done, we will lower this IR to TVM IR and TVM schedule primitives.
+ * Because we share a lot of common objects during search, the transformation is
+ * implemented in copy-on-write style. All objects are immutable, which is
+ * similar to TVM IR.
+ */
+
+#ifndef TVM_ANSOR_LOOP_STATE_H_
+#define TVM_ANSOR_LOOP_STATE_H_
+
+// #include
+// #include
+// #include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "expr_hasher.h"
+#include "utils.h"
+#include "compute_dag.h"
+
+namespace tvm {
+namespace ansor {
+
+using namespace tvm::tir;
+
+enum IteratorType {
+  kSpace,    // spatial iterator
+  kReduce,   // reduction iterator
+  kMixed,    // fused spatial and reduction iterator
+  kSpecial   // special iterator (e.g. the virtual root iterator)
+};
+
+enum IteratorAnnotation {
+  kNone, kUnroll, kVectorize, kParallel,
+  kVThread, kBlockX, kThreadX, kBlockY, kThreadY
+};
+
+enum StageType {
+  kPlaceholder, kCompute
+};
+
+enum ComputeAtType {
+  kRoot,     // compute at root
+  kInlined,  // inlined
+  kIter,     // compute at some iterator
+};
+
+/* Iterator and Stage */
+class Iterator; class Stage; class State;
+
+/*!
+ * \brief A for-loop iterator
+ * Similar to tvm::IterVar in `include/expr.h`
+ */
+class IteratorNode : public Object {
+ public:
+  std::string name;
+  Range range;  // domain of the for-loop
+  IteratorType iter_type;
+  IteratorAnnotation annotation;
+  std::vector<Iterator> ori_iters;
+
+  static Iterator make(std::string name, Range range,
+                       IteratorType iter_type, IteratorAnnotation annotation,
+                       const std::vector<Iterator>* ori_iters = nullptr);
+
+  static constexpr const char *_type_key = "ansor.Iterator";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(Iterator, ObjectRef, IteratorNode);
+
+/*!
+ * \brief A stage in the compute declaration
+ * Similar to te::Stage in `include/schedule.h`
+ */
+class StageNode : public Object {
+ public:
+  te::Operation op;
+  StageType op_type;
+  std::vector<Iterator> iters;
+  ComputeAtType compute_at;
+  int16_t auto_unroll_max_step;
+  int storage_offset;
+
+  static Stage make(te::Operation op);
+  static Stage make(te::Operation op, StageType op_type, const std::vector<Iterator>& iters,
+                    ComputeAtType compute_at, int16_t auto_unroll_max_step, int storage_offset);
+  static Stage make(te::Operation op, StageType op_type, std::vector<Iterator>&& iters,
+                    ComputeAtType compute_at, int16_t auto_unroll_max_step, int storage_offset);
+
+  static constexpr const char *_type_key = "ansor.Stage";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(Stage, ObjectRef, StageNode);
+
+
+/*! \brief The base class for a transformation step */
+class StepNode: public Object {
+ public:
+  int stage_id;
+
+  // Print the step as the equivalent python schedule API
+  virtual std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                                       StageToAxesMap *stage_to_axes,
+                                       te::Schedule *schedule,
+                                       const std::vector<Step>& transform_steps) const = 0;
+
+  static constexpr const char* _type_key = "ansor.Step";
+  TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object);
+};
+TVM_DEFINE_MUTABLE_NODE_REF(Step, StepNode);
+
+/*
+ * Note on how to add a new transform step
+ *
+ * Take fuse for example:
+ * 1. Define class FuseStepNode and FuseStep in loop_state.h, and implement the make function
+ *    FuseStepNode::make(...) in loop_state.cc
+ * 2. Implement FuseStepNode::ApplyToSchedule and FuseStepNode::PrintAsPythonAPI.
+ *    - In these two functions you need to lower this step with tvm's schedule API
+ * 3. Implement State::fuse and State::DoFuseStep.
+ *    - In these two functions you need to incrementally update all data structures in State with
+ *      the CopyOnWrite style
+ * 4. Add your step to ComputeDAG::ReplaySteps and make sure it works.
+ * 5. Add serialization support in `struct Handler<std::vector<::tvm::ansor::Step> >`
+ *    (in serialization.cc)
+ * 6.
Add hash support in `struct hash<::tvm::ansor::Step>` (search for this function in this file) + */ + +class ReorderStep; class SplitStep; class FollowSplitStep; +class FollowFusedSplitStep; +class FuseStep; class AnnotationStep; +class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep; +class PackForVecStep; class CacheReadStep; class CacheWriteStep; +class PragmaStep; class RfactorStep; class StorageAlignStep; +class AttachMap; + +class ReorderStepNode: public StepNode { + public: + std::vector after_ids; + + static ReorderStep make(int stage_id, const std::vector& after_ids); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.ReorderStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(ReorderStep, Step, ReorderStepNode); + + +class SplitStepNode: public StepNode { + public: + int iter_id; + PrimExpr extent; // the extent of the axis to split + std::vector lengths; // The split factors + bool inner_to_outer; + + static SplitStep make(int stage_id, int iter_id, PrimExpr extent, + const std::vector& lengths, + bool inner_to_outer); + + std::vector ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.SplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(SplitStep, Step, SplitStepNode); + +// Similar to SplitStepNode, but use split factor from another step(i.e. Follow another split step) +class FollowSplitStepNode: public StepNode { + public: + int iter_id; + int src_step_id; + int n_split; + + static FollowSplitStep make(int stage_id, int iter_id, + int src_step_id, int n_split); + + void ExtractSplitLengths(const std::vector& transform_steps, + std::vector* lengths) const; + + std::vector ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, + const std::vector& transform_steps) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.FollowSplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(FollowSplitStep, Step, FollowSplitStepNode); + + +// Similar to FollowSplitStep, but use split factors from multiple steps +// This can be used for the split in cooperative fetching. +class FollowFusedSplitStepNode: public StepNode { + public: + int iter_id; + std::vector src_step_ids; + int level; // Use the length in this split level + bool factor_or_nparts; // If this is true, use factor. 
Otherwise, use nparts + + static FollowFusedSplitStep make(int stage_id, int iter_id, + const std::vector& src_step_ids, int level, bool factor_or_nparts); + + PrimExpr ExtractSplitLength(const std::vector& transform_steps) const; + + std::vector ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, + const std::vector& transform_steps) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.FollowFusedSplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); + + +class FuseStepNode: public StepNode { + public: + std::vector fused_ids; + + static FuseStep make(int stage_id, const std::vector& fused_ids); + + IterVar ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.FuseStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(FuseStep, Step, FuseStepNode); + + +class AnnotationStepNode: public StepNode { + public: + int iter_id; + IteratorAnnotation annotation; + + static AnnotationStep make(int stage_id, int iter_id, IteratorAnnotation ann); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.AnnotationStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(AnnotationStep, Step, AnnotationStepNode); + + +class ComputeAtStepNode: public StepNode { + public: + int target_stage_id; + int target_iter_id; + + static ComputeAtStep make(int stage_id, int target_stage_id, int target_iter_id); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.ComputeAtStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(ComputeAtStep, Step, ComputeAtStepNode); + + +class ComputeRootStepNode: public StepNode { + public: + static ComputeRootStep make(int stage_id); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.ComputeRootStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(ComputeRootStep, Step, ComputeRootStepNode); + + +class ComputeInlineStepNode: public StepNode { + public: + static ComputeInlineStep make(int stage_id); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr 
const char* _type_key = "ansor.ComputeInlineStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(ComputeInlineStep, Step, ComputeInlineStepNode); + +class PackForVecStepNode: public StepNode { + public: + int iter_id; + int vec_size; + + static PackForVecStep make(int stage_id, int iter_id, int vec_size); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.PackForVecStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(PackForVecStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(PackForVecStep, Step, PackForVecStepNode); + + +/*! \brief Apply cache_read to a stage + * TVM Api: te::Schedule::cache_read(tensor, scope, readers) */ +class CacheReadStepNode: public StepNode { + public: + std::string scope_name; + std::vector reader_stage_ids; + + static CacheReadStep make(int stage_id, std::string scope_name, + const std::vector& reader_stage_id); + + te::Tensor ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.CacheReadStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(CacheReadStep, Step, CacheReadStepNode); + + +/*! \brief Apply cache_write to a stage + * TVM Api: te::Schedule::cache_write(tensor, scope) + * This step will cache_write all output tensors of target stage */ +class CacheWriteStepNode: public StepNode { + public: + std::string scope_name; + + static CacheWriteStep make(int stage_id, std::string scope_name); + + Array ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.CacheWriteStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(CacheWriteStep, Step, CacheWriteStepNode); + +/*! \brief Add pragma to a specific iterator */ +class PragmaStepNode: public StepNode { + public: + int iter_id; + std::string pragma_type; + + static PragmaStep make(int stage_id, int iter_id, std::string pragma_type); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.PragmaStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(PragmaStep, Step, PragmaStepNode); + +/*! 
\brief Factor a reduction axis + * TVM Api: te::Schedule::rfactor(tensor, axis, factor_axis) */ +class RfactorStepNode: public StepNode { + public: + int iter_id; + int factor_iter_id; + + static RfactorStep make(int stage_id, int iter_id, int factor_iter_id); + + Array ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.RfactorStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(RfactorStep, Step, RfactorStepNode); + +class StorageAlignStepNode: public StepNode { + public: + int iter_id; + int factor; + int offset; + + static StorageAlignStep make(int stage_id, int iter_id, int factor, + int offset); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.StorageAlignStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(StorageAlignStep, Step, StorageAlignStepNode); + +/*! \brief stores the compute_at relation between stages */ +class AttachMapNode: public Object { + public: + using StageKey = int; + using IterKey = std::pair; // stage_id and iter_id + + std::unordered_map stage_to_attach_iter; + std::unordered_map> iter_to_attached_stages; + + static AttachMap make(); + + static constexpr const char* _type_key = "ansor.AttachMap"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object); +}; + +/*! \brief stores the compute_at relation between stages + * This stores a bi-directional mapping from stages and iter: + * 1. Stage to its attached iterator 2. Iterator to the stage attached to it + * + * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages + * to query the relations */ +class AttachMap : public ObjectRef { + public: + using StageKey = int; + using IterKey = std::pair; // stage_id and iter_id + + void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id); + void DeleteStage(int stage_id); + void ReplaceIters(const std::vector& old_iters, + const std::vector& new_iters); + AttachMap ApplyStageIdOfffset(int start_id, int offset) const; + + TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode); + + private: + static void DeleteStageEntry(AttachMapNode* pnode, int stage_id); +}; + +/*! \brief The loop state and corresponding history steps to reach this state */ +class StateNode: public Object { + public: + std::vector stages; // Current stages and loop structures + std::vector transform_steps; // History transformation steps to reach this state + bool complete; // Indicate whether this state has unfilled tile sizes + AttachMap attach_map; // stores the compute_at relation between stages + ObjectRef aux_info; // Used to store any auxiliary info about this state + ComputeDAG task_dag; // The up-to-date ComputeDAG of this state. 
+                        // The default value is an empty NodeRef
+                        // (means no modification to the DAG)
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("complete", &complete);
+    v->Visit("aux_info", &aux_info);
+  }
+
+  static State make_empty_state();
+  static State make(const Array& ops);
+  static State make(const std::vector& stages,
+                    const std::vector& transform_steps, bool complete, ObjectRef aux_info);
+
+  static constexpr const char* _type_key = "ansor.State";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object);
+};
+
+/*! \brief The loop state and corresponding history steps to reach this state */
+class State : public ObjectRef {
+ public:
+  // Schedule primitives
+  void reorder(int stage_id, const std::vector& order);
+  std::vector split(int stage_id, const Iterator& it,
+                    const std::vector& lengths,
+                    bool inner_to_outer = true);
+  std::vector follow_split(int stage_id, const Iterator& it,
+                           int src_step_id, int n_split);
+  std::vector follow_fused_split(int stage_id, const Iterator& it,
+      const std::vector& src_step_ids, int level, bool factor_or_nparts);
+  Iterator fuse(int stage_id, const std::vector& iters);
+  Iterator vectorize(int stage_id, const Iterator& it);
+  Iterator parallel(int stage_id, const Iterator& it);
+  Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
+  // Valid thread_type values: kVThread, kBlockX, kThreadX, kThreadY
+  Iterator bind_thread(int stage_id, const Iterator& it,
+                       IteratorAnnotation thread_type);
+  void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
+  void compute_root(int stage_id);
+  void compute_inline(int stage_id);
+  void pack_for_vec(int stage_id, const Iterator& target_iter, int vec_size);
+  int cache_read(int stage_id, const std::string& scope_name,
+                 const std::vector& reader_stage_ids,
+                 const ComputeDAG& task_dag);
+  int cache_write(int stage_id, const std::string& scope_name,
+                  const ComputeDAG& task_dag);
+  void pragma(int stage_id, const Iterator& it, const std::string& pragma_type);
+  int rfactor(int stage_id, const Iterator& it, int factor_iter_id,
+              const ComputeDAG& task_dag);
+  void storage_align(int stage_id, const Iterator& it, int factor, int offset);
+
+  /* We separate these functions out, so you can call them to replay a state easily, given its history steps */
+  void DoReorderStep(const ReorderStep& step);
+  std::vector DoSplitStep(const SplitStep& step);
+  std::vector DoFollowSplitStep(const FollowSplitStep& step);
+  std::vector DoFollowFusedSplitStep(const FollowFusedSplitStep& step);
+  Iterator DoFuseStep(const FuseStep& step);
+  Iterator DoAnnotationStep(const AnnotationStep& step);
+  void DoComputeAtStep(const ComputeAtStep& step);
+  void DoComputeRootStep(const ComputeRootStep& step);
+  void DoComputeInlineStep(const ComputeInlineStep& step);
+  void DoPackForVecStep(const PackForVecStep& step);
+  int DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag);
+  int DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag);
+  void DoPragmaStep(const PragmaStep& step);
+  int DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag);
+  void DoStorageAlignStep(const StorageAlignStep& step);
+
+  /* Do transform steps
+   * Note: The following functions only change the loop state. They do not change transform_steps.
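+   *
+   * A minimal replay sketch (illustrative only; assumes `dag` is the
+   * ComputeDAG whose search produced `steps`):
+   *   State state = dag.GetInitState();
+   *   state.DoSteps(steps, dag);  // rebuilds the loop structure only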
+   */
+  void DoStep(const Step& step, const ComputeDAG& dag);
+  void DoSteps(const std::vector& steps, const ComputeDAG& dag);
+
+  // Print the state to a human-readable string
+  std::string ToStr(bool delete_trivial_loop = true) const;
+
+  TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
+
+ private:
+  // common function for DoSplitStep and DoFollowSplitStep
+  std::vector DoSplitStepCommon(int stage_id, int iter_id,
+                                const std::vector& lengths,
+                                bool inner_to_outer);
+};
+
+}  // namespace ansor
+}  // namespace tvm
+
+
+// Hash and equal function for State, Stage, Iterator and Step
+namespace std {
+
+template <>
+struct hash<::tvm::ansor::Step> {
+  std::size_t operator()(const ::tvm::ansor::Step& step) const {
+    if (auto ps = step.as<::tvm::ansor::ReorderStepNode>()) {
+      return ::dmlc::HashCombine(1,
+             ::dmlc::HashCombine(std::hash()(ps->stage_id),
+             ps->after_ids));
+    } else if (auto ps = step.as<::tvm::ansor::SplitStepNode>()) {
+      size_t ret = ::dmlc::HashCombine(2,
+                   ::dmlc::HashCombine(std::hash()(ps->stage_id),
+                   ::dmlc::HashCombine(std::hash()(ps->iter_id),
+                   ps->inner_to_outer)));
+      for (const auto& len : ps->lengths) {
+        if (len.defined()) {
+          auto pint = len.as<::tvm::tir::IntImmNode>();
+          CHECK(pint != nullptr);
+          ret = ::dmlc::HashCombine(ret, pint->value);
+        } else {
+          ret = ::dmlc::HashCombine(ret, 0x5D);  // a magic number
+        }
+      }
+      return ret;
+    } else if (auto ps = step.as<::tvm::ansor::FollowSplitStepNode>()) {
+      return ::dmlc::HashCombine(3,
+             ::dmlc::HashCombine(std::hash()(ps->stage_id),
+             ::dmlc::HashCombine(std::hash()(ps->iter_id),
+             ::dmlc::HashCombine(std::hash()(ps->src_step_id),
+             ps->n_split))));
+    } else if (auto ps = step.as<::tvm::ansor::FollowFusedSplitStepNode>()) {
+      return ::dmlc::HashCombine(4,
+             ::dmlc::HashCombine(std::hash()(ps->stage_id),
+             ::dmlc::HashCombine(std::hash()(ps->iter_id),
+             ::dmlc::HashCombine(std::hash>()(ps->src_step_ids),
+             ::dmlc::HashCombine(std::hash()(ps->level),
+             ps->factor_or_nparts)))));
+    } else if (auto ps = step.as<::tvm::ansor::FuseStepNode>()) {
+      return ::dmlc::HashCombine(5,
+             ::dmlc::HashCombine(std::hash()(ps->stage_id),
+             ps->fused_ids));
+    } else if (auto ps = step.as<::tvm::ansor::AnnotationStepNode>()) {
+      return ::dmlc::HashCombine(6,
+             ::dmlc::HashCombine(std::hash()(ps->stage_id),
+             ::dmlc::HashCombine(std::hash()(ps->iter_id),
+             static_cast(ps->annotation))));
+    } else if (auto ps = step.as<::tvm::ansor::ComputeAtStepNode>()) {
+      return ::dmlc::HashCombine(7,
+             ::dmlc::HashCombine(std::hash()(ps->stage_id),
+             ::dmlc::HashCombine(std::hash()(ps->target_stage_id),
+             ps->target_iter_id)));
+    } else if (auto ps = step.as<::tvm::ansor::ComputeRootStepNode>()) {
+      return ::dmlc::HashCombine(8,
+             ps->stage_id);
+    } else if (auto ps = step.as<::tvm::ansor::ComputeInlineStepNode>()) {
+      return ::dmlc::HashCombine(9,
+             ps->stage_id);
+    } else if (auto ps = step.as<::tvm::ansor::PackForVecStepNode>()) {
+      return ::dmlc::HashCombine(10,
+             ::dmlc::HashCombine(std::hash()(ps->stage_id),
+             ::dmlc::HashCombine(std::hash()(ps->iter_id),
+             ps->vec_size)));
+    } else if (auto ps = step.as<::tvm::ansor::CacheReadStepNode>()) {
+      return ::dmlc::HashCombine(11,
+             ::dmlc::HashCombine(std::hash()(ps->stage_id),
+             ::dmlc::HashCombine(std::hash()(ps->scope_name),
+             ps->reader_stage_ids)));
+    } else if (auto ps = step.as<::tvm::ansor::CacheWriteStepNode>()) {
+      return ::dmlc::HashCombine(12,
+             ::dmlc::HashCombine(std::hash()(ps->stage_id),
+             ps->scope_name));
+    } else if (auto ps = step.as<::tvm::ansor::PragmaStepNode>()) {
+      return ::dmlc::HashCombine(13,
+             ::dmlc::HashCombine(std::hash()(ps->stage_id),
+             ::dmlc::HashCombine(std::hash()(ps->iter_id),
+             ps->pragma_type)));
+    } else if (auto ps = step.as<::tvm::ansor::RfactorStepNode>()) {
+      return ::dmlc::HashCombine(14,
+             ::dmlc::HashCombine(std::hash()(ps->stage_id),
+             ::dmlc::HashCombine(std::hash()(ps->iter_id),
+             ps->factor_iter_id)));
+    } else if (auto ps = step.as<::tvm::ansor::StorageAlignStepNode>()) {
+      return ::dmlc::HashCombine(15,
+             ::dmlc::HashCombine(std::hash()(ps->stage_id),
+             ::dmlc::HashCombine(std::hash()(ps->iter_id),
+             ::dmlc::HashCombine(std::hash()(ps->factor),
+             ps->offset))));
+    } else {
+      LOG(FATAL) << "Invalid step";
+    }
+    return 0;
+  }
+};
+
+template <>
+struct hash<::tvm::ansor::State> {
+  std::size_t operator()(const ::tvm::ansor::State& state) const {
+    return std::hash()(state.ToStr());
+  }
+};
+
+template <>
+struct equal_to<::tvm::ansor::State> {
+  bool operator() (const ::tvm::ansor::State& lhs,
+                   const ::tvm::ansor::State& rhs) const {
+    return lhs.ToStr() == rhs.ToStr();
+  }
+};
+
+}  // namespace std
+
+#endif  // TVM_ANSOR_LOOP_STATE_H_
diff --git a/src/ansor/utils.cc b/src/ansor/utils.cc
new file mode 100644
index 000000000000..2018cf33d1a2
--- /dev/null
+++ b/src/ansor/utils.cc
@@ -0,0 +1,102 @@
+/*!
+ * Copyright (c) 2020 by Contributors
+ */
+
+#include "utils.h"
+#include
+
+namespace tvm {
+namespace ansor {
+
+
+NullStream& NullStream::Global() {
+  static NullStream stream;
+  return stream;
+}
+
+const std::vector >& SplitFactorizationMemo::GetFactorizationSchemes(
+    int extent, int n_lengths, int max_innermost_factor) {
+  QueryKey key = std::make_tuple(extent, n_lengths, max_innermost_factor);
+  auto it = memory_.find(key);
+  if (it != memory_.end()) {
+    return it->second;
+  }
+
+  tmp_stack_.assign(n_lengths, PrimExpr());
+  results_ = &memory_[key];
+  n_lengths_ = n_lengths;
+
+  DfsEnumerate(0, extent, max_innermost_factor);
+
+  return *results_;
+}
+
+void SplitFactorizationMemo::DfsEnumerate(int now, int remaining_length, int max_innermost_factor) {
+  if (now == n_lengths_) {
+    if (tmp_stack_.back().as()->value <= max_innermost_factor) {
+      results_->push_back(tmp_stack_);
+    }
+  } else {
+    for (const auto& f : GetFactors(remaining_length)) {
+      tmp_stack_[now] = PrimExpr(f);
+      DfsEnumerate(now + 1, remaining_length / f, max_innermost_factor);
+    }
+  }
+}
+
+const std::vector& SplitFactorizationMemo::GetFactors(int n) {
+  auto it = factor_memory_.find(n);
+  if (it != factor_memory_.end()) {
+    return it->second;
+  }
+
+  std::vector& res = factor_memory_[n];
+  int step = n % 2 == 0 ?
1 : 2; + for (size_t i = 1; i < static_cast(std::sqrt(n)) + 1; i += step) { + if (n % i == 0) { + res.push_back(i); + if (n / i != i) { + res.push_back(n/i); + } + } + } + std::sort(res.begin(), res.end()); + return res; +} + +ThreadPool& ThreadPool::Global() { + static ThreadPool* pool = new ThreadPool(); + static int ct = 0; + + ct = (ct + 1) % ThreadPool::REFRESH_EVERY; + + if (ct == 0) { + pool->Abort(); + delete pool; + pool = new ThreadPool(); + } + + if (pool->NumWorkers() == 0) { + pool->Launch(std::thread::hardware_concurrency()); + } + + return *pool; +} + +TVM_REGISTER_GLOBAL("ansor.utils.GetFactorizationSchemes") +.set_body([](TVMArgs args, TVMRetValue *ret) { + int extent = args[0]; + int n_lengths = args[1]; + int max_innermost_factor = args[2]; + SplitFactorizationMemo memo; + + Array > result; + for (const auto& lens : memo.GetFactorizationSchemes(extent, n_lengths, max_innermost_factor)) { + result.push_back(lens); + } + + *ret = result; +}); + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/utils.h b/src/ansor/utils.h new file mode 100644 index 000000000000..4ea7f283ad09 --- /dev/null +++ b/src/ansor/utils.h @@ -0,0 +1,482 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file ansor/utils.h + * \brief Common utilities + */ + +#ifndef TVM_ANSOR_UTILS_H_ +#define TVM_ANSOR_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace std { + +// hash function for std::pair, std::vector and std::tuple +template +struct hash > { + std::size_t operator()(const std::pair& k) const { + return ::dmlc::HashCombine(std::hash()(k.first), std::hash()(k.second)); + } +}; + +template +struct hash > { + std::size_t operator()(const std::tuple& k) const { + return ::dmlc::HashCombine( + ::dmlc::HashCombine(std::hash()(std::get<0>(k)), std::hash()(std::get<1>(k))), + std::hash()(std::get<2>(k))); + } +}; + +template +struct hash > { + std::size_t operator()(const std::vector& vec) const { + if (vec.empty()) { + return 0; + } + std::size_t ret = std::hash()(vec[0]); + for (size_t i = 1; i < vec.size(); ++i) { + ret = ::dmlc::HashCombine(ret, std::hash()(vec[i])); + } + return ret; + } +}; + +} // namespace std + +namespace tvm { +namespace ansor { + +/*! \brief Macro to make it easy to define mutable node ref type given node */ +#define TVM_DEFINE_MUTABLE_NODE_REF(TypeName, NodeName) \ + class TypeName : public ObjectRef { \ + public: \ + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ObjectRef, NodeName); \ + }; \ + +/*! + * \brief Macro to make it easy to define node ref type that + * has a CopyOnWrite member function. + */ +#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \ + class TypeName : public BaseType { \ + public: \ + TVM_DEFINE_OBJECT_REF_METHODS(TypeName, BaseType, NodeName); \ + TVM_DEFINE_OBJECT_REF_COW_METHOD(NodeName); \ + }; + +/********** Utilities for std::vector, std::set **********/ + +/*! \brief Get the first appearance index of elements in a vector */ +template +inline void GetIndices(const std::vector& array, + const std::vector& to_locate, + std::vector* indices) { + for (const auto& v : to_locate) { + auto it = std::find(array.begin(), array.end(), v); + if (it != array.end()) { + indices->push_back(it - array.begin()); + } else { + LOG(FATAL) << "Cannot find the item"; + } + } +} + +/*! 
\brief Get the first appearance index of an element in a vector */ +template +inline int GetIndex(const std::vector& array, const T& to_locate) { + for (size_t i = 0; i < array.size(); ++i) { + if (array[i] == to_locate) { + return i; + } + } + LOG(FATAL) << "Cannot find the item"; + return -1; +} + +/*! \brief Delete an element in a vector */ +template +inline void DeleteItem(std::vector* array, const T& to_delete) { + auto iter = std::find(array->begin(), array->end(), to_delete); + if (iter != array->end()) { + array->erase(iter); + } +} + +/*! \brief Compute the product of all elements in a vector */ +inline int64_t ElementProduct(const std::vector& array) { + int64_t ret = 1; + for (auto x : array) { + ret *= x; + } + return ret; +} + +/* \brief Get the maximum element in a vector */ +template +T MaximumElement(const std::vector& array) { + CHECK(!array.empty()); + const T* pmax = &array[0]; + for (size_t i = 1; i < array.size(); ++i) { + if (array[i] > *pmax) { + pmax = &array[i]; + } + } + return *pmax; +} + +/*! \brief Move elements from multiple vectors to one vector */ +template +std::vector& ConcatenateMove(std::vector* out, std::vector* in) { + out->insert(out->end(), std::make_move_iterator(in->begin()), + std::make_move_iterator(in->end())); + return *out; +} + +/*! \brief Move elements from multiple vectors to one vector */ +template +std::vector& ConcatenateMove(std::vector* out, std::vector* first, Args... args) { + ConcatenateMove(out, first); + ConcatenateMove(out, args...); + return *out; +} + +/* \brief Get a random permutation of integers [0, n-1] */ +template +void RandomPermutation(int n, std::vector* out, G* gen) { + out->assign(n, 0); + std::iota(out->begin(), out->end(), 0); + std::shuffle(out->begin(), out->end(), *gen); +} + +/* \brief Random sample without replacement */ +template +void RandomSample(std::vector* in_data, size_t out_size, G* gen) { + // Note: This function is inefficient in the cases when out_size << in_data.size() + out_size = std::min(in_data->size(), out_size); + + if (in_data->size() <= out_size) { // return all + return; + } + std::vector indices; + RandomPermutation(in_data->size(), &indices, gen); + + std::vector tmp_data; + tmp_data.reserve(out_size); + for (size_t i = 0; i < out_size; ++i) { + tmp_data.push_back(std::move((*in_data)[indices[i]])); + } + + *in_data = std::move(tmp_data); +} + +/*! \brief Argsort. Order: largest to smallest */ +template +inline void Argsort(const std::vector& scores, std::vector* index) { + index->clear(); index->reserve(scores.size()); + for (size_t i = 0; i < scores.size(); ++i) { + index->push_back(i); + } + auto cmp = [&scores](int l, int r) { + return scores[l] > scores[r]; + }; + std::sort(index->begin(), index->end(), cmp); +} + +// Do x++ for all x in the set such that x >= threshold +inline void SetAddOne(std::set* set, int threshold = 0) { + std::set new_set; + for (int x : *set) { + if (x >= threshold) { + new_set.insert(x + 1); + } else { + new_set.insert(x); + } + } + *set = std::move(new_set); +} + +// Compute Jaccard Similarity of two sets +template +double JaccardSimilarity(std::set s1, std::set s2) { + std::vector intersect; + std::set_intersection(s1.begin(), s1.end(), s2.begin(), s2.end(), + std::back_inserter(intersect)); + return 1.0 * intersect.size() / (s1.size() + s2.size() - intersect.size()); +} + +/********** Utilities for std::string **********/ + +/*! 
Return whether a string ends with another substring */
+inline bool StrEndsWith(const std::string& a, const std::string& b) {
+  if (b.size() > a.size()) return false;
+  return std::equal(a.begin() + a.size() - b.size(), a.end(), b.begin());
+}
+
+/*! Return whether a string starts with another substring */
+inline bool StrStartsWith(const std::string& a, const std::string& b) {
+  if (b.size() > a.size()) return false;
+  return std::equal(a.begin(), a.begin() + b.size(), b.begin());
+}
+
+/*! Replace a sub-string with another sub-string in a string */
+inline void StrReplace(std::string* base, const std::string& from, const std::string& to) {
+  auto pos = base->find(from);
+  while (pos != std::string::npos) {
+    base->replace(pos, from.size(), to);
+    pos = base->find(from, pos + to.size());
+  }
+}
+
+/********** Utilities for TVM Containers / ByteArray **********/
+
+/*! \brief Compute mean of a FloatImm array */
+inline double FloatArrayMean(const Array& float_array) {
+  double sum = 0;
+  if (float_array.empty()) {
+    return 0.0;
+  }
+
+  for (const auto& x : float_array) {
+    auto floatimm = x.as();
+    CHECK(floatimm != nullptr);
+    sum += floatimm->value;
+  }
+  return sum / float_array.size();
+}
+
+/*! \brief Serialize a 2-dimensional vector to TVMByteArray.
+ * This is used for sending data to python code */
+template
+inline TVMByteArray Serialize2dVector(std::vector >&& in_data,
+                                      std::vector* out_data) {
+  size_t total_bytes = 0;
+  std::vector size_vector;
+
+  // serialize sizes
+  total_bytes += (1 + in_data.size()) * sizeof(int);
+  size_vector.reserve(in_data.size() + 1);
+  size_vector.push_back(in_data.size());
+  for (const auto& x : in_data) {
+    size_vector.push_back(static_cast(x.size()));
+    total_bytes += sizeof(T) * x.size();
+  }
+
+  // resize (not reserve) so that writing through data() stays within the
+  // vector's valid element range
+  out_data->resize(total_bytes);
+  char* ptr = out_data->data();
+  memmove(ptr, reinterpret_cast(size_vector.data()), (1 + in_data.size()) * sizeof(int));
+  ptr += (1 + in_data.size()) * sizeof(int);
+
+  // serialize in_data
+  for (auto& x : in_data) {
+    memmove(ptr, x.data(), sizeof(T) * x.size());
+    ptr += sizeof(T) * x.size();
+    x.clear();
+  }
+
+  CHECK_EQ(ptr - out_data->data(), total_bytes);
+
+  return TVMByteArray{out_data->data(), total_bytes};
+}
+
+/********** Other Utilities **********/
+
+// Get an int value from an Expr
+inline int64_t GetIntImm(const PrimExpr& expr) {
+  auto pint = expr.as();
+  CHECK(pint != nullptr);
+  return pint->value;
+}
+
+
+// Compute the product of the lengths of axes
+inline int64_t AxisLengthProd(const Array& axes) {
+  int64_t ret = 1;
+  for (const auto& x : axes) {
+    if (const IntImmNode* imm = x->dom->extent.as()) {
+      ret *= imm->value;
+    } else {
+      return -1;
+    }
+  }
+  return ret;
+}
+
+
+// An empty output stream
+class NullStream : public std::ostream {
+ public:
+  NullStream() : std::ostream(nullptr) {}
+  NullStream(const NullStream &) : std::ostream(nullptr) {}
+  static NullStream& Global();
+};
+
+template
+NullStream& operator<<(NullStream& os, const T& value) {
+  return os;
+}
+
+/*! \brief Get std cout with verbose control */
+inline std::ostream& StdCout(int verbose) {
+  if (verbose >= 1) {
+    return std::cout;
+  } else {
+    return NullStream::Global();
+  }
+}
+
+/*!
\brief Print a title */ +inline void PrintTitle(const std::string& title, int verbose) { + if (verbose >= 1) { + std::cout << "------------------------------------------------------------" << "\n"; + std::cout << "----------------------- [ " << title << " ]\n"; + std::cout << "------------------------------------------------------------" << std::endl; + } +} + +/*! \brief A simple thread pool */ +class ThreadPool { + public: + void Launch(size_t n = 1) { + for (std::size_t i = 0; i < n; ++i) { + threads_.emplace_back([this] {WorkerFunc();}); + } + } + + void BeginBatch(int n) { + finish_ct_ = n; + is_finished_ = n <= 0; + } + + template::type> + std::future Enqueue(F&& f, Args&&... args) { + std::packaged_task p(std::bind(f, args...)); + + auto r = p.get_future(); + { + std::unique_lock l(m_); + work_.emplace_back(std::move(p)); + } + work_signal_.notify_one(); + return r; + } + + void WaitBatch() { + std::unique_lock l(finish_mutex_); + if (!is_finished_) { + finish_signal_.wait(l); + } + } + + void Abort() { + CancelPending(); + Join(); + } + + void CancelPending() { + std::unique_lock l(m_); + work_.clear(); + } + + void Join() { + { + std::unique_lock l(m_); + for (size_t i = 0; i < threads_.size(); ++i) { + work_.push_back({}); + } + } + work_signal_.notify_all(); + for (auto& t : threads_) { + t.join(); + } + threads_.clear(); + } + + size_t NumWorkers() { + return threads_.size(); + } + + static const int REFRESH_EVERY = 128; + static ThreadPool& Global(); + + ~ThreadPool() { + Join(); + } + + private: + void WorkerFunc() { + while (true) { + std::packaged_task f; + { + std::unique_lock l(m_); + if (work_.empty()) { + work_signal_.wait(l, [&]{ return !work_.empty(); }); + } + f = std::move(work_.front()); + work_.pop_front(); + } + if (!f.valid()) { return; } + f(); + + finish_ct_--; + if (finish_ct_ == 0) { + std::unique_lock l(finish_mutex_); + + is_finished_ = true; + finish_signal_.notify_one(); + } + } + } + + std::mutex m_; + std::condition_variable work_signal_; + std::deque> work_; + std::vector threads_; + + bool is_finished_; + std::mutex finish_mutex_; + std::atomic finish_ct_; + std::condition_variable finish_signal_; +}; + +/*! + * \brief Enumerate all possible factorization schemes for splitting an axes. + * \note This class will memorize the results for reuse. + */ +class SplitFactorizationMemo { + public: + using QueryKey = std::tuple; + + const std::vector >& GetFactorizationSchemes( + int extent, int n_lengths, int max_innermost_factor); + const std::vector& GetFactors(int n); + + private: + void DfsEnumerate(int now, int remaining_lenght, int max_innermost_factor); + + std::unordered_map > > memory_; + + int n_lengths_; + std::vector tmp_stack_; + std::vector >* results_; + std::unordered_map> factor_memory_; +}; + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_UTILS_H_ diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc new file mode 100644 index 000000000000..b9a4f25023bf --- /dev/null +++ b/tests/cpp/ansor_test.cc @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include +#include +#include "../../src/ansor/compute_dag.h" + +tvm::Array matmul_func(int n, int m, int k) { + using namespace tvm; + using namespace tvm::te; + + Tensor A = placeholder({n, k}, DataType::Float(32), "A"); + Tensor B = placeholder({k, m}, DataType::Float(32), "B"); + IterVar K = IterVarNode::make({0, k}, Var("k"), kCommReduce); + const auto& C = compute( + {n, m}, + [&](Var i, Var j) { return tvm::sum(A[i][K] * B[K][j], {K}); }, + "C"); + + return {A, B, C}; +} + +tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, + int CI, int CO, int kernel_size, int strides, int padding, + int dilation = 1) { + using namespace tvm; + using namespace tvm::te; + + Tensor data = placeholder({N, CI, H, W}, DataType::Float(32), "Data"); + Tensor kernel = placeholder({CO, CI, kernel_size, kernel_size}, + DataType::Float(32), "Kernel"); + Tensor bias = placeholder({CO, 1, 1}, DataType::Float(32), "Bias"); + Tensor bn_scale = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_scale"); + Tensor bn_offset = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_offset"); + + int OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1); + int OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1); + + const auto& conv = topi::conv2d_nchw(data, kernel, strides, padding, + dilation); + const auto& bias_add = compute( + {N, CO, OH, OW}, + [&](Var i, Var j, Var k, Var l) { + return conv[i][j][k][l] + bias[j][0][0]; + }, + "Bias_add"); + const auto& bn_mul = compute( + {N, CO, OH, OW}, + [&](Var i, Var j, Var k, Var l) { + return bias_add[i][j][k][l] * bn_scale[j][0][0]; + }, + "Bn_mul"); + const auto& bn_add = compute( + {N, CO, OH, OW}, + [&](Var i, Var j, Var k, Var l) { + return bn_mul[i][j][k][l] + bn_offset[j][0][0]; + }, + "Bn_add"); + const auto& out = topi::relu(bn_add); + + return {data, kernel, bias, bn_scale, bn_offset, out}; +} + +TEST(ComputeDAG, Basic) { + const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); + auto dag = tvm::ansor::ComputeDAGNode::make(tensors); + + LOG(INFO) << "\n" << dag; + LOG(INFO) << "\n" << dag->access_analyzer; +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} From 9fcbf0bc6e3a0e985ad507dff868458a3124eb6f Mon Sep 17 00:00:00 2001 From: Chenfan Date: Wed, 27 May 2020 18:06:42 +0800 Subject: [PATCH 02/78] Split transform_step out & Update more UTs (#3) * Split transform_step out * Update GetProducers & GetConsumers * Update UTs * Add UT for CacheReadWrite & Some bug fix --- src/ansor/compute_dag.cc | 188 +++---- src/ansor/compute_dag.h | 2 +- src/ansor/loop_state.cc | 1004 +++++------------------------------ src/ansor/loop_state.h | 547 +------------------ src/ansor/transform_step.cc | 820 ++++++++++++++++++++++++++++ src/ansor/transform_step.h | 551 +++++++++++++++++++ tests/cpp/ansor_test.cc | 466 +++++++++++++++- 7 files changed, 2074 insertions(+), 1504 deletions(-) create mode 100644 src/ansor/transform_step.cc create mode 100644 src/ansor/transform_step.h diff --git 
a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 31136985b330..e1ae3250d1a5 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -14,7 +14,7 @@ #include #include #include -// #include "loop_state.h" +#include "loop_state.h" #include "utils.h" // #include "../relay/pass/kernel_layout_transform.h" @@ -347,30 +347,30 @@ void AccessAnalyzer::GetProducers(const State& state, const te::Operation& op, } } -// void AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op, -// OperationSet* consumers) const { -// OperationSet inlined_ops; +void AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op, + OperationSet* consumers) const { + OperationSet inlined_ops; -// for (const auto& stage : state->stages) { -// if (stage->compute_at == kInlined) { -// inlined_ops.insert(stage->op); -// } -// } -// std::function collect; + for (const auto& stage : state->stages) { + if (stage->compute_at == kInlined) { + inlined_ops.insert(stage->op); + } + } + std::function collect; -// collect = [this, &collect, &inlined_ops, &consumers](const Operation& op) { -// for (const auto& iter : operator->()->read_by.at(op)) { -// if (inlined_ops.count(iter.first)) { -// collect(iter.first); -// } else { -// consumers->insert(iter.first); -// } -// } -// }; + collect = [this, &collect, &inlined_ops, &consumers](const te::Operation& op) { + for (const auto& iter : operator->()->read_by.at(op)) { + if (inlined_ops.count(iter.first)) { + collect(iter.first); + } else { + consumers->insert(iter.first); + } + } + }; -// consumers->clear(); -// collect(op); -// } + consumers->clear(); + collect(op); +} bool IntArrayEqual(const Array& arr1, const Array& arr2) { if (arr1.size() != arr2.size()) { @@ -547,9 +547,9 @@ void UpdateStageAxis(const te::Stage& stage, StageToAxesMap *stage_to_axes) { } } -// State ComputeDAG::GetInitState() const { -// return Downcast(operator->()->init_state); -// } +State ComputeDAG::GetInitState() const { + return Downcast(operator->()->init_state); +} ComputeDAG ComputeDAGNode::make(Array tensors) { auto node = make_object(); @@ -559,7 +559,7 @@ ComputeDAG ComputeDAGNode::make(Array tensors) { node->access_analyzer = AccessAnalyzerNode::make(node->tensors); node->ops = Array(node->access_analyzer->ops_topo_order); node->flop_ct = estimator.EstimateFlop(node->ops); -// node->init_state = StateNode::make(node->ops); + node->init_state = StateNode::make(node->ops); return ComputeDAG(node); } @@ -580,8 +580,8 @@ void ComputeDAGNode::VisitAttrs(tvm::AttrVisitor* v) { v->Visit("ops", &ops); v->Visit("flop_ct", &flop_ct); v->Visit("access_analyzer", &access_analyzer); -// State s = Downcast(init_state); -// v->Visit("init_state", &s); + State s = Downcast(init_state); + v->Visit("init_state", &s); } // Implemented in multi_stage_policy.cc @@ -1075,79 +1075,79 @@ void ComputeDAG::ReplayAndGetDAG(const std::vector &transform_steps, // } // } -// std::pair > ComputeDAG::ReplaySteps( -// const std::vector &transform_steps, -// std::vector *stages, -// StageToAxesMap *stage_to_axes) const { -// std::vector ops; -// for (const auto& op : operator->()->ops) { -// if (!op->IsInstance()) { -// ops.push_back(op); -// } -// } +std::pair > ComputeDAG::ReplaySteps( + const std::vector &transform_steps, + std::vector *stages, + StageToAxesMap *stage_to_axes) const { + std::vector ops; + for (const auto& op : operator->()->ops) { + if (!op->IsInstance()) { + ops.push_back(op); + } + } -// te::Schedule schedule = te::create_schedule({ops.back()}); + te::Schedule 
schedule = te::create_schedule({ops.back()}); -// // init axes -// stages->reserve(operator->()->ops.size()); -// for (const auto& x : operator->()->ops) { -// const te::Stage& stage = schedule.operator[](x); -// stages->push_back(stage); -// UpdateStageAxis(stage, stage_to_axes); -// } + // init axes + stages->reserve(operator->()->ops.size()); + for (const auto& x : operator->()->ops) { + const te::Stage& stage = schedule.operator[](x); + stages->push_back(stage); + UpdateStageAxis(stage, stage_to_axes); + } -// // todo(lmzheng): should we maintain the attach_map and keep the validity of compute_at -// // an splitted axis? + // todo(lmzheng): should we maintain the attach_map and keep the validity of compute_at + // an splitted axis? -// // Use complete rate for the study in the paper -// const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); -// double complete_rate = -1.0; -// if (complete_rate_str) { -// complete_rate = std::stod(complete_rate_str); -// } -// size_t ct = 0; + // Use complete rate for the study in the paper + const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); + double complete_rate = -1.0; + if (complete_rate_str) { + complete_rate = std::stod(complete_rate_str); + } + size_t ct = 0; -// // replay history -// for (const auto& step : transform_steps) { -// if (complete_rate >= 0 && ct++ > transform_steps.size() * complete_rate) { -// break; -// } + // replay history + for (const auto& step : transform_steps) { + if (complete_rate >= 0 && ct++ > transform_steps.size() * complete_rate) { + break; + } -// if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes, &schedule); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes, &schedule); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes, &schedule); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else { -// LOG(FATAL) << "Invalid Step"; -// } -// } + if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + 
ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, &schedule); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, &schedule); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, &schedule); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else { + LOG(FATAL) << "Invalid Step"; + } + } -// return std::make_pair(schedule, operator->()->tensors); -// } + return std::make_pair(schedule, operator->()->tensors); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index c8da44fee828..9d0708a77f1c 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -148,7 +148,7 @@ class ComputeDAG: public ObjectRef { // Internal common parts for replaying steps std::pair > ReplaySteps( const std::vector& transform_steps, std::vector* stages, - StageToAxesMap* stage_to_axes) const {}; + StageToAxesMap* stage_to_axes) const; static constexpr const char* _layout_free_placeholders_key = "layout_free_placeholders"; // Internal common parts for inferring bound diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 92157edc463d..f01899c4c793 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -8,825 +8,8 @@ namespace tvm { namespace ansor { -TVM_REGISTER_OBJECT_TYPE(StepNode); TVM_REGISTER_NODE_TYPE(StateNode); -inline std::string CleanName(const std::string& str) { - // to make the name valid in python code - std::string ret = str; - StrReplace(&ret, ".", "_"); - StrReplace(&ret, "@", "_"); - StrReplace(&ret, "outer", "o"); - StrReplace(&ret, "inner", "i"); - return ret; -} - -/********** Reorder **********/ -ReorderStep ReorderStepNode::make(int stage_id, const std::vector& after_ids) { - auto node = make_object(); - node->stage_id = stage_id; - node->after_ids = after_ids; - return ReorderStep(node); -} - -void ReorderStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - CHECK_EQ(after_ids.size(), axes.size()); - - std::vector new_axes; - new_axes.reserve(axes.size()); - for (auto i : after_ids) { - new_axes.push_back(axes[i]); - } - stage.reorder(new_axes); - (*stage_to_axes)[stage] = std::move(new_axes); -} - -std::string ReorderStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - const te::Stage& stage = (*stages)[stage_id]; - std::stringstream ss; - - ss << "s[" << CleanName(stage->op->func_name()) << "].reorder("; - for (size_t i = 0; i < after_ids.size(); ++i) { - ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint); - if (i != after_ids.size() - 1) { - ss << ", "; - } - } - ss << ")\n"; - - ApplyToSchedule(stages, stage_to_axes); - return ss.str(); -} - -/********** Split **********/ -std::vector ApplySplitToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - int stage_id, - int iter_id, - const std::vector& lengths, - bool inner_to_outer) { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - - std::vector outs; - if (inner_to_outer) { - IterVar outer = axes[iter_id], inner; 
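-    // Worked example (extent and factors assumed, not from this patch): for
-    // an axis of extent 512, lengths = [4, 8] and inner_to_outer = true, the
-    // loop below first splits by factor 8, then by factor 4, so new_axes
-    // receives iterators of extents [16, 4, 8] after the reversed insertion.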
- for (int i = static_cast(lengths.size()) - 1; i >= 0; i--) { - IterVar to_split = outer; - stage.split(to_split, lengths[i], &outer, &inner); - outs.push_back(inner); - } - outs.push_back(outer); - } else { - IterVar outer, inner = axes[iter_id]; - for (size_t i = 0; i < lengths.size(); i++) { - IterVar to_split = inner; - stage.split_by_nparts(to_split, lengths[i], &outer, &inner); - outs.push_back(outer); - } - outs.push_back(inner); - } - - std::vector new_axes; - new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + iter_id); - if (inner_to_outer) { - new_axes.insert(new_axes.end(), outs.rbegin(), outs.rend()); - } else { - new_axes.insert(new_axes.end(), outs.begin(), outs.end()); - } - new_axes.insert(new_axes.end(), axes.begin() + iter_id + 1, axes.end()); - (*stage_to_axes)[stage] = std::move(new_axes); - - return outs; -} - -std::string PrintSplitAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - int stage_id, - int iter_id, - const std::vector& lengths, - bool inner_to_outer) { - te::Stage& stage = (*stages)[stage_id]; - auto to_split = (*stage_to_axes)[stage][iter_id]; - const auto& func_name = CleanName(stage->op->func_name()); - const auto& outs = ApplySplitToSchedule(stages, stage_to_axes, stage_id, - iter_id, lengths, inner_to_outer); - - std::stringstream ss; - int size = static_cast(lengths.size()); - if (inner_to_outer) { - for (int i = size - 1; i >= 0; i--) { - ss << CleanName(outs[size - i]->var->name_hint) << ", " - << CleanName(outs[size - i - 1]->var->name_hint) - << " = s[" << func_name << "].split(" - << CleanName(to_split->var->name_hint) - << ", factor=" << lengths[i] << ")\n"; - to_split = outs[size - i]; - } - } else { - for (int i = 0; i < size; i++) { - ss << CleanName(outs[i]->var->name_hint) << ", " - << CleanName(outs[i + 1]->var->name_hint) - << " = s[" << func_name << "].split(" - << CleanName(to_split->var->name_hint) - << ", nparts=" << lengths[i] << ")\n"; - to_split = outs[i + 1]; - } - } - - return ss.str(); -} - -SplitStep SplitStepNode::make(int stage_id, int iter_id, - PrimExpr extent, const std::vector& lengths, - bool inner_to_outer) { - auto node = make_object(); - node->stage_id = stage_id; - // Extent can be a unreducible expression in some special cases - if (extent->IsInstance()) { - node->extent = std::move(extent); - } - node->iter_id = iter_id; - node->lengths = lengths; - node->inner_to_outer = inner_to_outer; - return SplitStep(node); -} - -std::vector SplitStepNode::ApplyToSchedule( - std::vector *stages, StageToAxesMap *stage_to_axes) const { - return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, - lengths, inner_to_outer); -} - -std::string SplitStepNode::PrintAsPythonAPI( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule, const std::vector& transform_steps) const { - return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, - lengths, inner_to_outer); -} - -/********** Follow Split **********/ -FollowSplitStep FollowSplitStepNode::make(int stage_id, int iter_id, - int src_step_id, int n_split) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->src_step_id = src_step_id; - node->n_split = n_split; - return FollowSplitStep(node); -} - -void FollowSplitStepNode::ExtractSplitLengths(const std::vector& transform_steps, - std::vector* lengths) const { - CHECK_LT(src_step_id, transform_steps.size()); - auto ps = transform_steps[src_step_id].as(); - CHECK(ps != nullptr); - - // get lengths from src step - 
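-  // Illustrative values: if the source SplitStep recorded lengths [4, 8, 16]
-  // and n_split == 2, the first factor 4 is copied and the tail is folded
-  // into a single factor, so *lengths becomes [4, 128] (8 * 16 = 128).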
lengths->reserve(n_split); - int j = 0; - for (; j < n_split - 1; ++j) { - lengths->push_back(ps->lengths[j]); - } - PrimExpr last_factor = 1; - for (; j < static_cast(ps->lengths.size()); ++j) { - if (ps->lengths[j].defined()) { - last_factor *= ps->lengths[j]; - } else { - last_factor = PrimExpr(); - break; - } - } - lengths->push_back(std::move(last_factor)); -} - -std::vector FollowSplitStepNode::ApplyToSchedule( - std::vector *stages, StageToAxesMap *stage_to_axes, - const std::vector& transform_steps) const { - std::vector lengths; - ExtractSplitLengths(transform_steps, &lengths); - return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, - lengths, true); -} - -std::string FollowSplitStepNode::PrintAsPythonAPI( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule, const std::vector& transform_steps) const { - std::vector lengths; - ExtractSplitLengths(transform_steps, &lengths); - return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, - lengths, true); -} - -/********** Follow Fused Split **********/ -FollowFusedSplitStep FollowFusedSplitStepNode::make(int stage_id, int iter_id, - const std::vector& src_step_ids, int level, bool factor_or_nparts) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->src_step_ids = src_step_ids;; - node->level = level; - node->factor_or_nparts = factor_or_nparts; - return FollowFusedSplitStep(node); -} - -PrimExpr FollowFusedSplitStepNode::ExtractSplitLength(const std::vector& transform_steps) const { - PrimExpr ret(1); - - for (int src_step_id : src_step_ids) { - CHECK_LT(src_step_id, transform_steps.size()); - auto ps = transform_steps[src_step_id].as(); - CHECK(ps != nullptr); - if (ps->lengths[level].defined() && ret.defined()) { - ret *= ps->lengths[level]; - } else { - return PrimExpr(); - } - } - - return ret; -} - -std::vector FollowFusedSplitStepNode::ApplyToSchedule( - std::vector *stages, StageToAxesMap *stage_to_axes, - const std::vector& transform_steps) const { - const PrimExpr& length = ExtractSplitLength(transform_steps); - return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, - {length}, factor_or_nparts); -} - -std::string FollowFusedSplitStepNode::PrintAsPythonAPI( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule, const std::vector& transform_steps) const { - const PrimExpr& length = ExtractSplitLength(transform_steps); - return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, - {length}, factor_or_nparts); -} - - -/********** Fuse **********/ -FuseStep FuseStepNode::make(int stage_id, const std::vector& fused_ids) { - auto node = make_object(); - node->stage_id = stage_id; - node->fused_ids = fused_ids; - return FuseStep(node); -} - -IterVar FuseStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - - Array to_fuse; - for (auto i : fused_ids) { - to_fuse.push_back(axes[i]); - } - IterVar fused_axis; - stage.fuse(to_fuse, &fused_axis); - std::vector new_axes; - new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids[0]); - new_axes.push_back(fused_axis); - new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, - axes.end()); - (*stage_to_axes)[stage] = std::move(new_axes); - - return fused_axis; -} - -std::string FuseStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, 
- const std::vector& transform_steps) const { - const auto& stage = (*stages)[stage_id]; - std::stringstream to_fuse; - - for (size_t i = 0; i < fused_ids.size(); ++i) { - to_fuse << CleanName((*stage_to_axes)[stage][fused_ids[i]]->var->name_hint); - if (i != fused_ids.size() - 1) { - to_fuse << ", "; - } - } - - std::stringstream ss; - const auto& fused = ApplyToSchedule(stages, stage_to_axes); - - ss << CleanName(fused->var->name_hint) << " = s[" - << CleanName(stage->op->func_name()) << "].fuse(" - << to_fuse.str() << ")\n"; - - return ss.str(); -} - -/********** Annotation **********/ -AnnotationStep AnnotationStepNode::make(int stage_id, int iter_id, IteratorAnnotation ann) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->annotation = ann; - return AnnotationStep(node); -} - -void AnnotationStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - - switch (annotation) { - case kUnroll: stage.unroll(axes[iter_id]); break; - case kVectorize: stage.vectorize(axes[iter_id]); break; - case kParallel: stage.parallel(axes[iter_id]); break; - case kVThread: stage.bind(axes[iter_id], te::thread_axis(Range(), "vthread")); break; - case kBlockX: stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.x")); break; - case kBlockY: stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.y")); break; - case kThreadX: - if (axes[iter_id]->iter_type == kCommReduce) { - const auto &thread_x = te::thread_axis(Range(), "threadIdx.x"); - stage.bind(axes[iter_id], thread_x); - stage.set_store_predicate(thread_x->var == 0); - } else { - stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.x")); - } - break; - case kThreadY: stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.y")); break; - case kNone: break; - default: LOG(FATAL) << "Invalid Annotation " << annotation; break; - } -} - -std::string AnnotationStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - const auto& iter = (*stage_to_axes)[stage][iter_id]; - - bool bind_reduce_iter = iter->iter_type == kCommReduce && annotation == kThreadX; - if (bind_reduce_iter) { - ss << "thread_x = tvm.thread_axis(\"threadIdx.x\")\n"; - } - - ss << "s[" << CleanName(stage->op->func_name()) << "]."; - switch (annotation) { - case kUnroll: ss << "unroll("; break; - case kVectorize: ss << "vectorize("; break; - case kParallel: ss << "parallel("; break; - case kVThread: - case kBlockX: - case kBlockY: - case kThreadX: - case kThreadY: ss << "bind("; break; - case kNone: break; - default: - LOG(FATAL) << "Invalid annotation " << annotation; break; - } - ss << CleanName(iter->var->name_hint); - switch (annotation) { - case kVThread: ss << ", tvm.thread_axis(\"vthread\")"; break; - case kBlockX: ss << ", tvm.thread_axis(\"blockIdx.x\")"; break; - case kBlockY: ss << ", tvm.thread_axis(\"blockIdy.y\")"; break; - case kThreadX: - if (bind_reduce_iter) { - ss << ", thread_x"; - } else { - ss << ", tvm.thread_axis(\"threadIdx.x\")"; - } - break; - case kThreadY: ss << ", tvm.thread_axis(\"threadIdx.y\")"; break; - default: break; - } - ss << ")\n"; - - if (bind_reduce_iter) { - ss << "s[" << CleanName(stage->op->func_name()) << "]" - << ".set_store_predicate(thread_x.var.equal(0))\n"; - } - - 
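-  // For a reduction iterator bound to threadIdx.x, the emitted Python looks
-  // roughly like this (stage/iterator names are illustrative):
-  //   thread_x = tvm.thread_axis("threadIdx.x")
-  //   s[C].bind(k, thread_x)
-  //   s[C].set_store_predicate(thread_x.var.equal(0))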
ApplyToSchedule(stages, stage_to_axes); - return ss.str(); -} - -/********** Compute at **********/ -ComputeAtStep ComputeAtStepNode::make(int stage_id, int target_stage_id, int target_iter_id) { - auto node = make_object(); - node->stage_id = stage_id; - node->target_stage_id = target_stage_id; - node->target_iter_id = target_iter_id; - return ComputeAtStep(node); -} - -void ComputeAtStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const IterVar& target_axis = - (*stage_to_axes)[(*stages)[target_stage_id]][target_iter_id]; - stage.compute_at((*stages)[target_stage_id], target_axis); -} - -std::string ComputeAtStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - const auto& target_stage = (*stages)[target_stage_id]; - - ss << "s[" << CleanName(stage->op->func_name()) << "].compute_at(s[" - << CleanName(target_stage->op->func_name()) << "], " - << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint); - - ss << ")\n"; - ApplyToSchedule(stages, stage_to_axes); - return ss.str(); -} - -/********** Compute Root **********/ -ComputeRootStep ComputeRootStepNode::make(int stage_id) { - auto node = make_object(); - node->stage_id = stage_id; - return ComputeRootStep(node); -} - -void ComputeRootStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - (*stages)[stage_id].compute_root(); -} - -std::string ComputeRootStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - - ss << "s[" << CleanName(stage->op->func_name()) << "].compute_root()\n"; - ApplyToSchedule(stages, stage_to_axes); - - return ss.str(); -} - -/********** Compute Inline **********/ -ComputeInlineStep ComputeInlineStepNode::make(int stage_id) { - auto node = make_object(); - node->stage_id = stage_id; - return ComputeInlineStep(node); -} - -void ComputeInlineStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - (*stages)[stage_id].compute_inline(); -} - -std::string ComputeInlineStepNode::PrintAsPythonAPI( - std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - - ss << "s[" << CleanName(stage->op->func_name()) << "].compute_inline()\n"; - ApplyToSchedule(stages, stage_to_axes); - - return ss.str(); -} - -/********** Pack for vec **********/ -PackForVecStep PackForVecStepNode::make(int stage_id, int iter_id, int vec_size) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->vec_size = vec_size; - return PackForVecStep(node); -} - -void PackForVecStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { - LOG(FATAL) << "Not implemented"; -} - -std::string PackForVecStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - LOG(FATAL) << "Not implemented"; - return ""; -} - -/********** Cache read **********/ -CacheReadStep CacheReadStepNode::make(int stage_id, std::string scope_name, - const std::vector& 
reader_stage_ids) { - auto node = make_object(); - node->stage_id = stage_id; - node->scope_name = std::move(scope_name); - node->reader_stage_ids = reader_stage_ids; - return CacheReadStep(node); -} - -te::Tensor CacheReadStepNode::ApplyToSchedule(std::vector* stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { - te::Stage& stage = (*stages)[stage_id]; - - Array readers; - for (const auto& i : reader_stage_ids) { - readers.push_back((*stages)[i]->origin_op); - } - auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, readers); - - const auto& new_stage = (*schedule)[out->op]; - UpdateStageAxis(new_stage, stage_to_axes); - stages->insert(stages->begin() + stage_id + 1, new_stage); - - return out; -} - -std::string CacheReadStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - // copy stage here, for the original stage will change after apply - auto stage = (*stages)[stage_id]; - std::vector reader_stages; - for (size_t i = 0; i < reader_stage_ids.size(); ++i) { - reader_stages.push_back((*stages)[reader_stage_ids[i]]); - } - - auto out = ApplyToSchedule(stages, stage_to_axes, schedule); - - ss << CleanName(out->op->func_name()) << " = " - << "s.cache_read(" << CleanName(stage->op->func_name()) << ", \"" - << scope_name << "\", [" - << CleanName(reader_stages[0]->op->func_name()); - for (size_t i = 1; i < reader_stage_ids.size(); ++i) { - ss << ", " << CleanName(reader_stages[i]->op->func_name()); - } - ss << "])\n"; - - const auto& iters = out->op->root_iter_vars(); - for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); - if (i != iters.size() - 1) { - ss << ", "; - } - } - ss << " = " << "tuple(" << CleanName(out->op->func_name()) - << ".op.axis)\n"; - - return ss.str(); -} - -/********** Cache write **********/ -CacheWriteStep CacheWriteStepNode::make(int stage_id, std::string scope_name) { - auto node = make_object(); - node->stage_id = stage_id; - node->scope_name = std::move(scope_name); - return CacheWriteStep(node); -} - -Array CacheWriteStepNode::ApplyToSchedule( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule) const { - te::Stage& stage = (*stages)[stage_id]; - - Array tensor_array; - // If the target stage has multi outputs, TVM requires to cache_write - // all of them or schedule.cache_write will raise an error - for (auto i = 0; i < stage->op->num_outputs(); ++i) { - tensor_array.push_back(stage->origin_op.output(i)); - } - auto outs = schedule->cache_write(tensor_array, scope_name); - - UpdateStageAxis(stage, stage_to_axes); - // Even if there is multi outputs, TVM schedule only generate one - // new stage - const auto& new_stage = (*schedule)[outs[0]->op]; - UpdateStageAxis(new_stage, stage_to_axes); - stages->insert(stages->begin() + stage_id, new_stage); - - return outs; -} - -std::string CacheWriteStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - // copy stage here, for the original stage will change after apply - te::Stage stage = (*stages)[stage_id]; - - auto outs = ApplyToSchedule(stages, stage_to_axes, schedule); - - for (size_t i = 0; i < outs.size(); ++i) { - ss << CleanName(outs[i]->op->func_name()) << ", "; - } - ss << "= " << "s.cache_write([" - << CleanName(stage->op.output(0)->op->name); - for (auto i = 1; i 
< stage->op->num_outputs(); ++i) { - ss << ", " << CleanName(stage->op.output(i)->op->name); - } - ss << "], \"" << scope_name << "\")\n"; - - for (const auto& out : outs) { - const auto& iters = out->op->root_iter_vars(); - for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); - if (i != iters.size() - 1) { - ss << ", "; - } - } - ss << " = " << "tuple(" << CleanName(out->op->func_name()) - << ".op.axis)" - << " + " << "tuple(" << CleanName(out->op->func_name()) - << ".op.reduce_axis)\n"; - } - - return ss.str(); -} - -/********** Pragma **********/ -PragmaStep PragmaStepNode::make(int stage_id, int iter_id, - std::string pragma_type) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->pragma_type = std::move(pragma_type); - return PragmaStep(node); -} - -void PragmaStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { - size_t pos = pragma_type.find('$'); - int value = atoi(pragma_type.c_str() + pos + 1); - stage.pragma(axes[iter_id], "auto_unroll_max_step", value); - stage.pragma(axes[iter_id], "unroll_explicit", true); - } else { - stage.pragma(axes[iter_id], pragma_type); - } -} - -std::string PragmaStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - - if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { - size_t pos = pragma_type.find('$'); - int value = atoi(pragma_type.c_str() + pos + 1); - ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) - << ", \"auto_unroll_max_step\", " << value << ")\n"; - ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) - << ", \"unroll_explicit\", True)\n"; - } else { - ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" - << pragma_type << "\")\n"; - } - - ApplyToSchedule(stages, stage_to_axes); - return ss.str(); -} - -/********** Rfactor **********/ -RfactorStep RfactorStepNode::make(int stage_id, int iter_id, int factor_iter_id) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->factor_iter_id = factor_iter_id; - return RfactorStep(node); -} - -Array RfactorStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { - const auto& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - - const te::Tensor& tensor = stage->origin_op.output(0); - const IterVar& axis = axes[iter_id]; - auto outs = schedule->rfactor(tensor, axis, factor_iter_id); - - UpdateStageAxis(stage, stage_to_axes); - - const auto& new_stage = (*schedule)[outs[0]->op]; - UpdateStageAxis(new_stage, stage_to_axes); - stages->insert(stages->begin() + stage_id, new_stage); - - return outs; -} - -std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - - const auto& tensor_name = 
CleanName(stage->origin_op.output(0)->op->name); - const auto& axis_name = CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint); - - const auto& outs = ApplyToSchedule(stages, stage_to_axes, schedule); - - for (size_t i = 0; i < outs.size(); ++i) { - ss << CleanName(outs[i]->op->func_name()); - if (i != outs.size() - 1) { - ss << ", "; - } - } - ss << " = " << "s.rfactor(" - << tensor_name << ", " - << axis_name << ", " - << factor_iter_id << ")\n"; - - for (const auto& out : outs) { - const auto& iters = out->op->root_iter_vars(); - for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); - if (i != iters.size() - 1) { - ss << ", "; - } - } - ss << " = " << "tuple(" << CleanName(out->op->func_name()) - << ".op.axis)" - << " + " << "tuple(" << CleanName(out->op->func_name()) - << ".op.reduce_axis)\n"; - } - - const auto& output = (*stages)[stage_id + 1]->op.output(0); - const auto& iters = output->op->root_iter_vars(); - for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); - if (i != iters.size() - 1) { - ss << ", "; - } - } - ss << " = " << "tuple(s[" << CleanName(output->op->func_name()) - << "].op.axis)" - << " + " << "tuple(s[" << CleanName(output->op->func_name()) - << "].op.reduce_axis)\n"; - - return ss.str(); -} - -/********** StorageAlign **********/ - -StorageAlignStep StorageAlignStepNode::make(int stage_id, int iter_id, - int factor, int offset) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->factor = factor; - node->offset = offset; - return StorageAlignStep(node); -} - -void StorageAlignStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - stage.storage_align(axes[iter_id], factor, offset); -} - -std::string StorageAlignStepNode::PrintAsPythonAPI( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule, const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->func_name()) << "].storage_align(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " - << factor << ", " << offset << ")\n"; - - ApplyToSchedule(stages, stage_to_axes); - return ss.str(); -} - -// Maker for other classes -Iterator IteratorNode::make(std::string name, Range range, - IteratorType iter_type, IteratorAnnotation annotation, - const std::vector* ori_iters) { - auto node = make_object(); - node->name = std::move(name); - node->range = std::move(range); - node->iter_type = iter_type; - node->annotation = annotation; - if (ori_iters != nullptr) { - node->ori_iters = *ori_iters; - } - return Iterator(node); -} - Stage StageNode::make(te::Operation op) { auto node = make_object(); if (op->IsInstance()) { @@ -854,8 +37,10 @@ Stage StageNode::make(te::Operation op) { return Stage(node); } -Stage StageNode::make(te::Operation op, StageType op_type, const std::vector& iters, - ComputeAtType compute_at, int16_t auto_unroll_max_step, int storage_offset) { +Stage StageNode::make(te::Operation op, StageType op_type, + const std::vector& iters, + ComputeAtType compute_at, int16_t auto_unroll_max_step, + int storage_offset) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; @@ -866,8 +51,10 @@ Stage StageNode::make(te::Operation op, StageType op_type, const std::vector&& iters, - ComputeAtType compute_at, int16_t 
auto_unroll_max_step, int storage_offset) { +Stage StageNode::make(te::Operation op, StageType op_type, + std::vector&& iters, + ComputeAtType compute_at, int16_t auto_unroll_max_step, + int storage_offset) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; @@ -927,8 +114,9 @@ void State::reorder(int stage_id, const std::vector& order) { DoReorderStep(step); } -std::vector State::split(int stage_id, - const Iterator& it, const std::vector& lengths, bool inner_to_outer) { +std::vector State::split(int stage_id, const Iterator& it, + const std::vector& lengths, + bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; SplitStep step = SplitStepNode::make(stage_id, GetIndex(stage->iters, it), @@ -949,8 +137,9 @@ std::vector State::follow_split(int stage_id, } -std::vector State::follow_fused_split(int stage_id, const Iterator& it, - const std::vector& src_step_ids, int level, bool factor_or_nparts) { +std::vector State::follow_fused_split( + int stage_id, const Iterator& it, const std::vector& src_step_ids, + int level, bool factor_or_nparts) { const Stage& stage = operator->()->stages[stage_id]; FollowFusedSplitStep step = FollowFusedSplitStepNode::make(stage_id, @@ -970,24 +159,24 @@ Iterator State::fuse(int stage_id, const std::vector& iters) { Iterator State::vectorize(int stage_id, const Iterator& it) { const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), - kVectorize); + AnnotationStep step = AnnotationStepNode::make( + stage_id, GetIndex(stage->iters, it), kVectorize); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); } Iterator State::parallel(int stage_id, const Iterator& it) { const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), - kParallel); + AnnotationStep step = AnnotationStepNode::make( + stage_id, GetIndex(stage->iters, it), kParallel); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); } Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), - kUnroll); + AnnotationStep step = AnnotationStepNode::make(stage_id, + GetIndex(stage->iters, it), kUnroll); // don't unroll if the extent is larger than max_unroll if (max_unroll != -1 && it->range.defined()) { @@ -1002,7 +191,8 @@ Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { return DoAnnotationStep(step); } -void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { +void State::compute_at(int stage_id, int target_stage_id, + const Iterator& target_iter) { const Stage& target_stage = operator->()->stages[target_stage_id]; ComputeAtStep step = ComputeAtStepNode::make(stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter)); @@ -1022,7 +212,8 @@ void State::compute_inline(int stage_id) { return DoComputeInlineStep(step); } -void State::pack_for_vec(int stage_id, const Iterator& target_iter, int vec_size) { +void State::pack_for_vec(int stage_id, const Iterator& target_iter, + int vec_size) { const Stage& stage = operator->()->stages[stage_id]; PackForVecStep step = PackForVecStepNode::make(stage_id, GetIndex(stage->iters, target_iter), vec_size); @@ -1044,8 +235,10 @@ Iterator State::bind_thread(int stage_id, const Iterator& it, } 
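// ---------------------------------------------------------------------------
// A minimal usage sketch of the State scheduling API above. Illustrative
// only: `dag`, stage index 0 being a compute stage, the iterator name `i`,
// and `ComputeDAG::GetInitState()` are assumptions, not guarantees of this
// patch.
//
//   State s = dag.GetInitState();
//   const Stage& st = s->stages[0];
//   // Split the first iterator (say `i`) by 32; with inner_to_outer=true
//   // the result is {i.0 (outer), i.1 (inner)}.
//   std::vector<Iterator> its = s.split(0, st->iters[0], {PrimExpr(32)}, true);
//   s.parallel(0, its[0]);   // records an AnnotationStep with kParallel
//   s.vectorize(0, its[1]);  // records an AnnotationStep with kVectorize
//
// Each call appends a transform step to transform_steps and immediately
// applies it to the loop state through the corresponding Do*Step function.
// ---------------------------------------------------------------------------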
int State::cache_read(int stage_id, const std::string& scope_name, - const std::vector& reader_stage_ids, const ComputeDAG& task_dag) { - CacheReadStep step = CacheReadStepNode::make(stage_id, scope_name, reader_stage_ids); + const std::vector& reader_stage_ids, + const ComputeDAG& task_dag) { + CacheReadStep step = CacheReadStepNode::make(stage_id, scope_name, + reader_stage_ids); CopyOnWrite()->transform_steps.push_back(step); return DoCacheReadStep(step, task_dag); } @@ -1057,7 +250,8 @@ int State::cache_write(int stage_id, const std::string& scope_name, return DoCacheWriteStep(step, task_dag); } -void State::pragma(int stage_id, const Iterator& it, const std::string& pragma_type) { +void State::pragma(int stage_id, const Iterator& it, + const std::string& pragma_type) { const Stage& stage = operator->()->stages[stage_id]; PragmaStep step = PragmaStepNode::make(stage_id, GetIndex(stage->iters, it), pragma_type); @@ -1068,7 +262,8 @@ void State::pragma(int stage_id, const Iterator& it, const std::string& pragma_t int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& task_dag) { const Stage& stage = operator->()->stages[stage_id]; - RfactorStep step = RfactorStepNode::make(stage_id, GetIndex(stage->iters, it), factor_iter_id); + RfactorStep step = RfactorStepNode::make(stage_id, GetIndex(stage->iters, it), + factor_iter_id); CopyOnWrite()->transform_steps.push_back(step); return DoRfactorStep(step, task_dag); } @@ -1093,15 +288,16 @@ void State::DoReorderStep(const ReorderStep& step) { StateNode* pstate = CopyOnWrite(); pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, - std::move(iters), stage->compute_at, + std::move(iters), + stage->compute_at, stage->auto_unroll_max_step, stage->storage_offset); } // common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep -std::vector State::DoSplitStepCommon(int stage_id, int iter_id, - const std::vector& lengths, - bool inner_to_outer) { +std::vector State::DoSplitStepCommon( + int stage_id, int iter_id, const std::vector& lengths, + bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; const Iterator& it = stage->iters[iter_id]; size_t old_iter_size = stage->iters.size(); @@ -1142,24 +338,29 @@ std::vector State::DoSplitStepCommon(int stage_id, int iter_id, range = Range::make_by_min_extent(tosplit_min, tosplit_extent); } if (inner_to_outer) { - outs.push_back(IteratorNode::make(it->name + ".0", range, it->iter_type, kNone)); + outs.push_back(IteratorNode::make(it->name + ".0", range, it->iter_type, + kNone)); std::reverse(outs.begin(), outs.end()); } else { - outs.push_back(IteratorNode::make(it->name + "." + std::to_string(lengths.size()), - range, it->iter_type, kNone)); + outs.push_back(IteratorNode::make( + it->name + "." 
+ std::to_string(lengths.size()), range, it->iter_type, + kNone)); } std::vector new_iters; - new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id); + new_iters.insert(new_iters.end(), stage->iters.begin(), + stage->iters.begin() + iter_id); new_iters.insert(new_iters.end(), outs.begin(), outs.end()); - new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id+1, stage->iters.end()); + new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id+1, + stage->iters.end()); StateNode* pstate = CopyOnWrite(); pstate->stages[stage_id] = StageNode::make(stage->op, stage->op_type, std::move(new_iters), stage->compute_at, stage->auto_unroll_max_step, stage->storage_offset); - // we have to replace the iterators in attach map, these two vectors keep the replacement mapping + // we have to replace the iterators in attach map, + // these two vectors keep the replacement mapping std::vector from_iters; std::vector to_iters; for (size_t i = iter_id; i < old_iter_size; ++i) { @@ -1181,9 +382,12 @@ std::vector State::DoFollowSplitStep(const FollowSplitStep& step) { return DoSplitStepCommon(step->stage_id, step->iter_id, lengths, true); } -std::vector State::DoFollowFusedSplitStep(const FollowFusedSplitStep& step) { - const PrimExpr& length = step->ExtractSplitLength(operator->()->transform_steps); - return DoSplitStepCommon(step->stage_id, step->iter_id, {length}, step->factor_or_nparts); +std::vector State::DoFollowFusedSplitStep( + const FollowFusedSplitStep& step) { + const PrimExpr& length = step->ExtractSplitLength( + operator->()->transform_steps); + return DoSplitStepCommon(step->stage_id, step->iter_id, {length}, + step->factor_or_nparts); } Iterator State::DoFuseStep(const FuseStep& step) { @@ -1202,8 +406,10 @@ Iterator State::DoFuseStep(const FuseStep& step) { } if (i != step->fused_ids.size() - 1) { - const auto& iter_to_attached_stage = operator->()->attach_map->iter_to_attached_stages; - if (iter_to_attached_stage.find(std::make_pair(stage_id, step->fused_ids[i])) + const auto& iter_to_attached_stage = + operator->()->attach_map->iter_to_attached_stages; + if (iter_to_attached_stage.find(std::make_pair(stage_id, + step->fused_ids[i])) != iter_to_attached_stage.end()) { LOG(FATAL) << "Invalid Fuse. 
Because you want to fuse iterators "
                  "that have been attached by some stages";
@@ -1233,20 +439,23 @@ Iterator State::DoFuseStep(const FuseStep& step) {
   if (new_extent.defined()) {
     range = Range::make_by_min_extent(0, new_extent);
   }
-  Iterator new_it = IteratorNode::make(new_name, range, new_iter_type, kNone, &ori_iters);
+  Iterator new_it = IteratorNode::make(new_name, range, new_iter_type, kNone,
+                                       &ori_iters);
   std::vector<Iterator> new_iters;
   new_iters.insert(new_iters.end(), stage->iters.begin(),
-      stage->iters.begin() + step->fused_ids.front());
+                   stage->iters.begin() + step->fused_ids.front());
   new_iters.push_back(new_it);
-  new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back() + 1,
-      stage->iters.end());
+  new_iters.insert(new_iters.end(),
+                   stage->iters.begin() + step->fused_ids.back() + 1,
+                   stage->iters.end());
 
   StateNode* pstate = CopyOnWrite();
   pstate->stages[stage_id] = StageNode::make(stage->op, stage->op_type,
       std::move(new_iters), stage->compute_at,
       stage->auto_unroll_max_step, stage->storage_offset);
 
-  // we have to replace the iterators in attach map, these two vectors keep the replacement mapping
+  // we have to replace the iterators in attach map,
+  // these two vectors keep the replacement mapping
   std::vector<Iterator> from_iters;
   std::vector<Iterator> to_iters;
   const int begin_id = step->fused_ids.front(), end_id = step->fused_ids.back();
@@ -1282,15 +491,18 @@ void State::DoComputeAtStep(const ComputeAtStep& step) {
   const Stage& stage = operator->()->stages[step->stage_id];
 
   // after compute_at, we don't know the accurate length information any more
-  // If we do want to know the accurate lengths, we can call ComputeDAG::ReplayAndInferBound
+  // If we do want to know the accurate lengths, we can call
+  // ComputeDAG::ReplayAndInferBound
   std::vector<Iterator> new_iters;
   for (const Iterator& it : stage->iters) {
     size_t s = it->name.size();
-    if (s >= 2 && it->name[s-2] == '.' && it->name[s-1] >= '1' && it->name[s-1] <= '4') {
-      // We use a dangerous heuristic rule here : For multi level splitted iterators, we assume
-      // their length does not change after compute_at.
-      // Reason: These iterators are generated in MultiStagePolicy by multi level tiling, they will
-      // be carefully compute_at their consumers. In this case, their lengths do not change.
+    if (s >= 2 && it->name[s-2] == '.' && it->name[s-1] >= '1' &&
+        it->name[s-1] <= '4') {
+      // We use a dangerous heuristic rule here: for multi-level split
+      // iterators, we assume their length does not change after compute_at.
+      // Reason: these iterators are generated in MultiStagePolicy by
+      // multi-level tiling, and they will be carefully computed at their
+      // consumers. In this case, their lengths do not change.
      // We do this to keep the AnnotateCPU pass to annotate more efficiently.
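      // For example, multi-level tiling an iterator `i` with three split
      // lengths yields i.0, i.1, i.2 and i.3; the recorded lengths of
      // i.1..i.3 are kept, and only i.0 is cleared below, since its extent
      // now depends on where this stage is attached.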
new_iters.push_back(it); } else { @@ -1303,14 +515,16 @@ void State::DoComputeAtStep(const ComputeAtStep& step) { pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, std::move(new_iters), kIter, stage->auto_unroll_max_step, stage->storage_offset); - pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, step->target_iter_id); + pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, + step->target_iter_id); } void State::DoComputeRootStep(const ComputeRootStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; // after compute_root, we don't know the accurate length information any more - // If we do want to know the accurate lengths, we can call ComputeDAG::ReplayAndInferBound + // If we do want to know the accurate lengths, we can call + // ComputeDAG::ReplayAndInferBound std::vector new_iters; for (const Iterator& it : stage->iters) { new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, @@ -1331,7 +545,8 @@ void State::DoComputeInlineStep(const ComputeInlineStep& step) { StateNode* pstate = CopyOnWrite(); // CHECK the validity of compute_inline - const auto& iter_to_attached_stages = pstate->attach_map->iter_to_attached_stages; + const auto& iter_to_attached_stages = + pstate->attach_map->iter_to_attached_stages; for (size_t i = 0; i < stage->iters.size(); ++i) { CHECK_EQ(iter_to_attached_stages.count(std::make_pair(step->stage_id, i)), 0) << "Invalid compute_inline: Because there are some other stages " @@ -1346,15 +561,18 @@ void State::DoPackForVecStep(const PackForVecStep& step) { LOG(FATAL) << "Not implemented"; } -// Common part for steps that add new stages (e.g. CacheReadStep, CacheWriteStep, RfactorStep) -void AddStageModificationSteps(size_t step_id, const std::vector& transform_steps, - std::vector* replay_steps) { +// Common part for steps that add new stages +// (e.g. CacheReadStep, CacheWriteStep, RfactorStep) +void AddStageModificationSteps(size_t step_id, + const std::vector& transform_steps, std::vector* replay_steps) { const Step& step = transform_steps[step_id]; - if (step->IsInstance() || step->IsInstance()) { + if (step->IsInstance() || + step->IsInstance()) { replay_steps->push_back(step); } else if (step->IsInstance()) { // add FuseStepNode required by rfactor - if (step_id >= 2 && transform_steps[step_id - 2]->IsInstance()) { + if (step_id >= 2 && + transform_steps[step_id - 2]->IsInstance()) { const Step& fuse_step = transform_steps[step_id - 2]; if (fuse_step->stage_id == step->stage_id) { replay_steps->push_back(fuse_step); @@ -1406,7 +624,12 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { break; } } + + int last_dag_op_size = pstate->task_dag.defined() ? 
+      pstate->task_dag->ops.size() : dag->ops.size();
   dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag));
+  int added_ops = pstate->task_dag->ops.size() - last_dag_op_size;
+  CHECK_GE(added_ops, 1);
 
   // target -> target_compute + target
   // Assume target stage has never been applied any steps before cache_write
@@ -1415,11 +638,24 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) {
       StageNode::make(operator->()->task_dag->ops[step->stage_id]));
   pstate->stages[step->stage_id + 1] =
       StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]);
-  for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) {
+  int next_stage_id = step->stage_id + 2;
+  // Notice: added_ops is expected to be exactly 1 (the CHECK above only
+  // enforces >= 1). The added_ops == 2 branch below is a workaround for
+  // TVM's cache_write bug with multiple outputs; see the CacheReadWrite
+  // test in tests/cpp/ansor_test.cc for more information.
+  // TODO(jcf94): Fix this
+  if (added_ops == 2) {
+    pstate->stages.insert(pstate->stages.begin() + next_stage_id,
+        StageNode::make(operator->()->task_dag->ops[next_stage_id]));
+    next_stage_id++;
+  } else if (added_ops > 2) {
+    LOG(ERROR) << "Unexpected behavior of CacheWrite.";
+  }
+  for (size_t i = next_stage_id; i < operator->()->task_dag->ops.size(); ++i) {
     pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i];
   }
   pstate->attach_map =
-      operator->()->attach_map.ApplyStageIdOfffset(step->stage_id, 1);
+      operator->()->attach_map.ApplyStageIdOfffset(step->stage_id, added_ops);
   return step->stage_id;
 }
@@ -1530,8 +766,8 @@ void State::DoSteps(const std::vector<Step>& steps, const ComputeDAG& dag) {
 }
 
-void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t base_indent,
-    bool delete_trivial_loop) {
+void PrintStage(std::ostream* os, int stage_id, const StateNode* state,
+                size_t base_indent, bool delete_trivial_loop) {
   const Stage& stage = state->stages[stage_id];
 
   if (stage->auto_unroll_max_step != 0) {
@@ -1553,7 +789,8 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t b
   for (size_t i = 0; i < stage->iters.size(); ++i) {
     const Iterator& iter = stage->iters[i];
-    if (!(delete_trivial_loop && iter->range.defined() && is_one(iter->range->extent))) {
+    if (!(delete_trivial_loop && iter->range.defined() &&
+          is_one(iter->range->extent))) {
       for (size_t j = 0; j < base_indent + indent; ++j) { *os << " "; }
@@ -1569,7 +806,8 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t b
       case kThreadY: *os << "gpu.threadIdx.y "; break;
     }
     if (iter->range.defined()) {
-      *os << iter->name << " (" << iter->range->min << "," << iter->range->extent << ")" << "\n";
+      *os << iter->name << " (" << iter->range->min << ","
+          << iter->range->extent << ")" << "\n";
     } else {
       *os << iter->name << " (None)" << "\n";
     }
@@ -1582,7 +820,8 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t b
     auto pair = state->attach_map->iter_to_attached_stages.find(iter_key);
     if (pair != state->attach_map->iter_to_attached_stages.end()) {
       for (const auto& attach_stage_id : pair->second) {
-        PrintStage(os, attach_stage_id, state, base_indent + indent, delete_trivial_loop);
+        PrintStage(os, attach_stage_id, state, base_indent + indent,
+                   delete_trivial_loop);
       }
     }
   }
@@ -1594,7 +833,8 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t b
   *os << stage->op->func_name() << " = ...\n";
 }
 
-void 
PrintState(std::ostream* os, const StateNode* node, + bool delete_trivial_loop) { // Gather placeholders std::vector placeholders; for (const auto& stage : node->stages) { @@ -1633,7 +873,8 @@ std::string State::ToStr(bool delete_trivial_loop) const { return os.str(); } -void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) { +void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, + int target_iter_id) { AttachMapNode* pnode = CopyOnWrite(); // delete the current entry of stage @@ -1641,7 +882,8 @@ void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_i // store the new relation IterKey iter_key(target_stage_id, target_iter_id); - pnode->stage_to_attach_iter[stage_id] = std::make_pair(target_stage_id, target_iter_id); + pnode->stage_to_attach_iter[stage_id] = std::make_pair(target_stage_id, + target_iter_id); pnode->iter_to_attached_stages[iter_key].push_back(stage_id); } diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 3ffe8a7feafb..dd56e267c0a0 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -19,36 +19,18 @@ #ifndef TVM_ANSOR_LOOP_STATE_H_ #define TVM_ANSOR_LOOP_STATE_H_ -// #include -// #include -// #include -#include #include #include #include #include #include -#include "expr_hasher.h" -#include "utils.h" -#include "compute_dag.h" +#include "transform_step.h" namespace tvm { namespace ansor { using namespace tvm::tir; -enum IteratorType { - kSpace, // spatial iterator - kReduce, // reduction iterator - kMixed, // fused spatial and reduction iterator - kSpecial // special iterator (e.g. virtual root iterator) -}; - -enum IteratorAnnotation { - kNone, kUnroll, kVectorize, kParallel, - kVThread, kBlockX, kThreadX, kBlockY, kThreadY -}; - enum StageType { kPlaceholder, kCompute }; @@ -59,29 +41,7 @@ enum ComputeAtType { kIter, // compute at some iterator }; -/* Iterator and Stage */ -class Iterator; class Stage; class State; - -/*! - * \brief An for loop iterator - * Similar to tvm::IterVar in `include/expr.h` - */ -class IteratorNode : public Object { - public: - std::string name; - Range range; // domain of for loop range - IteratorType iter_type; - IteratorAnnotation annotation; - std::vector ori_iters; - - static Iterator make(std::string name, Range range, - IteratorType iter_type, IteratorAnnotation annotation, - const std::vector* ori_iters = nullptr); - - static constexpr const char *_type_key = "ansor.Iterator"; - TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(Iterator, ObjectRef, IteratorNode); +class Stage; class State; /*! 
* \brief A stage in the compute declaration @@ -97,389 +57,20 @@ class StageNode : public Object { int storage_offset; static Stage make(te::Operation op); - static Stage make(te::Operation op, StageType op_type, const std::vector& iters, - ComputeAtType compute_at, int16_t auto_unroll_max_step, int storage_offset); - static Stage make(te::Operation op, StageType op_type, std::vector&& iters, - ComputeAtType compute_at, int16_t auto_unroll_max_step, int storage_offset); + static Stage make(te::Operation op, StageType op_type, + const std::vector& iters, + ComputeAtType compute_at, int16_t auto_unroll_max_step, + int storage_offset); + static Stage make(te::Operation op, StageType op_type, + std::vector&& iters, + ComputeAtType compute_at, int16_t auto_unroll_max_step, + int storage_offset); static constexpr const char *_type_key = "ansor.Stage"; - TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object); }; TVM_DEFINE_COW_NODE_REF(Stage, ObjectRef, StageNode); - -/*! \brief The base class for a transformation step */ -class StepNode: public Object { - public: - int stage_id; - - // Print step as equivalent python schedule API - virtual std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const = 0; - - static constexpr const char* _type_key = "ansor.Step"; - TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object); -}; -TVM_DEFINE_MUTABLE_NODE_REF(Step, StepNode); - -/* - * Note on how to add a new transform step - * - * Take fuse for example: - * 1. Define class FuseStepNode, FuseStep in loop_state.h, and implement its make function - * in FuseStepNode::make(...) loop_state.cc - * 2. Implement FuseStepNode::ApplyToSchedule and FuseStepNode::PrintAsPythonAPI. - * - In these two functions you need to lower this step with tvm's schedule API - * 3. Implement State::fuse and State::DoFuseStep. - * - In these two functions you need to incrementally update all data structures in State with - * CopyOnWrite style - * 4. Add you step to ComputeDAG::ReplaySteps and make sure it works. - * 5. Add serialization support in `struct Handler >` - * (in serialization.cc) - * 6. 
Add hash support in `struct hash<::tvm::ansor::Step>` (search for this function in this file) - */ - -class ReorderStep; class SplitStep; class FollowSplitStep; -class FollowFusedSplitStep; -class FuseStep; class AnnotationStep; -class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep; -class PackForVecStep; class CacheReadStep; class CacheWriteStep; -class PragmaStep; class RfactorStep; class StorageAlignStep; -class AttachMap; - -class ReorderStepNode: public StepNode { - public: - std::vector after_ids; - - static ReorderStep make(int stage_id, const std::vector& after_ids); - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.ReorderStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(ReorderStep, Step, ReorderStepNode); - - -class SplitStepNode: public StepNode { - public: - int iter_id; - PrimExpr extent; // the extent of the axis to split - std::vector lengths; // The split factors - bool inner_to_outer; - - static SplitStep make(int stage_id, int iter_id, PrimExpr extent, - const std::vector& lengths, - bool inner_to_outer); - - std::vector ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.SplitStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(SplitStep, Step, SplitStepNode); - -// Similar to SplitStepNode, but use split factor from another step(i.e. Follow another split step) -class FollowSplitStepNode: public StepNode { - public: - int iter_id; - int src_step_id; - int n_split; - - static FollowSplitStep make(int stage_id, int iter_id, - int src_step_id, int n_split); - - void ExtractSplitLengths(const std::vector& transform_steps, - std::vector* lengths) const; - - std::vector ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - const std::vector& transform_steps) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.FollowSplitStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(FollowSplitStep, Step, FollowSplitStepNode); - - -// Similar to FollowSplitStep, but use split factors from multiple steps -// This can be used for the split in cooperative fetching. -class FollowFusedSplitStepNode: public StepNode { - public: - int iter_id; - std::vector src_step_ids; - int level; // Use the length in this split level - bool factor_or_nparts; // If this is true, use factor. 
Otherwise, use nparts - - static FollowFusedSplitStep make(int stage_id, int iter_id, - const std::vector& src_step_ids, int level, bool factor_or_nparts); - - PrimExpr ExtractSplitLength(const std::vector& transform_steps) const; - - std::vector ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - const std::vector& transform_steps) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.FollowFusedSplitStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); - - -class FuseStepNode: public StepNode { - public: - std::vector fused_ids; - - static FuseStep make(int stage_id, const std::vector& fused_ids); - - IterVar ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.FuseStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(FuseStep, Step, FuseStepNode); - - -class AnnotationStepNode: public StepNode { - public: - int iter_id; - IteratorAnnotation annotation; - - static AnnotationStep make(int stage_id, int iter_id, IteratorAnnotation ann); - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.AnnotationStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(AnnotationStep, Step, AnnotationStepNode); - - -class ComputeAtStepNode: public StepNode { - public: - int target_stage_id; - int target_iter_id; - - static ComputeAtStep make(int stage_id, int target_stage_id, int target_iter_id); - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.ComputeAtStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(ComputeAtStep, Step, ComputeAtStepNode); - - -class ComputeRootStepNode: public StepNode { - public: - static ComputeRootStep make(int stage_id); - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.ComputeRootStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(ComputeRootStep, Step, ComputeRootStepNode); - - -class ComputeInlineStepNode: public StepNode { - public: - static ComputeInlineStep make(int stage_id); - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr 
const char* _type_key = "ansor.ComputeInlineStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(ComputeInlineStep, Step, ComputeInlineStepNode); - -class PackForVecStepNode: public StepNode { - public: - int iter_id; - int vec_size; - - static PackForVecStep make(int stage_id, int iter_id, int vec_size); - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.PackForVecStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(PackForVecStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(PackForVecStep, Step, PackForVecStepNode); - - -/*! \brief Apply cache_read to a stage - * TVM Api: te::Schedule::cache_read(tensor, scope, readers) */ -class CacheReadStepNode: public StepNode { - public: - std::string scope_name; - std::vector reader_stage_ids; - - static CacheReadStep make(int stage_id, std::string scope_name, - const std::vector& reader_stage_id); - - te::Tensor ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.CacheReadStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(CacheReadStep, Step, CacheReadStepNode); - - -/*! \brief Apply cache_write to a stage - * TVM Api: te::Schedule::cache_write(tensor, scope) - * This step will cache_write all output tensors of target stage */ -class CacheWriteStepNode: public StepNode { - public: - std::string scope_name; - - static CacheWriteStep make(int stage_id, std::string scope_name); - - Array ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.CacheWriteStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(CacheWriteStep, Step, CacheWriteStepNode); - -/*! \brief Add pragma to a specific iterator */ -class PragmaStepNode: public StepNode { - public: - int iter_id; - std::string pragma_type; - - static PragmaStep make(int stage_id, int iter_id, std::string pragma_type); - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.PragmaStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(PragmaStep, Step, PragmaStepNode); - -/*! 
\brief Factor a reduction axis - * TVM Api: te::Schedule::rfactor(tensor, axis, factor_axis) */ -class RfactorStepNode: public StepNode { - public: - int iter_id; - int factor_iter_id; - - static RfactorStep make(int stage_id, int iter_id, int factor_iter_id); - - Array ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.RfactorStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(RfactorStep, Step, RfactorStepNode); - -class StorageAlignStepNode: public StepNode { - public: - int iter_id; - int factor; - int offset; - - static StorageAlignStep make(int stage_id, int iter_id, int factor, - int offset); - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.StorageAlignStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(StorageAlignStep, Step, StorageAlignStepNode); - /*! \brief stores the compute_at relation between stages */ class AttachMapNode: public Object { public: @@ -523,13 +114,13 @@ class AttachMap : public ObjectRef { class StateNode: public Object { public: std::vector stages; // Current stages and loop structures - std::vector transform_steps; // History transformation steps to reach this state + std::vector transform_steps; // History transformation steps bool complete; // Indicate whether this state has unfilled tile sizes AttachMap attach_map; // stores the compute_at relation between stages - ObjectRef aux_info; // Used to store any auxiliary info about this state + ObjectRef aux_info; // Used to store any auxiliary info about this state ComputeDAG task_dag; // The up-to-date ComputeDAG of this state. 
- // The default value is an empty NodeRef - // (means no modification to the DAG) + // The default value is an empty NodeRef + // (means no modification to the DAG) void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("complete", &complete); @@ -539,7 +130,8 @@ class StateNode: public Object { static State make_empty_state(); static State make(const Array& ops); static State make(const std::vector& stages, - const std::vector& transform_steps, bool complete, ObjectRef aux_info); + const std::vector& transform_steps, bool complete, + ObjectRef aux_info); static constexpr const char* _type_key = "ansor.State"; TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object); @@ -556,7 +148,8 @@ class State : public ObjectRef { std::vector follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split); std::vector follow_fused_split(int stage_id, const Iterator& it, - const std::vector& src_step_ids, int level, bool factor_or_nparts); + const std::vector& src_step_ids, + int level, bool factor_or_nparts); Iterator fuse(int stage_id, const std::vector& iters); Iterator vectorize(int stage_id, const Iterator& it); Iterator parallel(int stage_id, const Iterator& it); @@ -564,7 +157,8 @@ class State : public ObjectRef { // Valide thread_type: kVThread, kBlockX, kThreadX, kThreadY Iterator bind_thread(int stage_id, const Iterator& it, IteratorAnnotation thread_type); - void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); + void compute_at(int stage_id, int target_stage_id, + const Iterator& target_iter); void compute_root(int stage_id); void compute_inline(int stage_id); void pack_for_vec(int stage_id, const Iterator& target_iter, int vec_size); @@ -578,7 +172,8 @@ class State : public ObjectRef { const ComputeDAG& task_dag); void storage_align(int stage_id, const Iterator& it, int factor, int offset); - /* We separate these functions out, so you can call them for replay easily given history steps */ + // We separate these functions out, + // so you can call them for replay easily given history steps void DoReorderStep(const ReorderStep& step); std::vector DoSplitStep(const SplitStep& step); std::vector DoFollowSplitStep(const FollowSplitStep& step); @@ -596,7 +191,9 @@ class State : public ObjectRef { void DoStorageAlignStep(const StorageAlignStep& step); /* Do transform steps - * Note: The following function only change loop state. They do not change transform_history. */ + * Note: The following function only change loop state. + * They do not change transform_history. 
+ */ void DoStep(const Step& step, const ComputeDAG& dag); void DoSteps(const std::vector& step, const ComputeDAG& dag); @@ -620,98 +217,6 @@ class State : public ObjectRef { // Hash and equal function for State, Stage, Iterator and Step namespace std { -template <> -struct hash<::tvm::ansor::Step> { - std::size_t operator()(const ::tvm::ansor::Step& step) const { - if (auto ps = step.as<::tvm::ansor::ReorderStepNode>()) { - return ::dmlc::HashCombine(1, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ps->after_ids)); - } else if (auto ps = step.as<::tvm::ansor::SplitStepNode>()) { - size_t ret = ::dmlc::HashCombine(2, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ps->inner_to_outer))); - for (const auto& len : ps->lengths) { - if (len.defined()) { - auto pint = len.as<::tvm::tir::IntImmNode>(); - CHECK(pint != nullptr); - ret = ::dmlc::HashCombine(ret, pint->value); - } else { - ret = ::dmlc::HashCombine(ret, 0x5D); // a magic number - } - return ret; - } - } else if (auto ps = step.as<::tvm::ansor::FollowSplitStepNode>()) { - return ::dmlc::HashCombine(3, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ::dmlc::HashCombine(std::hash()(ps->src_step_id), - ps->n_split)))); - } else if (auto ps = step.as<::tvm::ansor::FollowFusedSplitStepNode>()) { - return ::dmlc::HashCombine(4, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ::dmlc::HashCombine(std::hash>()(ps->src_step_ids), - ::dmlc::HashCombine(std::hash()(ps->level), - ps->factor_or_nparts))))); - } else if (auto ps = step.as<::tvm::ansor::FuseStepNode>()) { - return ::dmlc::HashCombine(5, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ps->fused_ids)); - } else if (auto ps = step.as<::tvm::ansor::AnnotationStepNode>()) { - return ::dmlc::HashCombine(6, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - static_cast(ps->annotation)))); - } else if (auto ps = step.as<::tvm::ansor::ComputeAtStepNode>()) { - return ::dmlc::HashCombine(7, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->target_stage_id), - ps->target_iter_id))); - } else if (auto ps = step.as<::tvm::ansor::ComputeRootStepNode>()) { - return ::dmlc::HashCombine(8, - ps->stage_id); - } else if (auto ps = step.as<::tvm::ansor::ComputeInlineStepNode>()) { - return ::dmlc::HashCombine(9, - ps->stage_id); - } else if (auto ps = step.as<::tvm::ansor::PackForVecStepNode>()) { - return ::dmlc::HashCombine(10, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ps->vec_size))); - } else if (auto ps = step.as<::tvm::ansor::CacheReadStepNode>()) { - return ::dmlc::HashCombine(11, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->scope_name), - ps->reader_stage_ids))); - } else if (auto ps = step.as<::tvm::ansor::CacheWriteStepNode>()) { - return ::dmlc::HashCombine(12, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ps->scope_name)); - } else if (auto ps = step.as<::tvm::ansor::PragmaStepNode>()) { - return ::dmlc::HashCombine(13, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ps->pragma_type))); - } else if (auto ps = step.as<::tvm::ansor::RfactorStepNode>()) { - return ::dmlc::HashCombine(14, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - 
ps->factor_iter_id))); - } else if (auto ps = step.as<::tvm::ansor::StorageAlignStepNode>()) { - return ::dmlc::HashCombine(15, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ::dmlc::HashCombine(std::hash()(ps->factor), - ps->offset)))); - } else { - LOG(FATAL) << "Invalid step"; - } - return 0; - } -}; - template <> struct hash<::tvm::ansor::State> { std::size_t operator()(const ::tvm::ansor::State& state) const { diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc new file mode 100644 index 000000000000..8cd8233ae9be --- /dev/null +++ b/src/ansor/transform_step.cc @@ -0,0 +1,820 @@ +/*! + * Copyright (c) 2020 by Contributors + */ +#include "transform_step.h" +#include +#include "utils.h" + +namespace tvm { +namespace ansor { + +TVM_REGISTER_OBJECT_TYPE(StepNode); + +/********** Reorder **********/ +ReorderStep ReorderStepNode::make(int stage_id, const std::vector& after_ids) { + auto node = make_object(); + node->stage_id = stage_id; + node->after_ids = after_ids; + return ReorderStep(node); +} + +void ReorderStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + CHECK_EQ(after_ids.size(), axes.size()); + + std::vector new_axes; + new_axes.reserve(axes.size()); + for (auto i : after_ids) { + new_axes.push_back(axes[i]); + } + stage.reorder(new_axes); + (*stage_to_axes)[stage] = std::move(new_axes); +} + +std::string ReorderStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + const te::Stage& stage = (*stages)[stage_id]; + std::stringstream ss; + + ss << "s[" << CleanName(stage->op->func_name()) << "].reorder("; + for (size_t i = 0; i < after_ids.size(); ++i) { + ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint); + if (i != after_ids.size() - 1) { + ss << ", "; + } + } + ss << ")\n"; + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + +/********** Split **********/ +std::vector ApplySplitToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, + int stage_id, + int iter_id, + const std::vector& lengths, + bool inner_to_outer) { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + + std::vector outs; + if (inner_to_outer) { + IterVar outer = axes[iter_id], inner; + for (int i = static_cast(lengths.size()) - 1; i >= 0; i--) { + IterVar to_split = outer; + stage.split(to_split, lengths[i], &outer, &inner); + outs.push_back(inner); + } + outs.push_back(outer); + } else { + IterVar outer, inner = axes[iter_id]; + for (size_t i = 0; i < lengths.size(); i++) { + IterVar to_split = inner; + stage.split_by_nparts(to_split, lengths[i], &outer, &inner); + outs.push_back(outer); + } + outs.push_back(inner); + } + + std::vector new_axes; + new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + iter_id); + if (inner_to_outer) { + new_axes.insert(new_axes.end(), outs.rbegin(), outs.rend()); + } else { + new_axes.insert(new_axes.end(), outs.begin(), outs.end()); + } + new_axes.insert(new_axes.end(), axes.begin() + iter_id + 1, axes.end()); + (*stage_to_axes)[stage] = std::move(new_axes); + + return outs; +} + +std::string PrintSplitAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + int stage_id, + int iter_id, + const std::vector& lengths, + bool inner_to_outer) { + te::Stage& stage = 
(*stages)[stage_id]; + auto to_split = (*stage_to_axes)[stage][iter_id]; + const auto& func_name = CleanName(stage->op->func_name()); + const auto& outs = ApplySplitToSchedule(stages, stage_to_axes, stage_id, + iter_id, lengths, inner_to_outer); + + std::stringstream ss; + int size = static_cast(lengths.size()); + if (inner_to_outer) { + for (int i = size - 1; i >= 0; i--) { + ss << CleanName(outs[size - i]->var->name_hint) << ", " + << CleanName(outs[size - i - 1]->var->name_hint) + << " = s[" << func_name << "].split(" + << CleanName(to_split->var->name_hint) + << ", factor=" << lengths[i] << ")\n"; + to_split = outs[size - i]; + } + } else { + for (int i = 0; i < size; i++) { + ss << CleanName(outs[i]->var->name_hint) << ", " + << CleanName(outs[i + 1]->var->name_hint) + << " = s[" << func_name << "].split(" + << CleanName(to_split->var->name_hint) + << ", nparts=" << lengths[i] << ")\n"; + to_split = outs[i + 1]; + } + } + + return ss.str(); +} + +SplitStep SplitStepNode::make(int stage_id, int iter_id, + PrimExpr extent, const std::vector& lengths, + bool inner_to_outer) { + auto node = make_object(); + node->stage_id = stage_id; + // Extent can be a unreducible expression in some special cases + if (extent->IsInstance()) { + node->extent = std::move(extent); + } + node->iter_id = iter_id; + node->lengths = lengths; + node->inner_to_outer = inner_to_outer; + return SplitStep(node); +} + +std::vector SplitStepNode::ApplyToSchedule( + std::vector *stages, StageToAxesMap *stage_to_axes) const { + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, + lengths, inner_to_outer); +} + +std::string SplitStepNode::PrintAsPythonAPI( + std::vector *stages, StageToAxesMap *stage_to_axes, + te::Schedule *schedule, const std::vector& transform_steps) const { + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, + lengths, inner_to_outer); +} + +/********** Follow Split **********/ +FollowSplitStep FollowSplitStepNode::make(int stage_id, int iter_id, + int src_step_id, int n_split) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->src_step_id = src_step_id; + node->n_split = n_split; + return FollowSplitStep(node); +} + +void FollowSplitStepNode::ExtractSplitLengths(const std::vector& transform_steps, + std::vector* lengths) const { + CHECK_LT(src_step_id, transform_steps.size()); + auto ps = transform_steps[src_step_id].as(); + CHECK(ps != nullptr); + + // get lengths from src step + lengths->reserve(n_split); + int j = 0; + for (; j < n_split - 1; ++j) { + lengths->push_back(ps->lengths[j]); + } + PrimExpr last_factor = 1; + for (; j < static_cast(ps->lengths.size()); ++j) { + if (ps->lengths[j].defined()) { + last_factor *= ps->lengths[j]; + } else { + last_factor = PrimExpr(); + break; + } + } + lengths->push_back(std::move(last_factor)); +} + +std::vector FollowSplitStepNode::ApplyToSchedule( + std::vector *stages, StageToAxesMap *stage_to_axes, + const std::vector& transform_steps) const { + std::vector lengths; + ExtractSplitLengths(transform_steps, &lengths); + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, + lengths, true); +} + +std::string FollowSplitStepNode::PrintAsPythonAPI( + std::vector *stages, StageToAxesMap *stage_to_axes, + te::Schedule *schedule, const std::vector& transform_steps) const { + std::vector lengths; + ExtractSplitLengths(transform_steps, &lengths); + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, + lengths, true); +} + +/********** Follow 
+/********** Follow Fused Split **********/
+FollowFusedSplitStep FollowFusedSplitStepNode::make(int stage_id, int iter_id,
+    const std::vector& src_step_ids, int level, bool factor_or_nparts) {
+  auto node = make_object();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_ids = src_step_ids;
+  node->level = level;
+  node->factor_or_nparts = factor_or_nparts;
+  return FollowFusedSplitStep(node);
+}
+
+PrimExpr FollowFusedSplitStepNode::ExtractSplitLength(const std::vector& transform_steps) const {
+  PrimExpr ret(1);
+
+  for (int src_step_id : src_step_ids) {
+    CHECK_LT(src_step_id, transform_steps.size());
+    auto ps = transform_steps[src_step_id].as();
+    CHECK(ps != nullptr);
+    if (ps->lengths[level].defined() && ret.defined()) {
+      ret *= ps->lengths[level];
+    } else {
+      return PrimExpr();
+    }
+  }
+
+  return ret;
+}
+
+std::vector FollowFusedSplitStepNode::ApplyToSchedule(
+    std::vector *stages, StageToAxesMap *stage_to_axes,
+    const std::vector& transform_steps) const {
+  const PrimExpr& length = ExtractSplitLength(transform_steps);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id,
+                              {length}, factor_or_nparts);
+}
+
+std::string FollowFusedSplitStepNode::PrintAsPythonAPI(
+    std::vector *stages, StageToAxesMap *stage_to_axes,
+    te::Schedule *schedule, const std::vector& transform_steps) const {
+  const PrimExpr& length = ExtractSplitLength(transform_steps);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id,
+                               {length}, factor_or_nparts);
+}
+
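+// Editor's note (illustrative): given two source SplitSteps with lengths
+// {4, 8} and {2, 8}, a FollowFusedSplitStep with level = 1 extracts
+// 8 * 8 = 64 as its single split length, i.e. the extent of the fused inner
+// iterators of the source splits. This is the split needed for cooperative
+// fetching, where one cache stage serves several tiled consumers.
+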
+/********** Fuse **********/
+FuseStep FuseStepNode::make(int stage_id, const std::vector& fused_ids) {
+  auto node = make_object();
+  node->stage_id = stage_id;
+  node->fused_ids = fused_ids;
+  return FuseStep(node);
+}
+
+IterVar FuseStepNode::ApplyToSchedule(std::vector *stages,
+                                      StageToAxesMap *stage_to_axes) const {
+  te::Stage& stage = (*stages)[stage_id];
+  const std::vector& axes = (*stage_to_axes)[stage];
+
+  Array to_fuse;
+  for (auto i : fused_ids) {
+    to_fuse.push_back(axes[i]);
+  }
+  IterVar fused_axis;
+  stage.fuse(to_fuse, &fused_axis);
+  std::vector new_axes;
+  new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids[0]);
+  new_axes.push_back(fused_axis);
+  new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1,
+                  axes.end());
+  (*stage_to_axes)[stage] = std::move(new_axes);
+
+  return fused_axis;
+}
+
+std::string FuseStepNode::PrintAsPythonAPI(std::vector *stages,
+                                           StageToAxesMap *stage_to_axes,
+                                           te::Schedule *schedule,
+                                           const std::vector& transform_steps) const {
+  const auto& stage = (*stages)[stage_id];
+  std::stringstream to_fuse;
+
+  for (size_t i = 0; i < fused_ids.size(); ++i) {
+    to_fuse << CleanName((*stage_to_axes)[stage][fused_ids[i]]->var->name_hint);
+    if (i != fused_ids.size() - 1) {
+      to_fuse << ", ";
+    }
+  }
+
+  std::stringstream ss;
+  const auto& fused = ApplyToSchedule(stages, stage_to_axes);
+
+  ss << CleanName(fused->var->name_hint) << " = s["
+     << CleanName(stage->op->func_name()) << "].fuse("
+     << to_fuse.str() << ")\n";
+
+  return ss.str();
+}
+
+/********** Annotation **********/
+AnnotationStep AnnotationStepNode::make(int stage_id, int iter_id, IteratorAnnotation ann) {
+  auto node = make_object();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->annotation = ann;
+  return AnnotationStep(node);
+}
+
+void AnnotationStepNode::ApplyToSchedule(std::vector *stages,
+                                         StageToAxesMap *stage_to_axes) const {
+  te::Stage& stage = (*stages)[stage_id];
+  const std::vector& axes = (*stage_to_axes)[stage];
+
+  switch (annotation) {
+    case kUnroll: stage.unroll(axes[iter_id]); break;
+    case kVectorize: stage.vectorize(axes[iter_id]); break;
+    case kParallel: stage.parallel(axes[iter_id]); break;
+    case kVThread: stage.bind(axes[iter_id], te::thread_axis(Range(), "vthread")); break;
+    case kBlockX: stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.x")); break;
+    case kBlockY: stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.y")); break;
+    case kThreadX:
+      if (axes[iter_id]->iter_type == kCommReduce) {
+        const auto &thread_x = te::thread_axis(Range(), "threadIdx.x");
+        stage.bind(axes[iter_id], thread_x);
+        stage.set_store_predicate(thread_x->var == 0);
+      } else {
+        stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.x"));
+      }
+      break;
+    case kThreadY: stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.y")); break;
+    case kNone: break;
+    default: LOG(FATAL) << "Invalid Annotation " << annotation; break;
+  }
+}
+
+std::string AnnotationStepNode::PrintAsPythonAPI(std::vector *stages,
+                                                 StageToAxesMap *stage_to_axes,
+                                                 te::Schedule *schedule,
+                                                 const std::vector& transform_steps) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  const auto& iter = (*stage_to_axes)[stage][iter_id];
+
+  bool bind_reduce_iter = iter->iter_type == kCommReduce && annotation == kThreadX;
+  if (bind_reduce_iter) {
+    ss << "thread_x = tvm.thread_axis(\"threadIdx.x\")\n";
+  }
+
+  ss << "s[" << CleanName(stage->op->func_name()) << "].";
+  switch (annotation) {
+    case kUnroll: ss << "unroll("; break;
+    case kVectorize: ss << "vectorize("; break;
+    case kParallel: ss << "parallel("; break;
+    case kVThread:
+    case kBlockX:
+    case kBlockY:
+    case kThreadX:
+    case kThreadY: ss << "bind("; break;
+    case kNone: break;
+    default:
+      LOG(FATAL) << "Invalid annotation " << annotation; break;
+  }
+  ss << CleanName(iter->var->name_hint);
+  switch (annotation) {
+    case kVThread: ss << ", tvm.thread_axis(\"vthread\")"; break;
+    case kBlockX: ss << ", tvm.thread_axis(\"blockIdx.x\")"; break;
+    case kBlockY: ss << ", tvm.thread_axis(\"blockIdx.y\")"; break;
+    case kThreadX:
+      if (bind_reduce_iter) {
+        ss << ", thread_x";
+      } else {
+        ss << ", tvm.thread_axis(\"threadIdx.x\")";
+      }
+      break;
+    case kThreadY: ss << ", tvm.thread_axis(\"threadIdx.y\")"; break;
+    default: break;
+  }
+  ss << ")\n";
+
+  if (bind_reduce_iter) {
+    ss << "s[" << CleanName(stage->op->func_name()) << "]"
+       << ".set_store_predicate(thread_x.var.equal(0))\n";
+  }
+
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
+
+/********** Compute at **********/
+ComputeAtStep ComputeAtStepNode::make(int stage_id, int target_stage_id, int target_iter_id) {
+  auto node = make_object();
+  node->stage_id = stage_id;
+  node->target_stage_id = target_stage_id;
+  node->target_iter_id = target_iter_id;
+  return ComputeAtStep(node);
+}
+
+void ComputeAtStepNode::ApplyToSchedule(std::vector *stages,
+                                        StageToAxesMap *stage_to_axes) const {
+  te::Stage& stage = (*stages)[stage_id];
+  const IterVar& target_axis =
+      (*stage_to_axes)[(*stages)[target_stage_id]][target_iter_id];
+  stage.compute_at((*stages)[target_stage_id], target_axis);
+}
+
+std::string ComputeAtStepNode::PrintAsPythonAPI(std::vector *stages,
+                                                StageToAxesMap *stage_to_axes,
+                                                te::Schedule *schedule,
+                                                const std::vector& transform_steps) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  const auto& target_stage = (*stages)[target_stage_id];
+
+  ss << "s[" << CleanName(stage->op->func_name()) << "].compute_at(s["
+     << 
CleanName(target_stage->op->func_name()) << "], " + << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint); + + ss << ")\n"; + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + +/********** Compute Root **********/ +ComputeRootStep ComputeRootStepNode::make(int stage_id) { + auto node = make_object(); + node->stage_id = stage_id; + return ComputeRootStep(node); +} + +void ComputeRootStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + (*stages)[stage_id].compute_root(); +} + +std::string ComputeRootStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + ss << "s[" << CleanName(stage->op->func_name()) << "].compute_root()\n"; + ApplyToSchedule(stages, stage_to_axes); + + return ss.str(); +} + +/********** Compute Inline **********/ +ComputeInlineStep ComputeInlineStepNode::make(int stage_id) { + auto node = make_object(); + node->stage_id = stage_id; + return ComputeInlineStep(node); +} + +void ComputeInlineStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + (*stages)[stage_id].compute_inline(); +} + +std::string ComputeInlineStepNode::PrintAsPythonAPI( + std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + ss << "s[" << CleanName(stage->op->func_name()) << "].compute_inline()\n"; + ApplyToSchedule(stages, stage_to_axes); + + return ss.str(); +} + +/********** Pack for vec **********/ +PackForVecStep PackForVecStepNode::make(int stage_id, int iter_id, int vec_size) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->vec_size = vec_size; + return PackForVecStep(node); +} + +void PackForVecStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { + LOG(FATAL) << "Not implemented"; +} + +std::string PackForVecStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + LOG(FATAL) << "Not implemented"; + return ""; +} + +/********** Cache read **********/ +CacheReadStep CacheReadStepNode::make(int stage_id, std::string scope_name, + const std::vector& reader_stage_ids) { + auto node = make_object(); + node->stage_id = stage_id; + node->scope_name = std::move(scope_name); + node->reader_stage_ids = reader_stage_ids; + return CacheReadStep(node); +} + +te::Tensor CacheReadStepNode::ApplyToSchedule(std::vector* stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { + te::Stage& stage = (*stages)[stage_id]; + + Array readers; + for (const auto& i : reader_stage_ids) { + readers.push_back((*stages)[i]->origin_op); + } + auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, readers); + + const auto& new_stage = (*schedule)[out->op]; + UpdateStageAxis(new_stage, stage_to_axes); + stages->insert(stages->begin() + stage_id + 1, new_stage); + + return out; +} + +std::string CacheReadStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + // copy stage here, for the original stage will change after apply + auto stage = (*stages)[stage_id]; + 
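// Illustrative output (hypothetical stage names): caching a stage "Kernel"
+  // into "global" scope for a single reader prints roughly
+  //   Kernel_global = s.cache_read(Kernel, "global", [T_conv2d_nchw])
+  //   ax0, ax1, ax2, ax3 = tuple(Kernel_global.op.axis)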
+  std::vector reader_stages;
+  for (size_t i = 0; i < reader_stage_ids.size(); ++i) {
+    reader_stages.push_back((*stages)[reader_stage_ids[i]]);
+  }
+
+  auto out = ApplyToSchedule(stages, stage_to_axes, schedule);
+
+  ss << CleanName(out->op->func_name()) << " = "
+     << "s.cache_read(" << CleanName(stage->op->func_name()) << ", \""
+     << scope_name << "\", ["
+     << CleanName(reader_stages[0]->op->func_name());
+  for (size_t i = 1; i < reader_stage_ids.size(); ++i) {
+    ss << ", " << CleanName(reader_stages[i]->op->func_name());
+  }
+  ss << "])\n";
+
+  const auto& iters = out->op->root_iter_vars();
+  for (size_t i = 0; i < iters.size(); ++i) {
+    ss << CleanName(iters[i]->var->name_hint);
+    if (i != iters.size() - 1) {
+      ss << ", ";
+    }
+  }
+  ss << " = " << "tuple(" << CleanName(out->op->func_name())
+     << ".op.axis)\n";
+
+  return ss.str();
+}
+
+/********** Cache write **********/
+CacheWriteStep CacheWriteStepNode::make(int stage_id, std::string scope_name) {
+  auto node = make_object();
+  node->stage_id = stage_id;
+  node->scope_name = std::move(scope_name);
+  return CacheWriteStep(node);
+}
+
+Array CacheWriteStepNode::ApplyToSchedule(
+    std::vector *stages, StageToAxesMap *stage_to_axes,
+    te::Schedule *schedule) const {
+  te::Stage& stage = (*stages)[stage_id];
+
+  Array tensor_array;
+  // If the target stage has multiple outputs, TVM requires caching all of
+  // them, or schedule.cache_write will raise an error
+  for (auto i = 0; i < stage->op->num_outputs(); ++i) {
+    tensor_array.push_back(stage->origin_op.output(i));
+  }
+  auto outs = schedule->cache_write(tensor_array, scope_name);
+
+  UpdateStageAxis(stage, stage_to_axes);
+  // Even if there are multiple outputs, the TVM schedule only generates one
+  // new stage
+  const auto& new_stage = (*schedule)[outs[0]->op];
+  UpdateStageAxis(new_stage, stage_to_axes);
+  stages->insert(stages->begin() + stage_id, new_stage);
+
+  return outs;
+}
+
+std::string CacheWriteStepNode::PrintAsPythonAPI(std::vector *stages,
+                                                 StageToAxesMap *stage_to_axes,
+                                                 te::Schedule *schedule,
+                                                 const std::vector& transform_steps) const {
+  std::stringstream ss;
+  // Copy the stage here, because the original stage will change after apply
+  te::Stage stage = (*stages)[stage_id];
+
+  auto outs = ApplyToSchedule(stages, stage_to_axes, schedule);
+
+  for (size_t i = 0; i < outs.size(); ++i) {
+    ss << CleanName(outs[i]->op->func_name()) << ", ";
+  }
+  ss << "= " << "s.cache_write(["
+     << CleanName(stage->op.output(0)->op->name);
+  for (auto i = 1; i < stage->op->num_outputs(); ++i) {
+    ss << ", " << CleanName(stage->op.output(i)->op->name);
+  }
+  ss << "], \"" << scope_name << "\")\n";
+
+  for (const auto& out : outs) {
+    const auto& iters = out->op->root_iter_vars();
+    for (size_t i = 0; i < iters.size(); ++i) {
+      ss << CleanName(iters[i]->var->name_hint);
+      if (i != iters.size() - 1) {
+        ss << ", ";
+      }
+    }
+    ss << " = " << "tuple(" << CleanName(out->op->func_name())
+       << ".op.axis)"
+       << " + " << "tuple(" << CleanName(out->op->func_name())
+       << ".op.reduce_axis)\n";
+  }
+
+  return ss.str();
+}
+
+/********** Pragma **********/
+PragmaStep PragmaStepNode::make(int stage_id, int iter_id,
+                                std::string pragma_type) {
+  auto node = make_object();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->pragma_type = std::move(pragma_type);
+  return PragmaStep(node);
+}
+
+void PragmaStepNode::ApplyToSchedule(std::vector *stages,
+                                     StageToAxesMap *stage_to_axes) const {
+  te::Stage& stage = (*stages)[stage_id];
+  const std::vector& axes = (*stage_to_axes)[stage];
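+  // pragma_type may pack a value after a '$' separator; e.g. the string
+  // "auto_unroll_max_step$16" (an illustrative value) applies
+  // auto_unroll_max_step=16 together with unroll_explicit.
+  if 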
(StrStartsWith(pragma_type, "auto_unroll_max_step")) { + size_t pos = pragma_type.find('$'); + int value = atoi(pragma_type.c_str() + pos + 1); + stage.pragma(axes[iter_id], "auto_unroll_max_step", value); + stage.pragma(axes[iter_id], "unroll_explicit", true); + } else { + stage.pragma(axes[iter_id], pragma_type); + } +} + +std::string PragmaStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { + size_t pos = pragma_type.find('$'); + int value = atoi(pragma_type.c_str() + pos + 1); + ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) + << ", \"auto_unroll_max_step\", " << value << ")\n"; + ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) + << ", \"unroll_explicit\", True)\n"; + } else { + ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" + << pragma_type << "\")\n"; + } + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + +/********** Rfactor **********/ +RfactorStep RfactorStepNode::make(int stage_id, int iter_id, int factor_iter_id) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->factor_iter_id = factor_iter_id; + return RfactorStep(node); +} + +Array RfactorStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { + const auto& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + + const te::Tensor& tensor = stage->origin_op.output(0); + const IterVar& axis = axes[iter_id]; + auto outs = schedule->rfactor(tensor, axis, factor_iter_id); + + UpdateStageAxis(stage, stage_to_axes); + + const auto& new_stage = (*schedule)[outs[0]->op]; + UpdateStageAxis(new_stage, stage_to_axes); + stages->insert(stages->begin() + stage_id, new_stage); + + return outs; +} + +std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + const auto& tensor_name = CleanName(stage->origin_op.output(0)->op->name); + const auto& axis_name = CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint); + + const auto& outs = ApplyToSchedule(stages, stage_to_axes, schedule); + + for (size_t i = 0; i < outs.size(); ++i) { + ss << CleanName(outs[i]->op->func_name()); + if (i != outs.size() - 1) { + ss << ", "; + } + } + ss << " = " << "s.rfactor(" + << tensor_name << ", " + << axis_name << ", " + << factor_iter_id << ")\n"; + + for (const auto& out : outs) { + const auto& iters = out->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != iters.size() - 1) { + ss << ", "; + } + } + ss << " = " << "tuple(" << CleanName(out->op->func_name()) + << ".op.axis)" + << " + " << "tuple(" << CleanName(out->op->func_name()) + << ".op.reduce_axis)\n"; + } + + const auto& output = (*stages)[stage_id + 1]->op.output(0); + const auto& iters = output->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != 
iters.size() - 1) {
+      ss << ", ";
+    }
+  }
+  ss << " = " << "tuple(s[" << CleanName(output->op->func_name())
+     << "].op.axis)"
+     << " + " << "tuple(s[" << CleanName(output->op->func_name())
+     << "].op.reduce_axis)\n";
+
+  return ss.str();
+}
+
+/********** StorageAlign **********/
+
+StorageAlignStep StorageAlignStepNode::make(int stage_id, int iter_id,
+                                            int factor, int offset) {
+  auto node = make_object();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->factor = factor;
+  node->offset = offset;
+  return StorageAlignStep(node);
+}
+
+void StorageAlignStepNode::ApplyToSchedule(std::vector *stages,
+                                           StageToAxesMap *stage_to_axes) const {
+  te::Stage& stage = (*stages)[stage_id];
+  const std::vector& axes = (*stage_to_axes)[stage];
+  stage.storage_align(axes[iter_id], factor, offset);
+}
+
+std::string StorageAlignStepNode::PrintAsPythonAPI(
+    std::vector *stages, StageToAxesMap *stage_to_axes,
+    te::Schedule *schedule, const std::vector& transform_steps) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  ss << "s[" << CleanName(stage->op->func_name()) << "].storage_align("
+     << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", "
+     << factor << ", " << offset << ")\n";
+
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
+
+// Maker for other classes
+Iterator IteratorNode::make(std::string name, Range range,
+                            IteratorType iter_type, IteratorAnnotation annotation,
+                            const std::vector* ori_iters) {
+  auto node = make_object();
+  node->name = std::move(name);
+  node->range = std::move(range);
+  node->iter_type = iter_type;
+  node->annotation = annotation;
+  if (ori_iters != nullptr) {
+    node->ori_iters = *ori_iters;
+  }
+  return Iterator(node);
+}
+
+} // namespace ansor
+} // namespace tvm
diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h
new file mode 100644
index 000000000000..9b430be99bd3
--- /dev/null
+++ b/src/ansor/transform_step.h
@@ -0,0 +1,551 @@
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file ansor/transform_step.h
+ * \brief Data structures for loop transformations
+
+ * Basically this is a simplified TVM IR with schedule primitives.
+ * We don't use the existing TVM IR because
+ * 1. We want fast incremental change to the loop structures
+ * 2. We want serializable history for replay and backtracking
+ * 3. We want simplified IR for easy and clean feature extraction
+ * 4. We may create some macro schedule primitives
+
+ * After search is done, we will lower this IR to TVM IR and TVM schedule
+ * primitives. Because we share a lot of common objects during search, the
+ * transformation is implemented in copy-on-write style. All objects are
+ * immutable, which is similar to TVM IR.
+ */
+
+#ifndef TVM_ANSOR_TRANSFORM_STEP_H_
+#define TVM_ANSOR_TRANSFORM_STEP_H_
+
+#include
+#include
+#include
+#include "compute_dag.h"
+
+namespace tvm {
+namespace ansor {
+
+using namespace tvm::tir;
+
+inline std::string CleanName(const std::string& str) {
+  // Make the name a valid identifier in Python code
+  std::string ret = str;
+  StrReplace(&ret, ".", "_");
+  StrReplace(&ret, "@", "_");
+  StrReplace(&ret, "outer", "o");
+  StrReplace(&ret, "inner", "i");
+  return ret;
+}
+
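+// For example (assuming StrReplace substitutes every occurrence):
+//   CleanName("ax1.outer.inner")      == "ax1_o_i"
+//   CleanName("T_conv2d_nchw.global") == "T_conv2d_nchw_global"
+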
+enum IteratorType {
+  kSpace,    // spatial iterator
+  kReduce,   // reduction iterator
+  kMixed,    // fused spatial and reduction iterator
+  kSpecial   // special iterator (e.g. virtual root iterator)
+};
+
+enum IteratorAnnotation {
+  kNone, kUnroll, kVectorize, kParallel,
+  kVThread, kBlockX, kThreadX, kBlockY, kThreadY
+};
+
+class Iterator;
+
+/*!
+ * \brief A for-loop iterator
+ * Similar to tvm::IterVar in `include/expr.h`
+ */
+class IteratorNode : public Object {
+ public:
+  std::string name;
+  Range range;  // domain of the for loop
+  IteratorType iter_type;
+  IteratorAnnotation annotation;
+  std::vector ori_iters;
+
+  static Iterator make(std::string name, Range range,
+                       IteratorType iter_type, IteratorAnnotation annotation,
+                       const std::vector* ori_iters = nullptr);
+
+  static constexpr const char *_type_key = "ansor.Iterator";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(Iterator, ObjectRef, IteratorNode);
+
+/*! \brief The base class for a transformation step */
+class StepNode: public Object {
+ public:
+  int stage_id;
+
+  // Print this step as the equivalent Python schedule API
+  virtual std::string PrintAsPythonAPI(std::vector *stages,
+                                       StageToAxesMap *stage_to_axes,
+                                       te::Schedule *schedule,
+                                       const std::vector& transform_steps) const = 0;
+
+  static constexpr const char* _type_key = "ansor.Step";
+  TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object);
+};
+TVM_DEFINE_MUTABLE_NODE_REF(Step, StepNode);
+
+/*
+ * Note on how to add a new transform step
+ *
+ * Take fuse for example:
+ * 1. Define class FuseStepNode, FuseStep in transform_step.h, and implement its make function
+ *    FuseStepNode::make(...) in transform_step.cc
+ * 2. Implement FuseStepNode::ApplyToSchedule and FuseStepNode::PrintAsPythonAPI.
+ *    - In these two functions you need to lower this step with TVM's schedule API
+ * 3. Implement State::fuse and State::DoFuseStep.
+ *    - In these two functions you need to incrementally update all data structures in State with
+ *      CopyOnWrite style
+ * 4. Add your step to ComputeDAG::ReplaySteps and make sure it works.
+ * 5. Add serialization support in `struct Handler >`
+ *    (in serialization.cc)
+ * 6. 
Add hash support in `struct hash<::tvm::ansor::Step>` (search for this function in this file) + */ + +class ReorderStep; class SplitStep; class FollowSplitStep; +class FollowFusedSplitStep; +class FuseStep; class AnnotationStep; +class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep; +class PackForVecStep; class CacheReadStep; class CacheWriteStep; +class PragmaStep; class RfactorStep; class StorageAlignStep; +class AttachMap; + +class ReorderStepNode: public StepNode { + public: + std::vector after_ids; + + static ReorderStep make(int stage_id, const std::vector& after_ids); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.ReorderStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(ReorderStep, Step, ReorderStepNode); + + +class SplitStepNode: public StepNode { + public: + int iter_id; + PrimExpr extent; // the extent of the axis to split + std::vector lengths; // The split factors + bool inner_to_outer; + + static SplitStep make(int stage_id, int iter_id, PrimExpr extent, + const std::vector& lengths, + bool inner_to_outer); + + std::vector ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.SplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(SplitStep, Step, SplitStepNode); + +// Similar to SplitStepNode, but use split factor from another step +// (i.e. Follow another split step) +class FollowSplitStepNode: public StepNode { + public: + int iter_id; + int src_step_id; + int n_split; + + static FollowSplitStep make(int stage_id, int iter_id, + int src_step_id, int n_split); + + void ExtractSplitLengths(const std::vector& transform_steps, + std::vector* lengths) const; + + std::vector ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, + const std::vector& transform_steps) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.FollowSplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(FollowSplitStep, Step, FollowSplitStepNode); + + +// Similar to FollowSplitStep, but use split factors from multiple steps +// This can be used for the split in cooperative fetching. +class FollowFusedSplitStepNode: public StepNode { + public: + int iter_id; + std::vector src_step_ids; + int level; // Use the length in this split level + bool factor_or_nparts; // If this is true, use factor. 
Otherwise, use nparts + + static FollowFusedSplitStep make(int stage_id, int iter_id, + const std::vector& src_step_ids, + int level, bool factor_or_nparts); + + PrimExpr ExtractSplitLength(const std::vector& transform_steps) const; + + std::vector ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, + const std::vector& transform_steps) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.FollowFusedSplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); + + +class FuseStepNode: public StepNode { + public: + std::vector fused_ids; + + static FuseStep make(int stage_id, const std::vector& fused_ids); + + IterVar ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.FuseStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(FuseStep, Step, FuseStepNode); + + +class AnnotationStepNode: public StepNode { + public: + int iter_id; + IteratorAnnotation annotation; + + static AnnotationStep make(int stage_id, int iter_id, IteratorAnnotation ann); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.AnnotationStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(AnnotationStep, Step, AnnotationStepNode); + + +class ComputeAtStepNode: public StepNode { + public: + int target_stage_id; + int target_iter_id; + + static ComputeAtStep make(int stage_id, int target_stage_id, + int target_iter_id); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.ComputeAtStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(ComputeAtStep, Step, ComputeAtStepNode); + + +class ComputeRootStepNode: public StepNode { + public: + static ComputeRootStep make(int stage_id); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.ComputeRootStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(ComputeRootStep, Step, ComputeRootStepNode); + + +class ComputeInlineStepNode: public StepNode { + public: + static ComputeInlineStep make(int stage_id); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr 
const char* _type_key = "ansor.ComputeInlineStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(ComputeInlineStep, Step, ComputeInlineStepNode); + +class PackForVecStepNode: public StepNode { + public: + int iter_id; + int vec_size; + + static PackForVecStep make(int stage_id, int iter_id, int vec_size); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.PackForVecStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(PackForVecStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(PackForVecStep, Step, PackForVecStepNode); + + +/*! \brief Apply cache_read to a stage + * TVM Api: te::Schedule::cache_read(tensor, scope, readers) */ +class CacheReadStepNode: public StepNode { + public: + std::string scope_name; + std::vector reader_stage_ids; + + static CacheReadStep make(int stage_id, std::string scope_name, + const std::vector& reader_stage_id); + + te::Tensor ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.CacheReadStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(CacheReadStep, Step, CacheReadStepNode); + + +/*! \brief Apply cache_write to a stage + * TVM Api: te::Schedule::cache_write(tensor, scope) + * This step will cache_write all output tensors of target stage */ +class CacheWriteStepNode: public StepNode { + public: + std::string scope_name; + + static CacheWriteStep make(int stage_id, std::string scope_name); + + Array ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.CacheWriteStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(CacheWriteStep, Step, CacheWriteStepNode); + +/*! \brief Add pragma to a specific iterator */ +class PragmaStepNode: public StepNode { + public: + int iter_id; + std::string pragma_type; + + static PragmaStep make(int stage_id, int iter_id, std::string pragma_type); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.PragmaStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(PragmaStep, Step, PragmaStepNode); + +/*! 
\brief Factor a reduction axis
+ * TVM Api: te::Schedule::rfactor(tensor, axis, factor_axis) */
+class RfactorStepNode: public StepNode {
+ public:
+  int iter_id;
+  int factor_iter_id;
+
+  static RfactorStep make(int stage_id, int iter_id, int factor_iter_id);
+
+  Array ApplyToSchedule(std::vector *stages,
+                        StageToAxesMap *stage_to_axes,
+                        te::Schedule *schedule) const;
+
+  std::string PrintAsPythonAPI(std::vector *stages,
+                               StageToAxesMap *stage_to_axes,
+                               te::Schedule *schedule,
+                               const std::vector& transform_steps) const final;
+
+  static constexpr const char* _type_key = "ansor.RfactorStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(RfactorStep, Step, RfactorStepNode);
+
+class StorageAlignStepNode: public StepNode {
+ public:
+  int iter_id;
+  int factor;
+  int offset;
+
+  static StorageAlignStep make(int stage_id, int iter_id, int factor,
+                               int offset);
+
+  void ApplyToSchedule(std::vector *stages,
+                       StageToAxesMap *stage_to_axes) const;
+
+  std::string PrintAsPythonAPI(std::vector *stages,
+                               StageToAxesMap *stage_to_axes,
+                               te::Schedule *schedule,
+                               const std::vector& transform_steps) const final;
+
+  static constexpr const char* _type_key = "ansor.StorageAlignStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(StorageAlignStep, Step, StorageAlignStepNode);
+
+} // namespace ansor
+} // namespace tvm
+
+// Hash function for Step
+namespace std {
+
+template <>
+struct hash<::tvm::ansor::Step> {
+  std::size_t operator()(const ::tvm::ansor::Step& step) const {
+    if (auto ps = step.as<::tvm::ansor::ReorderStepNode>()) {
+      return ::dmlc::HashCombine(1,
+          ::dmlc::HashCombine(std::hash()(ps->stage_id),
+                              ps->after_ids));
+    } else if (auto ps = step.as<::tvm::ansor::SplitStepNode>()) {
+      size_t ret = ::dmlc::HashCombine(2,
+          ::dmlc::HashCombine(std::hash()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash()(ps->iter_id),
+                              ps->inner_to_outer)));
+      for (const auto& len : ps->lengths) {
+        if (len.defined()) {
+          auto pint = len.as<::tvm::tir::IntImmNode>();
+          CHECK(pint != nullptr);
+          ret = ::dmlc::HashCombine(ret, pint->value);
+        } else {
+          ret = ::dmlc::HashCombine(ret, 0x5D);  // a magic number
+        }
+      }
+      // Return after the loop, so that every split length is hashed;
+      // returning inside the loop would drop all factors but the first.
+      return ret;
+    } else if (auto ps = step.as<::tvm::ansor::FollowSplitStepNode>()) {
+      return ::dmlc::HashCombine(3,
+          ::dmlc::HashCombine(std::hash()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash()(ps->iter_id),
+          ::dmlc::HashCombine(std::hash()(ps->src_step_id),
+                              ps->n_split))));
+    } else if (auto ps = step.as<::tvm::ansor::FollowFusedSplitStepNode>()) {
+      return ::dmlc::HashCombine(4,
+          ::dmlc::HashCombine(std::hash()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash()(ps->iter_id),
+          ::dmlc::HashCombine(std::hash>()(ps->src_step_ids),
+          ::dmlc::HashCombine(std::hash()(ps->level),
+                              ps->factor_or_nparts)))));
+    } else if (auto ps = step.as<::tvm::ansor::FuseStepNode>()) {
+      return ::dmlc::HashCombine(5,
+          ::dmlc::HashCombine(std::hash()(ps->stage_id),
+                              ps->fused_ids));
+    } else if (auto ps = step.as<::tvm::ansor::AnnotationStepNode>()) {
+      return ::dmlc::HashCombine(6,
+          ::dmlc::HashCombine(std::hash()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash()(ps->iter_id),
+                              static_cast(ps->annotation))));
+    } else if (auto ps = step.as<::tvm::ansor::ComputeAtStepNode>()) {
+      return ::dmlc::HashCombine(7,
+          ::dmlc::HashCombine(std::hash()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash()(ps->target_stage_id),
+                              ps->target_iter_id)));
+    } else if (auto ps = 
step.as<::tvm::ansor::ComputeRootStepNode>()) { + return ::dmlc::HashCombine(8, + ps->stage_id); + } else if (auto ps = step.as<::tvm::ansor::ComputeInlineStepNode>()) { + return ::dmlc::HashCombine(9, + ps->stage_id); + } else if (auto ps = step.as<::tvm::ansor::PackForVecStepNode>()) { + return ::dmlc::HashCombine(10, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->iter_id), + ps->vec_size))); + } else if (auto ps = step.as<::tvm::ansor::CacheReadStepNode>()) { + return ::dmlc::HashCombine(11, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->scope_name), + ps->reader_stage_ids))); + } else if (auto ps = step.as<::tvm::ansor::CacheWriteStepNode>()) { + return ::dmlc::HashCombine(12, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ps->scope_name)); + } else if (auto ps = step.as<::tvm::ansor::PragmaStepNode>()) { + return ::dmlc::HashCombine(13, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->iter_id), + ps->pragma_type))); + } else if (auto ps = step.as<::tvm::ansor::RfactorStepNode>()) { + return ::dmlc::HashCombine(14, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->iter_id), + ps->factor_iter_id))); + } else if (auto ps = step.as<::tvm::ansor::StorageAlignStepNode>()) { + return ::dmlc::HashCombine(15, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->iter_id), + ::dmlc::HashCombine(std::hash()(ps->factor), + ps->offset)))); + } else { + LOG(FATAL) << "Invalid step"; + } + return 0; + } +}; +} // namespace std + +#endif // TVM_ANSOR_TRANSFORM_STEP_H_ diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc index b9a4f25023bf..87e7ad71a7c0 100644 --- a/tests/cpp/ansor_test.cc +++ b/tests/cpp/ansor_test.cc @@ -19,10 +19,10 @@ #include #include - +#include #include #include -#include "../../src/ansor/compute_dag.h" +#include "../../src/ansor/loop_state.h" tvm::Array matmul_func(int n, int m, int k) { using namespace tvm; @@ -52,11 +52,14 @@ tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, Tensor bn_scale = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_scale"); Tensor bn_offset = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_offset"); - int OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1); - int OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1); + int OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1; + int OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1; + + const auto& conv = topi::conv2d_nchw(data, kernel, padding, padding, strides, + strides); + CHECK(conv->shape[2].as()->value == OH); + CHECK(conv->shape[3].as()->value == OW); - const auto& conv = topi::conv2d_nchw(data, kernel, strides, padding, - dilation); const auto& bias_add = compute( {N, CO, OH, OW}, [&](Var i, Var j, Var k, Var l) { @@ -82,12 +85,461 @@ tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, TEST(ComputeDAG, Basic) { const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); - auto dag = tvm::ansor::ComputeDAGNode::make(tensors); + const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors); + const auto& state = tvm::ansor::StateNode::make(dag->ops); + CHECK(std::equal_to()(state, dag.GetInitState())); + LOG(INFO) << "\n" << state; LOG(INFO) << "\n" << dag; LOG(INFO) << "\n" << dag->access_analyzer; } +TEST(ComputeDAG, GetProducersConsumers) { + using namespace tvm::ansor; + + const auto& tensors = 
conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); + const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors); + int data = 0, padding = 1, kernel = 2, conv = 3, bias = 4, bias_add = 5; + int bn_scale = 6, bn_mul = 7, bn_offset = 8, bn_add = 9, relu = 10; + + State s0 = dag.GetInitState(); + std::unordered_set set; + { + std::vector> consumer_list = { + {data, padding}, {padding, conv}, {kernel, conv}, {conv, bias_add}, + {bias, bias_add}, {bias_add, bn_mul}, {bn_scale, bn_mul}, + {bn_mul, bn_add}, {bn_offset, bn_add}, {bn_add, relu} + }; + for (const auto& pair : consumer_list) { + dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &set); + CHECK_EQ(set.size(), 1); + CHECK_EQ((*set.begin()), s0->stages[pair.second]->op); + } + std::vector>> producer_list = { + {padding, {data}}, {conv, {padding, kernel}}, {bias_add, {conv, bias}}, + {bn_mul, {bias_add, bn_scale}}, {bn_add, {bn_mul, bn_offset}}, + {relu, {bn_add}} + }; + for (const auto& pair : producer_list) { + dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op, &set); + CHECK_EQ(set.size(), pair.second.size()); + for (const auto& target : pair.second) { + CHECK(set.count(s0->stages[target]->op)); + } + } + } + + s0.compute_inline(bn_add); + s0.compute_inline(bn_mul); + s0.compute_inline(bias_add); + s0.compute_inline(padding); + { + std::vector> consumer_list = { + {data, conv}, {kernel, conv}, {conv, relu} + }; + for (const auto& pair : consumer_list) { + dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &set); + CHECK_EQ(set.size(), 1); + CHECK_EQ((*set.begin()), s0->stages[pair.second]->op); + } + std::vector>> producer_list = { + {padding, {data}}, {conv, {padding, kernel}}, {bias_add, {conv, bias}}, + {bn_mul, {bias_add, bn_scale}}, {bn_add, {bn_mul, bn_offset}}, + {relu, {bn_add}} + }; + for (const auto& pair : producer_list) { + dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op, &set); + CHECK_EQ(set.size(), pair.second.size()); + for (const auto& target : pair.second) { + CHECK(set.count(s0->stages[target]->op)); + } + } + } +} + +TEST(Step, SplitFuseReorder) { + using namespace tvm::ansor; + + const auto& tensors = matmul_func(512, 512, 512); + const auto& dag = ComputeDAGNode::make(tensors); + + State s0 = dag.GetInitState(); + State s1 = s0; + Iterator ti = s0->stages[2]->iters[0]; + Iterator tj = s0->stages[2]->iters[1]; + Iterator tk = s0->stages[2]->iters[2]; + std::vector its; + + CHECK_EQ(s1->stages[2]->iters[0]->range->extent.as()->value, 512); + + its = s0.split(2, ti, {16}); + CHECK_EQ(s0->stages[2]->iters[0]->range->extent.as()->value, 32); + CHECK_EQ(s0->stages[2]->iters[1]->range->extent.as()->value, 16); + + Iterator tio = its[0], tii = its[1]; + its = s0.split(2, tj, {8}); + CHECK_EQ(s0->stages[2]->iters[2]->range->extent.as()->value, 64); + CHECK_EQ(s0->stages[2]->iters[3]->range->extent.as()->value, 8); + + Iterator tjo = its[0], tji = its[1]; + s0.reorder(2, {tio, tjo, tk, tji, tii}); + CHECK_EQ(s0->stages[2]->iters[2]->range->extent.as()->value, 512); + + s0.fuse(2, {tio, tjo}); + CHECK_EQ(s0->stages[2]->iters[0]->range->extent.as()->value, 2048); + + s1.split(2, ti, {8, 2}); + s1.split(2, tj, {32, 8}, false); + CHECK_EQ(s1->stages[2]->iters[0]->range->extent.as()->value, 32); + CHECK_EQ(s1->stages[2]->iters[1]->range->extent.as()->value, 8); + CHECK_EQ(s1->stages[2]->iters[2]->range->extent.as()->value, 2); + CHECK_EQ(s1->stages[2]->iters[3]->range->extent.as()->value, 32); + CHECK_EQ(s1->stages[2]->iters[4]->range->extent.as()->value, 8); + 
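// Editor's note: split(2, ti, {8, 2}) uses inner_to_outer = true, so {8, 2}
+  // become the two innermost extents (512 -> 32 x 8 x 2); split(2, tj,
+  // {32, 8}, false) uses nparts, so {32, 8} become the two outermost extents
+  // (512 -> 32 x 8 x 2). The CHECK_EQs around this comment verify both.
+  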
CHECK_EQ(s1->stages[2]->iters[5]->range->extent.as()->value, 2); +} + +TEST(Step, ComputeAtRootInline) { + using namespace tvm::ansor; + + const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); + const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors); + // int data = 0, padding = 1, kernel = 2; + int conv = 3; + // int bias = 4; + int bias_add = 5; + // int bn_scale = 6; + int bn_mul = 7; + // int bn_offset = 8; + int bn_add = 9, relu = 10; + + State s0 = dag.GetInitState(); + s0.compute_inline(bn_add); + s0.compute_inline(bn_mul); + s0.compute_inline(bias_add); + s0.compute_at(conv, relu, s0->stages[relu]->iters[2]); + const auto& conv_stage_attach = s0->attach_map->stage_to_attach_iter.find(conv); + std::pair iterkey(relu, 2); + CHECK(conv_stage_attach->second == iterkey); + const auto& conv_iter_attach = s0->attach_map->iter_to_attached_stages.find(iterkey); + CHECK_EQ(conv_iter_attach->second.size(), 1); + CHECK_EQ(conv_iter_attach->second[0], conv); + std::stringstream ss; + ss << "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + << "for ax1 (0,3)\n" + << " for ax2 (0,230)\n" + << " for ax3 (0,230)\n" + << " T_pad = ...\n" + << "for ax1 (0,64)\n" + << " for ax2 (0,112)\n" + << " for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " for i (None)\n" + << " for kh (None)\n" + << " for kw (None)\n" + << " T_conv2d_nchw = ...\n" + << " for ax3 (0,112)\n" + << " T_relu = ...\n"; + CHECK_EQ(s0.ToStr().compare(ss.str()), 0); + + s0.compute_root(conv); + s0.compute_root(bn_mul); + CHECK_EQ(s0->attach_map->stage_to_attach_iter.size(), 0); + CHECK_EQ(s0->attach_map->iter_to_attached_stages.size(), 0); + ss.str(std::string()); + ss << "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + << "for ax1 (0,3)\n" + << " for ax2 (0,230)\n" + << " for ax3 (0,230)\n" + << " T_pad = ...\n" + << "for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " for i (None)\n" + << " for kh (None)\n" + << " for kw (None)\n" + << " T_conv2d_nchw = ...\n" + << "for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " Bn_mul = ...\n" + << "for ax1 (0,64)\n" + << " for ax2 (0,112)\n" + << " for ax3 (0,112)\n" + << " T_relu = ...\n"; + CHECK_EQ(s0.ToStr().compare(ss.str()), 0); +} + +TEST(Step, CacheReadWrite) { + using namespace tvm; + using namespace tvm::te; + using namespace tvm::ansor; + + const auto& test_func = []() -> Array { + int N = 4, H = 7, W = 7, CO = 512, CI = 512, KH = 3, KW = 3, stride = 1; + int padding = 1; + Tensor data = placeholder({N, CI, H, W}, DataType::Float(32), "Data"); + Tensor kernel_data = placeholder({CO, CI, KH, KW}, DataType::Float(32), + "kernel_data"); + const auto& k_split = compute(kernel_data->shape, + [&](const Array& i) { + return Array({kernel_data[i[0]][i[1]][i[2]][i[3]] + 1, + div(kernel_data[i[0]][i[1]][i[2]][i[3]], 2)}); + }, + "Kernel_split"); + const auto& kernel = compute(kernel_data->shape, + [&](Var i, Var j, Var k, Var l) { + return (k_split[0])[i][j][k][l] + (k_split[1])[i][j][k][l]; + }, + "Kernel"); + const auto& conv = topi::conv2d_nchw(data, kernel, padding, padding, stride, + stride); + const auto& relu = topi::relu(conv); + const auto& out = compute(relu->shape, + [&](Var i, Var j, Var k, Var l) { + return data[i][j][k][l] + relu[i][j][k][l]; + }, + "Add"); + return {data, kernel_data, out}; + }; + const auto& dag = ComputeDAGNode::make(test_func()); + + int data = 0, pad_temp = 1, 
kernel_data = 2, kernel_split = 3, kernel = 4; + int conv = 5, relu = 6, add = 7; + + // 0: init state + auto s0 = dag.GetInitState(); + std::vector ori_its = s0->stages[add]->iters; + auto its = s0.split(add, s0->stages[add]->iters[0], {2}); + s0.reorder(add, {its[0], ori_its[1], its[1], ori_its[2], ori_its[3]}); + s0.compute_inline(relu); + + // 1: simple cache_write with compute_at + int conv_global = s0.cache_write(conv, "global", dag); + conv++; relu++; add++; + s0.compute_at(conv_global, conv, s0->stages[conv]->iters[3]); + + // 2: simple cache_read with compute_at + int kernel_global = s0.cache_read(kernel, "global", {conv_global}, dag); + conv_global++; conv++; relu++; add++; + s0.compute_at(kernel_global, conv_global, s0->stages[conv_global]->iters[4]); + std::stringstream ss; + ss << "Placeholder: Data, kernel_data\n" + << "for ax0 (0,4)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,9)\n" + << " for ax3 (0,9)\n" + << " T_pad = ...\n" + << "for ax0 (0,512)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,3)\n" + << " for ax3 (0,3)\n" + << " Kernel_split = ...\n" + << "for ax0 (0,512)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,3)\n" + << " for ax3 (0,3)\n" + << " Kernel = ...\n" + << "for ax0 (0,4)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,7)\n" + << " for ax3 (0,7)\n" + << " for ax0_c (None)\n" + << " for ax1_c (None)\n" + << " for ax2_c (None)\n" + << " for ax3_c (None)\n" + << " for i (None)\n" + << " for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " Kernel.global = ...\n" + << " for kh (None)\n" + << " for kw (None)\n" + << " T_conv2d_nchw.global = ...\n" + << " T_conv2d_nchw = ...\n" + << "for ax0.0 (0,2)\n" + << " for ax1 (0,512)\n" + << " for ax0.1 (0,2)\n" + << " for ax2 (0,7)\n" + << " for ax3 (0,7)\n" + << " Add = ...\n"; + CHECK_EQ(s0.ToStr().compare(ss.str()), 0); + + // 3: two level cache_read with compute_at + // preparing for GPU's shared memory & local memory + int pad_temp_global = s0.cache_read(pad_temp, "global", {conv_global}, dag); + kernel_data++; kernel_split++; kernel++; kernel_global++; + conv_global++; conv++; relu++; add++; + int pad_temp_shared = s0.cache_read(pad_temp_global, "shared", {conv_global}, + dag); + kernel_data++; kernel_split++; kernel++; kernel_global++; + conv_global++; conv++; relu++; add++; + s0.compute_at(pad_temp_global, conv_global, + s0->stages[conv_global]->iters[2]); + s0.compute_at(pad_temp_shared, conv_global, + s0->stages[conv_global]->iters[4]); + + // 4: cache_read with multi readers + // This stage cannot be compute at to its consumer + s0.cache_read(data, "global", {pad_temp, add}, dag); + pad_temp++; pad_temp_global++; pad_temp_shared++; + kernel_data++; kernel_split++; kernel++; kernel_global++; + conv_global++; conv++; relu++; add++; + ss.str(std::string()); + ss << "Placeholder: Data, kernel_data\n" + << "for ax0 (0,4)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,7)\n" + << " for ax3 (0,7)\n" + << " Data.global = ...\n" + << "for ax0 (0,4)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,9)\n" + << " for ax3 (0,9)\n" + << " T_pad = ...\n" + << "for ax0 (0,512)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,3)\n" + << " for ax3 (0,3)\n" + << " Kernel_split = ...\n" + << "for ax0 (0,512)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,3)\n" + << " for ax3 (0,3)\n" + << " Kernel = ...\n" + << "for ax0 (0,4)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,7)\n" + << " for ax3 (0,7)\n" + << " for ax0_c (None)\n" + << " for ax1_c (None)\n" + << " for ax2_c (None)\n" + << 
" for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " T_pad.global = ...\n" + << " for ax3_c (None)\n" + << " for i (None)\n" + << " for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " Kernel.global = ...\n" + << " for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " T_pad.global.shared = ...\n" + << " for kh (None)\n" + << " for kw (None)\n" + << " T_conv2d_nchw.global = ...\n" + << " T_conv2d_nchw = ...\n" + << "for ax0.0 (0,2)\n" + << " for ax1 (0,512)\n" + << " for ax0.1 (0,2)\n" + << " for ax2 (0,7)\n" + << " for ax3 (0,7)\n" + << " Add = ...\n"; + CHECK_EQ(s0.ToStr().compare(ss.str()), 0); + + // 5: cache_write with multi outputs + // TVM's cache_write actually has a bug with this case: + + // After schedule.cache_write, TVM generate one new stage: + // From: kernel_data -> kernel_split -> kernel + // To: kernel_data -> kernel_split_global -> kernel_split -> kernel + + // But with topo sort analyse, we get: + // kernel_data -> kernel_split_global -> kernel_split -> kernel + // \ / + // ----------------> kernel_split ----------------> + + // Seems there's bug with the input/output tensor. Such multi outputs case + // should be unusual, so we make some hack on DoCacheWrite + // To be fixed in the future + s0.cache_write(kernel_split, "global", dag); + ss.str(std::string()); + ss << "Placeholder: Data, kernel_data\n" + << "for ax0 (0,4)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,7)\n" + << " for ax3 (0,7)\n" + << " Data.global = ...\n" + << "for ax0 (0,4)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,9)\n" + << " for ax3 (0,9)\n" + << " T_pad = ...\n" + << "for ax0_c (0,512)\n" + << " for ax1_c (0,512)\n" + << " for ax2_c (0,3)\n" + << " for ax3_c (0,3)\n" + << " Kernel_split.global = ...\n" + << "for ax0 (0,512)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,3)\n" + << " for ax3 (0,3)\n" + << " Kernel_split = ...\n" + << "for ax0 (0,512)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,3)\n" + << " for ax3 (0,3)\n" + << " Kernel_split = ...\n" + << "for ax0 (0,512)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,3)\n" + << " for ax3 (0,3)\n" + << " Kernel = ...\n" + << "for ax0 (0,4)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,7)\n" + << " for ax3 (0,7)\n" + << " for ax0_c (None)\n" + << " for ax1_c (None)\n" + << " for ax2_c (None)\n" + << " for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " T_pad.global = ...\n" + << " for ax3_c (None)\n" + << " for i (None)\n" + << " for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " Kernel.global = ...\n" + << " for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " T_pad.global.shared = ...\n" + << " for kh (None)\n" + << " for kw (None)\n" + << " T_conv2d_nchw.global = ...\n" + << " T_conv2d_nchw = ...\n" + << "for ax0.0 (0,2)\n" + << " for ax1 (0,512)\n" + << " for ax0.1 (0,2)\n" + << " for ax2 (0,7)\n" + << " for ax3 (0,7)\n" + << " Add = ...\n"; + CHECK_EQ(s0.ToStr().compare(ss.str()), 0); +} + +TEST(Step, FollowSplitFollowFusedSplit) { + // todo +} + +TEST(Step, Rfactor) { + // todo +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; From f43e82f0ba4353f8fff8fcd830ce08c3bc94c793 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Thu, 28 May 2020 15:56:38 +0800 Subject: [PATCH 03/78] 
Add search_task, measure and serialization (#4) * Add FollowSplit & FollowFusedSplit tests * Update dag.InferBound & its UT * Add search_task, measure and serialization * Update Serialization UT --- include/tvm/ir/expr.h | 5 + include/tvm/runtime/device_api.h | 3 +- src/ansor/compute_dag.cc | 261 ++++++----- src/ansor/measure.cc | 314 +++++++++++++ src/ansor/measure.h | 262 +++++++++++ src/ansor/search_task.cc | 120 +++++ src/ansor/search_task.h | 92 ++++ src/ansor/serialization.cc | 573 ++++++++++++++++++++++++ src/ansor/serialization.h | 78 ++++ src/ansor/utils.h | 7 + src/ir/expr.cc | 2 + src/runtime/cuda/cuda_device_api.cc | 4 + src/runtime/opencl/opencl_device_api.cc | 3 + tests/cpp/ansor_test.cc | 122 ++++- 14 files changed, 1723 insertions(+), 123 deletions(-) create mode 100644 src/ansor/measure.cc create mode 100644 src/ansor/measure.h create mode 100644 src/ansor/search_task.cc create mode 100644 src/ansor/search_task.h create mode 100644 src/ansor/serialization.cc create mode 100644 src/ansor/serialization.h diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index b2ce50d91f58..b3e527ca6fd9 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -112,6 +112,11 @@ class PrimExpr : public BaseExpr { * \param value The value to be constructed. */ TVM_DLL PrimExpr(float value); // NOLINT(*) + /*! + * \brief construct from double. + * \param value The value to be constructed. + */ + TVM_DLL PrimExpr(double value); // NOLINT(*) /*! \return the data type of this expression. */ DataType dtype() const { return static_cast(get())->dtype; } diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 421811a52c3b..9b2eb6be2160 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -44,7 +44,8 @@ enum DeviceAttrKind : int { kMaxClockRate = 6, kMultiProcessorCount = 7, kMaxThreadDimensions = 8, - kGcnArch = 9 + kGcnArch = 9, + kMaxRegistersPerBlock = 10 }; /*! \brief Number of bytes each allocation must align to */ diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index e1ae3250d1a5..feaefe9f8e9f 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -3,6 +3,7 @@ */ #include "compute_dag.h" #include +#include #include #include #include @@ -32,7 +33,8 @@ using OperationSet = std::unordered_set; // Topo-sort ops from tensors according to their read-write relations. 
// Results are stored in ops -void TopoSortOps(const Array& tensors, std::vector* ops) { +void TopoSortOps(const Array& tensors, + std::vector* ops) { std::unordered_map degree; std::unordered_map > edge_set; std::unordered_map priority; @@ -193,7 +195,8 @@ bool IsInjective(const te::Operation& op, const std::vector& index, } // Gather all VarNodes in an expr -static void GatherVars(const PrimExpr& expr, std::unordered_set* vars) { +static void GatherVars(const PrimExpr& expr, + std::unordered_set* vars) { PostOrderVisit(expr, [&vars](const ObjectRef &node) { if (const VarNode* op = node.as()) { vars->insert(op); @@ -206,7 +209,8 @@ static bool HasExpensiveOp(const PrimExpr& expr) { bool found = false; PostOrderVisit(expr, [&found](const ObjectRef &node) { if (const CallNode* op = node.as()) { - if (op->call_type == CallNode::CallType::PureIntrinsic && op->name == "exp") { + if (op->call_type == CallNode::CallType::PureIntrinsic && + op->name == "exp") { found = true; } } @@ -224,7 +228,8 @@ AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) { // build read & write access map for (const auto& op : node->ops_topo_order) { if (op->IsInstance()) { - node->read_from[op] = OperationMap > >(); + node->read_from[op] = + OperationMap > >(); } else if (auto cop = op.as()) { TensorAccessExtractor extractor; for (const auto& exp : cop->body) { @@ -232,8 +237,10 @@ AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) { } for (const auto& iter : extractor.buf_accesses) { - std::vector >& accesses = node->read_by[iter.first][op]; - accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end()); + std::vector >& accesses = + node->read_by[iter.first][op]; + accesses.insert(accesses.begin(), iter.second.begin(), + iter.second.end()); } node->read_from[op] = std::move(extractor.buf_accesses); @@ -251,7 +258,8 @@ AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) { node->is_strict_inlineable[op] = false; node->is_output[op] = false; } else if (auto pop = op.as()) { - // check whether is element-wise and strict-inlineable (see definition in compute_dag.h) + // check whether is element-wise and strict-inlineable + // (see definition in compute_dag.h) bool is_injective = true; bool is_strict_inlineable = true; @@ -259,12 +267,14 @@ AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) { for (const auto& pair : node->read_from[op]) { const std::vector >& access = pair.second; for (const auto& index : access) { - if (!IsInjective(op, index, &axis_missing, &axis_duplicated, &same_order)) { + if (!IsInjective(op, index, &axis_missing, &axis_duplicated, + &same_order)) { is_injective = false; is_strict_inlineable = false; break; } - if (!same_order || axis_duplicated) { // do not strictly inline transpose + if (!same_order || axis_duplicated) { + // do not strictly inline transpose is_strict_inlineable = false; } } @@ -281,9 +291,11 @@ AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) { } node->is_injective[op] = is_injective; - node->is_strict_inlineable[op] = is_strict_inlineable && !has_expensive_op; + node->is_strict_inlineable[op] = is_strict_inlineable && + !has_expensive_op; - // check whether the op needs multi-level tiling (see definition in compute_dag.h) + // check whether the op needs multi-level tiling + // (see definition in compute_dag.h) bool needs_multi_level_tiling = false; int n_missing = 0; @@ -297,7 +309,8 @@ AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) { } bool missing = false; for (const auto& axis : pop->axis) { - 
if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get()) == 0) { + if (GetIntImm(axis->dom->extent) > 1 && + vars.count(axis->var.get()) == 0) { missing = true; } } @@ -928,89 +941,90 @@ std::pair > ComputeDAG::ApplySteps( } } -// std::string ComputeDAG::PrintStepsAsPython( -// const std::vector& transform_steps) const { -// std::vector stages; -// StageToAxesMap stage_to_axes; -// Array ops; -// for (const auto& op : operator->()->ops) { -// if (!op->IsInstance()) { -// ops.push_back(op); -// } -// } -// te::Schedule schedule = te::create_schedule({ops.back()}); +std::string ComputeDAG::PrintStepsAsPython(const std::vector& transform_steps) const { + std::vector stages; + StageToAxesMap stage_to_axes; + Array ops; + for (const auto& op : operator->()->ops) { + if (!op->IsInstance()) { + ops.push_back(op); + } + } + te::Schedule schedule = te::create_schedule({ops.back()}); -// // init axes -// for (const auto& x : operator->()->ops) { -// const te::Stage& stage = schedule.operator[](x); -// stages.push_back(stage); -// UpdateStageAxis(stage, &stage_to_axes); -// } + // init axes + for (const auto& x : operator->()->ops) { + const te::Stage& stage = schedule.operator[](x); + stages.push_back(stage); + UpdateStageAxis(stage, &stage_to_axes); + } -// std::stringstream ss; + std::stringstream ss; -// for (const auto& stage : stages) { -// if (stage->op->IsInstance()) { -// for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { -// ss << stage->leaf_iter_vars[i]->var->name_hint; -// if (i != stage->leaf_iter_vars.size() - 1) { -// ss << ", "; -// } -// } -// ss << " = " << "tuple(" << stage->op->func_name() << ".op.axis)" -// << " + " << "tuple(" << stage->op->func_name() << ".op.reduce_axis)\n"; -// } -// } + for (const auto& stage : stages) { + if (stage->op->IsInstance()) { + for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { + ss << stage->leaf_iter_vars[i]->var->name_hint; + if (i != stage->leaf_iter_vars.size() - 1) { + ss << ", "; + } + } + ss << " = " << "tuple(" << stage->op->func_name() << ".op.axis)" + << " + " << "tuple(" << stage->op->func_name() << ".op.reduce_axis)\n"; + } + } -// for (const auto& step : transform_steps) { -// ss << step->PrintAsPythonAPI(&stages, &stage_to_axes, &schedule, transform_steps); -// } + for (const auto& step : transform_steps) { + ss << step->PrintAsPythonAPI(&stages, &stage_to_axes, &schedule, + transform_steps); + } -// return ss.str(); -// } + return ss.str(); +} -// State ComputeDAG::ReplayAndInferBound(const std::vector& transform_steps) const { -// State ret_state = GetInitState(); -// StateNode* pstate = ret_state.CopyOnWrite(); -// pstate->transform_steps = transform_steps; -// ret_state.DoSteps(transform_steps, *this); +State ComputeDAG::ReplayAndInferBound( + const std::vector& transform_steps) const { + State ret_state = GetInitState(); + StateNode* pstate = ret_state.CopyOnWrite(); + pstate->transform_steps = transform_steps; + ret_state.DoSteps(transform_steps, *this); -// InferBoundCommon(pstate); + InferBoundCommon(pstate); -// return ret_state; -// } + return ret_state; +} -// State ComputeDAG::InferBound(const State& state) const { -// State ret_state = state; -// StateNode* pstate = ret_state.CopyOnWrite(); +State ComputeDAG::InferBound(const State& state) const { + State ret_state = state; + StateNode* pstate = ret_state.CopyOnWrite(); -// InferBoundCommon(pstate); + InferBoundCommon(pstate); -// return ret_state; -// } + return ret_state; +} -// void ComputeDAG::InferBound(std::vector* states) const { -// 
std::vector out_states(states->size(), State()); +void ComputeDAG::InferBound(std::vector* states) const { + std::vector out_states(states->size(), State()); -// auto worker_func = [&states, &out_states, this](int idx) { -// try { -// out_states[idx] = this->InferBound((*states)[idx]); -// } catch (dmlc::Error &e) { -// LOG(WARNING) << "InferBound fails on the state:\n" << (*states)[idx] -// << "\n" << e.what() << std::endl; -// } -// }; + auto worker_func = [&states, &out_states, this](int idx) { + try { + out_states[idx] = this->InferBound((*states)[idx]); + } catch (dmlc::Error &e) { + LOG(WARNING) << "InferBound fails on the state:\n" << (*states)[idx] + << "\n" << e.what() << std::endl; + } + }; -// // Lower states in parallel -// ThreadPool& pool = ThreadPool::Global(); -// pool.BeginBatch(states->size()); -// for (size_t i = 0; i < states->size(); ++i) { -// pool.Enqueue(worker_func, i); -// } -// pool.WaitBatch(); + // Lower states in parallel + ThreadPool& pool = ThreadPool::Global(); + pool.BeginBatch(states->size()); + for (size_t i = 0; i < states->size(); ++i) { + pool.Enqueue(worker_func, i); + } + pool.WaitBatch(); -// *states = std::move(out_states); -// } + *states = std::move(out_states); +} void ComputeDAG::ReplayAndGetDAG(const std::vector &transform_steps, ComputeDAG *task_dag) const { @@ -1019,7 +1033,8 @@ void ComputeDAG::ReplayAndGetDAG(const std::vector &transform_steps, te::Schedule sch; Array old_tensors; - std::tie(sch, old_tensors) = ReplaySteps(transform_steps, &stages, &stage_to_axes); + std::tie(sch, old_tensors) = ReplaySteps(transform_steps, &stages, + &stage_to_axes); Array new_tensors; for (auto stage : sch->stages) { @@ -1035,45 +1050,47 @@ void ComputeDAG::ReplayAndGetDAG(const std::vector &transform_steps, } -// void ComputeDAG::InferBoundCommon(StateNode* pstate) const { -// std::vector stages; -// StageToAxesMap stage_to_axes; -// te::Schedule sch; -// Array tensors; -// Map bounds; +void ComputeDAG::InferBoundCommon(StateNode* pstate) const { + std::vector stages; + StageToAxesMap stage_to_axes; + te::Schedule sch; + Array tensors; + Map bounds; -// std::tie(sch, tensors) = ReplaySteps(pstate->transform_steps, &stages, &stage_to_axes); -// sch = sch.normalize(); -// bounds = schedule::InferBound(sch); + std::tie(sch, tensors) = ReplaySteps(pstate->transform_steps, &stages, + &stage_to_axes); + sch = sch.normalize(); + bounds = te::InferBound(sch); -// for (size_t i = 0; i < pstate->stages.size(); ++i) { -// const Stage& stage = pstate->stages[i]; + for (size_t i = 0; i < pstate->stages.size(); ++i) { + const Stage& stage = pstate->stages[i]; -// if (stage->compute_at == kInlined) { -// continue; -// } + if (stage->compute_at == kInlined) { + continue; + } -// std::vector new_iters; -// new_iters.reserve(stage->iters.size()); -// for (size_t j = 0; j < stage->iters.size(); ++j) { -// const Iterator& iter = stage->iters[j]; -// const IterVar& axis = stage_to_axes.at(stages[i])[j]; - -// auto find_res = bounds.find(axis); -// if (find_res != bounds.end()) { -// new_iters.push_back(IteratorNode::make(iter->name, (*find_res).second, -// iter->iter_type, iter->annotation, -// &iter->ori_iters)); -// } else { -// LOG(FATAL) << "Infer bound fails"; -// } -// } + std::vector new_iters; + new_iters.reserve(stage->iters.size()); + for (size_t j = 0; j < stage->iters.size(); ++j) { + const Iterator& iter = stage->iters[j]; + const IterVar& axis = stage_to_axes.at(stages[i])[j]; + + auto find_res = bounds.find(axis); + if (find_res != bounds.end()) { + 
new_iters.push_back(IteratorNode::make(iter->name, (*find_res).second,
+                                               iter->iter_type,
+                                               iter->annotation,
+                                               &iter->ori_iters));
+      } else {
+        LOG(FATAL) << "Infer bound fails";
+      }
+    }
-//     pstate->stages[i] = StageNode::make(stage->op, stage->op_type,
-//         std::move(new_iters), stage->compute_at, stage->auto_unroll_max_step,
-//         stage->storage_offset);
-//   }
-// }
+    pstate->stages[i] = StageNode::make(stage->op, stage->op_type,
+        std::move(new_iters), stage->compute_at,
+        stage->auto_unroll_max_step, stage->storage_offset);
+  }
+}
 
 std::pair<te::Schedule, Array<te::Tensor> > ComputeDAG::ReplaySteps(
     const std::vector<Step> &transform_steps,
@@ -1096,8 +1113,8 @@ std::pair<te::Schedule, Array<te::Tensor> > ComputeDAG::ReplaySteps(
     UpdateStageAxis(stage, stage_to_axes);
   }
 
-  // todo(lmzheng): should we maintain the attach_map and keep the validity of compute_at
-  // an splitted axis?
+  // todo(lmzheng): should we maintain the attach_map and keep the validity of
+  // compute_at on a split axis?
 
   // Use complete rate for the study in the paper
   const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE");
@@ -1183,8 +1200,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       } else if (combiner->IsInstance<SelectNode>()) {
         const auto& select = combiner.as<SelectNode>();
         ss << " select(" << select->condition << ", " << select->true_value
-           << ", " << select->false_value << ")= "
-           << '(' << preduce->source[0] << ',' << preduce->source[1] << ")\n";
+           << ", " << select->false_value << ")= " << '('
+           << preduce->source[0] << ',' << preduce->source[1] << ")\n";
       } else {
         LOG(FATAL) << "Unsupported reduction operator" << combiner;
       }
@@ -1208,7 +1225,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     p->stream << "is_injective:\t" << node->is_injective.at(op) << "\t\t";
     p->stream << "needs_multi_level_tiling:\t"
               << node->needs_multi_level_tiling.at(op) << std::endl;
-    p->stream << "is_strict_inlinable:\t" << node->is_strict_inlineable.at(op) << "\t";
+    p->stream << "is_strict_inlineable:\t" << node->is_strict_inlineable.at(op)
+              << "\t";
     p->stream << "is_output:\t" << node->is_output.at(op) << std::endl;
     p->stream << "Read from:\t";
     for (const auto& pair : node->read_from.at(op)) {
@@ -1233,7 +1251,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
   for (size_t i = 0; i < node->ops_topo_order.size(); ++i) {
     for (size_t j = 0; j < node->ops_topo_order.size(); ++j) {
       if (i == j) { continue; }
-      if (ana.ElementWiseMatch(node->ops_topo_order[i], node->ops_topo_order[j])) {
+      if (ana.ElementWiseMatch(node->ops_topo_order[i],
+                               node->ops_topo_order[j])) {
         p->stream << node->ops_topo_order[i]->func_name() << " -> "
                   << node->ops_topo_order[j]->func_name() << "\n";
       }
diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc
new file mode 100644
index 000000000000..1bae02b3f2c5
--- /dev/null
+++ b/src/ansor/measure.cc
@@ -0,0 +1,314 @@
+/*!
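+ * \file ansor/measure.cc
+ * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs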
+ * Copyright (c) 2020 by Contributors + */ +#include "measure.h" +// #include +#include +#include +#include +#include +#include +#include +// #include "search_policy/search_policy.h" + +namespace tvm { +namespace ansor { + +TVM_REGISTER_NODE_TYPE(MeasureInputNode); +TVM_REGISTER_NODE_TYPE(BuildResultNode); +TVM_REGISTER_NODE_TYPE(MeasureResultNode); +TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); +TVM_REGISTER_OBJECT_TYPE(RunnerNode); +TVM_REGISTER_OBJECT_TYPE(BuilderNode); +TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode); +TVM_REGISTER_OBJECT_TYPE(RPCRunnerNode); +TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode); +TVM_REGISTER_OBJECT_TYPE(ProgramMeasurerNode); + +const char *ErrorNoToStr[] = { + "NoError", + "InstantiationError", + "CompileHostError", + "CompileDeviceError", + "RuntimeDeviceError", + "WrongAnswerError", + "BuildTimeoutError", + "RunTimeoutError", + "UnknownError", +}; + +// Maker +MeasureInput MeasureInputNode::make(SearchTask task, State state) { + auto node = make_object(); + node->task = std::move(task); + node->state = std::move(state); + return MeasureInput(node); +} + +MeasureInput MeasureInputNode::copy() const { + auto node = make_object(); + node->task = task; + node->state = state; + return MeasureInput(node); +} + +BuildResult BuildResultNode::make(std::string filename, Array args, int error_no, + std::string error_msg, double time_cost) { + auto node = make_object(); + node->filename = std::move(filename); + node->args = std::move(args); + node->error_no = error_no; + node->error_msg = std::move(error_msg); + node->time_cost = time_cost; + return BuildResult(node); +} + +MeasureResult MeasureResultNode::make(Array costs, int error_no, + std::string error_msg, double all_cost, double timestamp) { + auto node = make_object(); + node->costs = std::move(costs); + node->error_no = error_no; + node->error_msg = std::move(error_msg); + node->all_cost = all_cost; + node->timestamp = timestamp; + return MeasureResult(node); +} + +MeasureResult MeasureResultNode::copy() const { + auto node = make_object(); + node->costs = costs; + node->error_no = error_no; + node->error_msg = error_msg; + node->all_cost = all_cost; + node->timestamp = timestamp; + return MeasureResult(node); +} + +Builder LocalBuilderNode::make(int timeout, int n_parallel, const std::string& build_func) { + auto node = make_object(); + node->timeout = timeout; + node->n_parallel = n_parallel; + node->build_func = build_func; + return Builder(node); +} + +// LocalBuilder and LocalRunner +Array LocalBuilderNode::Build(const Array &inputs, int verbose) { + if (const auto* f = runtime::Registry::Get("ansor.local_builder.build")) { + Array results = (*f)(inputs, timeout, n_parallel, build_func, verbose); + return results; + } else { + LOG(FATAL) << "ansor.local_builder.build is not registered"; + } + return Array(); +} + +Runner RPCRunnerNode::make(const std::string& key, const std::string& host, int port, + int priority, int timeout, int n_parallel, int number, + int repeat, int min_repeat_ms, double cooldown_interval) { + auto node = make_object(); + node->key = key; + node->host = host; + node->port = port; + node->priority = priority; + node->timeout = timeout; + node->n_parallel = n_parallel; + node->number = number; + node->repeat = repeat; + node->min_repeat_ms = min_repeat_ms; + node->cooldown_interval = cooldown_interval; + return Runner(node); +} + +Array RPCRunnerNode::Run(const Array& inputs, + const Array& build_results, + int verbose) { + if (const auto* f = 
runtime::Registry::Get("ansor.rpc_runner.run")) { + Array results = (*f)(inputs, build_results, key, host, port, priority, + timeout, n_parallel, number, repeat, + min_repeat_ms, cooldown_interval, verbose); + return results; + } else { + LOG(FATAL) << "ansor.rpc_runner.run is not registered"; + } + return Array(); +} + +Runner LocalRunnerNode::make(int timeout, int number, int repeat, + int min_repeat_ms, double cooldown_interval) { + ObjectPtr node = make_object(); + node->timeout = timeout; + node->number = number; + node->repeat = repeat; + node->min_repeat_ms = min_repeat_ms; + node->cooldown_interval = cooldown_interval; + return Runner(node); +} + +Array LocalRunnerNode::Run(const Array& inputs, + const Array& build_results, + int verbose) { + if (const auto* f = runtime::Registry::Get("ansor.local_runner.run")) { + Array results = (*f)(inputs, build_results, timeout, number, + repeat, min_repeat_ms, cooldown_interval, verbose); + return results; + } else { + LOG(FATAL) << "ansor.local_runner.run is not registered"; + } + return Array(); +} + +ProgramMeasurer ProgramMeasurerNode::make(Builder builder, Runner runner, + Array callbacks, + int verbose, + int max_continous_error) { + auto node = make_object(); + node->builder = std::move(builder); + node->runner = std::move(runner); + node->callbacks = std::move(callbacks); + node->verbose = verbose; + node->max_continous_error = max_continous_error < 0 ? + DEFAULT_MAX_CONTINOUS_ERROR : max_continous_error; + return ProgramMeasurer(node); +} + +void ProgramMeasurerNode::Reset() { + ct = error_ct = 0; + best_flops.clear(); + best_ct.clear(); + best_state.clear(); +} + +void ProgramMeasurerNode::Measure(const SearchTask& task, + const SearchPolicy& policy, + const std::vector& inputs, + std::vector* results, + int batch_size) { + results->clear(); + results->reserve(inputs.size()); + + if (batch_size == -1) { + // set default batch size + batch_size = builder->n_parallel * 2; + } + + StdCout(verbose) << "Get " << inputs.size() << " programs for measure. 
(This may take a while)" + << std::endl; + + for (size_t i = 0; i < inputs.size(); i += batch_size) { + std::vector input_batch(inputs.begin() + i, + inputs.begin() + std::min(i + batch_size, inputs.size())); + std::vector result_batch; + + // build and run + SilentMeasure(task, input_batch, &result_batch); + + // update current best state according to the new measure result + for (size_t j = 0; j < input_batch.size(); ++j) { + double flops; + if (result_batch[j]->error_no == 0) { + flops = task->compute_dag->flop_ct / FloatArrayMean(result_batch[j]->costs); + error_ct = 0; + } else { + flops = 0.0; + error_ct++; + } + + const std::string& workload_key = input_batch[j]->task->workload_key; + if (flops > best_flops[workload_key]) { + best_flops[workload_key] = flops; + best_state[workload_key] = input_batch[j]->state; + best_ct[workload_key] = ct; + } + + ct++; + if (verbose >= 1) { + std::cout << std::fixed << std::setprecision(2); + std::cout << "===============================================\n"; + std::cout << "No: " << ct + << "\tGFLOPS: " << flops / 1e9 << " / " << best_flops[workload_key] / 1e9 + << "\tresults: " << result_batch[j] << "\n"; + std::cout << "===============================================\n"; + std::cout << input_batch[j]->state << "\n"; + } + } + + // Call callback functions + for (const auto& callback : callbacks) { + callback->callback(policy, input_batch, result_batch); + } + + // Store result batch + for (auto& res : result_batch) { + results->push_back(res); + } + + if (error_ct > max_continous_error) { + LOG(FATAL) << "Too many errors happened during tuning"; + } + } +} + +void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, + const std::vector& inputs, + std::vector* results) { + // Close the thread pool to avoid the conflits with python environment + ThreadPool::Global().Abort(); + + results->clear(); + results->reserve(inputs.size()); + Array input_batch(inputs.begin(), inputs.end()); + + // Call builder and runner + Array build_res_batch = builder->Build(input_batch, verbose); + Array result_batch = runner->Run(input_batch, build_res_batch, verbose); + + // Store result batch + for (auto& res : result_batch) { + results->push_back(res); + } +} + +// Printing functions +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { + p->stream << "MeasureInput()"; +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { + auto* node = static_cast(ref.get()); + if (node->error_no == kNoError) { + p->stream << "MeasureResult(cost:["; + auto old_config = p->stream.precision(4); + for (size_t i = 0; i < node->costs.size(); ++i) { + auto pf = node->costs[i].as(); + CHECK(pf != nullptr); + p->stream << pf->value; + if (i != node->costs.size() - 1) { + p->stream << ","; + } + } + p->stream.precision(old_config); + p->stream << "], "; + p->stream << "error_no:" << 0 << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; + } else { + p->stream << "MeasureResult(" + << "error_type:" << ErrorNoToStr[node->error_no] << ", " + << "error_msg:" << node->error_msg << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; + } +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { + auto* node = static_cast(ref.get()); + p->stream << "BuildResult(" << node->filename << ", " << node->error_no + << ", " << node->time_cost << ")"; +}); + +} // namespace 
ansor
+} // namespace tvm
diff --git a/src/ansor/measure.h b/src/ansor/measure.h
new file mode 100644
index 000000000000..4ea1562315ff
--- /dev/null
+++ b/src/ansor/measure.h
@@ -0,0 +1,262 @@
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file ansor/measure.h
+ * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs
+ */
+
+#ifndef TVM_ANSOR_MEASURE_H_
+#define TVM_ANSOR_MEASURE_H_
+
+// #include
+#include
+#include
+#include
+#include
+#include "search_task.h"
+#include "loop_state.h"
+
+namespace tvm {
+namespace ansor {
+
+class SearchPolicy;
+class MeasureInput; class BuildResult; class MeasureResult;
+class Builder; class Runner; class MeasureCallback; class ProgramMeasurer;
+
+extern const char *ErrorNoToStr[];
+
+enum MeasureErrorNO {
+  kNoError = 0,             // No error
+  kInstantiationError = 1,  // Errors happen when applying transform steps from the init state
+  kCompileHostError = 2,    // Errors happen when compiling code on host (when building the module)
+  kCompileDeviceError = 3,  // Errors happen when compiling code on device (when loading the module)
+  kRuntimeDeviceError = 4,  // Errors happen when running the program on device
+  kWrongAnswerError = 5,    // Answer is wrong when compared to a reference output
+  kBuildTimeoutError = 6,   // Timeout during compilation
+  kRunTimeoutError = 7,     // Timeout during run
+  kUnknonwError = 8,        // Unknown error
+};
+
+// Inputs and results of one measurement
+
+/*! \brief Store the input of a measurement */
+class MeasureInputNode: public Object {
+ public:
+  SearchTask task;
+  State state;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("task", &task);
+    v->Visit("state", &state);
+  }
+
+  static MeasureInput make(SearchTask task, State state);
+  MeasureInput copy() const;  // Do deep copy
+
+  static constexpr const char* _type_key = "ansor.MeasureInput";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MeasureInputNode, Object);
+};
+TVM_DEFINE_NODE_REF(MeasureInput, MeasureInputNode);
+
+/*! \brief Store the result of a build */
+class BuildResultNode: public Object {
+ public:
+  std::string filename;
+  Array<te::Tensor> args;
+  int error_no;
+  std::string error_msg;
+  double time_cost;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("filename", &filename);
+    v->Visit("args", &args);
+    v->Visit("error_no", &error_no);
+    v->Visit("error_msg", &error_msg);
+    v->Visit("time_cost", &time_cost);
+  }
+
+  static BuildResult make(std::string filename, Array<te::Tensor> args,
+                          int error_no, std::string error_msg, double time_cost);
+
+  static constexpr const char* _type_key = "ansor.BuildResult";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BuildResultNode, Object);
+};
+TVM_DEFINE_NODE_REF(BuildResult, BuildResultNode);
+
+/*! \brief Store the results of a measurement */
+class MeasureResultNode: public Object {
+ public:
+  Array<PrimExpr> costs;
+  int error_no;
+  std::string error_msg;
+  double all_cost;
+  double timestamp;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("costs", &costs);
+    v->Visit("error_no", &error_no);
+    v->Visit("error_msg", &error_msg);
+    v->Visit("all_cost", &all_cost);
+    v->Visit("timestamp", &timestamp);
+  }
+
+  MeasureResult copy() const;  // Do deep copy
+
+  static MeasureResult make(Array<PrimExpr> costs, int error_no, std::string error_msg,
+                            double all_cost, double timestamp);
+
+  static constexpr const char* _type_key = "ansor.MeasureResult";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MeasureResultNode, Object);
+};
+TVM_DEFINE_NODE_REF(MeasureResult, MeasureResultNode);
+
+
+// Measure callback
+class MeasureCallbackNode: public Object {
+ public:
+  virtual void callback(const SearchPolicy& policy,
+                        const Array<MeasureInput>& inputs,
+                        const Array<MeasureResult>& results) = 0;
+  static constexpr const char *_type_key = "ansor.MeasureCallback";
+  TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object);
+};
+TVM_DEFINE_MUTABLE_NODE_REF(MeasureCallback, MeasureCallbackNode);
+
+
+// Base class for builder and runner
+
+/*! \brief Builder that builds the programs */
+class BuilderNode: public Object {
+ public:
+  int n_parallel;
+  int timeout;
+
+  virtual Array<BuildResult> Build(const Array<MeasureInput>& inputs, int verbose) = 0;
+
+  static constexpr const char* _type_key = "ansor.Builder";
+  TVM_DECLARE_BASE_OBJECT_INFO(BuilderNode, Object);
+};
+TVM_DEFINE_MUTABLE_NODE_REF(Builder, BuilderNode);
+
+/*! \brief Runner that runs the built programs and measures the time cost */
+class RunnerNode: public Object {
+ public:
+  int timeout;
+
+  virtual Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
+                                   const Array<BuildResult>& build_results,
+                                   int verbose) = 0;
+
+  static constexpr const char* _type_key = "ansor.Runner";
+  TVM_DECLARE_BASE_OBJECT_INFO(RunnerNode, Object);
+};
+TVM_DEFINE_MUTABLE_NODE_REF(Runner, RunnerNode);
+
+
+// Implementation of various builders and runners
+/*! \brief LocalBuilder uses local CPU cores to build programs in parallel */
+class LocalBuilderNode: public BuilderNode {
+ public:
+  std::string build_func;
+
+  static Builder make(int timeout, int n_parallel, const std::string& build_func);
+
+  Array<BuildResult> Build(const Array<MeasureInput>& inputs, int verbose) final;
+
+  static constexpr const char* _type_key = "ansor.LocalBuilder";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LocalBuilderNode, BuilderNode);
+};
+
+class RPCRunnerNode : public RunnerNode {
+ public:
+  std::string key;
+  std::string host;
+  int port;
+  int priority;
+  int n_parallel;
+  int number;
+  int repeat;
+  int min_repeat_ms;
+  double cooldown_interval;
+
+  static Runner make(const std::string& key, const std::string& host, int port,
+                     int priority, int timeout, int n_parallel, int number,
+                     int repeat, int min_repeat_ms, double cooldown_interval);
+
+  Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
+                           const Array<BuildResult>& build_results,
+                           int verbose) final;
+
+  static constexpr const char* _type_key = "ansor.RPCRunner";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RPCRunnerNode, RunnerNode);
+};
+
+/*! \brief LocalRunner uses the local CPU/GPU to run programs in serial and measure the time cost */
+class LocalRunnerNode: public RunnerNode {
+ public:
+  int number;
+  int repeat;
+  int min_repeat_ms;
+  double cooldown_interval;
+
+  static Runner make(int timeout, int number, int repeat,
+                     int min_repeat_ms, double cooldown_interval);
+
+  Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
+                           const Array<BuildResult>& build_results,
+                           int verbose) final;
+
+  static constexpr const char* _type_key = "ansor.LocalRunner";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LocalRunnerNode, RunnerNode);
+};
+
+
+/*!
+ * \brief Measurer that measures the time costs of tvm programs
+ * This class combines Builder and Runner, and provides a simpler API
+ */
+class ProgramMeasurerNode: public Object {
+ public:
+  static const int DEFAULT_MAX_CONTINOUS_ERROR = 150;
+
+  int ct;
+  int error_ct;  // continuous error counter
+  std::unordered_map<std::string, double> best_flops;
+  std::unordered_map<std::string, State> best_state;
+  std::unordered_map<std::string, int> best_ct;
+
+  Builder builder;
+  Runner runner;
+  Array<MeasureCallback> callbacks;
+  int verbose;
+  int max_continous_error;
+
+  static ProgramMeasurer make(Builder builder, Runner runner,
+                              Array<MeasureCallback> callbacks,
+                              int verbose,
+                              int max_continous_error = -1);
+
+  /*! \brief Reset bookkeeping variables */
+  void Reset();
+
+  /*! \brief Do measurement */
+  void Measure(const SearchTask& task,
+               const SearchPolicy& policy,
+               const std::vector<MeasureInput>& inputs,
+               std::vector<MeasureResult>* results,
+               int batch_size = -1);
+
+  /*! \brief Do measurement silently */
+  void SilentMeasure(const SearchTask& task,
+                     const std::vector<MeasureInput>& inputs,
+                     std::vector<MeasureResult>* results);
+
+  static constexpr const char* _type_key = "ansor.ProgramMeasurer";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ProgramMeasurerNode, Object);
+};
+TVM_DEFINE_MUTABLE_NODE_REF(ProgramMeasurer, ProgramMeasurerNode);
+
+
+}  // namespace ansor
+}  // namespace tvm
+
+#endif  // TVM_ANSOR_MEASURE_H_
diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc
new file mode 100644
index 000000000000..b9cda9168b9e
--- /dev/null
+++ b/src/ansor/search_task.cc
@@ -0,0 +1,120 @@
+/*!
+ * Copyright (c) 2020 by Contributors
+ */
+#include "search_task.h"
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_OBJECT_TYPE(HardwareParamsNode);
+TVM_REGISTER_OBJECT_TYPE(SearchTaskNode);
+
+HardwareParams HardwareParamsNode::make(int num_cores, int vector_unit_bytes,
+                                        int cache_line_bytes, int max_unroll_vec,
+                                        int max_innermost_split_factor) {
+  auto node = make_object<HardwareParamsNode>();
+  node->num_cores = num_cores;
+  node->vector_unit_bytes = vector_unit_bytes;
+  node->cache_line_bytes = cache_line_bytes;
+  node->max_unroll_vec = max_unroll_vec;
+  node->max_innermost_split_factor = max_innermost_split_factor;
+  return HardwareParams(node);
+}
+
+HardwareParams HardwareParamsNode::GetDefaultHardwareParams(
+    const Target& target, const Target& target_host) {
+  if (target->target_name == "llvm") {
+    return HardwareParamsNode::make(tvm::runtime::threading::MaxConcurrency(),
+                                    32, 64, 16, 64);
+  } else if (target->device_type == kDLGPU) {
+    // TODO(jcf94): temp implementation, max vectorize size in GPU is related
+    // to the data type
+    auto hardware_params = HardwareParamsNode::make(100000, 16, 64, 4, 64);
+    auto* p_hardware_params = hardware_params.CopyOnWrite();
+
+    auto ctx = TVMContext{kDLGPU, 0};
+    auto func = tvm::runtime::Registry::Get("device_api.gpu");
+    CHECK(func != nullptr) << "Cannot find GPU device_api in registry";
+    auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());
+
+    tvm::runtime::TVMRetValue ret;
+    device_api->GetAttr(ctx,
+                        tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock,
+                        &ret);
+    p_hardware_params->max_shared_memory_per_block = ret;
+
+    device_api->GetAttr(ctx,
+                        tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock,
+                        &ret);
+    p_hardware_params->max_registers_per_block = ret;
+
+    device_api->GetAttr(ctx,
+                        tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock,
+                        &ret);
+    p_hardware_params->max_threads_per_block = ret;
+
+    device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret);
+    p_hardware_params->warp_size = ret;
+
+    // Manually set now
+    p_hardware_params->max_vthread_extent = 4;
+
+    return hardware_params;
+  } else if (target->device_type == kDLOpenCL) {
+    // TODO(jcf94): temp implementation
+    auto hardware_params = HardwareParamsNode::make(100000, 16, 64, 4, 64);
+    auto p_hardware_params = hardware_params.CopyOnWrite();
+
+    auto ctx = TVMContext{kDLOpenCL, 0};
+    auto func = tvm::runtime::Registry::Get("device_api.opencl");
+    CHECK(func != nullptr) << "Cannot find OpenCL device_api in registry";
+    auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());
+
+    tvm::runtime::TVMRetValue ret;
+    device_api->GetAttr(ctx,
+                        tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock,
+                        &ret);
+
p_hardware_params->max_shared_memory_per_block = ret; + + device_api->GetAttr(ctx, + tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, + &ret); + p_hardware_params->max_threads_per_block = ret; + + device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret); + p_hardware_params->warp_size = ret; + + // Manually set now + p_hardware_params->max_vthread_extent = 4; + + return hardware_params; + } else { + LOG(FATAL) << "No default hardware parameters for target: " << target; + } + return HardwareParams(); +} + + +SearchTask SearchTaskNode::make(ComputeDAG compute_dag, std::string workload_key, + Target target, Target target_host, HardwareParams hardware_params) { + auto node = make_object(); + node->compute_dag = std::move(compute_dag); + node->workload_key = std::move(workload_key); + node->target = std::move(target); + node->target_host = std::move(target_host); + if (hardware_params.defined()) { + node->hardware_params = std::move(hardware_params); + } else { + node->hardware_params = HardwareParamsNode::GetDefaultHardwareParams( + node->target, node->target_host); + } + return SearchTask(node); +} + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h new file mode 100644 index 000000000000..7db98a5197a5 --- /dev/null +++ b/src/ansor/search_task.h @@ -0,0 +1,92 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file ansor/search_task.h + * \brief Meta information for a search task + */ + +#ifndef TVM_ANSOR_SEARCH_TASK_H_ +#define TVM_ANSOR_SEARCH_TASK_H_ + +#include +#include +#include "compute_dag.h" + +namespace tvm { +namespace ansor { + +class HardwareParams; class SearchTask; + +/*! \brief Hardware related parameters */ +class HardwareParamsNode : public Object { + public: + int num_cores; + int vector_unit_bytes; + int cache_line_bytes; + // The max length of the axis to be unrolled or vectorized + int max_unroll_vec; + // The max split factor for the innermost tile + int max_innermost_split_factor; + + // Limit params for GPU schedule + int max_shared_memory_per_block{INT32_MAX}; + int max_registers_per_block{INT32_MAX}; + int max_threads_per_block{INT32_MAX}; + int max_vthread_extent{INT32_MAX}; + int warp_size{INT32_MAX}; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("num_cores", &num_cores); + v->Visit("vector_unit_bytes", &vector_unit_bytes); + v->Visit("cache_line_bytes", &cache_line_bytes); + v->Visit("max_unroll_vec", &max_unroll_vec); + v->Visit("max_innermost_split_factor", &max_innermost_split_factor); + + v->Visit("max_shared_memory_per_block", &max_shared_memory_per_block); + v->Visit("max_registers_per_block", &max_registers_per_block); + v->Visit("max_threads_per_block", &max_threads_per_block); + v->Visit("max_vthread_extent", &max_vthread_extent); + v->Visit("warp_size", &warp_size); + } + + static HardwareParams make(int num_cores, int vector_unit_bytes, + int cache_line_bytes, int max_unroll_vec, + int max_innermost_split_factor); + static HardwareParams GetDefaultHardwareParams(const Target& target, + const Target& target_host); + + static constexpr const char *_type_key = "ansor.HardwareParams"; + TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(HardwareParams, ObjectRef, HardwareParamsNode); + + +/*! 
\brief Meta-info for a search task */ +class SearchTaskNode : public Object { + public: + ComputeDAG compute_dag; + std::string workload_key; + Target target; + Target target_host; + HardwareParams hardware_params; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("compute_dag", &compute_dag); + v->Visit("workload_key", &workload_key); + v->Visit("target", &target); + v->Visit("target_host", &target_host); + v->Visit("hardware_params", &hardware_params); + } + + static SearchTask make(ComputeDAG compute_dag, std::string workload_key, + Target target, Target target_host, + HardwareParams hardware_params); + + static constexpr const char *_type_key = "ansor.SearchTask"; + TVM_DECLARE_FINAL_OBJECT_INFO(SearchTaskNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(SearchTask, ObjectRef, SearchTaskNode); + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_SEARCH_TASK_H_ diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc new file mode 100644 index 000000000000..0e2b0be42587 --- /dev/null +++ b/src/ansor/serialization.cc @@ -0,0 +1,573 @@ +/*! + * Copyright (c) 2020 by Contributors + */ +#include +// #include +#include +#include +#include +#include +#include +#include +#include "serialization.h" +#include "loop_state.h" +#include "utils.h" + +// Json serialization handler for MeasureInput, MeasureResult +// (and recursively SearchTask, State, Step, ... +namespace dmlc { +namespace json { + +inline std::vector& FloatArrayToVector(std::vector* out, + const ::tvm::Array<::tvm::PrimExpr>& data) { + out->clear(); + for (const auto&x : data) { + auto pf = x.as<::tvm::tir::FloatImmNode>(); + CHECK(pf != nullptr) << "Cost can only contain float values"; + out->push_back(pf->value); + } + return *out; +} + +inline std::vector& IntArrayToVector(std::vector* out, + const ::tvm::Array<::tvm::PrimExpr>& data) { + out->clear(); + for (const auto&x : data) { + auto pi = x.as<::tvm::tir::IntImmNode>(); + CHECK(pi != nullptr) << "Cost can only contain int values"; + out->push_back(pi->value); + } + return *out; +} + +template <> +struct Handler > { + inline static void Write(dmlc::JSONWriter* writer, + const std::vector<::tvm::ansor::Stage> & data) { + // todo(lmzheng): support serialization of Stage + writer->BeginArray(false); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + std::vector<::tvm::ansor::Stage> * data) { + bool s; + reader->BeginArray(); + s = reader->NextArrayItem(); CHECK(!s); + } +}; + +template <> +struct Handler > { + inline static void Write(dmlc::JSONWriter* writer, + const std::vector<::tvm::ansor::Step> & data) { + std::vector tmp; + writer->BeginArray(false); + for (size_t i = 0; i < data.size(); ++i) { + writer->WriteArraySeperator(); + writer->BeginArray(false); + if (auto ps = data[i].as<::tvm::ansor::ReorderStepNode>()) { + writer->WriteArrayItem(std::string("RS")); + writer->WriteArrayItem(ps->stage_id); + + writer->WriteArraySeperator(); + writer->BeginArray(false); + for (int x : ps->after_ids) { + writer->WriteArrayItem(x); + } + writer->EndArray(); + } else if (auto ps = data[i].as<::tvm::ansor::SplitStepNode>()) { + writer->WriteArrayItem(std::string("SS")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + if (ps->extent.defined()) { + writer->WriteArrayItem(::tvm::ansor::GetIntImm(ps->extent)); + } else { + writer->WriteArrayItem(0); + } + writer->WriteArrayItem(IntArrayToVector(&tmp, ps->lengths)); + writer->WriteArrayItem(static_cast(ps->inner_to_outer)); + } else if (auto ps = 
data[i].as<::tvm::ansor::FollowSplitStepNode>()) { + writer->WriteArrayItem(std::string("FSS")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + writer->WriteArrayItem(ps->src_step_id); + writer->WriteArrayItem(ps->n_split); + } else if (auto ps = data[i].as<::tvm::ansor::FollowFusedSplitStepNode>()) { + writer->WriteArrayItem(std::string("FFSS")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + + writer->WriteArraySeperator(); + writer->BeginArray(false); + for (int x : ps->src_step_ids) { + writer->WriteArrayItem(x); + } + writer->EndArray(); + + writer->WriteArrayItem(ps->level); + writer->WriteArrayItem(static_cast(ps->factor_or_nparts)); + } else if (auto ps = data[i].as<::tvm::ansor::FuseStepNode>()) { + writer->WriteArrayItem(std::string("FS")); + writer->WriteArrayItem(ps->stage_id); + + writer->WriteArraySeperator(); + writer->BeginArray(false); + for (int x : ps->fused_ids) { + writer->WriteArrayItem(x); + } + writer->EndArray(); + } else if (auto ps = data[i].as<::tvm::ansor::AnnotationStepNode>()) { + writer->WriteArrayItem(std::string("AS")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + writer->WriteArrayItem(static_cast(ps->annotation)); + } else if (auto ps = data[i].as<::tvm::ansor::ComputeAtStepNode>()) { + writer->WriteArrayItem(std::string("CA")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->target_stage_id); + writer->WriteArrayItem(ps->target_iter_id); + } else if (auto ps = data[i].as<::tvm::ansor::ComputeRootStepNode>()) { + writer->WriteArrayItem(std::string("CR")); + writer->WriteArrayItem(ps->stage_id); + } else if (auto ps = data[i].as<::tvm::ansor::ComputeInlineStepNode>()) { + writer->WriteArrayItem(std::string("CI")); + writer->WriteArrayItem(ps->stage_id); + } else if (auto ps = data[i].as<::tvm::ansor::CacheReadStepNode>()) { + writer->WriteArrayItem(std::string("CHR")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->scope_name); + writer->WriteArrayItem(ps->reader_stage_ids); + } else if (auto ps = data[i].as<::tvm::ansor::CacheWriteStepNode>()) { + writer->WriteArrayItem(std::string("CHW")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->scope_name); + } else if (auto ps = data[i].as<::tvm::ansor::PragmaStepNode>()) { + writer->WriteArrayItem(std::string("PS")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + writer->WriteArrayItem(ps->pragma_type); + } else if (auto ps = data[i].as<::tvm::ansor::RfactorStepNode>()) { + writer->WriteArrayItem(std::string("RFS")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + writer->WriteArrayItem(ps->factor_iter_id); + } else if (auto ps = data[i].as<::tvm::ansor::StorageAlignStepNode>()) { + writer->WriteArrayItem(std::string("SA")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + writer->WriteArrayItem(ps->factor); + writer->WriteArrayItem(ps->offset); + } else { + LOG(FATAL) << "Invalid step: " << data[i]; + } + writer->EndArray(); + } + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + std::vector<::tvm::ansor::Step> * data) { + std::vector int_list; + bool s, inner_to_outer, factor_or_nparts; + std::string name, scope_name, pragma_type; + int stage_id, target_stage_id, iter_id, src_step_id, n_split, ann, extent; + int level, factor_iter_id, factor, offset; + + reader->BeginArray(); + data->clear(); + while 
(reader->NextArrayItem()) { + reader->BeginArray(); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&name); + if (name == "RS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&int_list); + data->push_back(::tvm::ansor::ReorderStepNode::make(stage_id, int_list)); + } else if (name == "SS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&extent); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&int_list); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&inner_to_outer); + data->push_back(::tvm::ansor::SplitStepNode::make( + stage_id, iter_id, extent, + std::vector<::tvm::PrimExpr>(int_list.begin(), int_list.end()), + inner_to_outer)); + } else if (name == "FSS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&src_step_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&n_split); + data->push_back(::tvm::ansor::FollowSplitStepNode::make( + stage_id, iter_id, src_step_id, n_split)); + } else if (name == "FFSS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&int_list); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&level); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&factor_or_nparts); + data->push_back(::tvm::ansor::FollowFusedSplitStepNode::make( + stage_id, iter_id, int_list, level, factor_or_nparts)); + } else if (name == "FS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&int_list); + data->push_back(::tvm::ansor::FuseStepNode::make(stage_id, int_list)); + } else if (name == "AS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&ann); + data->push_back(::tvm::ansor::AnnotationStepNode::make(stage_id, + iter_id, ::tvm::ansor::IteratorAnnotation(ann))); + } else if (name == "CA") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&target_stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + data->push_back(::tvm::ansor::ComputeAtStepNode::make( + stage_id, target_stage_id, iter_id)); + } else if (name == "CR") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + data->push_back(::tvm::ansor::ComputeRootStepNode::make(stage_id)); + } else if (name == "CI") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + data->push_back(::tvm::ansor::ComputeInlineStepNode::make(stage_id)); + } else if (name == "CHR") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&scope_name); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&int_list); + data->push_back(::tvm::ansor::CacheReadStepNode::make( + stage_id, scope_name, int_list)); + } else if (name == "CHW") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&scope_name); + 
data->push_back(::tvm::ansor::CacheWriteStepNode::make( + stage_id, scope_name)); + } else if (name == "PS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&pragma_type); + data->push_back(::tvm::ansor::PragmaStepNode::make( + stage_id, iter_id, pragma_type)); + } else if (name == "RFS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&factor_iter_id); + data->push_back(::tvm::ansor::RfactorStepNode::make( + stage_id, iter_id, factor_iter_id)); + } else if (name == "SA") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&factor); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&offset); + data->push_back(::tvm::ansor::StorageAlignStepNode::make( + stage_id, iter_id, factor, offset)); + } else { + LOG(FATAL) << "Invalid step format"; + } + s = reader->NextArrayItem(); CHECK(!s); + } + } +}; + +template <> +struct Handler<::tvm::ansor::StateNode> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::ansor::StateNode& data) { + writer->BeginArray(false); + writer->WriteArrayItem(data.stages); + writer->WriteArrayItem(data.transform_steps); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + ::tvm::ansor::StateNode* data) { + reader->BeginArray(); + bool s; + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&data->stages); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&data->transform_steps); + s = reader->NextArrayItem(); CHECK(!s); + } +}; + +template <> +struct Handler<::tvm::ansor::SearchTaskNode> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::ansor::SearchTaskNode& data) { + writer->BeginArray(false); + writer->WriteArrayItem(data.workload_key); + writer->WriteArrayItem(data.target->str()); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + ::tvm::ansor::SearchTaskNode* data) { + std::string target_str; + bool s; + + reader->BeginArray(); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&data->workload_key); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&target_str); + data->target = ::tvm::Target::Create(target_str); + s = reader->NextArrayItem(); CHECK(!s); + } +}; + +template <> +struct Handler<::tvm::ansor::MeasureInputNode> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::ansor::MeasureInputNode& data) { + writer->BeginArray(false); + writer->WriteArrayItem(*data.task.operator->()); + writer->WriteArrayItem(*data.state.operator->()); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + ::tvm::ansor::MeasureInputNode* data) { + bool s; + auto task_node = ::tvm::make_object<::tvm::ansor::SearchTaskNode>(); + auto state_node = ::tvm::make_object<::tvm::ansor::StateNode>(); + state_node->complete = true; + + reader->BeginArray(); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(task_node.get()); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(state_node.get()); + s = reader->NextArrayItem(); CHECK(!s); + + data->task = ::tvm::ansor::SearchTask(task_node); + data->state = ::tvm::ansor::State(state_node); + } +}; + +template <> +struct 
Handler<::tvm::ansor::MeasureResultNode> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::ansor::MeasureResultNode& data) { + writer->BeginArray(false); + writer->WriteArraySeperator(); + writer->BeginArray(false); + for (const auto&x : data.costs) { + auto pf = x.as<::tvm::tir::FloatImmNode>(); + CHECK(pf != nullptr) << "Cost can only contain float values"; + writer->WriteArrayItem(pf->value); + } + writer->EndArray(); + writer->WriteArrayItem(data.error_no); + writer->WriteArrayItem(data.all_cost); + writer->WriteArrayItem(static_cast((data.timestamp))); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + ::tvm::ansor::MeasureResultNode* data) { + bool s; + std::vector tmp; + + reader->BeginArray(); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&tmp); + data->costs.clear(); + for (const auto& i : tmp) { + data->costs.push_back(i); + } + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&data->error_no); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&data->all_cost); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&data->timestamp); + s = reader->NextArrayItem(); CHECK(!s); + } +}; + +} // namespace json +} // namespace dmlc + +namespace tvm { +namespace ansor { + +TVM_REGISTER_OBJECT_TYPE(LogToFileNode); +TVM_REGISTER_OBJECT_TYPE(LogReaderNode); + +const std::string ansor_LOG_VERSION = "v0.1"; // NOLINT(*) + +MeasureCallback LogToFileNode::make(std::string filename) { + auto node = make_object(); + node->filename = std::move(filename); + return MeasureCallback(node); +} + +void WriteMeasureRecords(std::ostream* os, + const Array& inputs, + const Array& results) { + dmlc::JSONWriter writer(os); + for (size_t i = 0; i < inputs.size(); ++i) { + writer.BeginObject(false); + writer.WriteObjectKeyValue("i", *inputs[i].operator->()); + writer.WriteObjectKeyValue("r", *results[i].operator->()); + writer.WriteObjectKeyValue("v", ansor_LOG_VERSION); + writer.EndObject(); + *os << "\n"; + } +} + +void ReadMeasureRecords(std::string str, + MeasureInputNode* inp, + MeasureResultNode* res, + std::string* log_version) { + std::istringstream ss(str); + dmlc::JSONReader reader(&ss); + std::string key; + + reader.BeginObject(); + while (reader.NextObjectItem(&key)) { + if (key == "i") { + reader.Read(inp); + } else if (key == "r") { + reader.Read(res); + } else if (key == "v") { + reader.Read(log_version); + } else { + LOG(FATAL) << "Invalid key in json log: " << key; + } + } +} + +TVM_REGISTER_GLOBAL("ansor.write_measure_records_to_file") +.set_body([](TVMArgs args, TVMRetValue *ret) { + std::string filename = args[0]; + Array in = args[1]; + Array res = args[2]; + std::ofstream ofs(filename, std::ofstream::app); + WriteMeasureRecords(&ofs, in, res); +}); + +void LogToFileNode::callback(const SearchPolicy& policy, + const Array& inputs, + const Array& results) { + std::ofstream ofs(filename, std::ofstream::app); + WriteMeasureRecords(&ofs, inputs, results); +} + +LogReader LogReaderNode::make(std::string filename) { + auto node = make_object(); + node->filename = filename; + node->infile.open(filename, std::ifstream::in); + return LogReader(node); +} + +bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { + std::string log_version; + + while (std::getline(infile, cur_line)) { + if (cur_line[0] == '#' || cur_line[0] == ' ') { + // skip comment lines begin with '#' or ' ' + continue; + } + + try { + ReadMeasureRecords(cur_line, inp, res, &log_version); + } catch (...) 
{ + return false; + } + + return true; + } + + return false; +} + +std::pair, Array > LogReaderNode::ReadLines( + int max_size, int skip_size) { + auto inp = make_object(); + auto res = make_object(); + Array inputs; + Array results; + + while (ReadNext(inp.get(), res.get())) { + if (skip_size > 0) { + skip_size--; + continue; + } + + inputs.push_back(inp->copy()); + results.push_back(res->copy()); + + if (max_size > 0 && static_cast(inputs.size()) >= max_size) { + break; + } + } + + return std::make_pair(inputs, results); +} + +std::pair BestMeasurePairInFile(const std::string& filename, + const std::string& workload_key, + const Target& target) { + std::pair best_pair; + double best_cost = 1e30; + + auto inp = make_object(); + auto res = make_object(); + LogReader reader = LogReaderNode::make(filename); + + while (reader->ReadNext(inp.get(), res.get())) { + if (res->error_no != kNoError || inp->task->workload_key != workload_key + || inp->task->target->target_name != target->target_name) { + continue; + } + + double cost = FloatArrayMean(res->costs); + + if (cost < best_cost) { + best_cost = cost; + best_pair = std::make_pair(inp->copy(), res->copy()); + } + } + + return best_pair; +} + +} // namespace ansor +} // namespace tvm \ No newline at end of file diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h new file mode 100644 index 000000000000..96dfb0ee320b --- /dev/null +++ b/src/ansor/serialization.h @@ -0,0 +1,78 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file ansor/serialization.h + * \brief Json serialization format for dumping and loading tuning records + */ + +#ifndef TVM_ANSOR_SERIALIZATION_H_ +#define TVM_ANSOR_SERIALIZATION_H_ + +#include +#include +#include +#include "measure.h" +// #include "search_policy/search_policy.h" + +namespace tvm { +namespace ansor { + +class LogReader; + +/*! \brief Log the input and results of measurments to file */ +class LogToFileNode: public MeasureCallbackNode { + public: + std::string filename; + + static MeasureCallback make(std::string filename); + + /*! \brief Log measure pairs to file. This is called by the search policy */ + void callback(const SearchPolicy& policy, + const Array& inputs, + const Array& results) final; + + static constexpr const char *_type_key = "ansor.LogToFile"; + TVM_DECLARE_FINAL_OBJECT_INFO(LogToFileNode, MeasureCallbackNode); +}; + +/*! \brief Log reader */ +class LogReaderNode: public Object { + public: + std::string filename; + std::ifstream infile; + + static LogReader make(std::string filename); + + /*! \brief Read next line in the log file + * \return Whether the read is successful */ + bool ReadNext(MeasureInputNode* inp, MeasureResultNode* res); + + /*! \brief Read multiple lines from the log file + * \param max_size The maximum number of lines. 
-1 means read all lines + * \param skip_size Skip the first n lines */ + std::pair, Array > ReadLines( + int max_size = -1, int skip_size = 0); + + static constexpr const char* _type_key = "ansor.LogReader"; + TVM_DECLARE_FINAL_OBJECT_INFO(LogReaderNode, Object); + private: + std::string cur_line; +}; +TVM_DEFINE_MUTABLE_NODE_REF(LogReader, LogReaderNode); + +void WriteMeasureRecords(std::ostream* os, + const Array& inputs, + const Array& results); + +void ReadMeasureRecords(std::string str, + MeasureInputNode* inp, + MeasureResultNode* res, + std::string* log_version); + +std::pair BestMeasurePairInFile(const std::string& filename, + const std::string& workload_key, + const Target& target); + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_SERIALIZATION_H_ diff --git a/src/ansor/utils.h b/src/ansor/utils.h index 4ea7f283ad09..67ebb836c680 100644 --- a/src/ansor/utils.h +++ b/src/ansor/utils.h @@ -61,6 +61,13 @@ struct hash > { namespace tvm { namespace ansor { +/*! \brief Macro to make it easy to define node ref type given node */ +#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \ + class TypeName : public ObjectRef { \ + public: \ + TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ObjectRef, NodeName); \ + }; \ + /*! \brief Macro to make it easy to define mutable node ref type given node */ #define TVM_DEFINE_MUTABLE_NODE_REF(TypeName, NodeName) \ class TypeName : public ObjectRef { \ diff --git a/src/ir/expr.cc b/src/ir/expr.cc index fd380aa33f86..6e898dd5ddb4 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -38,6 +38,8 @@ PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) { PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} +PrimExpr::PrimExpr(double value) : PrimExpr(FloatImm(DataType::Float(64), value)) {} + PrimExpr PrimExpr::FromObject_(ObjectRef ref) { using runtime::ObjectTypeChecker; if (auto* ptr = ref.as()) { diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index a6d4a5499469..4e71383cc1bb 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -94,6 +94,10 @@ class CUDADeviceAPI final : public DeviceAPI { } case kGcnArch: return; + case kMaxRegistersPerBlock: { + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxRegistersPerBlock, ctx.device_id)); + break; + } } *rv = value; } diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 6d9835e6231c..71d3232ca4d5 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -109,6 +109,9 @@ void OpenCLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* } case kGcnArch: return; + default: { + LOG(WARNING) << "Attr not implemented."; + } } } diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc index 87e7ad71a7c0..c43ec5c0a751 100644 --- a/tests/cpp/ansor_test.cc +++ b/tests/cpp/ansor_test.cc @@ -23,6 +23,7 @@ #include #include #include "../../src/ansor/loop_state.h" +#include "../../src/ansor/serialization.h" tvm::Array matmul_func(int n, int m, int k) { using namespace tvm; @@ -157,6 +158,63 @@ TEST(ComputeDAG, GetProducersConsumers) { } } +TEST(ComputeDAG, InferBoundSerialization) { + using namespace tvm::ansor; + + const auto& tensors = matmul_func(512, 512, 512); + const auto& dag = ComputeDAGNode::make(tensors); + int A = 0, B = 1, C = 2; + + State s0 = dag.GetInitState(); + int C_global = s0.cache_write(C, "global", dag); + C++; + const auto& its0 = s0.split(C, 
s0->stages[C]->iters[0], {4, 8, 8}); + const auto& its1 = s0.split(C, s0->stages[C]->iters[4], {8, 4, 4}); + s0.reorder(C, {its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], + its0[3], its1[3]}); + s0.compute_at(C_global, C, s0->stages[C]->iters[3]); + s0.split(C_global, s0->stages[C_global]->iters[2], {16}); + int B_global = s0.cache_read(B, "global", {C_global}, dag); + C++; C_global++; + s0.compute_at(B_global, C_global, s0->stages[C_global]->iters[0]); + int A_global = s0.cache_read(A, "global", {C_global}, dag); + B++; B_global++; C++; C_global++; + s0.compute_at(A_global, C_global, s0->stages[C_global]->iters[2]); + + const auto& s1 = dag.InferBound(s0); + std::vector s2 = {s0}; + dag.InferBound(&s2); + const auto& s3 = dag.ReplayAndInferBound(s0->transform_steps); + + CHECK_EQ(s1->stages[B_global]->iters[0]->range->extent.as()->value, + 512); + CHECK_EQ(s1->stages[B_global]->iters[1]->range->extent.as()->value, + 16); + CHECK_EQ(s1->stages[A_global]->iters[0]->range->extent.as()->value, + 1); + CHECK_EQ(s1->stages[A_global]->iters[1]->range->extent.as()->value, + 16); + CHECK_EQ(s1->stages[C_global]->iters[0]->range->extent.as()->value, + 64); + CHECK(std::equal_to()(s1, s2[0])); + CHECK(std::equal_to()(s1, s3)); + + const auto& minp0 = MeasureInputNode::make( + SearchTaskNode::make(dag, "test", tvm::target::llvm(), + tvm::target::llvm(), + HardwareParams()), + s0); + const auto& mres0 = MeasureResultNode::make({0.1}, 0, "", 0.1, 0.1); + std::stringstream ss; + WriteMeasureRecords(&ss, {minp0}, {mres0}); + auto minp1 = tvm::make_object(); + auto mres1 = tvm::make_object(); + std::string log_version; + ReadMeasureRecords(ss.str(), minp1.get(), mres1.get(), &log_version); + const auto& s4 = dag.ReplayAndInferBound(minp1->state->transform_steps); + CHECK(std::equal_to()(s1, s4)); +} + TEST(Step, SplitFuseReorder) { using namespace tvm::ansor; @@ -533,7 +591,69 @@ TEST(Step, CacheReadWrite) { } TEST(Step, FollowSplitFollowFusedSplit) { - // todo + using namespace tvm::ansor; + + const auto& tensors = matmul_func(512, 512, 512); + const auto& dag = ComputeDAGNode::make(tensors); + + State s0 = dag.GetInitState(); + int C = 2; + + int C_global = s0.cache_write(C, "global", dag); + C++; + + // FollowSplitStep currently only support `inner_to_outer = true` + const auto& its0 = s0.split(C, s0->stages[C]->iters[0], {4, 2, 8, 4}, true); + int split_step0 = s0->transform_steps.size() - 1; + // const auto& its1 = s0.split(C, s0->stages[C]->iters[5], {4, 2, 8, 4}, false); + // int split_step1 = s0->transform_steps.size() - 1; + for (int level = 1; level <= 5; level++) { + State tmp = s0; + tmp.follow_split(C_global, s0->stages[C_global]->iters[0], split_step0, + level); + // tmp.follow_split(C_global, s0->stages[C_global]->iters[5], split_step1, + // level); + const auto& stage_C = tmp->stages[C]; + const auto& stage_C_global = tmp->stages[C_global]; + for (int i = 0; i < level; i++) { + CHECK_EQ(stage_C->iters[i]->range->extent.as()->value, + stage_C_global->iters[i]->range->extent.as()->value); + } + // for (int i = 0; i < level; i++) { + // CHECK(stage_C->iters[i+5]->range->extent.as()->value == + // stage_C_global->iters[i+5]->range->extent.as()->value); + // } + } + + const auto& its1 = s0.split(C, s0->stages[C]->iters[5], {2, 2, 4, 8}); + int split_step1 = s0->transform_steps.size() - 1; + std::vector its; + for (int i = 0; i < 5; i++) { + its.push_back(its0[i]); + its.push_back(its1[i]); + } + s0.reorder(C, its); + for (int i = 0; i < 5; i++) { + s0.fuse(C, {s0->stages[C]->iters[i], 
s0->stages[C]->iters[i+1]}); + } + for (int level = 0; level < 4; level++) { + State tmp = s0; + tmp.follow_fused_split(C_global, tmp->stages[C_global]->iters[0], + {split_step0, split_step1}, level, false); + const auto& stage_C = tmp->stages[C]; + const auto& stage_C_global = tmp->stages[C_global]; + CHECK_EQ(stage_C->iters[level+1]->range->extent.as()->value, + stage_C_global->iters[0]->range->extent.as()->value); + } + for (int level = 0; level < 4; level++) { + State tmp = s0; + tmp.follow_fused_split(C_global, tmp->stages[C_global]->iters[0], + {split_step0, split_step1}, level, true); + const auto& stage_C = tmp->stages[C]; + const auto& stage_C_global = tmp->stages[C_global]; + CHECK_EQ(stage_C->iters[level+1]->range->extent.as()->value, + stage_C_global->iters[1]->range->extent.as()->value); + } } TEST(Step, Rfactor) { From e0a5ed58b1f9e8296f1a6e9fb269a3426037cbf1 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Fri, 29 May 2020 18:46:01 +0800 Subject: [PATCH 04/78] Add MetaTileRewritePolicy (#5) * Add feature * Add cost_model, meta_tile_rewrite_policy * Add MetaTileRewritePolicy basic UT --- src/ansor/cost_model/cost_model.cc | 163 ++ src/ansor/cost_model/cost_model.h | 98 ++ src/ansor/feature.cc | 1386 ++++++++++++++++ src/ansor/feature.h | 63 + .../search_policy/meta_tile_rewrite_policy.cc | 1420 +++++++++++++++++ .../search_policy/meta_tile_rewrite_policy.h | 101 ++ src/ansor/search_policy/search_policy.cc | 14 + src/ansor/search_policy/search_policy.h | 53 + src/ansor/search_policy/utils.cc | 609 +++++++ src/ansor/search_policy/utils.h | 428 +++++ tests/cpp/ansor_test.cc | 99 +- 11 files changed, 4420 insertions(+), 14 deletions(-) create mode 100644 src/ansor/cost_model/cost_model.cc create mode 100644 src/ansor/cost_model/cost_model.h create mode 100644 src/ansor/feature.cc create mode 100644 src/ansor/feature.h create mode 100644 src/ansor/search_policy/meta_tile_rewrite_policy.cc create mode 100644 src/ansor/search_policy/meta_tile_rewrite_policy.h create mode 100644 src/ansor/search_policy/search_policy.cc create mode 100644 src/ansor/search_policy/search_policy.h create mode 100644 src/ansor/search_policy/utils.cc create mode 100644 src/ansor/search_policy/utils.h diff --git a/src/ansor/cost_model/cost_model.cc b/src/ansor/cost_model/cost_model.cc new file mode 100644 index 000000000000..d4304bccb4bf --- /dev/null +++ b/src/ansor/cost_model/cost_model.cc @@ -0,0 +1,163 @@ +/*! 
+ * Copyright (c) 2020 by Contributors
+ */
+#include "cost_model.h"
+#include
+#include
+#include
+
+namespace tvm {
+namespace ansor {
+
+using ::tvm::runtime::NDArray;
+
+TVM_REGISTER_OBJECT_TYPE(CostModelNode);
+TVM_REGISTER_OBJECT_TYPE(RandomModelNode);
+TVM_REGISTER_OBJECT_TYPE(MeasureModelNode);
+TVM_REGISTER_OBJECT_TYPE(PythonBasedCostModelNode);
+
+void RandomNumber(TVMArgs args, TVMRetValue* rv) {
+  int n = args[0];
+  void* data = args[1];
+  float* fdata = reinterpret_cast<float*>(data);
+  // rand_r requires a valid seed pointer; passing 0 (i.e. a null pointer) is
+  // undefined behavior, so keep a static seed instead.
+  static unsigned int seed = 0;
+  for (int i = 0; i < n; i++) {
+    fdata[i] = static_cast<float>(rand_r(&seed)) / (static_cast<float>(RAND_MAX));
+  }
+}
+
+CostModel RandomModelNode::make() {
+  ObjectPtr<RandomModelNode> node = make_object<RandomModelNode>();
+  node->random_number_func =
+      runtime::Registry::Get("ansor.cost_model.random_number");
+  if (node->random_number_func == nullptr) {
+    LOG(WARNING) << "ansor.cost_model.random_number is not registered, "
+                 << "using the C++ default random_number func instead.";
+    static PackedFunc cost_model_random_number(RandomNumber);
+    node->random_number_func = &cost_model_random_number;
+  }
+  return CostModel(node);
+}
+
+void RandomModelNode::Update(const Array<MeasureInput>& inputs,
+                             const Array<MeasureResult>& results) {
+}
+
+void RandomModelNode::Predict(const SearchTask& task,
+                              const std::vector<State>& states,
+                              std::vector<float>* scores) {
+  scores->resize(states.size());
+  (*random_number_func)(states.size(), static_cast<void*>(scores->data()));
+}
+
+CostModel MeasureModelNode::make(Builder builder, Runner runner) {
+  ObjectPtr<MeasureModelNode> node = make_object<MeasureModelNode>();
+  node->measurer = ProgramMeasurerNode::make(std::move(builder), std::move(runner),
+                                             Array(), 0);
+  return CostModel(node);
+}
+
+void MeasureModelNode::Update(const Array<MeasureInput>& inputs,
+                              const Array<MeasureResult>& results) {
+}
+
+void MeasureModelNode::Predict(const SearchTask& task,
+                               const std::vector<State>& states,
+                               std::vector<float>* scores) {
+  std::vector<MeasureInput> inputs;
+  std::vector<MeasureResult> results;
+
+  inputs.clear(); inputs.reserve(states.size());
+  for (const auto& state : states) {
+    inputs.push_back(MeasureInputNode::make(task, state));
+  }
+  measurer->SilentMeasure(task, inputs, &results);
+
+  scores->clear();
+  scores->reserve(results.size());
+  for (const auto& res : results) {
+    scores->push_back(1.0 / FloatArrayMean(res->costs));
+  }
+}
+
+CostModel PythonBasedCostModelNode::make(PackedFunc update_func, PackedFunc predict_func,
+                                         PackedFunc predict_stage_func) {
+  auto node = make_object<PythonBasedCostModelNode>();
+  node->update_func = std::move(update_func);
+  node->predict_func = std::move(predict_func);
+  node->predict_stage_func = std::move(predict_stage_func);
+  return CostModel(node);
+}
+
+void PythonBasedCostModelNode::Update(const Array<MeasureInput>& inputs,
+                                      const Array<MeasureResult>& results) {
+  update_func(inputs, results);
+}
+
+void PythonBasedCostModelNode::Predict(const SearchTask& task,
+                                       const std::vector<State>& states,
+                                       std::vector<float>* scores) {
+  scores->resize(states.size());
+  predict_func(task, Array<State>(states.begin(), states.end()),
+               static_cast<void*>(scores->data()));
+}
+
+void PythonBasedCostModelNode::PredictStages(const SearchTask& task,
+                                             const std::vector<State>& states,
+                                             std::vector<float>* state_scores,
+                                             std::vector<std::vector<float>>* stage_scores) {
+  int n_states = states.size();
+  int n_stages = task->compute_dag.GetInitState()->stages.size();
+  std::vector<float> flatten_scores;
+  flatten_scores.resize(n_states * n_stages * 2);  // Allocate sufficient space.
+  predict_stage_func(task, Array<State>(states.begin(), states.end()),
+                     static_cast<void*>(flatten_scores.data()));
+
+  // Unpack flatten scores.
+  state_scores->clear();
+  stage_scores->clear();
+
+  // Score of each state.
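+  // Assumed layout of `flatten_scores` (inferred from the unpacking below):
+  //   [ score(state_0), ..., score(state_{n-1}),
+  //     n_scored_stages(state_0), stage scores of state_0 ...,
+  //     n_scored_stages(state_1), stage scores of state_1 ... ]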
+ for (int i = 0; i < n_states; ++i) { + state_scores->push_back(flatten_scores[i]); + } + + // Score of each stage in each states. + size_t idx = n_states; + for (int i = 0; i < n_states; ++i) { + CHECK_LE(idx, flatten_scores.size()); + + // Number of scored stages of this state. + int s_length = (int)flatten_scores[idx++]; + + if (s_length > 0) { + std::vector scores; + int offset = 0; + + if ((*state_scores)[i] > -INFINITY) { + // If the score is valid. Copy scored stages and assign 0 to placeholder and inlined stages. + // If the score is 0, meaning this state failed to be lowered. Just bypass to update offset. + for (const Stage& stage : states[i]->stages) { + if (stage->op_type == kPlaceholder) { + scores.push_back(0); + continue; + } + if (stage->compute_at == kInlined) { + scores.push_back(0); + continue; + } + scores.push_back(flatten_scores[idx + offset]); + offset++; + } + CHECK_EQ(offset, s_length); + stage_scores->push_back(std::move(scores)); + } + idx += s_length; + } else { + // Cost model does not provide any stage score details. + stage_scores->push_back({}); + } + } +} + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/cost_model/cost_model.h b/src/ansor/cost_model/cost_model.h new file mode 100644 index 000000000000..36179573c617 --- /dev/null +++ b/src/ansor/cost_model/cost_model.h @@ -0,0 +1,98 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file ansor/cost_model.h + * \brief Base class of cost model + */ + +#ifndef TVM_ANSOR_COST_MODEL_COST_MODEL_H_ +#define TVM_ANSOR_COST_MODEL_COST_MODEL_H_ + +#include +#include +#include +#include +#include "../measure.h" + +namespace tvm { +namespace ansor { + +using runtime::PackedFunc; + +class CostModel; + +/*! \brief The base class for cost model */ +class CostModelNode: public Object { + public: + virtual void Update(const Array& inputs, const Array& results) = 0; + virtual void Predict(const SearchTask& task, const std::vector& states, + std::vector* scores) = 0; + virtual void PredictStages(const SearchTask& task, const std::vector& states, + std::vector* state_scores, + std::vector>* stage_scores) = 0; + + static constexpr const char *_type_key = "ansor.CostModel"; + TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object); +}; +TVM_DEFINE_MUTABLE_NODE_REF(CostModel, CostModelNode); + +/*! \brief The cost model returns random value for all predictions */ +class RandomModelNode: public CostModelNode { + public: + const PackedFunc* random_number_func; + + static CostModel make(); + + void Update(const Array& inputs, const Array& results) final; + void Predict(const SearchTask& task, const std::vector& states, + std::vector* scores) final; + void PredictStages(const SearchTask& task, const std::vector& states, + std::vector* state_scores, + std::vector>* stage_scores) { ; } + + static constexpr const char *_type_key = "ansor.RandomModel"; + TVM_DECLARE_FINAL_OBJECT_INFO(RandomModelNode, CostModelNode); +}; + +class MeasureModelNode : public CostModelNode { + public: + ProgramMeasurer measurer; + + static CostModel make(Builder builder, Runner runner); + + void Update(const Array& inputs, const Array& results) final; + void Predict(const SearchTask& task, const std::vector& states, + std::vector* scores) final; + void PredictStages(const SearchTask& task, const std::vector& states, + std::vector* state_scores, + std::vector>* stage_scores) { ; } + + static constexpr const char* _type_key = "ansor.MeasureModel"; + TVM_DECLARE_FINAL_OBJECT_INFO(MeasureModelNode, CostModelNode); +}; + +/*! 
\brief A wrapper for cost model defined by python code + * This class will call python's function */ +class PythonBasedCostModelNode: public CostModelNode { + public: + PackedFunc update_func; + PackedFunc predict_func; + PackedFunc predict_stage_func; + + static CostModel make(PackedFunc update_func, PackedFunc predict_func, + PackedFunc predict_stage_func); + + void Update(const Array& inputs, const Array& results) final; + void Predict(const SearchTask& task, const std::vector& states, + std::vector* scores) final; + void PredictStages(const SearchTask& task, const std::vector& states, + std::vector* state_scores, + std::vector>* stage_scores) final; + + static constexpr const char *_type_key = "ansor.PythonBasedCostModel"; + TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedCostModelNode, CostModelNode); +}; + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_COST_MODEL_COST_MODEL_H_ diff --git a/src/ansor/feature.cc b/src/ansor/feature.cc new file mode 100644 index 000000000000..cb865bc3b5ae --- /dev/null +++ b/src/ansor/feature.cc @@ -0,0 +1,1386 @@ +/*! + * Copyright (c) 2020 by Contributors + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "measure.h" +#include "serialization.h" +#include "utils.h" +// #include "../arithmetic/compute_expr.h" + +namespace tvm { +/* Import the function from build_module.cc */ +extern void GetBinds(const Array& args, + bool compact, + const std::unordered_map& binds, + Map* out_binds, + Array* out_arg_list, + const BuildConfig& config); +} // namespace tvm + + +namespace tvm { +namespace ansor { + +using namespace tvm::tir; +using arith::ConstIntBound; +using arith::Analyzer; + +static const int ARITH_INTENSITY_CURVE_SAMPLE_N = 10; + +// Annotation position encoding +enum AnnotationPosType { + kPosNone, kPosInnerSpatial, kPosMiddleSpatial, kPosOuterSpatial, + kPosInnerReduce, kPosMiddleReduce, kPosOuterReduce, kPosMixed +}; + +// Buffer access type +enum BufferAccessType { + kRead, kWrite, kReadWrite, kUnknownRW +}; + +// Accesses to a buffer +struct BufferAccess { + BufferAccessType acc_type{kUnknownRW}; + std::vector > indices; +}; + +// Data reuse type +enum ReuseType { + kLoopMultipleRead, kSerialMultipleReadWrite, kNoReuse +}; + +// Feature for an access of a buffer +struct BufferAccessFeature { + std::string tensor_name; + BufferAccessType acc_type; + float bytes; + float unique_bytes; + float lines; + float unique_lines; + ReuseType reuse_type; + float reuse_dis_iter; // reuse distance in iterator number + float reuse_dis_bytes; // reuse distance in total touched bytes + float reuse_ct; // reuse times + float bytes_d_reuse_ct; + float unique_bytes_d_reuse_ct; + float lines_d_reuse_ct; + float unique_lines_d_reuse_ct; + float stride; +}; + +// Feature set of a statement +struct FeatureSet { + // compute feature + float float_mad; + float float_addsub; + float float_mul; + float float_divmod; + float float_cmp; + float float_math_func; + float float_other_func; + float int_mad; + float int_addsub; + float int_mul; + float int_divmod; + float int_cmp; + float int_math_func; + float int_other_func; + float bool_op; + float select_op; + float vec_num; // The number of vectorized iterators + float vec_prod; // The product of the lengths of vectorized iterators + float vec_len; // The length of the innermost vectorized iterator + AnnotationPosType vec_type; + float unroll_num; // The number of unrolled iterators + float unroll_prod; // The product of the lengths of vectorized 
iterators + float unroll_len; // The length of the innermost unrolled iterator + AnnotationPosType unroll_type; + float parallel_num; // The number of paralleled iterators + float parallel_prod; // The product of the lengths of paralleled iterators + float parallel_len; // The length of the innermost paralleled iterators + AnnotationPosType parallel_type; + float is_gpu; + float blockIdx_x_len; + float blockIdx_y_len; + float blockIdx_z_len; + float threadIdx_x_len; + float threadIdx_y_len; + float threadIdx_z_len; + float vthread_len; + + float arith_intensity_curve[ARITH_INTENSITY_CURVE_SAMPLE_N]; + + // buffer access feature (per buffer) + std::vector access_feas; + + // allocation feature + float alloc_size; + float alloc_prod; + float alloc_outer_prod; + float alloc_inner_prod; + + // overall feature + float outer_prod; + float num_loops; + float auto_unroll_max_step; +}; + +// Return whether a var is in an expr +bool VarInExpr(const Var& var, const PrimExpr& expr) { + bool find = false; + + PostOrderVisit(expr, [&find, &var](const ObjectRef &node) { + if (find) { + return; + } + + if (const VarNode* op = node.as()) { + if (op == var.get()) { + find = true; + } + } + }); + + return find; +} + +// Get position encoding for annotation +AnnotationPosType GetAnnotationPosEncoding( + const Var& var, const Array& spatial_args, + const Array& axis, const Array& reduce_axis) { + // Try to match spatial args first + size_t find_i = 0; + size_t find_ct = 0; + for (size_t i = 0; i < spatial_args.size(); ++i) { + if (VarInExpr(var, spatial_args[i])) { + find_i = i; + find_ct += 1; + } + } + + if (find_ct == 0) { + // If not find in spatial args, then it is a reduce iteartor. + // Use name to match + for (size_t i = 0; i < reduce_axis.size(); ++i) { + if (var->name_hint.find(reduce_axis[i]->var->name_hint) != std::string::npos) { + find_i = i; + find_ct++; + } + } + if (find_ct >= 1) { + if (find_i == 0) { + return kPosInnerReduce; + } else if (find_i == reduce_axis.size() - 1) { + return kPosOuterReduce; + } else { + return kPosMiddleReduce; + } + } else { + // If the axis is not found in both spatial args and reduce axis, + // then this stage must compute_at somewhere under this aixs and this axis is simplified out + // We assume it is an outer spatial + return kPosOuterSpatial; + } + } else if (find_ct == 1) { + if (find_i == spatial_args.size() - 1) { + return kPosInnerSpatial; + } else if (find_i == 0) { + return kPosOuterSpatial; + } else { + return kPosMiddleSpatial; + } + } else { + return kPosMixed; + } +} + +// Count math ops in an expr +class MathOpCounter : public StmtExprVisitor { + public: +#define VisitBinary(Type, float_ct, int_ct) \ + void VisitExpr_(const Type* op) final { \ + if (op->a.dtype().is_float()) { \ + float_ct++; \ + } else { \ + int_ct++; \ + } \ + StmtExprVisitor::VisitExpr_(op); \ + } \ + + VisitBinary(AddNode, float_addsub, int_addsub); + VisitBinary(SubNode, float_addsub, int_addsub); + VisitBinary(MulNode, float_mul, int_mul); + VisitBinary(DivNode, float_divmod, int_divmod); + VisitBinary(ModNode, float_divmod, int_divmod); + VisitBinary(FloorDivNode, float_divmod, int_divmod); + VisitBinary(FloorModNode, float_divmod, int_divmod); + VisitBinary(MaxNode, float_cmp, int_cmp); + VisitBinary(MinNode, float_cmp, int_cmp); + VisitBinary(EQNode, float_cmp, int_cmp); + VisitBinary(NENode, float_cmp, int_cmp); + VisitBinary(LTNode, float_cmp, int_cmp); + VisitBinary(LENode, float_cmp, int_cmp); + VisitBinary(GTNode, float_cmp, int_cmp); + VisitBinary(GENode, float_cmp, 
int_cmp);
+
+  void VisitExpr_(const AndNode* op) final { bool_op++; StmtExprVisitor::VisitExpr_(op); }
+  void VisitExpr_(const OrNode* op) final { bool_op++; StmtExprVisitor::VisitExpr_(op); }
+  void VisitExpr_(const NotNode* op) final { bool_op++; StmtExprVisitor::VisitExpr_(op); }
+  void VisitExpr_(const SelectNode* op) final { select_op++; StmtExprVisitor::VisitExpr_(op); }
+
+  // TODO(...): CallNode with type CallNode::Halide has been modified to BufferLoadNode
+  void VisitExpr_(const CallNode* op) final {
+    if (op->call_type == CallNode::CallType::PureIntrinsic) {
+      if (op->dtype.is_float()) {
+        float_math_func++;
+      } else {
+        int_math_func++;
+      }
+    } else if (op->call_type != CallNode::CallType::Halide) {
+      if (op->dtype.is_float()) {
+        float_other_func++;
+      } else {
+        int_other_func++;
+      }
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  // todo(lmzheng): detect mad
+  size_t float_mad{0}, float_addsub{0}, float_mul{0}, float_divmod{0},
+      float_cmp{0}, float_math_func{0}, float_other_func{0};
+  size_t int_mad{0}, int_addsub{0}, int_mul{0}, int_divmod{0},
+      int_cmp{0}, int_math_func{0}, int_other_func{0};
+  size_t bool_op{0}, select_op{0};
+};
+
+
+// Extract all buffer accesses in an expr
+class BufferAccessExtractor : public StmtExprVisitor {
+ public:
+  void ExtractReads(const PrimExpr& expr) {
+    this->VisitExpr(expr);
+  }
+
+  void InsertAccess(const te::Tensor& ten, BufferAccessType acc_type,
+                    const Array<PrimExpr>& indices) {
+    BufferAccess& acc = buf_accesses[ten];
+    acc.acc_type = acc_type;
+    acc.indices.push_back(std::vector<PrimExpr>(indices.begin(), indices.end()));
+  }
+
+  // TODO(...): CallNode with type CallNode::Halide has been modified to BufferLoadNode
+  void VisitExpr_(const CallNode *op) final {
+    if (op->call_type == CallNode::CallType::Halide) {
+      te::Tensor ten = Downcast<te::Operation>(op->func).output(op->value_index);
+      BufferAccess& acc = buf_accesses[ten];
+      switch (acc.acc_type) {
+        case kRead:
+          break;
+        case kWrite:
+          acc.acc_type = kReadWrite; break;
+        case kReadWrite:
+          break;
+        case kUnknownRW:
+        default:
+          acc.acc_type = kRead; break;
+      }
+
+      if (acc.acc_type != kReadWrite) {
+        // If a buffer is both read and written, in the tvm DSL it must be an update,
+        // so the indices should be the same. Then we can skip appending indices for it.
+        // Otherwise we do the following.
+        buf_accesses[ten].indices.push_back(
+            std::vector<PrimExpr>(op->args.begin(), op->args.end()));
+      }
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  std::unordered_map<te::Tensor, BufferAccess> buf_accesses;
+};
+
+// Compute the coefficient of a loop iterator in an expression
+// Note: we use an approximation strategy to find the coefficient.
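+// For example, for the index expression `i * 16 + j`, the extracted
+// coefficient (stride) of `i` is 16 and that of `j` is 1.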
+// Hopefully, it is faster than DetectLinearEquation and can handle more cases (non-linear) +class CoefficientExtractor : public StmtExprVisitor { + public: + void VisitExpr_(const MulNode *node) final { + StmtExprVisitor::VisitExpr_(node); + if (visited_var) { + if (!visited_add) { + if (auto a = node->a.as()) { + visited_mul = true; + stride = a->value; + } else if (auto b = node->b.as()) { + visited_mul = true; + stride = b->value; + } + } + } + } + + void VisitExpr_(const AddNode *node) final { + StmtExprVisitor::VisitExpr_(node); + if (visited_var) { + if (!visited_mul) { + visited_add = true; + stride = 1; + } + } + } + + void VisitExpr_(const VarNode *node) final { + if (node == var_) { + visited_var = true; + // This is a magic default stride in case our approximation strategy fails + stride = 2; + } + } + + int ExtractCoefficient(const PrimExpr& expr, const VarNode* var) { + visited_var = visited_mul = visited_add = false; + var_ = var; + + this->VisitExpr(expr); + + if (visited_var && !visited_mul && !visited_add) { + return 1; + } else { + return stride; + } + } + + bool visited_var{false}; + bool visited_mul{false}; + bool visited_add{false}; + int stride{0}; + + private: + const VarNode* var_{nullptr}; +}; + +// Compute stride for the accesses to a buffer +int64_t ComputeStride(const std::vector >& indices, + const std::vector& shape, + const VarNode* stride_var) { + int64_t min_stride = std::numeric_limits::max(); + bool find = false; + CoefficientExtractor extractor; + + for (const auto &index : indices) { + int64_t shape_stride = 1; + for (int i = static_cast(index.size()) - 1; i >= 0; i--) { + int coefficient = extractor.ExtractCoefficient(index[i], stride_var); + if (extractor.visited_var) { + find = true; + min_stride = std::min(min_stride, std::abs(coefficient) * shape_stride); + break; + } + shape_stride *= shape[i]; + } + } + + return find ? min_stride : 0; +} + +// Compute touched bytes and cache lines for accesses to a buffer +void ComputeRegion( + const std::vector > &indices, + arith::Analyzer* ana, + std::vector* region) { + region->clear(); + + if (indices.empty()) { + return; + } + + region->reserve(indices[0].size()); + + if (indices.size() == 1) { + for (const auto& index : indices[0]) { + ConstIntBound bound = ana->const_int_bound(index); + region->push_back(bound->max_value - bound->min_value + 1); + } + } else { + // future(lmzheng): implement a more accurate IntSet? 
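+    // Multiple accesses: approximate the union of the touched regions
+    // dimension by dimension, using (max - min + 1) over the const-int
+    // bounds of all accesses as the extent of each dimension.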
+ for (size_t i = 0; i < indices[0].size(); ++i) { + int64_t minimum = ConstIntBound::kPosInf, maximum = ConstIntBound::kNegInf; + for (size_t j = 0; j < indices.size(); ++j) { + ConstIntBound bound = ana->const_int_bound(indices[j][i]); + + minimum = std::min(minimum, bound->min_value); + maximum = std::max(maximum, bound->max_value); + } + region->push_back(maximum - minimum + 1); + } + } +} + +// Compute reuse distance and reuse ratio for accesses to a buffer +// return values: reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct +std::tuple ComputeReuse( + const te::Tensor& t, + const std::vector >& indices, + const std::vector& for_loop_stack, + const std::unordered_map > > >& for_touch_regions) { + float reuse_dis_iter = 1.0f; + float reuse_dis_bytes = -1.0f; + + for (int i = static_cast(for_loop_stack.size()) - 1; i >= 0; --i) { + const ForNode* cur_for = for_loop_stack[i]; + bool find = false; + + for (size_t j = 0; j < indices.size(); j++) { + for (size_t k = 0; k < indices[j].size(); k++) { + if (VarInExpr(cur_for->loop_var, indices[j][k])) { + find = true; + break; + } + } + if (find) { + break; + } + } + + int64_t extent = GetIntImm(for_loop_stack[i]->extent); + if (find) { + // accumulate/update reuse distance + reuse_dis_iter *= extent; + reuse_dis_bytes = 0.0f; + for (const auto& iter : for_touch_regions.at(cur_for)) { + for (const auto& access : iter.second) { + reuse_dis_bytes += std::get<1>(access) * std::get<2>(access); + } + } + } else { + // Have LoopMultipleRead reuse + if (reuse_dis_bytes < 0) { + // For the reuse in the innermost axis, the above code won't be executed. + // So we compute bytes here + reuse_dis_bytes = 0.0f; + for (const auto& iter : for_touch_regions.at(cur_for)) { + for (const auto& access : iter.second) { + reuse_dis_bytes += 1 * std::get<2>(access); + } + } + } + return std::make_tuple(kLoopMultipleRead, reuse_dis_iter, reuse_dis_bytes, extent); + } + + const std::unordered_map > >& + tensor_map = for_touch_regions.at(cur_for); + + int serial_reuse = static_cast(tensor_map.at(t).size()) - 1; + if (serial_reuse > 0) { + int64_t extent = GetIntImm(cur_for->extent); + + // Have SerialMultipleReadWrite reuse + reuse_dis_iter = std::numeric_limits::max(); + for (const auto& acc_info : tensor_map.at(t)) { + reuse_dis_iter = std::min(reuse_dis_iter, static_cast(std::get<1>(acc_info))); + } + + reuse_dis_bytes = 0.0f; + for (const auto& iter : for_touch_regions.at(cur_for)) { + for (const auto& access : iter.second) { + reuse_dis_bytes += std::get<1>(access) * std::get<2>(access); + } + } + + return std::make_tuple(kSerialMultipleReadWrite, + reuse_dis_iter / extent, reuse_dis_bytes / extent, serial_reuse); + } + } + + return std::make_tuple(kNoReuse, 0, 0, 0); +} + +// Extract features for every Provide statement +class PerStmtFeatureExtractor : public StmtExprVisitor { + public: + explicit PerStmtFeatureExtractor(int cache_line_size) : + cache_line_size_(cache_line_size) {} + + void VisitStmt_(const AttrStmtNode* node) final { + if (node->attr_key == tir::attr::thread_extent || + node->attr_key == tir::attr::virtual_thread) { + const Var& var = node->node.as()->var; + int extent = GetIntImm(node->value); + + int* plen = nullptr; + + const std::string& name = var.get()->name_hint; + if (node->attr_key == tir::attr::thread_extent) { + if (name == "blockIdx.x") { + plen = &blockIdx_x_len; + } else if (name == "blockIdx.y") { + plen = &blockIdx_y_len; + } else if (name == "blockIdx.z") { + plen = &blockIdx_z_len; + } else if (name == "threadIdx.x") { + 
plen = &threadIdx_x_len; + } else if (name == "threadIdx.y") { + plen = &threadIdx_y_len; + } else if (name == "threadIdx.z") { + plen = &threadIdx_z_len; + } else { + LOG(FATAL) << "invalid thread itervar " + name; + } + } else { + plen = &vthread_len; + } + + int extent_before = *plen; + if (node->attr_key == tir::attr::thread_extent) { + *plen = extent; + } else { + *plen *= extent; + } + + is_gpu = true; + + // make a fake for node for blockIdx.x or threadIdx.x + Stmt fake_for_node = ForNode::make(var, 0, extent, ForType::Parallel, + DeviceAPI::None, node->body); + + outer_loop_prod *= extent; + for_loop_stack.push_back(fake_for_node.as()); + StmtExprVisitor::VisitStmt_(node); + for_loop_stack.pop_back(); + outer_loop_prod /= extent; + + *plen = extent_before; + } else if (node->attr_key == "pragma_auto_unroll_max_step") { + int value = GetIntImm(node->value); + + int16_t old_value = cur_auto_unroll_max_step; + cur_auto_unroll_max_step = value; + StmtExprVisitor::VisitStmt_(node); + cur_auto_unroll_max_step = old_value; + } else { + StmtExprVisitor::VisitStmt_(node); + } + } + + void VisitStmt_(const ForNode* node) final { + int64_t loop_extent = GetIntImm(node->extent); + + if (node->for_type == ForType::Vectorized) { + vec_for_stack.push_back(node); + } else if (node->for_type == ForType::Unrolled) { + unroll_for_stack.push_back(node); + } else if (node->for_type == ForType::Parallel) { + parallel_for_stack.push_back(node); + } + + outer_loop_prod *= loop_extent; + for_loop_stack.push_back(node); + StmtExprVisitor::VisitStmt_(node); + for_loop_stack.pop_back(); + outer_loop_prod /= loop_extent; + + if (node->for_type == ForType::Vectorized) { + vec_for_stack.pop_back(); + } else if (node->for_type == ForType::Unrolled) { + unroll_for_stack.pop_back(); + } else if (node->for_type == ForType::Parallel) { + parallel_for_stack.pop_back(); + } + } + + // TODO(...): ProvideNode is deprecated, move to BufferStoreNode + void VisitStmt_(const ProvideNode* node) final { + te::Operation op = Downcast(node->func); + te::Tensor ten = op.output(node->value_index); + const te::ComputeOpNode* pcompute = op.as(); + + FeatureSet &fea = op_features[ten]; + + // compute feature + MathOpCounter mathops; + mathops(node->value); + fea.float_mad = outer_loop_prod * mathops.float_mad; + fea.float_addsub = outer_loop_prod * mathops.float_addsub; + fea.float_mul = outer_loop_prod * mathops.float_mul; + fea.float_divmod = outer_loop_prod * mathops.float_divmod; + fea.float_cmp = outer_loop_prod * mathops.float_cmp; + fea.float_math_func = outer_loop_prod * mathops.float_math_func; + fea.float_other_func = outer_loop_prod * mathops.float_other_func; + fea.int_mad = outer_loop_prod * mathops.int_mad; + fea.int_addsub = outer_loop_prod * mathops.int_addsub; + fea.int_mul = outer_loop_prod * mathops.int_mul; + fea.int_divmod = outer_loop_prod * mathops.int_divmod; + fea.int_math_func = outer_loop_prod * mathops.int_math_func; + fea.int_cmp = outer_loop_prod * mathops.int_cmp; + fea.int_other_func = outer_loop_prod * mathops.int_other_func; + fea.bool_op = outer_loop_prod * mathops.bool_op; + fea.select_op = outer_loop_prod * mathops.select_op; + + fea.outer_prod = outer_loop_prod; + fea.num_loops = for_loop_stack.size(); + fea.auto_unroll_max_step = cur_auto_unroll_max_step; + fea.vec_len = fea.unroll_len = fea.parallel_len = 0.0f; + fea.vec_type = fea.unroll_type = fea.parallel_type = kPosNone; + + fea.vec_num = vec_for_stack.size(); + if (!vec_for_stack.empty()) { + fea.vec_len = 
GetIntImm(vec_for_stack.back()->extent); + fea.vec_prod = 1.0; + for (const ForNode* pfor : vec_for_stack) { + fea.vec_prod *= GetIntImm(pfor->extent); + } + fea.vec_type = GetAnnotationPosEncoding(vec_for_stack.back()->loop_var, + node->args, pcompute->axis, pcompute->reduce_axis); + } + + fea.unroll_num = unroll_for_stack.size(); + if (!unroll_for_stack.empty()) { + fea.unroll_len = GetIntImm(unroll_for_stack.back()->extent); + fea.unroll_prod = 1.0; + for (const ForNode* pfor : unroll_for_stack) { + fea.unroll_prod *= GetIntImm(pfor->extent); + } + fea.unroll_type = GetAnnotationPosEncoding(unroll_for_stack.back()->loop_var, + node->args, pcompute->axis, pcompute->reduce_axis); + } + + fea.parallel_num = parallel_for_stack.size(); + if (!parallel_for_stack.empty()) { + fea.parallel_len = GetIntImm(parallel_for_stack.back()->extent); + fea.parallel_prod = 1.0; + for (const ForNode* pfor : parallel_for_stack) { + fea.parallel_prod *= GetIntImm(pfor->extent); + } + fea.parallel_type = GetAnnotationPosEncoding(parallel_for_stack.back()->loop_var, + node->args, pcompute->axis, pcompute->reduce_axis); + } + + // GPU threads + fea.is_gpu = is_gpu; + fea.blockIdx_x_len = blockIdx_x_len; + fea.blockIdx_y_len = blockIdx_y_len; + fea.blockIdx_z_len = blockIdx_z_len; + fea.threadIdx_x_len = threadIdx_x_len; + fea.threadIdx_y_len = threadIdx_y_len; + fea.threadIdx_z_len = threadIdx_z_len; + fea.vthread_len = vthread_len; + + // Extract all buffer access + std::vector acc_feas; + BufferAccessExtractor buf_extractor; + buf_extractor.InsertAccess(ten, kWrite, node->args); + buf_extractor.ExtractReads(node->value); + + // Compute touched region for all outer loops + Analyzer ana; + for (auto x : for_loop_stack) { + ana.Bind(x->loop_var, Range::make_by_min_extent(x->min, 1)); + } + + std::vector mem_bytes_list; + std::vector compute_ops_list; + + mem_bytes_list.reserve(for_loop_stack.size()); + compute_ops_list.reserve(for_loop_stack.size()); + + int cur_compute_ops = mathops.float_mad + mathops.float_addsub + mathops.float_mul + + mathops.float_divmod + mathops.float_cmp + + mathops.float_math_func + mathops.float_other_func; + + std::vector tmp_region; + for (int i = static_cast(for_loop_stack.size()) - 1; i >= 0; i--) { + const ForNode* p_for = for_loop_stack[i]; + + ana.Bind(p_for->loop_var, + Range::make_by_min_extent(for_loop_stack[i]->min, for_loop_stack[i]->extent)); + + // Note, here we do overwrite. + // So if there are multiple Provides, the last one will overwrite the first few. + // e.g. The update part in gemm will overwrite the init part. + std::unordered_map > >& + tensor_regions_map = for_touch_regions[p_for]; + + int64_t mem_bytes = 0; + for (const auto &x : buf_extractor.buf_accesses) { + const te::Tensor& t = x.first; + const BufferAccess& acc = x.second; + + ComputeRegion(acc.indices, &ana, &tmp_region); + int64_t touched_size = ElementProduct(tmp_region); + tensor_regions_map[t].push_back(std::make_tuple(acc.acc_type, + touched_size, t->dtype.bytes())); + mem_bytes += touched_size * t->dtype.bytes(); + } + + mem_bytes_list.push_back(std::log2(mem_bytes)); + cur_compute_ops *= GetIntImm(for_loop_stack[i]->extent); + compute_ops_list.push_back(std::log2(cur_compute_ops)); + } + + // Compute arithmetic intensity curve (y axis : arithmetic intensity, x axis : flops). + // We use piecewise linear interpolation to fit this curve. 
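+    // Sample ARITH_INTENSITY_CURVE_SAMPLE_N points evenly on the cumulative
+    // FLOPs axis and, for each sample, interpolate ops/bytes (both in log2
+    // scale) between the two nearest loop levels recorded above.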
+ int pt = 0; + if (cur_compute_ops <= 0 || compute_ops_list.empty()) { + std::fill(fea.arith_intensity_curve, + fea.arith_intensity_curve + ARITH_INTENSITY_CURVE_SAMPLE_N, 0.0); + } else { + for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { + float cur_compute_ops = compute_ops_list.back() * (i+1) / ARITH_INTENSITY_CURVE_SAMPLE_N; + while (compute_ops_list[pt] < cur_compute_ops - 1e-4) { + pt++; + } + CHECK_LT(pt, compute_ops_list.size()); + + float value; + if (pt == 0) { + value = compute_ops_list[pt] / mem_bytes_list[pt]; + } else { + float base = compute_ops_list[pt-1] / mem_bytes_list[pt-1]; + float slope = (compute_ops_list[pt] / mem_bytes_list[pt] - + compute_ops_list[pt-1] / mem_bytes_list[pt-1]) / + (compute_ops_list[pt] - compute_ops_list[pt-1]); + value = base + slope * (cur_compute_ops - compute_ops_list[pt-1]); + } + fea.arith_intensity_curve[i] = value; + } + } + + // Compute buffer access feature + for (const auto &x : buf_extractor.buf_accesses) { + const te::Tensor& t = x.first; + const BufferAccess& acc = x.second; + + std::vector int_shape; + for (const auto& dim : t->shape) { + int_shape.push_back(GetIntImm(dim)); + } + + size_t ele_bytes = t->dtype.bytes(); + + // calculate bytes + float bytes = outer_loop_prod * ele_bytes; + float unique_bytes; + + // calculate cache lines + int64_t stride; + float lines; + float unique_lines; + + if (for_loop_stack.empty()) { + unique_bytes = ele_bytes; + stride = 0; + lines = 1.0f; + unique_lines = 1.0f; + } else { + unique_bytes = std::get<1>(for_touch_regions[for_loop_stack.front()][t].front()) + * ele_bytes; + + stride = 0; + int64_t reduce_ratio = 1; + + int i; + for (i = static_cast(for_loop_stack.size()) - 1; i >= 0; i--) { + stride = ComputeStride(acc.indices, int_shape, for_loop_stack[i]->loop_var.get()); + if (stride != 0) { + break; + } + reduce_ratio *= GetIntImm(for_loop_stack.back()->extent); + } + + lines = outer_loop_prod / reduce_ratio * + std::min(1.0f, 1.0f * stride * ele_bytes / cache_line_size_); + lines = std::max(lines, 1.0f); + + // convert `stride` back to the stride of the innermost iterator + stride = (i == static_cast(for_loop_stack.size()) - 1 ? 
stride : 0); + + float n_continuous = ele_bytes; + for (int i = static_cast(tmp_region.size()) - 1; i >= 0; i--) { + if (tmp_region[i] == int_shape[i]) { + n_continuous *= tmp_region[i]; + break; + } + } + unique_lines = unique_bytes / std::min(n_continuous, + static_cast(cache_line_size_)); + unique_lines = std::max(unique_lines, 1.0f); + } + + ReuseType reuse_type; + float reuse_dis_iter, reuse_dis_bytes, reuse_ct; + std::tie(reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct) = + ComputeReuse(t, acc.indices, for_loop_stack, for_touch_regions); + + acc_feas.emplace_back(); + BufferAccessFeature& acc_fea = acc_feas.back(); + + acc_fea.tensor_name = t->op->func_name(); + acc_fea.acc_type = acc.acc_type; + acc_fea.stride = stride; + acc_fea.bytes = bytes; + acc_fea.unique_bytes = unique_bytes; + acc_fea.lines = lines; + acc_fea.unique_lines = unique_lines; + acc_fea.reuse_type = reuse_type; + acc_fea.reuse_dis_iter = reuse_dis_iter; + acc_fea.reuse_dis_bytes = reuse_dis_bytes; + acc_fea.reuse_ct = reuse_ct; + if (acc_fea.reuse_ct > 0.5) { + acc_fea.bytes_d_reuse_ct = bytes / reuse_ct; + acc_fea.unique_bytes_d_reuse_ct = unique_bytes / reuse_ct; + acc_fea.lines_d_reuse_ct = lines / reuse_ct; + acc_fea.unique_lines_d_reuse_ct = unique_lines / reuse_ct; + } else { + // no reuse, multiply by a magic number '2' + acc_fea.bytes_d_reuse_ct = bytes * 2; + acc_fea.unique_bytes_d_reuse_ct = unique_bytes * 2; + acc_fea.lines_d_reuse_ct = lines * 2; + acc_fea.unique_lines_d_reuse_ct = unique_lines* 2; + } + } + + fea.access_feas = acc_feas; + } + + // TODO(...): RealizeNode is deprecated, move to BufferRealizeNode + void VisitStmt_(const RealizeNode *node) final { + StmtExprVisitor::VisitStmt_(node); + + te::Operation op = Downcast(node->func); + te::Tensor ten = op.output(node->value_index); + + FeatureSet& fea = op_features[ten]; + + float allocation_size = 1.0f; + for (const auto& x : node->bounds) { + allocation_size *= GetIntImm(x->extent); + } + // allocation feature + fea.alloc_size = allocation_size * ten->dtype.bytes(); + fea.alloc_prod = allocation_size * outer_loop_prod; + fea.alloc_outer_prod = outer_loop_prod; + fea.alloc_inner_prod = fea.outer_prod / outer_loop_prod; + } + + float outer_loop_prod = 1.0f; + + std::vector for_loop_stack; + std::vector parallel_for_stack; + std::vector vec_for_stack; + std::vector unroll_for_stack; + + bool is_gpu; + int blockIdx_x_len{1}; + int blockIdx_y_len{1}; + int blockIdx_z_len{1}; + int threadIdx_x_len{1}; + int threadIdx_y_len{1}; + int threadIdx_z_len{1}; + int vthread_len{1}; + int16_t cur_auto_unroll_max_step{0}; + + std::unordered_map op_features; + + // for a loop, for all its touched tensors, for all different accesses to the tensors, + // its (access type, number of touched elements, number of bytes of single element) + std::unordered_map > > > for_touch_regions; + + private: + const int cache_line_size_ = 64; +}; + +// shifted log to incorporate the property that slog(0) = 0 +inline float slog(float x) { + return x < 0 ? -std::log2(-x+1) : std::log2(x+1); +} + +// Get features for all ir::Provide statements in a TVM program. 
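+// (each Provide statement yields one fixed-length feature vector; scalar
+// features are compressed with the shifted log `slog` above, e.g.
+// slog(3) = 2, slog(0) = 0, slog(-3) = -2).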
+// So we call it `PerStmt` feature +void GetPerStmtFeature(const Stmt& stmt, + int cache_line_size, + int max_n_bufs, + std::vector* ret) { + LOG(WARNING) << "RealizeNode & ProvideNode deprecated, " + << "need to fix the implementation of PerStmtFeatureExtractor."; + PerStmtFeatureExtractor extractor(cache_line_size); + extractor(stmt); + + ret->push_back(extractor.op_features.size()); + + for (const auto& x : extractor.op_features) { + const FeatureSet& fea_set = x.second; + + /***** compute feature *****/ + ret->push_back(slog(fea_set.float_mad)); + ret->push_back(slog(fea_set.float_addsub)); + ret->push_back(slog(fea_set.float_mul)); + ret->push_back(slog(fea_set.float_divmod)); + ret->push_back(slog(fea_set.float_cmp)); + ret->push_back(slog(fea_set.float_math_func)); + ret->push_back(slog(fea_set.float_other_func)); + ret->push_back(slog(fea_set.int_mad)); + ret->push_back(slog(fea_set.int_addsub)); + ret->push_back(slog(fea_set.int_mul)); + ret->push_back(slog(fea_set.int_divmod)); + ret->push_back(slog(fea_set.int_cmp)); + ret->push_back(slog(fea_set.int_math_func)); + ret->push_back(slog(fea_set.int_other_func)); + ret->push_back(slog(fea_set.bool_op)); + ret->push_back(slog(fea_set.select_op)); + + ret->push_back(slog(fea_set.vec_num)); + ret->push_back(slog(fea_set.vec_prod)); + ret->push_back(slog(fea_set.vec_len)); + for (int i = 0; i <= kPosMixed; i++) { + ret->push_back(i == fea_set.vec_type); + } + + ret->push_back(slog(fea_set.unroll_num)); + ret->push_back(slog(fea_set.unroll_prod)); + ret->push_back(slog(fea_set.unroll_len)); + for (int i = 0; i <= kPosMixed; i++) { + ret->push_back(i == fea_set.unroll_type); + } + + ret->push_back(slog(fea_set.parallel_num)); + ret->push_back(slog(fea_set.parallel_prod)); + ret->push_back(slog(fea_set.parallel_len)); + for (int i = 0; i <= kPosMixed; i++) { + ret->push_back(i == fea_set.parallel_type); + } + + ret->push_back(fea_set.is_gpu); + ret->push_back(slog(fea_set.blockIdx_x_len)); + ret->push_back(slog(fea_set.blockIdx_y_len)); + ret->push_back(slog(fea_set.blockIdx_z_len)); + ret->push_back(slog(fea_set.threadIdx_x_len)); + ret->push_back(slog(fea_set.threadIdx_y_len)); + ret->push_back(slog(fea_set.threadIdx_z_len)); + ret->push_back(slog(fea_set.vthread_len)); + + for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { + ret->push_back(fea_set.arith_intensity_curve[i]); + } + + /***** access feature *****/ + // sort according to pair (lines, bytes) + std::vector > buf_order_key; + for (const auto& acc_fea : fea_set.access_feas) { + buf_order_key.emplace_back(acc_fea.lines, acc_fea.bytes); + } + std::vector buf_order(buf_order_key.size()); + std::iota(buf_order.begin(), buf_order.end(), 0); + + auto cmp = [&buf_order_key](int l, int r) { + return buf_order_key[l].first > buf_order_key[r].first + || (buf_order_key[l].first == buf_order_key[r].first + && buf_order_key[l].second > buf_order_key[r].second); + }; + std::sort(buf_order.begin(), buf_order.end(), cmp); + int n_bufs = std::min(max_n_bufs, static_cast(buf_order.size())); + buf_order.resize(n_bufs); + + for (int idx : buf_order) { + const auto& acc_fea = fea_set.access_feas[idx]; + for (int j = 0; j <= kReadWrite; ++j) { + ret->push_back(j == acc_fea.acc_type); + } + ret->push_back(slog(acc_fea.bytes)); + ret->push_back(slog(acc_fea.unique_bytes)); + ret->push_back(slog(acc_fea.lines)); + ret->push_back(slog(acc_fea.unique_lines)); + for (int j = 0; j <= kNoReuse; ++j) { + ret->push_back(acc_fea.reuse_type == j); + } + ret->push_back(slog(acc_fea.reuse_dis_iter)); 
+ ret->push_back(slog(acc_fea.reuse_dis_bytes)); + ret->push_back(slog(acc_fea.reuse_ct)); + ret->push_back(slog(acc_fea.bytes_d_reuse_ct)); + ret->push_back(slog(acc_fea.unique_bytes_d_reuse_ct)); + ret->push_back(slog(acc_fea.lines_d_reuse_ct)); + ret->push_back(slog(acc_fea.unique_lines_d_reuse_ct)); + ret->push_back(slog(acc_fea.stride)); + } + // - fill padding + for (int i = 0; i < max_n_bufs - n_bufs; ++i) { + for (int j = 0; j <= kReadWrite; ++j) { // 3 + ret->push_back(0.0f); + } + ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); + for (int j = 0; j <= kNoReuse; ++j) { // 3 + ret->push_back(0.0f); + } + ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); + ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); + } + + /***** allocation feature *****/ + ret->push_back(slog(fea_set.alloc_size)); + ret->push_back(slog(fea_set.alloc_prod)); + ret->push_back(slog(fea_set.alloc_outer_prod)); + ret->push_back(slog(fea_set.alloc_inner_prod)); + + /***** overall feature *****/ + ret->push_back(slog(fea_set.outer_prod)); + ret->push_back(slog(fea_set.num_loops)); + ret->push_back(slog(fea_set.auto_unroll_max_step)); + } +} + + +/* \brief Get the name of every element in the feature vector. Use this for debug and inspection */ +void GetPerStmtFeatureName(int max_n_bufs, std::vector *ret) { + /***** compute feature *****/ + ret->push_back(("float_mad")); + ret->push_back(("float_addsub")); + ret->push_back(("float_mul")); + ret->push_back(("float_divmod")); + ret->push_back(("float_cmp")); + ret->push_back(("float_mathfunc")); + ret->push_back(("float_otherfunc")); + ret->push_back(("int_mad")); + ret->push_back(("int_addsub")); + ret->push_back(("int_mul")); + ret->push_back(("int_divmod")); + ret->push_back(("int_cmp")); + ret->push_back(("int_mathfunc")); + ret->push_back(("int_otherfunc")); + ret->push_back(("bool_op")); + ret->push_back(("select_op")); + ret->push_back(("vec_num")); + ret->push_back(("vec_prod")); + ret->push_back(("vec_len")); + ret->push_back(("vec_type.kPosNone")); + ret->push_back(("vec_type.kPosInnerSpatial")); + ret->push_back(("vec_type.kPosMiddleSpatial")); + ret->push_back(("vec_type.kPosOuterSpatial")); + ret->push_back(("vec_type.kPosInnerReduce")); + ret->push_back(("vec_type.kPosMiddleReduce")); + ret->push_back(("vec_type.kPosOuterReduce")); + ret->push_back(("vec_type.kPosMixed")); + ret->push_back(("unroll_num")); + ret->push_back(("unroll_prod")); + ret->push_back(("unroll_len")); + ret->push_back(("unroll_type.kPosNone")); + ret->push_back(("unroll_type.kPosInnerSpatial")); + ret->push_back(("unroll_type.kPosMiddleSpatial")); + ret->push_back(("unroll_type.kPosOuterSpatial")); + ret->push_back(("unroll_type.kPosInnerReduce")); + ret->push_back(("unroll_type.kPosMiddleReduce")); + ret->push_back(("unroll_type.kPosOuterReduce")); + ret->push_back(("unroll_type.kPosMixed")); + ret->push_back(("parallel_num")); + ret->push_back(("parallel_prod")); + ret->push_back(("parallel_len")); + ret->push_back(("parallel_type.kPosNone")); + ret->push_back(("parallel_type.kPosInnerSpatial")); + ret->push_back(("parallel_type.kPosMiddleSpatial")); + ret->push_back(("parallel_type.kPosOuterSpatial")); + ret->push_back(("parallel_type.kPosInnerReduce")); + ret->push_back(("parallel_type.kPosMiddleReduce")); + ret->push_back(("parallel_type.kPosOuterReduce")); + ret->push_back(("parallel_type.kPosMixed")); + ret->push_back(("is_gpu")); + 
ret->push_back(("blockIdx_x_len")); + ret->push_back(("blockIdx_y_len")); + ret->push_back(("blockIdx_z_len")); + ret->push_back(("threadIdx_x_len")); + ret->push_back(("threadIdx_y_len")); + ret->push_back(("threadIdx_z_len")); + ret->push_back(("vthread_len")); + for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { + ret->push_back(("arith_intensity_curve_" + std::to_string(i))); + } + // section total: 55 + ARITH_INTENSITY_CURVE_SAMPLE_N = 65 + + /***** access feature *****/ + for (size_t i = 0; i < static_cast(max_n_bufs); ++i) { + std::string prefix = "B" + std::to_string(i) + "."; + ret->push_back((prefix + "acc_type.kRead")); + ret->push_back((prefix + "acc_type.kWrite")); + ret->push_back((prefix + "acc_type.kReadWrite")); + ret->push_back((prefix + "bytes")); + ret->push_back((prefix + "unique_bytes")); + ret->push_back((prefix + "lines")); + ret->push_back((prefix + "unique_lines")); + ret->push_back((prefix + "reuse_type.kLoopMultipleRead")); + ret->push_back((prefix + "reuse_type.kSerialMultipleReadWrite")); + ret->push_back((prefix + "reuse_type.kNoReuse")); + ret->push_back((prefix + "reuse_dis_iter")); + ret->push_back((prefix + "reuse_dis_bytes")); + ret->push_back((prefix + "reuse_ct")); + ret->push_back((prefix + "bytes_d_reuse_ct")); + ret->push_back((prefix + "unique_bytes_d_reuse_ct")); + ret->push_back((prefix + "lines_d_reuse_ct")); + ret->push_back((prefix + "unique_lines_d_reuse_ct")); + ret->push_back((prefix + "stride")); + } + // section total : max_n_bufs * 18 + + /***** allocation feature *****/ + ret->push_back(("alloc_size")); + ret->push_back(("alloc_prod")); + ret->push_back(("alloc_outer_prod")); + ret->push_back(("alloc_inner_prod")); + // section total : 4 + + /***** overall feature *****/ + ret->push_back(("outer_prod")); + ret->push_back(("num_loops")); + ret->push_back(("auto_unroll_max_step")); + // section total : 2 +} + +void GetPerStmtFeaturesWorkerFunc(const SearchTask& task, const State& state, + int max_n_bufs, std::vector* feature, std::atomic* error_ct) { + te::Schedule sch; + Array tensors; + Map bounds; + GlobalVar g("main"); + + std::tie(sch, tensors) = task->compute_dag.ApplySteps(state->transform_steps); + sch = sch.normalize(); + bounds = te::InferBound(sch); + + try { + auto stmt = te::ScheduleOps(sch, bounds, false); + Map out_binds; Array out_arg_list; + bool compact = te::VerifyCompactBuffer(stmt); + GetBinds(tensors, compact, std::unordered_map(), + &out_binds, &out_arg_list, BuildConfig::Create()); + tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, + std::move(stmt), out_binds); + f = WithAttr(std::move(f), "global_symbol", runtime::String("main")); + auto mod = IRModule(Map({{g, f}})); + auto pass_list = Array(); + if (task->target->device_type == kDLGPU) { + pass_list.push_back(tir::transform::InjectPrefetch()); + pass_list.push_back(tir::transform::StorageFlatten(64)); + pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tir::transform::VectorizeLoop()); + pass_list.push_back(tir::transform::InjectVirtualThread()); + pass_list.push_back(tir::transform::StorageRewrite()); + pass_list.push_back(tir::transform::Simplify()); + tvm::Map gpu_params { + {"max_shared_memory_per_block", + task->hardware_params->max_shared_memory_per_block}, + {"max_local_memory_per_block", + task->hardware_params->max_registers_per_block}, + {"max_threads_per_block", + task->hardware_params->max_threads_per_block}, + {"max_vector_bytes", + task->hardware_params->vector_unit_bytes} + }; + 
pass_list.push_back(tir::transform::VerifyGPUCode(gpu_params)); + const auto& optimize = tir::transform::Sequential(pass_list); + optimize(mod); + } + pass_list.clear(); + pass_list.push_back(tir::transform::Simplify()); + const auto& optimize = tir::transform::Sequential(pass_list); + mod = optimize(std::move(mod)); + const auto& it = mod->functions.find(g); + CHECK(it != mod->functions.end()); + const auto& prim_func = (*it).second.as(); + GetPerStmtFeature(prim_func->body, + task->hardware_params->cache_line_bytes, + max_n_bufs, feature); + } catch (dmlc::Error &e) { + (*error_ct)++; + } +} + +void GetPerStmtFeaturesFromStates(const Array& states, + const SearchTask& task, + int max_n_bufs, + int skip_first_n_feature_extraction, + std::vector >* features) { + // extract features + features->assign(states.size(), std::vector()); + + std::atomic error_ct(0); + + ThreadPool& pool = ThreadPool::Global(); + pool.BeginBatch(static_cast(states.size()) - skip_first_n_feature_extraction); + for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) { + pool.Enqueue(GetPerStmtFeaturesWorkerFunc, task, states[i], + max_n_bufs, &(*features)[i], &error_ct); + } + pool.WaitBatch(); + + if (error_ct > 0) { + std::cerr << "Encountered " << error_ct + << " errors during feature extraction. Ignored." << std::endl; + } +} + + +void GetPerStmtFeaturesFromStates(const Array& states, + const std::vector& tasks, + int max_n_bufs, + int skip_first_n_feature_extraction, + std::vector >* features) { + // extract features + features->assign(states.size(), std::vector()); + + std::atomic error_ct(0); + + ThreadPool& pool = ThreadPool::Global(); + pool.BeginBatch(static_cast(states.size()) - skip_first_n_feature_extraction); + for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) { + pool.Enqueue(GetPerStmtFeaturesWorkerFunc, tasks[i], states[i], + max_n_bufs, &(*features)[i], &error_ct); + } + pool.WaitBatch(); + + if (error_ct > 0) { + std::cerr << "Encountered " << error_ct + << " errors during feature extraction. Ignored." 
<< std::endl; + } +} + +void GetPerStmtFeaturesFromFile(const std::string& filename, + int n_lines, + int max_n_bufs, + std::vector >* features, + std::vector* normalized_throughputs, + std::vector* task_ids) { + Array states; + // ArrayNode* pstates = states.CopyOnWrite(); + std::vector tasks; + + normalized_throughputs->clear(); + task_ids->clear(); + + // (workload_key, target) -> (search_task, task_id) + std::unordered_map, std::pair> task_cache; + // task_id -> min_cost + std::vector min_costs; + + // read from file + LogReader reader = LogReaderNode::make(filename); + auto cur_inp = make_object(); + auto cur_res = make_object(); + while (reader->ReadNext(cur_inp.get(), cur_res.get())) { + float cost = static_cast(FloatArrayMean(cur_res->costs)); + const std::string& workload_key = cur_inp->task->workload_key; + + SearchTask task; + size_t task_id; + std::pair key(workload_key, cur_inp->task->target->str()); + auto find_res = task_cache.find(key); + if (find_res == task_cache.end()) { + // rebuild task + task = SearchTaskNode::make(ComputeDAGNode::make_by_workload_key(workload_key), + workload_key, + cur_inp->task->target, + cur_inp->task->target_host, + cur_inp->task->hardware_params); + task_id = task_cache.size(); + + // compute min cost for each task + task_cache.insert(std::make_pair(key, std::make_pair(task, task_id))); + min_costs.push_back(cost); + } else { + std::tie(task, task_id) = find_res->second; + min_costs[task_id] = std::min(min_costs[task_id], cost); + } + + tasks.push_back(std::move(task)); + task_ids->push_back(task_id); + // pstates->data.push_back(cur_inp->state); + states.push_back(cur_inp->state); + normalized_throughputs->push_back(cost); + + if (n_lines > 0 && static_cast(states.size()) >= n_lines) { + break; + } + } + + for (size_t i = 0; i < normalized_throughputs->size(); ++i) { + (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i]; + } + + GetPerStmtFeaturesFromStates(states, tasks, max_n_bufs, 0, features); +} + +void GetPerStmtFeaturesFromMeasurePairs(const Array& inputs, + const Array& results, + int max_n_bufs, + int skip_first_n_feature_extraction, + std::vector >* features, + std::vector* normalized_throughputs, + std::vector* task_ids) { + Array states; + // ArrayNode* pstates = states.CopyOnWrite(); + std::vector tasks; + + normalized_throughputs->clear(); + task_ids->clear(); + + // (workload_key, target) -> (search_task, task_id) + std::unordered_map, std::pair> task_cache; + // task_id -> min_cost + std::vector min_costs; + + tasks.reserve(inputs.size()); + normalized_throughputs->reserve(inputs.size()); + task_ids->reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + float cost = static_cast(FloatArrayMean(results[i]->costs)); + const std::string& workload_key = inputs[i]->task->workload_key; + SearchTask task; + + size_t task_id; + std::pair key(workload_key, inputs[i]->task->target->str()); + auto find_res = task_cache.find(key); + if (find_res == task_cache.end()) { + if (inputs[i]->task->compute_dag.defined()) { // the measure input is complete + task = inputs[i]->task; + } else { // the measure input is incomplete + // rebuild task for incomplete measure pairs read from file + task = SearchTaskNode::make(ComputeDAGNode::make_by_workload_key(workload_key), + workload_key, + inputs[i]->task->target, + inputs[i]->task->target_host, + inputs[i]->task->hardware_params); + } + task_id = task_cache.size(); + + // compute min cost for each task + task_cache.insert(std::make_pair(key, 
          std::make_pair(task, task_id)));
+      min_costs.push_back(cost);
+    } else {
+      std::tie(task, task_id) = find_res->second;
+      min_costs[task_id] = std::min(min_costs[task_id], cost);
+    }
+
+    tasks.push_back(std::move(task));
+    task_ids->push_back(task_id);
+    // pstates->data.push_back(inputs[i]->state);
+    states.push_back(inputs[i]->state);
+    normalized_throughputs->push_back(cost);
+  }
+
+  for (size_t i = 0; i < normalized_throughputs->size(); ++i) {
+    (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i];
+  }
+
+  GetPerStmtFeaturesFromStates(states, tasks, max_n_bufs,
+                               skip_first_n_feature_extraction, features);
+}
+
+}  // namespace ansor
+}  // namespace tvm
diff --git a/src/ansor/feature.h b/src/ansor/feature.h
new file mode 100644
index 000000000000..149c59e8cb7d
--- /dev/null
+++ b/src/ansor/feature.h
@@ -0,0 +1,63 @@
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file ansor/feature.h
+ * \brief Feature extraction for the cost model
+ */
+
+#ifndef TVM_ANSOR_FEATURE_H_
+#define TVM_ANSOR_FEATURE_H_
+
+// #include
+#include
+#include
+#include "compute_dag.h"
+#include "measure.h"
+
+namespace tvm {
+namespace ansor {
+
+/*! \brief Get PerStmt feature from a tvm IR stmt */
+void GetPerStmtFeature(const Stmt& stmt,
+                       int cache_line_size,
+                       int max_n_bufs,
+                       std::vector<float>* ret);
+
+/*! \brief Get the name of every element in the feature vector. Use this for debug and inspection */
+void GetPerStmtFeatureName(int max_n_bufs, std::vector<std::string>* ret);
+
+
+/*! \brief Get PerStmt feature from states */
+void GetPerStmtFeaturesFromStates(const Array<State>& states,
+                                  const SearchTask& task,
+                                  int max_n_bufs,
+                                  int skip_first_n_feature_extraction,
+                                  std::vector<std::vector<float> >* features);
+
+/*! \brief Get PerStmt feature from states */
+void GetPerStmtFeaturesFromStates(const Array<State>& states,
+                                  const std::vector<SearchTask>& tasks,
+                                  int max_n_bufs,
+                                  int skip_first_n_feature_extraction,
+                                  std::vector<std::vector<float> >* features);
+
+/*! \brief Get PerStmt feature from a log file */
+void GetPerStmtFeaturesFromFile(const std::string& filename,
+                                int n_lines,
+                                int max_n_bufs,
+                                std::vector<std::vector<float> >* features,
+                                std::vector<float>* normalized_throughputs,
+                                std::vector<int>* task_ids);
+
+/*! \brief Get PerStmt feature from measure pairs */
+void GetPerStmtFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
+                                        const Array<MeasureResult>& results,
+                                        int max_n_bufs,
+                                        int skip_first_n_feature_extraction,
+                                        std::vector<std::vector<float> >* features,
+                                        std::vector<float>* normalized_throughputs,
+                                        std::vector<int>* task_ids);
+
+}  // namespace ansor
+}  // namespace tvm
+
+#endif  // TVM_ANSOR_FEATURE_H_
diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc
new file mode 100644
index 000000000000..b3b93ec9c839
--- /dev/null
+++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc
@@ -0,0 +1,1420 @@
+/*!
+ * Copyright (c) 2020 by Contributors
+ */
+
+#include "meta_tile_rewrite_policy.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "utils.h"
+
+#define IS_GPU(task) ((task)->target->device_type == kDLGPU || \
+                      (task)->target->device_type == kDLOpenCL)
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_OBJECT_TYPE(MetaTileRewritePolicyNode);
+
+// All possible candidates for auto_unroll
+const std::vector<int> MetaTileRewritePolicyNode::auto_unroll_configs{0, 16, 64, 512, 1024};
+
+SearchPolicy MetaTileRewritePolicyNode::make(CostModel program_cost_model,
+                                             Map params,
+                                             int seed) {
+  auto node = make_object<MetaTileRewritePolicyNode>();
+  node->program_cost_model = std::move(program_cost_model);
+  node->rand_gen_ = std::mt19937(seed);
+  node->params = std::move(params);
+  return SearchPolicy(node);
+}
+
+State MetaTileRewritePolicyNode::Search(SearchTask task, int n_trials,
+                                        int early_stopping, int num_measure_per_iter,
+                                        int verbose, ProgramMeasurer measurer) {
+  std::vector<State> best_states, random_states;
+  cur_task_ = task;
+  verbose_ = verbose;
+  num_measure_per_iter_ = num_measure_per_iter;
+
+  if (n_trials <= 1) {  // no measurement is allowed
+    SearchOneRound(&best_states, 0, &random_states);
+    CHECK_GT(best_states.size(), 0);
+    return best_states[0];
+  } else {
+    std::vector<MeasureInput> inputs;
+    std::vector<MeasureResult> results;
+    int num_random = static_cast<int>(GetDoubleParam(params, "eps_greedy") * num_measure_per_iter);
+
+    measurer->Reset();
+
+    early_stopping = early_stopping < 0 ? std::numeric_limits<int>::max() >> 1 : early_stopping;
+
+    int ct = 0;
+    while (ct < n_trials) {
+      if (!inputs.empty()) {
+        // retrain cost models
+        PrintTitle("Train cost model", verbose_);
+        program_cost_model->Update(inputs, results);
+      }
+
+      // Search one round to get promising states
+      PrintTitle("Search", verbose_);
+      SearchOneRound(&best_states, num_random, &random_states);
+
+      // Fill in the correct bounds. This is necessary for computing the
+      // correct ToStr() for the redundancy check
+      cur_task_->compute_dag.InferBound(&best_states);
+      cur_task_->compute_dag.InferBound(&random_states);
+
+      // Pick `num_measure_per_iter` states to measure, check hash to remove already measured states.
+      // Also pick some random states to do eps-greedy
+      PickStatesWithEpsGreedy(&inputs, best_states, random_states, n_trials - ct);
+
+      // We have traversed the whole search space
+      if (inputs.empty()) {
+        StdCout(verbose) << "All candidates in the search space have been measured." << std::endl;
+        break;
+      }
+
+      // Measure candidate states
+      PrintTitle("Measure", verbose_);
+      measurer->Measure(cur_task_, GetRef<SearchPolicy>(this), inputs, &results);
+      ct += inputs.size();
+
+      if (ct - measurer->best_ct[cur_task_->workload_key] > early_stopping) {
+        StdCout(verbose) << "Meet the early stopping condition." << std::endl;
+        break;
+      }
+
+      // Update measured states.
+std::pair<Array<MeasureInput>, Array<MeasureResult> >
+    MetaTileRewritePolicyNode::ContinueSearchOneRound(
+    SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) {
+  if (cur_task_.defined()) {
+    CHECK_EQ(cur_task_, task);
+  } else {
+    cur_task_ = task;
+  }
+  verbose_ = verbose;
+  num_measure_per_iter_ = num_measure;
+
+  std::vector<State> best_states, random_states;
+  std::vector<MeasureInput> inputs;
+  std::vector<MeasureResult> results;
+  int num_random = static_cast<int>(GetDoubleParam(params, "eps_greedy") * num_measure);
+
+  // Search one round to get promising states
+  PrintTitle("Search", verbose);
+  SearchOneRound(&best_states, num_random * 2, &random_states);
+
+  // Fill in correct bounds. This is necessary for computing the correct ToStr()
+  // for the redundancy check
+  cur_task_->compute_dag.InferBound(&best_states);
+  cur_task_->compute_dag.InferBound(&random_states);
+
+  // Pick `num_measure` states to measure, check hash to remove already measured states
+  // Also pick some random states to do eps-greedy
+  PickStatesWithEpsGreedy(&inputs, best_states, random_states, num_measure);
+
+  // Measure candidate states
+  PrintTitle("Measure", verbose);
+  measurer->Measure(cur_task_, GetRef<SearchPolicy>(this), inputs, &results);
+
+  // Update throughputs of measured states. These states will join the LocalMutation in later rounds
+  for (const auto& res : results) {
+    measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs));
+  }
+
+  // Update the cost model
+  Array<MeasureInput> inputs_arr(std::make_move_iterator(inputs.begin()),
+                                 std::make_move_iterator(inputs.end()));
+  Array<MeasureResult> results_arr(std::make_move_iterator(results.begin()),
+                                   std::make_move_iterator(results.end()));
+
+  PrintTitle("Train cost model", verbose);
+  program_cost_model->Update(inputs_arr, results_arr);
+  return std::make_pair(std::move(inputs_arr), std::move(results_arr));
+}
+
+void MetaTileRewritePolicyNode::PickStatesWithEpsGreedy(
+    std::vector<MeasureInput>* inputs,
+    const std::vector<State>& best_states,
+    const std::vector<State>& random_states,
+    int remaining_n_trials) {
+  int num_random = static_cast<int>(GetDoubleParam(params, "eps_greedy") * num_measure_per_iter_);
+  int num_good = num_measure_per_iter_ - num_random;
+
+  inputs->clear();
+  size_t offset_best = 0, offset_random = 0;
+
+  while (static_cast<int>(inputs->size()) < std::min(num_measure_per_iter_, remaining_n_trials)) {
+    const State* pstate;
+
+    bool has_best = offset_best < best_states.size();
+    bool has_random = offset_random < random_states.size();
+
+    if (static_cast<int>(inputs->size()) < num_good) {
+      // prefer best states
+      if (has_best) {
+        pstate = &best_states[offset_best++];
+      } else if (has_random) {
+        pstate = &random_states[offset_random++];
+      } else {
+        break;
+      }
+    } else {
+      // prefer random states
+      if (has_random) {
+        pstate = &random_states[offset_random++];
+      } else if (has_best) {
+        pstate = &best_states[offset_best++];
+      } else {
+        break;
+      }
+    }
+
+    // Check if it has already been measured
+    std::string state_str = pstate->ToStr();
+
+    if (measured_states_set_.count(state_str)) { continue; }
+    measured_states_set_.insert(state_str);
+
+    inputs->push_back(MeasureInputNode::make(cur_task_, *pstate));
+    measured_states_vector_.push_back(std::move(*pstate));
+  }
+}
+
+void MetaTileRewritePolicyNode::SearchOneRound(std::vector<State>* best_states,
+                                               int num_random_states,
+                                               std::vector<State>* random_states) {
+  best_states->clear();
+  random_states->clear();
+
+  // Get parameters
+  int population = GetIntParam(params, "evolutionary_search_population");
+  int num_use_measured = std::min(static_cast<int>(measured_states_vector_.size()),
+      static_cast<int>(
+          GetDoubleParam(params, "evolutionary_search_use_measured_ratio") * population));
+  bool have_cost_model = !program_cost_model->IsInstance();
+
+  if (!have_cost_model) {
+    num_use_measured = 0;
+  }
+
+  // Synthesize meta structure
+  std::vector<State> meta_structures;
+  SynthesizeMetaStructure(&meta_structures);
+
+  // PrintAllStates(meta_structures);
+  // exit(0);
+
+  // Sample the init population
+  std::vector<State> init_population;
+  SampleInitPopulation(meta_structures, population - num_use_measured, &init_population);
+
+  // PrintAllStates(init_population);
+  // exit(0);
+
+  if (have_cost_model) {
+    // Also insert already measured good states to the initial population
+    std::vector<int> indices;
+    Argsort(measured_states_throughputs_, &indices);
+    for (int i = 0; i < num_use_measured; i++) {
+      init_population.push_back(measured_states_vector_[indices[i]]);
+    }
+
+    // Perform evolutionary search
+    EvolutionarySearch(init_population, num_measure_per_iter_ * 2, best_states);
+  } else {
+    // If the cost model is useless (i.e. RandomCostModel), skip the evolutionary search
+    RandomSampleStates(init_population, &rand_gen_, num_measure_per_iter_ * 3, best_states);
+  }
+
+  // Sample some random states for eps-greedy
+  RandomSampleStates(init_population, &rand_gen_, num_random_states * 10, random_states);
+}
+
+// The base class of derivation rules used in meta structure synthesis
+class StructureSynthesisRule {
+ public:
+  enum ConditionEnum {
+    kPass, kApply, kApplyAndSkipRest
+  };
+
+  virtual ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
+                                      const State& state, int stage_id) = 0;
+  virtual std::vector<std::pair<State, int> > Apply(const MetaTileRewritePolicyNode* policy,
+                                                    const State& state, int stage_id) = 0;
+};
+
+static inline bool ShouldBeCacheRead(
+    const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) {
+  const SearchTask& task = policy->cur_task_;
+  const Stage& stage = state->stages[stage_id];
+
+  if (HasAttrsFlag(state, stage_id,
+                   SearchPolicyNode::no_cache_read_key)) {
+    return false;
+  }
+
+  std::unordered_set consumers;
+  GetConsumers(task, state, stage->op, &consumers);
+  if (consumers.size() != 1) {
+    return false;
+  }
+
+  int target_stage_id = OperationToStage(*consumers.begin(), state);
+  if (!NeedsMultilevelTiling(task, state,
+                             state->stages[target_stage_id]->op)) {
+    return false;
+  }
+
+  std::unordered_set producers;
+  GetProducers(task, state, state->stages[target_stage_id]->op, &producers);
+  // Only those directly mapped stages can do CacheRead
+  if (producers.find(stage->op) == producers.end()) {
+    return false;
+  }
+
+  return true;
+}
+
+static inline bool ShouldAlwaysBeInlined(
+    const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) {
+  const SearchTask& task = policy->cur_task_;
+  const Stage& stage = state->stages[stage_id];
+
+  if (stage->op->IsInstance()) {
+    return false;
+  }
+
+  // Inline limitation of TVM
+  if (!IsOutputOp(task, state, stage->op) && !HasReduceIter(stage)) {
+    // Always inline conditions:
+    // 1. Has attrs that say this must be inlined
+    // 2. Analysis shows it is strictly inlineable
+    // 3. A GPU stage can be inlined (if it should be cache-read, do that first)
+    if (HasAttrsFlag(state, stage_id,
+                     SearchPolicyNode::always_compute_inline_key) ||
+        IsStrictInlineable(task, state, stage->op) ||
+        (IS_GPU(policy->cur_task_) &&
+         !ShouldBeCacheRead(policy, state, stage_id))) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
+// The rule that inlines simple elementwise ops
+class RuleAlwaysInline : public StructureSynthesisRule {
+ public:
+  ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
+                              const State& state, int stage_id) final {
+    return ShouldAlwaysBeInlined(policy, state, stage_id) ?
+        kApplyAndSkipRest : kPass;
+  }
+
+  std::vector<std::pair<State, int> > Apply(const MetaTileRewritePolicyNode* policy,
+                                            const State& state, int stage_id) final {
+    State tmp_s = state;
+    tmp_s.compute_inline(stage_id);
+    return {std::make_pair(std::move(tmp_s), stage_id - 1)};
+  }
+};
+
+// The rule that simply skips the current stage
+class RuleSkipStage : public StructureSynthesisRule {
+ public:
+  ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
+                              const State& state, int stage_id) final {
+    const SearchTask& task = policy->cur_task_;
+    const Stage& stage = state->stages[stage_id];
+
+    const auto& attrs = stage->op->attrs;
+    if ((attrs.count(SearchPolicyNode::no_split_at_inner_key) ||
+         attrs.count(SearchPolicyNode::no_split_at_outer_key)) &&
+        NeedsMultilevelTiling(task, state, stage->op)) {
+      // for the transform stages in Winograd
+      return kPass;
+    }
+
+    return kApply;
+  }
+
+  std::vector<std::pair<State, int> > Apply(const MetaTileRewritePolicyNode* policy,
+                                            const State& state, int stage_id) final {
+    return {std::make_pair(state, stage_id - 1)};
+  }
+};
+
+// The rule that performs multi-level tiling
+class RuleMultiLevelTiling : public StructureSynthesisRule {
+ public:
+  ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
+                              const State& state, int stage_id) final {
+    const SearchTask& task = policy->cur_task_;
+    const Stage& stage = state->stages[stage_id];
+
+    return NeedsMultilevelTiling(task, state, stage->op) ?
+        (IS_GPU(policy->cur_task_) ? kApplyAndSkipRest : kApply) : kPass;
+  }
+
+  std::vector<std::pair<State, int> > Apply(const MetaTileRewritePolicyNode* policy,
+                                            const State& state, int stage_id) final {
+    std::string multi_level_tiling_structure = IS_GPU(policy->cur_task_) ?
+        GetStringParam(policy->params, "gpu_multi_level_tiling_structure") :
+        GetStringParam(policy->params, "cpu_multi_level_tiling_structure");
+
+    std::vector<int> spatial_split_step_ids;
+    State tmp_s = state;
+    tmp_s = DoMultiLevelTiling(tmp_s, stage_id, multi_level_tiling_structure,
+                               &spatial_split_step_ids);
+    return {std::make_pair(std::move(tmp_s), stage_id - 1)};
+  }
+};
+
+// The rule that performs multi-level tiling and fuses later consumers
+class RuleMultiLevelTilingWithFusion : public StructureSynthesisRule {
+ public:
+  ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
+                              const State& state, int stage_id) final {
+    const SearchTask& task = policy->cur_task_;
+    const Stage& stage = state->stages[stage_id];
+
+    int target_stage_id;
+
+    if (IS_GPU(policy->cur_task_)) {
+      return NeedsMultilevelTiling(task, state, stage->op) &&
+             HasSingleElementwiseMatchedConsumer(task, state, stage,
+                                                 &target_stage_id) &&
+             (!HasCacheReadStage(state, stage_id) ||
+              HasCacheWriteStage(state, stage_id)) ?
+          kApplyAndSkipRest : kPass;
+    }
+
+    return NeedsMultilevelTiling(task, state, stage->op) &&
+           HasSingleElementwiseMatchedConsumer(task, state, stage,
+                                               &target_stage_id) ?
+        kApply : kPass;
+  }
+
+  std::vector<std::pair<State, int> > Apply(const MetaTileRewritePolicyNode* policy,
+                                            const State& state, int stage_id) final {
+    const SearchTask& task = policy->cur_task_;
+    const Stage& stage = state->stages[stage_id];
+    std::string multi_level_tiling_structure = IS_GPU(policy->cur_task_) ?
+        GetStringParam(policy->params, "gpu_multi_level_tiling_structure") :
+        GetStringParam(policy->params, "cpu_multi_level_tiling_structure");
+
+    std::vector<int> spatial_split_step_ids;
+    int target_stage_id;
+    std::unordered_set consumers;
+
+    GetConsumers(task, state, state->stages[stage_id]->op, &consumers);
+    CHECK(HasSingleElementwiseMatchedConsumer(task, state, stage, &target_stage_id));
+
+    State base_state = state;
+    base_state = DoMultiLevelTiling(base_state, stage_id,
+                                    multi_level_tiling_structure, &spatial_split_step_ids);
+    std::vector<int> follow_tiling_levels;
+    if (IS_GPU(policy->cur_task_)) {
+      follow_tiling_levels.push_back(3);
+    } else {
+      follow_tiling_levels.push_back(1);
+      follow_tiling_levels.push_back(2);
+    }
+
+    std::vector<std::pair<State, int> > ret;
+    for (int level : follow_tiling_levels) {
+      if (tolower(multi_level_tiling_structure[level - 1]) != 's') {
+        continue;
+      }
+      State tmp_s = base_state;
+      tmp_s = FollowTiling(tmp_s, target_stage_id, spatial_split_step_ids, level);
+      const Iterator& target_iter = tmp_s->stages[target_stage_id]->iters[
+          level * spatial_split_step_ids.size() - 1];
+      tmp_s.compute_at(stage_id, target_stage_id, target_iter);
+
+      ret.emplace_back(std::move(tmp_s), stage_id - 1);
+    }
+
+    return ret;
+  }
+};
+
+// The rule that adds a cache write stage
+class RuleAddCacheWrite : public StructureSynthesisRule {
+ public:
+  ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
+                              const State& state, int stage_id) final {
+    const SearchTask& task = policy->cur_task_;
+    const Stage& stage = state->stages[stage_id];
+
+    int target_stage_id;
+
+    // Add a cache write if a stage needs multi-level tiling
+    // but does not have an element-wise matched consumer
+    return NeedsMultilevelTiling(task, state, stage->op) &&
+           !HasAttrsFlag(state, stage_id, SearchPolicyNode::no_cache_write_key) &&
+           (!HasSingleElementwiseMatchedConsumer(task, state, stage,
+                                                 &target_stage_id) ||
+            (HasCacheReadStage(state, stage_id) &&
+             !HasCacheWriteStage(state, stage_id))) ?
+        kApply : kPass;
+  }
+
+  std::vector<std::pair<State, int> > Apply(const MetaTileRewritePolicyNode* policy,
+                                            const State& state, int stage_id) final {
+    const SearchTask& task = policy->cur_task_;
+
+    State tmp_s = state;
+    tmp_s.cache_write(stage_id, "local", task->compute_dag);
+    return {std::make_pair(std::move(tmp_s), stage_id)};
+  }
+};
+
+// The rule that adds a cache read stage
+// Mainly used for GPU cooperative fetching
+// Currently only supports 1-to-1 matched cache read
+class RuleAddCacheRead : public StructureSynthesisRule {
+ public:
+  ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
+                              const State& state, int stage_id) final {
+    return ShouldBeCacheRead(policy, state, stage_id) ?
+ kApplyAndSkipRest : kPass; + } + + std::vector > Apply(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + const SearchTask& task = policy->cur_task_; + const Stage& stage = state->stages[stage_id]; + + std::unordered_set consumers; + GetConsumers(task, state, stage->op, &consumers); + CHECK_EQ(consumers.size(), 1); + int target_stage_id = OperationToStage(*consumers.begin(), state); + State tmp_s = state; + int added_stage_id = tmp_s.cache_read(stage_id, "shared", + {target_stage_id}, + task->compute_dag); + target_stage_id++; + const auto& share_read_pos = GetLastReduceIteratorInOutermostReduceTile( + tmp_s->stages[target_stage_id]); + tmp_s.compute_at(added_stage_id, target_stage_id, share_read_pos); + + return {std::make_pair(std::move(tmp_s), stage_id)}; + } +}; + +// The rule that adds rfactor stage +class RuleAddRfactor : public StructureSynthesisRule { + public: + ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + const SearchTask& task = policy->cur_task_; + const Stage& stage = state->stages[stage_id]; + + return NeedsRfactor(task, state, stage->op) && + !HasCacheWriteStage(state, stage_id) ? + kApply : kPass; + } + + std::vector > Apply(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + const SearchTask& task = policy->cur_task_; + const Stage& stage = state->stages[stage_id]; + + std::vector > ret; + + State tmp_s = state; + + // fuse reduce iters + std::vector space_iters, reduce_iters; + for (const auto &iter : stage->iters) { + if (iter->iter_type == kSpace) { + space_iters.push_back(iter); + } else if (iter->iter_type == kReduce) { + reduce_iters.push_back(iter); + } + } + CHECK(!reduce_iters.empty()); + Iterator fused_reduce_iter; + if (reduce_iters.size() > 1) { + fused_reduce_iter = tmp_s.fuse(stage_id, reduce_iters); + } else { + fused_reduce_iter = reduce_iters[0]; + } + + // split reduce iters + const auto &split_res = tmp_s.split(stage_id, fused_reduce_iter, {1}); + int factor_axis_id = static_cast(space_iters.size()); + State base_state = tmp_s; + for (const auto &split_iter : split_res) { + tmp_s = base_state; + tmp_s.rfactor(stage_id, split_iter, factor_axis_id, task->compute_dag); + + // reorder the space iterator to innermost for vectorization + if (split_iter == split_res[1]) { + std::vector new_order; + for (size_t i = 0; i < tmp_s->stages[stage_id]->iters.size(); ++i) { + if (i != space_iters.size()) { + new_order.push_back(tmp_s->stages[stage_id]->iters[i]); + } + } + new_order.push_back(tmp_s->stages[stage_id]->iters[space_iters.size()]); + tmp_s.reorder(stage_id, new_order); + } + ret.emplace_back(std::move(tmp_s), stage_id - 1); + } + + return ret; + } +}; + +void MetaTileRewritePolicyNode::SynthesizeMetaStructure(std::vector* out_states) { + State init_state = cur_task_->compute_dag.GetInitState(); + std::string cpu_multi_level_tiling_structure = + GetStringParam(params, "cpu_multi_level_tiling_structure"); + + // two ping pong buffers to avoid copy + std::vector states_buf1, states_buf2; + std::vector *pnow, *pnext; + pnow = &states_buf1; + pnext = &states_buf2; + pnow->push_back(init_state); + + // A map that maps state to its current working position (stage_id) + std::unordered_map cur_stage_id_map; + cur_stage_id_map[init_state] = static_cast(init_state->stages.size() - 1); + + static RuleSkipStage rule_skip_stage; + static RuleAlwaysInline rule_always_inline; + static RuleMultiLevelTiling rule_multi_level_tiling; + 
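+  // Note on the rule protocol: MeetCondition() classifies a (state, stage_id)
+  // pair as kPass (the rule does not fire), kApply (fire, and also try the
+  // remaining rules), or kApplyAndSkipRest (fire, and drop the remaining
+  // rules for this state). Apply() returns the derived states together with
+  // the next stage id to process; stage ids are walked from the last stage
+  // down to -1, at which point a state is complete.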
+  static RuleMultiLevelTilingWithFusion rule_multi_level_tiling_with_fusion;
+  static RuleAddCacheWrite rule_add_cache_write_stage;
+  static RuleAddCacheRead rule_add_cache_read_stage;
+  static RuleAddRfactor rule_add_rfactor;
+  // We may apply a rule and skip the rest when processing some rules,
+  // so take care of the order of the rule vector here
+  static std::vector<StructureSynthesisRule*> all_rules {
+    &rule_always_inline, &rule_add_cache_write_stage,
+    &rule_multi_level_tiling_with_fusion, &rule_multi_level_tiling,
+    &rule_add_rfactor, &rule_skip_stage
+  };
+  if (IS_GPU(cur_task_)) {
+    // Try cache read first before cache write
+    all_rules.insert(all_rules.begin() + 1, &rule_add_cache_read_stage);
+  }
+  // TODO(xian): Add a new rule to try the combination of multi-level tiling + rfactor
+
+  // Derivation-rule based synthesizer
+  while (!pnow->empty()) {
+    pnext->clear();
+
+    for (const State& state : *pnow) {
+      int stage_id = cur_stage_id_map[state];
+
+      // Reached the terminal stage
+      if (stage_id < 0) {
+        out_states->push_back(state);
+        continue;
+      }
+
+      // Try all derivation rules
+      for (const auto& rule : all_rules) {
+        auto rule_check = rule->MeetCondition(this, state, stage_id);
+        if (rule_check > StructureSynthesisRule::ConditionEnum::kPass) {
+          for (const auto& pair : rule->Apply(this, state, stage_id)) {
+            cur_stage_id_map[pair.first] = pair.second;
+            pnext->push_back(pair.first);
+          }
+          // Skip the rest of the rules
+          if (rule_check == StructureSynthesisRule::ConditionEnum::kApplyAndSkipRest) {
+            break;
+          }
+        }
+      }
+    }
+
+    std::swap(pnow, pnext);
+  }
+
+  // Hack for rfactor: Replace the split factor for rfactor with the undefined Expr(),
+  // so later we can sample a random value for the split factor.
+  // Why don't we use Expr() when doing the split for rfactor at the first time?
+  // Because during ApplySteps, a rfactor with undefined Expr() will crash TVM.
+  // So rfactor with undefined Expr() would conflict with cache_write, cache_read and rfactor
+  // in other stages
+  for (size_t i = 0; i < out_states->size(); ++i) {
+    auto pstate = (*out_states)[i].CopyOnWrite();
+    for (size_t step_id = 0; step_id < pstate->transform_steps.size(); ++step_id) {
+      if (pstate->transform_steps[step_id]->IsInstance()) {
+        CHECK_GE(step_id, 1);
+        int split_step_id = step_id - 1;
+        auto step = pstate->transform_steps[split_step_id].as<SplitStepNode>();
+        CHECK(step != nullptr);
+        pstate->transform_steps[split_step_id]
+            = SplitStepNode::make(step->stage_id, step->iter_id, step->extent, {PrimExpr()},
+                                  step->inner_to_outer);
+      }
+    }
+  }
+
+  StdCout(verbose_) << "Synthesize Meta Structure\t\t#s: " << out_states->size() << std::endl;
+}
+
+int InitPopulationFillTileSize(const MetaTileRewritePolicyNode* policy,
+                               State* state, std::mt19937* rand_gen,
+                               SplitFactorizationMemo* split_memo) {
+  for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) {
+    if (auto ps = (*state)->transform_steps[step_id].as<SplitStepNode>()) {
+      bool defined = true;
+      for (const PrimExpr& len : ps->lengths) {
+        if (!len.defined()) {
+          defined = false;
+        }
+      }
+
+      if (defined) {
+        continue;
+      }
+
+      int extent = GetIntImm(ps->extent);
+      const std::vector<std::vector<PrimExpr> >& candidate_lens =
+          split_memo->GetFactorizationSchemes(
+              extent, ps->lengths.size(),
+              policy->cur_task_->hardware_params->max_innermost_split_factor);
+
+      StateNode* pstate = state->CopyOnWrite();
+      pstate->transform_steps[step_id] = SplitStepNode::make(
+          ps->stage_id, ps->iter_id, ps->extent,
+          candidate_lens[(*rand_gen)() % candidate_lens.size()],
+          ps->inner_to_outer);
+    }
+  }
+
+  return 0;
+}
+
+int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy,
+                             State* state) {
+  for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
+    const Stage& stage = (*state)->stages[stage_id];
+    auto pop = stage->op.as<te::ComputeOpNode>();
+
+    if (stage->compute_at != kRoot || stage->op_type == kPlaceholder) {
+      continue;
+    }
+
+    std::vector<Iterator> to_fuse;
+
+    // This stage has not been tiled, but in a GPU schedule we must tile it
+    // to do thread binding
+    if (!HasSplitStep(*state, stage_id)) {
+      for (const auto& it : (*state)->stages[stage_id]->iters) {
+        if (it->iter_type == kReduce) {
+          break;
+        }
+        to_fuse.push_back(it);
+      }
+      const auto& fused_it = state->fuse(stage_id, to_fuse);
+      // Set default vthread=1 & threadIdx.x=default_warp_size
+      // EvolutionarySearch will try more possibilities
+      if (GetExtent(fused_it) <=
+          policy->cur_task_->hardware_params->warp_size) {
+        state->bind_thread(stage_id, fused_it, kThreadX);
+      } else {
+        const auto& split_its = state->split(stage_id, fused_it,
+            {1, policy->cur_task_->hardware_params->warp_size});
+        state->bind_thread(stage_id, split_its[0], kBlockX);
+        state->bind_thread(stage_id, split_its[1], kVThread);
+        state->bind_thread(stage_id, split_its[2], kThreadX);
+      }
+
+      continue;
+    }
+
+    int total_space_extent = 1;
+    for (const auto& i : pop->root_iter_vars()) {
+      CHECK(i->dom.defined());
+      const auto& pint = i->dom->extent.as<IntImmNode>();
+      CHECK(pint);
+      total_space_extent *= pint->value;
+    }
+
+    // TODO(..): Add ThreadBind support for rfactor
+    if (total_space_extent <= policy->cur_task_->hardware_params->warp_size) {
+      for (const auto& it : (*state)->stages[stage_id]->iters) {
+        if (it->iter_type == kReduce) {
+          break;
+        }
+        to_fuse.push_back(it);
+      }
+      const auto& fused_it = state->fuse(stage_id, to_fuse);
+      state->bind_thread(stage_id, fused_it, kThreadX);
+
+      continue;
+    }
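+
+    // The multi-level tiling applied earlier names the split parts of an
+    // iterator `i` as `i.0`, `i.1`, `i.2`, ... (outermost first). The three
+    // loops below rely on that naming convention to locate the first, second
+    // and third outermost space tiles and bind them to blockIdx.x, vthread
+    // and threadIdx.x respectively.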
+    // Fuse the outermost space tile as blockIdx
+    for (size_t i = 0; i < pop->axis.size(); i++) {
+      const auto& it = (*state)->stages[stage_id]->iters[i];
+      if (!StringEndWith(it->name, ".0")) {
+        break;
+      }
+      to_fuse.push_back(it);
+    }
+    const auto& blockidx_it = state->fuse(stage_id, to_fuse);
+    state->bind_thread(stage_id, blockidx_it, kBlockX);
+
+    // Fuse the second outermost space tile as vthread
+    to_fuse.clear();
+    for (size_t i = 1; i < pop->axis.size() + 1; i++) {
+      const auto& it = (*state)->stages[stage_id]->iters[i];
+      if (!StringEndWith(it->name, ".1")) {
+        break;
+      }
+      to_fuse.push_back((*state)->stages[stage_id]->iters[i]);
+    }
+    const auto& vthread_it = state->fuse(stage_id, to_fuse);
+    if (GetExtent(vthread_it) >
+        policy->cur_task_->hardware_params->max_vthread_extent) {
+      return -1;
+    }
+    state->bind_thread(stage_id, vthread_it, kVThread);
+
+    // Fuse the third outermost space tile as threadIdx
+    to_fuse.clear();
+    for (size_t i = 2; i < pop->axis.size() + 2; i++) {
+      const auto& it = (*state)->stages[stage_id]->iters[i];
+      if (!StringEndWith(it->name, ".2")) {
+        break;
+      }
+      to_fuse.push_back((*state)->stages[stage_id]->iters[i]);
+    }
+    const auto& threadidx_it = state->fuse(stage_id, to_fuse);
+    if (GetExtent(threadidx_it) <
+        policy->cur_task_->hardware_params->warp_size) {
+      return -1;
+    }
+    state->bind_thread(stage_id, threadidx_it, kThreadX);
+  }
+
+  return 0;
+}
+
+int InitPopulationCooperativeFetching(const MetaTileRewritePolicyNode* policy,
+                                      State* state) {
+  for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
+    // Do cooperative fetching with a cache read stage
+    // For two stages: A -> B
+    // 1. A -> A_cache_read -> B
+    //          *
+    // 2. A -> A_cache_write -> A_cache_read -> B
+    //                           *
+    if ((stage_id > 0 && HasCacheReadStage((*state), stage_id - 1) &&
+         !HasCacheWriteStage((*state), stage_id - 1)) ||
+        (stage_id > 1 && HasCacheReadStage((*state), stage_id - 2) &&
+         HasCacheWriteStage((*state), stage_id - 2))) {
+      // Get spatial_split_step_ids from the root stage
+      std::unordered_set consumers;
+      std::vector<int> spatial_split_step_ids;
+      const Stage& target_stage = (*state)->stages[stage_id];
+      GetConsumers(policy->cur_task_, (*state), target_stage->op, &consumers);
+      CHECK_EQ(consumers.size(), 1);
+      int target_stage_id = OperationToStage(*consumers.begin(), (*state));
+      GetSpaceSplitStepIds((*state), target_stage_id, &spatial_split_step_ids);
+
+      // Fuse all axes to do cooperative fetching
+      Iterator fused = state->fuse(stage_id,
+                                   (*state)->stages[stage_id]->iters);
+      // Leave a vectorized cooperative fetching split placeholder
+      const auto& iters0 = state->split(stage_id, fused, {1});
+      state->vectorize(stage_id, iters0[1]);
+      // Follow split to keep the same thread extent as the root stage
+      const auto& iters1 = state->follow_fused_split(stage_id, iters0[0],
+                                                     spatial_split_step_ids,
+                                                     1, true);
+      state->bind_thread(stage_id, iters1[1], kThreadX);
+    }
+  }
+
+  return 0;
+}
+
+int InitPopulationChangeComputeLocation(const MetaTileRewritePolicyNode* policy,
+                                        State* state, std::mt19937* rand_gen) {
+  if (GetIntParam(policy->params, "disable_change_compute_location")) {
+    return 0;
+  }
+
+  for (int stage_id = static_cast<int>((*state)->stages.size()) - 1; stage_id >= 0; stage_id--) {
+    const Stage& stage = (*state)->stages[stage_id];
+
+    if (stage->op_type == kPlaceholder) {
+      continue;
+    }
+
+    if (IsTiled(stage) || stage->compute_at == kInlined) {
+      continue;
+    }
+
+    if (NeedsMultilevelTiling(policy->cur_task_, (*state), stage->op)) {
+      continue;
+    }
+
+    std::unordered_set
consumers; + + GetConsumers(policy->cur_task_, (*state), stage->op, &consumers); + if (consumers.empty()) { + continue; + } + + int target_stage_id; + if (consumers.size() == 1) { + target_stage_id = OperationToStage(*consumers.begin(), *state); + } else { + // check all consumers share a common root + int common_root_id = -1; + bool mismatch = false; + for (const auto& consumer : consumers) { + int consumer_stage_id = OperationToStage(consumer, *state); + int root_id = -1; + if ((*state)->stages[consumer_stage_id]->compute_at == kRoot) { + root_id = consumer_stage_id; + } else if ((*state)->stages[consumer_stage_id]->compute_at == kIter) { + root_id = (*state)->attach_map->stage_to_attach_iter.at(consumer_stage_id).first; + } else { + LOG(FATAL) << "Invalid case"; + } + + if (common_root_id == -1) { + common_root_id = root_id; + } else { + if (common_root_id != root_id) { + mismatch = true; + break; + } + } + } + + if (mismatch) { + continue; + } + target_stage_id = common_root_id; + } + + const Stage& target_stage = (*state)->stages[target_stage_id]; + std::set to_unroll_name_set; + if (target_stage->op->attrs.count(policy->always_unroll_key)) { + to_unroll_name_set = GetIterNameSetParam(target_stage->op->attrs, + policy->always_unroll_key); + } + + std::vector > candidates; + bool target_compute_at_other = target_stage->compute_at == kIter; + bool target_is_tiled = IsTiled(target_stage); + + bool visited_reduce = false; + // enumerate compute_at location at target_stage + int ct = 0; + for (const auto& target_iter : target_stage->iters) { + if (target_iter->iter_type == kReduce) { + visited_reduce = true; + if (!target_is_tiled) { // do not go into reduce iter + break; + } + } else if (target_iter->iter_type == kSpace) { + if (visited_reduce) { // do not go into inner tile + break; + } + } + + if (to_unroll_name_set.count(target_iter->name)) { + // Do not go into always unroll region + break; + } + + if (GetExtent(target_iter) == 1) { // skip iterators with length of 1 + continue; + } + if (target_compute_at_other && target_iter->iter_type == kSpace && + StrEndsWith(target_iter->name, ".0")) { + // skip the first level iterators if target stage compute_at another stage + // In this case, the lengths of first level iterators are always one + continue; + } + candidates.emplace_back(target_stage_id, target_iter); + + if ((*state)->attach_map->iter_to_attached_stages.count( + std::make_pair(target_stage_id, ct++))) { + break; + } + } + + // if the target_stage is already compute_at another stage X, try also compute_at X + // We call stage X as `target_target_stage` + if (target_compute_at_other) { + int target_target_stage_id; + target_target_stage_id = (*state)->attach_map->stage_to_attach_iter.at( + target_stage_id).first; + const Stage& target_target_stage = (*state)->stages[target_target_stage_id]; + if (target_target_stage->op->attrs.count(policy->always_unroll_key)) { + to_unroll_name_set = GetIterNameSetParam(target_target_stage->op->attrs, + policy->always_unroll_key); + } else { + to_unroll_name_set.clear(); + } + + int ct = 0; + for (const auto& target_target_iter : target_target_stage->iters) { + if (target_target_iter->iter_type == kReduce || + (*state)->attach_map->iter_to_attached_stages.count( + std::make_pair(target_target_stage_id, ct++))) { + break; + } + + if (to_unroll_name_set.count(target_target_iter->name)) { + // Do not go into always unroll region + break; + } + + if (GetExtent(target_target_iter) == 1) { // skip iterators with length of 1 + continue; + } + + 
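+        // Survived all the checks above, so this iterator of the parent
+        // stage is also a legal compute_at location for the current stage.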
candidates.push_back(std::make_pair(target_target_stage_id, target_target_iter)); + } + } + + int choice = (*rand_gen)() % (candidates.size() + 2); + + if (choice == 0) { + if (!HasReduceIter(stage)) { + state->compute_inline(stage_id); + } + } else if (choice == 1) { + state->compute_root(stage_id); + } else { + choice = choice - 2; + state->compute_at(stage_id, candidates[choice].first, candidates[choice].second); + } + } + + return 0; +} + +int InitPopulationParallel(const MetaTileRewritePolicyNode* policy, + State* state) { + std::function annotate_parallel; + + annotate_parallel = [&annotate_parallel]( + const MetaTileRewritePolicyNode* policy, State* state, int stage_id, int iter_offset) { + const Stage& stage = (*state)->stages[stage_id]; + + std::vector to_fuse; + int64_t parallel_degree = 1; + + // strategy: try to fuse and parallel the outermost n iterators + // Stop if we meet reduce iterator or we have enough parallel degree + size_t iter_id = iter_offset; + for (; iter_id < stage->iters.size(); ++iter_id) { + const Iterator& it = stage->iters[iter_id]; + if (it->iter_type == kReduce || it->annotation != kNone) { + break; + } + + to_fuse.push_back(it); + parallel_degree *= GetExtent(it); + + if (parallel_degree > policy->cur_task_->hardware_params->num_cores * 16) { + break; + } + + if ((*state)->attach_map->iter_to_attached_stages.count( + std::make_pair(stage_id, iter_id))) { + break; + } + } + + if (parallel_degree == 1) { + auto res = (*state)->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id)); + if (res != (*state)->attach_map->iter_to_attached_stages.end()) { + for (int attached_stage_id : res->second) { + annotate_parallel(policy, state, attached_stage_id, 0); + } + annotate_parallel(policy, state, stage_id, iter_id + 1); + } + } + + if (!to_fuse.empty()) { + if (to_fuse.size() == 1) { + state->parallel(stage_id, to_fuse[0]); + } else { + Iterator fused_iter = state->fuse(stage_id, to_fuse); + state->parallel(stage_id, fused_iter); + } + } + }; + + for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { + const Stage& stage = (*state)->stages[stage_id]; + if (stage->compute_at != kRoot || stage->op_type == kPlaceholder) { + continue; + } + + annotate_parallel(policy, state, stage_id, 0); + } + + return 0; +} + +int InitPopulationVectorization(const MetaTileRewritePolicyNode* policy, + State* state, std::mt19937* rand_gen) { + for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { + const Stage& stage = (*state)->stages[stage_id]; + + if (stage->op_type == kPlaceholder) { + continue; + } + + // Skip cooperative fetching stage + if (IS_GPU(policy->cur_task_) && + HasCacheReadStage((*state), stage_id - 1)) { + continue; + } + + // try to fuse and vectorize the space iterators in the inner most tile + int cum_length_prod = 1; + + std::set to_unroll_name_set; + if (stage->op->attrs.count(policy->always_unroll_key)) { + to_unroll_name_set = GetIterNameSetParam(stage->op->attrs, + policy->always_unroll_key); + } + + int num_fusible = 0; + while (num_fusible < static_cast(stage->iters.size())) { + int iter_id = static_cast(stage->iters.size()) - 1 - num_fusible; + if ((*state)->attach_map->iter_to_attached_stages.count( + std::make_pair(stage_id, iter_id))) { + break; + } + + const Iterator& it = stage->iters[iter_id]; + + // Stop if we meet a reduce iterator + if (it->iter_type == kReduce || it->annotation != kNone || + to_unroll_name_set.count(it->name)) { + break; + } + + // Stop if the memory access is not 
continuous (vectorizable)
+      // Note: this check is hard to do precisely, so we use a heuristic here
+      if (IsTiled(stage) && num_fusible != 0) {
+        // If the stage is tiled, then the memory access must not be continuous
+        // for the innermost two iterators
+        break;
+      }
+
+      cum_length_prod *= GetExtent(it);
+      if (cum_length_prod > policy->cur_task_->hardware_params->max_unroll_vec) {
+        break;
+      }
+
+      num_fusible++;
+    }
+
+    if (num_fusible > 1) {
+      num_fusible = 1 + (*rand_gen)() % (num_fusible - 1);  // Select a random range to fuse
+    }
+
+    if (num_fusible == 1) {
+      state->vectorize(stage_id, stage->iters.back());
+    } else if (num_fusible > 1) {
+      std::vector<Iterator> to_fuse(stage->iters.end() - num_fusible,
+                                    stage->iters.end());
+      state->vectorize(stage_id, state->fuse(stage_id, to_fuse));
+    }
+  }
+
+  return 0;
+}
+
+int InitPopulationUnroll(const MetaTileRewritePolicyNode* policy,
+                         State* state, std::mt19937* rand_gen) {
+  for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
+    const Stage& stage = (*state)->stages[stage_id];
+
+    if (stage->op_type == kPlaceholder) {
+      continue;
+    }
+
+    if (stage->op->attrs.count(policy->always_unroll_inner_key)) {
+      // Special unroll policy
+      auto to_unroll_name_set = GetIterNameSetParam(stage->op->attrs,
+                                                    policy->always_unroll_inner_key);
+      std::set<std::string> visited_names;
+
+      // Unroll the space iterators and reduce iterators listed in the attrs
+      // in the innermost tile
+      int n = static_cast<int>(stage->iters.size()) - 1;
+      visited_names.clear();
+      while (n >= 0) {
+        const Iterator& it = stage->iters[n];
+
+        // If we meet two iterators that come from the same original iterator,
+        // then we are out of the innermost tile
+        size_t size_before = visited_names.size();
+        ExtractOriginalIterators(it->name, &visited_names);
+        if (size_before == visited_names.size()) {
+          break;
+        }
+
+        std::set<std::string> name;
+        ExtractOriginalIterators(it->name, &name);
+        if (name.size() == 1 && to_unroll_name_set.count(*name.begin())) {
+          state->unroll(stage_id, it);
+        }
+
+        n--;
+      }
+    } else if (stage->op->attrs.count(policy->always_unroll_key)) {
+      // Special unroll policy
+      auto to_unroll_name_set = GetIterNameSetParam(stage->op->attrs,
+                                                    policy->always_unroll_key);
+
+      // Unroll the space iterators and reduce iterators listed in the attrs
+      int n = static_cast<int>(stage->iters.size()) - 1;
+      while (n >= 0) {
+        const Iterator& it = stage->iters[n];
+        if (to_unroll_name_set.count(it->name)) {
+          state->unroll(stage_id, it);
+        }
+        n--;
+      }
+    } else if (HasReduceIter(stage)) {
+      // use auto unroll for multi-level tiled stages
+      int value = policy->auto_unroll_configs[
+          (*rand_gen)() % policy->auto_unroll_configs.size()];
+      state->pragma(stage_id, (*state)->stages[stage_id]->iters[0],
+                    std::string("auto_unroll_max_step") + "$" + std::to_string(value));
+    }
+  }
+
+  return 0;
+}
+
+void MetaTileRewritePolicyNode::SampleInitPopulation(const std::vector<State>& meta_structures,
+                                                     int out_size, std::vector<State>* out_states) {
+  std::uniform_real_distribution<> dis(0.0, 1.0);
+  int continue_count = 0;
+
+  // TODO(...): Maybe try multi-threading here
+  while (static_cast<int>(out_states->size()) < out_size &&
+         continue_count < out_size * 10) {
+    State tmp_s = meta_structures[rand_gen_() % meta_structures.size()];
+
+    InitPopulationFillTileSize(this, &tmp_s, &rand_gen_, &split_memo_);
+
+    if (IS_GPU(cur_task_)) {
+      tmp_s = cur_task_->compute_dag.InferBound(tmp_s);
+
+      if (InitPopulationThreadBind(this, &tmp_s)) {
+        continue_count++;
+        continue;
+      }
+
+      InitPopulationCooperativeFetching(this, &tmp_s);
+    } else {
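+      // CPU path: mutate compute locations first, then re-infer bounds so
+      // that iterator extents are up to date before the parallel annotation.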
InitPopulationChangeComputeLocation(this, &tmp_s, &rand_gen_); + + tmp_s = cur_task_->compute_dag.InferBound(tmp_s); + + InitPopulationParallel(this, &tmp_s); + } + + InitPopulationVectorization(this, &tmp_s, &rand_gen_); + + InitPopulationUnroll(this, &tmp_s, &rand_gen_); + + out_states->push_back(std::move(tmp_s)); + } + + StdCout(verbose_) << "Sample Initial Population\t\t#s: " + << out_states->size() << std::endl; +} + +void MetaTileRewritePolicyNode::EvolutionarySearch( + const std::vector& init_population, + int num_best_states, std::vector* best_states) { + auto tic_begin = std::chrono::high_resolution_clock::now(); + + // Set parameters for genetic algorithm + int population = GetIntParam(params, "evolutionary_search_population"); + int num_iters = GetIntParam(params, "evolutionary_search_num_iters"); + double mutation_prob = GetDoubleParam(params, "evolutionary_search_mutation_prob"); + int num_cross_over = static_cast(population * 0.0); // NOT IMPLEMENTED currently + int num_cross_over_trial_upper_bound = num_cross_over * 3; + CostModel cost_model = program_cost_model; + + // Two ping pong buffers to avoid copy + std::vector states_buf1, states_buf2; + std::vector *pnow = &states_buf1, *pnext = &states_buf2; + states_buf1.reserve(population); + states_buf2.reserve(population); + states_buf1.insert(states_buf1.begin(), init_population.begin(), init_population.end()); + + // A heap to keep the best states during evolution + using StateItem = std::pair; + auto cmp = [](const StateItem& left, const StateItem& right) { + return left.second > right.second; + }; + std::vector heap; + std::unordered_set in_heap(measured_states_set_); + heap.reserve(num_best_states); + + // auxiliary global variables + std::vector scores; + std::vector prefix_sum_probs; + double max_score = 0.0; + scores.reserve(population); + prefix_sum_probs.reserve(population); + std::uniform_real_distribution<> dis(0.0, 1.0); + int mutation_fail_ct = 0; + + // Genetic Algorithm + for (int k = 0; k < num_iters + 1; ++k) { + // Maintain the heap + cur_task_->compute_dag.InferBound(pnow); + PruneUndefined(pnow); + cost_model->Predict(cur_task_, *pnow, &scores); + + for (size_t i = 0; i < pnow->size(); ++i) { + const State& state = (*pnow)[i]; + std::string state_str = state.ToStr(); + + if (in_heap.count(state_str) == 0) { + if (static_cast(heap.size()) < num_best_states) { + heap.emplace_back((*pnow)[i], scores[i]); + std::push_heap(heap.begin(), heap.end(), cmp); + in_heap.insert(state_str); + } else if (scores[i] > heap.front().second) { + std::string old_state_str = heap.front().first.ToStr(); + in_heap.erase(old_state_str); + in_heap.insert(state_str); + + std::pop_heap(heap.begin(), heap.end(), cmp); + heap.back() = StateItem(state, scores[i]); + std::push_heap(heap.begin(), heap.end(), cmp); + } + if (scores[i] > max_score) { + max_score = scores[i]; + } + } + } + + if (k % 5 == 0 || k == num_iters) { + StdCout(verbose_) << "GA Iter: " << k << std::fixed << std::setprecision(4) + << "\tMax score: " << max_score + << "\tMin score: " << heap.front().second + << "\tPop size: " << pnow->size() << std::endl; + } + + if (k == num_iters) { + break; + } + + // Compute selection probability + double sum = 0.0; + prefix_sum_probs.resize(scores.size()); + for (size_t i = 0; i < scores.size(); ++i) { + sum += std::max(scores[i], 0.0f); + prefix_sum_probs[i] = sum; + } + for (size_t i = 0; i < scores.size(); ++i) { + prefix_sum_probs[i] = prefix_sum_probs[i] / sum; + } + + // Do cross over + int ct = 0; + while 
(static_cast(pnext->size()) < num_cross_over + && ct < num_cross_over_trial_upper_bound) { + int p1 = RandomChoose(prefix_sum_probs, &rand_gen_); + int p2 = RandomChoose(prefix_sum_probs, &rand_gen_); + + if (p1 == p2) { + pnext->push_back((*pnow)[p1]); + } else { + State tmp_s = CrossOverState((*pnow)[p1], (*pnow)[p2]); + if (tmp_s.defined()) { + pnext->push_back(std::move(tmp_s)); + } + } + ct++; + } + + // Do mutation + mutation_fail_ct = 0; + while (static_cast(pnext->size()) < population) { + int id = RandomChoose(prefix_sum_probs, &rand_gen_); + + if (dis(rand_gen_) < mutation_prob) { + const std::vector rule_prefix_sum_probs{0.9, 0.95, 1.0}; + + int rule_id = RandomChoose(rule_prefix_sum_probs, &rand_gen_); + + State tmp_s; + if (rule_id == 0) { + tmp_s = RandomMutateTileSize((*pnow)[id], &split_memo_, &rand_gen_, + cur_task_->hardware_params->max_innermost_split_factor); + } else if (rule_id == 1) { + tmp_s = RandomMutateMaxUnrollStep((*pnow)[id], &rand_gen_, auto_unroll_configs); + } else if (rule_id == 2) { + tmp_s = MutataParallel((*pnow)[id], &split_memo_, &rand_gen_, cur_task_); + } + + if (tmp_s.defined()) { + pnext->push_back(std::move(tmp_s)); + } else { + mutation_fail_ct++; + } + } else { + pnext->push_back((*pnow)[id]); + } + } + + std::swap(pnext, pnow); pnext->clear(); + } + + // Copy best states in the heap to out_states + std::sort(heap.begin(), heap.end(), cmp); + best_states->clear(); + for (auto& item : heap) { + best_states->push_back(std::move(item.first)); + } + + double duration = std::chrono::duration_cast >( + std::chrono::high_resolution_clock::now()- tic_begin).count(); + StdCout(verbose_) << "EvolutionarySearch\t\t#s: " << best_states->size() + << "\tTime elapsed: " + << std::fixed << std::setprecision(2) << duration << std::endl; +} + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/meta_tile_rewrite_policy.h new file mode 100644 index 000000000000..56a75f8e52fe --- /dev/null +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.h @@ -0,0 +1,101 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file ansor/meta_tile_rewrite_policy.h + * \brief A search policy that search with meta tiling structure and random rewrite + */ +#ifndef TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ +#define TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ + +#include +#include +#include +#include +#include +#include "search_policy.h" +#include "../cost_model/cost_model.h" +#include "../utils.h" + + +namespace tvm { +namespace ansor { + +/*! 
 Multi-stage search policy */
+class MetaTileRewritePolicyNode: public SearchPolicyNode {
+ public:
+  CostModel program_cost_model;
+
+  /* this->params is used to store the following arguments
+   * int    evolutionary_search_population      // The population size for evolutionary search
+   * double evolutionary_search_mutation_prob   // The probability of mutation for evolutionary search
+   * int    evolutionary_search_num_iters;      // The number of iterations for evolutionary search
+   * double evolutionary_search_use_measured_ratio;  // The maximum percentage of measured states in the
+   *                                                 // initial population for evolutionary search
+   * double eps_greedy;  // Always allocate this percentage of measurements to random sampled states
+   * str    cpu_multi_level_tiling_structure    // The structure of multi-level tiling for CPU
+   * str    gpu_multi_level_tiling_structure    // The structure of multi-level tiling for GPU
+   */
+  Map params;
+
+  static SearchPolicy make(CostModel program_cost_model,
+                           Map params,
+                           int seed);
+
+  // Search and make n_trials measurements
+  // Return the best state
+  State Search(SearchTask task, int n_trials,
+               int early_stopping, int num_measure_per_iter,
+               int verbose, ProgramMeasurer measurer) final;
+
+  // Continue the search. This is used by JointTuner
+  std::pair<Array<MeasureInput>, Array<MeasureResult> > ContinueSearchOneRound(
+      SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) final;
+
+  static constexpr const char *_type_key = "ansor.MetaTileRewritePolicy";
+  static const std::vector<int> auto_unroll_configs;
+
+  TVM_DECLARE_FINAL_OBJECT_INFO(MetaTileRewritePolicyNode, SearchPolicyNode);
+
+  SearchTask cur_task_;  // The current task
+
+  friend class MetaTileRewritePolicyNodeTest;  // Hack friend class for UT
+
+ protected:
+  // Pick states from the best states and random states with an eps-greedy policy
+  void PickStatesWithEpsGreedy(std::vector<MeasureInput>* inputs,
+                               const std::vector<State>& best_states,
+                               const std::vector<State>& random_states, int remaining_n_trials);
+
+ private:
+  // Run one round of the search pipeline
+  void SearchOneRound(std::vector<State>* best_states,
+                      int num_random_states, std::vector<State>* random_states);
+
+  // Synthesize meta tiling structures without tile sizes
+  void SynthesizeMetaStructure(std::vector<State>* out_states);
+
+  // Sample the init population
+  void SampleInitPopulation(const std::vector<State>& meta_structures,
+                            int out_size, std::vector<State>* out_states);
+
+  // Perform evolutionary search
+  void EvolutionarySearch(const std::vector<State>& init_population,
+                          int num_best_states, std::vector<State>* best_states);
+
+  SplitFactorizationMemo split_memo_;  // Memorize the split space for Split
+  std::mt19937 rand_gen_;              // Random generator
+  int verbose_;                        // Verbose level (0 means silent)
+  int num_measure_per_iter_;           // The number of states to measure per iteration
+
+  // The set of already measured states. We store the string format for the redundancy check
+  std::unordered_set<std::string> measured_states_set_;
+
+  // The array of already measured states.
+  std::vector<State> measured_states_vector_;
+
+  // The throughputs of already measured states
+  std::vector<float> measured_states_throughputs_;
+};
+
+}  // namespace ansor
+}  // namespace tvm
+
+#endif  // TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_
diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc
new file mode 100644
index 000000000000..89bfeb1a8edd
--- /dev/null
+++ b/src/ansor/search_policy/search_policy.cc
@@ -0,0 +1,14 @@
+/*!
+ * Copyright (c) 2020 by Contributors + */ + +#include "search_policy.h" + +namespace tvm { +namespace ansor { + +TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); + +} // namespace ansor +} // namespace tvm + diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h new file mode 100644 index 000000000000..5bd9fb3118b1 --- /dev/null +++ b/src/ansor/search_policy/search_policy.h @@ -0,0 +1,53 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file ansor/search_policy.h + * \brief Base class of search policy + */ +#ifndef TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ +#define TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ + +#include +#include +#include +#include +#include "../search_task.h" +#include "../measure.h" + +namespace tvm { +namespace ansor { + +class SearchPolicy; + +/*! \brief The base class for search policy */ +class SearchPolicyNode : public Object { + public: + virtual State Search(SearchTask task, int n_trials, + int early_stopping, int num_measure_per_iter, + int verbose, ProgramMeasurer measurer) = 0; + + virtual std::pair, Array > ContinueSearchOneRound( + SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) = 0; + + // Dict keys + static constexpr const char* always_unroll_inner_key = "ansor_always_unroll_inner"; + static constexpr const char* always_unroll_key = "ansor_always_unroll"; + static constexpr const char* no_split_at_inner_key = "ansor_no_split_at_inner"; + static constexpr const char* no_split_at_outer_key = "ansor_no_split_at_outer"; + static constexpr const char* debug_skip_region_key = "ansor_debug_skip_region"; + static constexpr const char* last_split_is_one_key = "ansor_last_split_is_one"; + + // Flag keys + static constexpr const char* always_compute_inline_key = "ansor_always_compute_inline"; + static constexpr const char* no_cache_write_key = "ansor_no_cache_write"; + static constexpr const char* no_cache_read_key = "ansor_no_cache_read"; + static constexpr const char* tensor_core_support_key = "ansor_tensor_core_support"; + + static constexpr const char *_type_key = "ansor.SearchPolicy"; + TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); +}; +TVM_DEFINE_MUTABLE_NODE_REF(SearchPolicy, SearchPolicyNode); + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ diff --git a/src/ansor/search_policy/utils.cc b/src/ansor/search_policy/utils.cc new file mode 100644 index 000000000000..9c597b4eb811 --- /dev/null +++ b/src/ansor/search_policy/utils.cc @@ -0,0 +1,609 @@ +/*! 
+ * Copyright (c) 2020 by Contributors
+ */
+
+#include "utils.h"
+#include "search_policy.h"
+
+namespace tvm {
+namespace ansor {
+
+void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector<int>* spatial_split_step_ids) {
+  auto pop = s->stages[stage_id]->op.as<te::ComputeOpNode>();
+  CHECK(pop != nullptr);
+
+  auto no_split_name_pair = QueryNoSplitAxis(s->stages[stage_id]);
+  std::set<std::string> no_split_at_inner_name_set = no_split_name_pair.first;
+  std::set<std::string> no_split_at_outer_name_set = no_split_name_pair.second;
+  size_t reduce_count = 0;
+  for (const auto& axis : pop->reduce_axis) {
+    if (!no_split_at_inner_name_set.count(axis->var->name_hint) &&
+        !no_split_at_outer_name_set.count(axis->var->name_hint)) {
+      reduce_count++;
+    }
+  }
+
+  for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
+    if (s->transform_steps[i]->IsInstance() ||
+        s->transform_steps[i]->IsInstance() ||
+        s->transform_steps[i]->IsInstance()) {
+      if (stage_id > s->transform_steps[i]->stage_id) {
+        stage_id--;
+      }
+    } else if (auto ps = s->transform_steps[i].as<SplitStepNode>()) {
+      if (stage_id == ps->stage_id) {
+        if (reduce_count) {
+          reduce_count--;
+        } else {
+          spatial_split_step_ids->push_back(i);
+        }
+      }
+    }
+  }
+}
+
+// Query axes that should not be split, according to the attribute from tvm.compute
+std::pair<std::set<std::string>, std::set<std::string> > QueryNoSplitAxis(const Stage& stage) {
+  std::pair<std::set<std::string>, std::set<std::string> > ret;
+  if (stage->op->attrs.count(SearchPolicyNode::no_split_at_inner_key)) {
+    ret.first = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::no_split_at_inner_key);
+  }
+  if (stage->op->attrs.count(SearchPolicyNode::no_split_at_outer_key)) {
+    ret.second = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::no_split_at_outer_key);
+  }
+  return ret;
+}
+
+// Query axes whose last split is one
+std::set<std::string> QueryLastSplitIsOneAxis(const Stage& stage) {
+  std::set<std::string> ret;
+  if (stage->op->attrs.count(SearchPolicyNode::last_split_is_one_key)) {
+    ret = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::last_split_is_one_key);
+  }
+  return ret;
+}
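+
+// A worked example of the tiling format handled below: with format "SSRSRS"
+// (so n_space = 4 and n_reduce = 2), a space iterator `i` is split into
+// i.0 / i.1 / i.2 / i.3 and a reduce iterator `k` into k.0 / k.1, and for
+// axes (i, j, k) the final loop order becomes
+//   i.0 j.0 | i.1 j.1 | k.0 | i.2 j.2 | k.1 | i.3 j.3
+// The split factors are left as undefined exprs here; they are filled in
+// later by InitPopulationFillTileSize.
+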
+// Apply a multi-level tiling structure according to a string format
+State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
+                         std::vector<int>* spatial_split_step_ids) {
+  std::vector<std::vector<Iterator> > space_levels;
+  std::vector<std::vector<Iterator> > reduce_levels;
+  std::vector<Iterator> space_outer, space_inner, reduce_outer, reduce_inner;
+  std::vector<Iterator> split_res;
+
+  for (const auto c : format) {
+    if (tolower(c) == 's') {
+      space_levels.emplace_back();
+    } else if (tolower(c) == 'r') {
+      reduce_levels.emplace_back();
+    } else {
+      LOG(FATAL) << "Invalid multi level tiling format: " << format;
+    }
+  }
+  size_t n_space = space_levels.size();
+  size_t n_reduce = reduce_levels.size();
+
+  spatial_split_step_ids->clear();
+
+  State tmp_s = state;
+  const Stage& stage = state->stages[stage_id];
+  auto no_split_name_pair = QueryNoSplitAxis(stage);  // handle special split strategies
+  auto last_split_is_one_name_set = QueryLastSplitIsOneAxis(stage);
+  std::set<std::string> no_split_at_inner_name_set = no_split_name_pair.first;
+  std::set<std::string> no_split_at_outer_name_set = no_split_name_pair.second;
+
+  for (const auto& iter : state->stages[stage_id]->iters) {
+    if (iter->iter_type == kSpace) {
+      if (!no_split_at_inner_name_set.count(iter->name) &&
+          !no_split_at_outer_name_set.count(iter->name)) {
+        CHECK_GE(n_space, 1);
+        int tmp_n_space = n_space;
+
+        if (last_split_is_one_name_set.count(iter->name)) {
+          tmp_n_space--;
+        }
+
+        if (tmp_n_space == 1) {
+          space_levels[0].push_back(iter);
+        } else {
+          split_res = tmp_s.split(stage_id, iter, std::vector<PrimExpr>(tmp_n_space - 1));
+          for (int i = 0; i < tmp_n_space; i++) {
+            space_levels[i].push_back(std::move(split_res[i]));
+          }
+          spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1);
+        }
+      } else {
+        if (no_split_at_inner_name_set.count(iter->name)) {
+          space_inner.push_back(iter);
+        }
+        if (no_split_at_outer_name_set.count(iter->name)) {
+          space_outer.push_back(iter);
+        }
+      }
+    } else if (iter->iter_type == kReduce) {
+      // for a reduce iterator, split it into two iterators
+      if (!no_split_at_inner_name_set.count(iter->name) &&
+          !no_split_at_outer_name_set.count(iter->name)) {
+        CHECK_GE(n_reduce, 1);
+        if (n_reduce == 1) {
+          reduce_levels[0].push_back(iter);
+        } else {
+          split_res = tmp_s.split(stage_id, iter, std::vector<PrimExpr>(n_reduce - 1));
+          for (size_t i = 0; i < n_reduce; i++) {
+            reduce_levels[i].push_back(std::move(split_res[i]));
+          }
+        }
+      } else {
+        if (no_split_at_inner_name_set.count(iter->name)) {
+          reduce_inner.push_back(iter);
+        }
+        if (no_split_at_outer_name_set.count(iter->name)) {
+          reduce_outer.push_back(iter);
+        }
+      }
+    } else {
+      LOG(FATAL) << "Invalid iter type: " << iter->iter_type;
+    }
+  }
+
+  if (!space_outer.empty()) {
+    CHECK(!space_levels.empty());
+    space_levels.front().insert(space_levels.front().begin(),
+                                space_outer.begin(), space_outer.end());
+  }
+  if (!space_inner.empty()) {
+    CHECK(!space_levels.empty());
+    space_levels.back().insert(space_levels.back().begin(),
+                               space_inner.begin(), space_inner.end());
+  }
+
+  if (!reduce_outer.empty()) {
+    CHECK(!reduce_levels.empty());
+    reduce_levels.front().insert(reduce_levels.front().begin(),
+                                 reduce_outer.begin(), reduce_outer.end());
+  }
+  if (!reduce_inner.empty()) {
+    CHECK(!reduce_levels.empty());
+    reduce_levels.back().insert(reduce_levels.back().begin(),
+                                reduce_inner.begin(), reduce_inner.end());
+  }
+
+  std::vector<Iterator> order;
+  int space_ct = 0, reduce_ct = 0;
+  for (const auto c : format) {
+    if (tolower(c) == 's') {
+      order.insert(order.end(), std::make_move_iterator(space_levels[space_ct].begin()),
+                   std::make_move_iterator(space_levels[space_ct].end()));
+      space_ct++;
+    } else if (tolower(c) == 'r') {
+      order.insert(order.end(), std::make_move_iterator(reduce_levels[reduce_ct].begin()),
+                   std::make_move_iterator(reduce_levels[reduce_ct].end()));
+      reduce_ct++;
+    } else {
+      LOG(FATAL) << "Invalid multi level tiling format: " << format;
+    }
+  }
+
+  tmp_s.reorder(stage_id, order);
+  return tmp_s;
+}
+
+// Apply tiling structure: space, space
+// But use tile sizes from another SplitStep
+State FollowTiling(const State& state, int stage_id,
+                   const std::vector<int>& split_step_ids, int n_split) {
+  if (n_split < 1 || n_split > 3) {
+    LOG(FATAL) << "Invalid number of split parts, currently only 1, 2 and 3 are supported";
+  }
+  // Apply an up to three-level tiling structure: space_L0, space_L1, space_L2
+  std::vector<Iterator> space_0, space_1, space_2, space_3;
+  std::vector<Iterator> split_res, tmp_order;
+
+  auto pop = state->stages[stage_id]->op.as<te::ComputeOpNode>();
+  CHECK(pop != nullptr);
+  const Stage& stage = state->stages[stage_id];
+  auto no_split_name_pair = QueryNoSplitAxis(stage);  // handle special split strategies
+  const std::set<std::string>& no_split_at_inner_name_set = no_split_name_pair.first;
+  const std::set<std::string>& no_split_at_outer_name_set = no_split_name_pair.second;
+  int no_split_at_inner_name_in_stage_cnt = 0;
+  int no_split_at_outer_name_in_stage_cnt = 0;
+  for (const auto& iter : state->stages[stage_id]->iters) {
+    no_split_at_inner_name_in_stage_cnt += no_split_at_inner_name_set.count(iter->name);
+    no_split_at_outer_name_in_stage_cnt +=
no_split_at_outer_name_set.count(iter->name); + } + + CHECK_EQ(state->stages[stage_id]->iters.size() + - no_split_at_inner_name_in_stage_cnt + - no_split_at_outer_name_in_stage_cnt, + split_step_ids.size()); + + State tmp_s = state; + int ct = 0; + for (const auto& iter : state->stages[stage_id]->iters) { + if (iter->iter_type == kSpace) { + // For spatial iterator, split it into multi iterators + if (!no_split_at_inner_name_set.count(iter->name) && + !no_split_at_outer_name_set.count(iter->name)) { + IteratorAnnotation ann_type = iter->annotation; + split_res = tmp_s.follow_split(stage_id, iter, split_step_ids[ct], + n_split); + // Restore annotation. Move unroll and vectorize to inner, move parallel + // to outer + switch (ann_type) { + case kUnroll: + split_res[n_split] = tmp_s.unroll(stage_id, split_res[n_split]); + break; + case kVectorize: + split_res[n_split] = tmp_s.vectorize(stage_id, split_res[n_split]); + break; + case kParallel: + split_res[0] = tmp_s.parallel(stage_id, split_res[0]); break; + default: + break; + } + + space_0.push_back(std::move(split_res[0])); + space_1.push_back(std::move(split_res[1])); + if (n_split >= 2) { + space_2.push_back(std::move(split_res[2])); + if (n_split == 3) { + space_3.push_back(std::move(split_res[3])); + } + } + ct++; + } else { + if (no_split_at_outer_name_set.count(iter->name)) { + space_0.push_back(iter); + } + if (no_split_at_inner_name_set.count(iter->name)) { + if (n_split == 1) { + space_1.push_back(iter); + } else if (n_split == 2) { + space_2.push_back(iter); + } else { + CHECK_EQ(n_split, 3); + space_3.push_back(iter); + } + } + } + } else { + LOG(FATAL) << "Invalid iter type: " << iter->iter_type; + } + } + if (n_split == 3) { + ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2, &space_3); + } else if (n_split == 2) { + ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2); + } else { + ConcatenateMove(&tmp_order, &space_0, &space_1); + } + tmp_s.reorder(stage_id, tmp_order); + return tmp_s; +} + +// Randomly mutate the tile size of one SplitStep +State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split_memo, + std::mt19937* random_gen, int max_innermost_split_factor) { + State tmp_s = old_state; + + // Extract all SplitStep + std::vector split_step_ids; + for (size_t i = 0; i < tmp_s->transform_steps.size(); ++i) { + if (auto ps = tmp_s->transform_steps[i].as()) { + if (ps->extent.defined() && ps->extent->IsInstance() && + GetIntImm(ps->lengths.back()) <= max_innermost_split_factor) { + split_step_ids.push_back(i); + } + } + } + if (split_step_ids.empty()) { + return State(); + } + + // Find a SplitStep with extent != 1 + int retry_ct = 0; + int64_t extent = 1; + int step_id; + const SplitStepNode* ps; + + do { + step_id = split_step_ids[(*random_gen)() % split_step_ids.size()]; + ps = tmp_s->transform_steps[step_id].as(); + CHECK(ps != nullptr); + extent = GetIntImm(ps->extent); + retry_ct += 1; + } while (retry_ct < static_cast(split_step_ids.size()) << 2 && extent == 1); + + if (extent == 1) { + return State(); + } + + // Mutate tile size + std::vector lengths(ps->lengths.size() + 1, 1); + for (int i = 0; i < static_cast(ps->lengths.size()); ++i) { + lengths[i + 1] = GetIntImm(ps->lengths[i]); + } + lengths[0] = extent / ElementProduct(lengths); + + std::vector random_perm; + RandomPermutation(lengths.size(), &random_perm, random_gen); + + for (size_t i = 0; i < random_perm.size(); ++i) { + size_t src_idx = random_perm[i]; + int length = lengths[src_idx]; + + if (length == 1) { + continue; 
+    }
+
+    // Divide one factor from lengths[src_idx] and multiply it to lengths[dst_idx]
+    size_t dst_idx = random_perm[(i + 1) % random_perm.size()];
+
+    const std::vector<int>& factors = split_memo->GetFactors(length);
+    CHECK_GE(factors.size(), 1);
+
+    int divide_factor;
+    if (dst_idx == lengths.size() - 1) {
+      // Maintain the restriction of hardware_params.max_innermost_split_factor
+      int max_factor_index = static_cast<int>(factors.size()) - 1;
+      for (; max_factor_index >= 1; max_factor_index--) {
+        if (factors[max_factor_index] * lengths[dst_idx] <= max_innermost_split_factor) {
+          break;
+        }
+      }
+      if (max_factor_index == 0) {
+        // This dst_idx failed; try the next one.
+        continue;
+      }
+      divide_factor = factors[1 + (*random_gen)() % max_factor_index];
+    } else {
+      divide_factor = factors[1 + (*random_gen)() % (factors.size() - 1)];
+    }
+
+    std::vector<PrimExpr> new_lengths;
+    for (size_t j = 1; j < lengths.size(); ++j) {
+      if (j == src_idx) {
+        new_lengths.emplace_back(lengths[j] / divide_factor);
+      } else if (j == dst_idx) {
+        new_lengths.emplace_back(lengths[j] * divide_factor);
+      } else {
+        new_lengths.emplace_back(lengths[j]);
+      }
+    }
+
+    CHECK_LE(GetIntImm(new_lengths.back()), max_innermost_split_factor);
+
+    auto pstate = tmp_s.CopyOnWrite();
+    pstate->transform_steps[step_id] = SplitStepNode::make(
+        ps->stage_id, ps->iter_id, ps->extent, new_lengths, ps->inner_to_outer);
+    return tmp_s;
+  }
+
+  return State();
+}
+
+// Randomly mutate the value of one auto_unroll_max_step PragmaStep
+State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen,
+                                const std::vector<int>& auto_unroll_configs) {
+  State tmp_s = old_state;
+
+  // Extract all auto_unroll_max_step pragma steps.
+  std::vector<size_t> annotate_steps;
+  for (size_t i = 0; i < old_state->transform_steps.size(); ++i) {
+    if (auto ps = tmp_s->transform_steps[i].as<PragmaStepNode>()) {
+      if (ps->pragma_type.find("auto_unroll_max_step") != std::string::npos) {
+        annotate_steps.push_back(i);
+      }
+    }
+  }
+  if (annotate_steps.empty()) {
+    return State();
+  }
+
+  // Randomly pick one step.
+  auto step_id = annotate_steps[(*random_gen)() % annotate_steps.size()];
+  auto ps = tmp_s->transform_steps[step_id].as<PragmaStepNode>();
+  auto val = std::to_string(auto_unroll_configs[(*random_gen)() % auto_unroll_configs.size()]);
+
+  auto pstate = tmp_s.CopyOnWrite();
+  pstate->transform_steps[step_id] = PragmaStepNode::make(
+      ps->stage_id, ps->iter_id, std::string("auto_unroll_max_step") + "$" + val);
+  return tmp_s;
+}
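
Both mutations above leave the schedule's loop extents intact: RandomMutateTileSize only moves an integer factor from one tile level to another, so the product of all tile lengths is unchanged. A minimal standalone sketch of that factor-moving invariant, on plain ints, with a hypothetical factors() helper standing in for the patch's SplitFactorizationMemo::GetFactors (assumed to return ascending divisors starting at 1):

#include <random>
#include <vector>

// Hypothetical helper: all divisors of n in ascending order, so f[0] == 1.
std::vector<int> factors(int n) {
  std::vector<int> f;
  for (int d = 1; d <= n; ++d) {
    if (n % d == 0) f.push_back(d);
  }
  return f;
}

// Move one non-trivial divisor of lengths[src] over to lengths[dst]; the
// product of all tile lengths is unchanged, which is the invariant the
// tile-size mutation above relies on.
void MoveFactor(std::vector<int>* lengths, size_t src, size_t dst,
                std::mt19937* rng) {
  const std::vector<int> fs = factors((*lengths)[src]);
  if (fs.size() < 2) return;  // only the trivial divisor 1: nothing to move
  int f = fs[1 + (*rng)() % (fs.size() - 1)];
  (*lengths)[src] /= f;
  (*lengths)[dst] *= f;
}
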
+// Mutate a parallel loop.
+State MutateParallel(const State& state, SplitFactorizationMemo* split_memo,
+                     std::mt19937* random_gen, SearchTask& task, int verbose) {
+  // To keep this mutation simple but promising, we only focus on a specific case:
+  // parallel was added to the outermost loop, and the loop was generated by fusing
+  // other loops. In short, we mutate the step pattern (fuse -> parallel).
+
+  // Extract all parallel steps.
+  std::vector<size_t> parallel_steps;
+  for (size_t s = 0; s < state->transform_steps.size(); ++s) {
+    auto ps = state->transform_steps[s].as<AnnotationStepNode>();
+    if (!ps || ps->annotation != kParallel) {
+      continue;
+    }
+    parallel_steps.push_back(s);
+  }
+  if (parallel_steps.empty()) {
+    StdCout(verbose) << "Parallel mutation failed: No parallel annotations" << std::endl;
+    return State();
+  }
+
+  // Randomly pick one step.
+  int retry_ct = 0;
+  size_t step_id = 0;
+  size_t stage_id = 0;
+  do {
+    step_id = parallel_steps[(*random_gen)() % parallel_steps.size()];
+    auto step = state->transform_steps[step_id].as<AnnotationStepNode>();
+    stage_id = step->stage_id;
+
+    // Check the assumption: the parallel annotation sits on the outermost
+    // iterator and is immediately preceded by a fuse step.
+    auto iter_id = step->iter_id;
+    if (iter_id == 0 && step_id > 0 && state->transform_steps[step_id - 1].as<FuseStepNode>()) {
+      break;
+    }
+    retry_ct++;
+  } while (retry_ct <= 3);
+
+  if (retry_ct > 3) {
+    StdCout(verbose) << "Parallel mutation failed: No valid parallel annotations" << std::endl;
+    return State();
+  }
+
+  // 0: fuse less; 1: fuse more.
+  std::vector<double> fuse_dir = {0.5, 1.0};
+
+  // The iter is an attach target, so we can only fuse less.
+  if (state->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, 0)) > 0) {
+    fuse_dir[0] = 1.0;
+  }
+
+  // Determine the fuse direction.
+  auto fuse_step = state->transform_steps[step_id - 1].as<FuseStepNode>();
+  std::vector<int> fused_ids = fuse_step->fused_ids;
+  int iter_offset = 0;
+  if (RandomChoose(fuse_dir, random_gen) == 0) {
+    StdCout(verbose) << "Parallel mutation: release iter " << fused_ids.back() << std::endl;
+    fused_ids.pop_back();
+    iter_offset = 1;
+  } else {
+    StdCout(verbose) << "Parallel mutation: include iter " << fused_ids.back() + 1 << std::endl;
+    fused_ids.push_back(fused_ids.back() + 1);
+    iter_offset = -1;
+  }
+
+  // Replay a new state.
+  State tmp_s = task->compute_dag.GetInitState();
+  for (size_t s = 0; s < state->transform_steps.size(); ++s) {
+    auto step = state->transform_steps[s];
+    if (s == step_id - 1) {
+      step = FuseStepNode::make(step->stage_id, fused_ids);
+    } else if (s > step_id && step->stage_id == static_cast<int>(stage_id)) {
+      // Since we changed the loop structure, iterator ids in later steps on the
+      // same stage have to be adjusted.
+      auto ps = step.as<AnnotationStepNode>();
+      if (ps) {
+        CHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size());
+        step = AnnotationStepNode::make(ps->stage_id, ps->iter_id + iter_offset, ps->annotation);
+      } else {
+        StdCout(verbose) << "Parallel mutation: Cannot apply " << step << " after fuse"
+                         << std::endl;
+        return State();
+      }
+    }
+    tmp_s.CopyOnWrite()->transform_steps.push_back(step);
+    tmp_s.DoStep(step, task->compute_dag);
+  }
+  // Return the replayed state; returning the input state here would discard
+  // the mutation.
+  return tmp_s;
+}
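
A note on the sampling above: fuse_dir is a prefix-sum probability vector of the kind RandomChoose consumes, so {0.5, 1.0} picks "fuse less" and "fuse more" with equal probability, and overwriting fuse_dir[0] with 1.0 forces "fuse less". A small self-contained check of that behavior (a sketch, not part of the patch):

#include <algorithm>
#include <iostream>
#include <random>
#include <vector>

int main() {
  std::mt19937 rng(0);
  std::uniform_real_distribution<> dis(0.0, 1.0);
  // Prefix sums of the point probabilities {0.5, 0.5}.
  const std::vector<double> fuse_dir = {0.5, 1.0};

  int counts[2] = {0, 0};
  for (int i = 0; i < 10000; ++i) {
    double x = dis(rng);
    int idx = std::lower_bound(fuse_dir.begin(), fuse_dir.end(), x) -
              fuse_dir.begin();
    ++counts[idx];
  }
  // Expect roughly 5000/5000; with fuse_dir = {1.0, 1.0}, index 0 always wins.
  std::cout << counts[0] << " " << counts[1] << std::endl;
  return 0;
}
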
+// Create all possible tile size states for all SplitStep
+void GridMutateTileSize(const State& old_state, std::vector<State>* cands,
+                        SplitFactorizationMemo* split_memo, int max_innermost_split_factor) {
+  // Extract all SplitStep.
+  std::vector<size_t> split_step_ids;
+  for (size_t i = 0; i < old_state->transform_steps.size(); ++i) {
+    if (old_state->transform_steps[i]->IsInstance<SplitStepNode>()) {
+      split_step_ids.push_back(i);
+    }
+  }
+  if (split_step_ids.empty()) {
+    return;
+  }
+
+  // Move tile sizes and generate candidates.
+  for (size_t step_id : split_step_ids) {
+    const SplitStepNode* ps = old_state->transform_steps[step_id].as<SplitStepNode>();
+    CHECK(ps != nullptr);
+
+    int extent = GetIntImm(ps->extent);
+    if (extent == 1) {
+      continue;
+    }
+
+    // Get the current tile sizes.
+    std::vector<int> lengths(ps->lengths.size(), 1);
+    for (int i = 0; i < static_cast<int>(ps->lengths.size()); ++i) {
+      lengths[i] = GetIntImm(ps->lengths[i]);
+    }
+
+    const std::vector<int>& const_factors = split_memo->GetFactors(extent);
+    CHECK_GE(const_factors.size(), 1);
+
+    // Move tile size.
+    for (size_t i = 0; i < ps->lengths.size(); ++i) {
+      int old_length = lengths[i];
+
+      for (int factor : const_factors) {
+        if (i == ps->lengths.size() - 1 && factor > max_innermost_split_factor) {
+          // Limit the innermost factor.
+          break;
+        }
+
+        // Make new length expressions and a new state.
+        std::vector<PrimExpr> length_exprs;
+        lengths[i] = factor;
+        int outermost = extent / ElementProduct(lengths);
+        if (outermost == 0) {
+          break;
+        }
+
+        // std::cout << "Mutated extent " << extent << ": " << outermost;
+        for (size_t j = 0; j < lengths.size(); ++j) {
+          // std::cout << ", " << lengths[j];
+          length_exprs.emplace_back(lengths[j]);
+        }
+        // std::cout << std::endl;
+
+        State tmp_s = old_state;
+        const SplitStepNode* new_ps = tmp_s->transform_steps[step_id].as<SplitStepNode>();
+        auto pstate = tmp_s.CopyOnWrite();
+        pstate->transform_steps[step_id] = SplitStepNode::make(
+            new_ps->stage_id, new_ps->iter_id, new_ps->extent, length_exprs,
+            new_ps->inner_to_outer);
+        if (tmp_s.defined()) {
+          cands->push_back(std::move(tmp_s));
+        }
+      }
+      lengths[i] = old_length;
+    }
+  }
+}
+
+// Randomly choose an index according to a prefix-sum probability vector
+int RandomChoose(const std::vector<double>& prefix_sum_probs, std::mt19937* random_gen) {
+  std::uniform_real_distribution<> dis(0.0, 1.0);
+  double x = dis(*random_gen);
+
+  CHECK(!prefix_sum_probs.empty());
+
+  return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) -
+         prefix_sum_probs.begin();
+}
+
+// Prune undefined states.
+void PruneUndefined(std::vector<State>* states) {
+  size_t pt = 0;
+  for (size_t i = 0; i < states->size(); ++i) {
+    if (!(*states)[i].defined()) {
+      continue;
+    }
+    (*states)[pt++] = std::move((*states)[i]);
+  }
+
+  if (pt == 0) {
+    LOG(FATAL) << "All states are undefined.";
+  } else {
+    states->resize(pt);
+  }
+}
+
+// GA: crossover of two states (stub; returns an undefined state for now)
+State CrossOverState(const State& p1, const State& p2) { return State(); }
+
+}  // namespace ansor
+}  // namespace tvm
+
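
CrossOverState is left as a stub at this point in the series: it returns a default-constructed, undefined State, which is the same failure sentinel the mutation functions return and that PruneUndefined compacts away. The compaction itself is the standard stable in-place filter; a minimal generic sketch of the same shape, over plain pointers:

#include <cstddef>
#include <vector>

// Stable in-place compaction: keep non-null entries, preserve their order,
// then shrink the vector -- the same pt/resize pattern as PruneUndefined.
template <typename T>
void PruneNull(std::vector<T*>* v) {
  size_t pt = 0;
  for (size_t i = 0; i < v->size(); ++i) {
    if ((*v)[i] != nullptr) {
      (*v)[pt++] = (*v)[i];
    }
  }
  v->resize(pt);
}
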
diff --git a/src/ansor/search_policy/utils.h b/src/ansor/search_policy/utils.h
new file mode 100644
index 000000000000..05b50775b52d
--- /dev/null
+++ b/src/ansor/search_policy/utils.h
@@ -0,0 +1,428 @@
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file ansor/search_policy/utils.h
+ * \brief Common utilities for local mutation in search policy
+ */
+
+#ifndef TVM_ANSOR_SEARCH_POLICY_UTILS_H_
+#define TVM_ANSOR_SEARCH_POLICY_UTILS_H_
+
+#include <tvm/te/operation.h>
+#include <random>
+#include <set>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+#include "../cost_model/cost_model.h"
+#include "../utils.h"
+#include "search_policy.h"
+
+namespace tvm {
+namespace ansor {
+
+inline bool StringEndWith(const std::string& str, const std::string& target) {
+  int str_len = str.length();
+  int target_len = target.length();
+  if (str_len <= target_len) {
+    return false;
+  }
+  return str.compare(str_len - target_len, target_len, target) == 0;
+}
+
+// Get an integer from a tvm str Map
+inline int GetIntParam(const Map<std::string, ObjectRef>& attr_dict,
+                       const std::string& key) {
+  CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
+  auto pint = attr_dict[key].as<IntImmNode>();
+  CHECK(pint != nullptr);
+  return pint->value;
+}
+
+// Get a double from a tvm str Map
+inline double GetDoubleParam(const Map<std::string, ObjectRef>& attr_dict,
+                             const std::string& key) {
+  CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
+  auto pdouble = attr_dict[key].as<FloatImmNode>();
+  CHECK(pdouble != nullptr);
+  return pdouble->value;
+}
+
+// Get a string from a tvm str Map
+inline std::string GetStringParam(const Map<std::string, ObjectRef>& attr_dict,
+                                  const std::string& key) {
+  CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
+  auto pstr = attr_dict[key].as<StringImmNode>();
+  CHECK(pstr != nullptr);
+  return pstr->value;
+}
+
+// Get an iterator name set from a tvm str Map
+inline std::set<std::string> GetIterNameSetParam(const Map<std::string, ObjectRef>& attr_dict,
+                                                 const std::string& key) {
+  std::set<std::string> ret;
+  CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
+  auto names = attr_dict[key].as<ArrayNode>();
+  CHECK(names != nullptr);
+  for (auto name = names->begin(); name != names->end(); name++) {
+    ret.insert(name->as<StringImmNode>()->value);
+  }
+  return ret;
+}
+
+// Convert an operation to its stage id
+inline int OperationToStage(const te::Operation& op, const State& state) {
+  for (size_t i = 0; i < state->stages.size(); ++i) {
+    if (op == state->stages[i]->op) {
+      return i;
+    }
+  }
+  LOG(FATAL) << "Cannot find op: " << op;
+  return -1;
+}
+
+// Return the extent of an iterator
+inline int64_t GetExtent(const Iterator& it) {
+  if (it->range.defined()) {
+    if (auto pint = it->range->extent.as<IntImmNode>()) {
+      return pint->value;
+    }
+  }
+  return -1;
+}
+
+// Return whether an op is strictly inlineable
+inline bool IsStrictInlineable(const SearchTask& task, const State& state,
+                               const te::Operation& op) {
+  if (state->task_dag.defined()) {
+    return state->task_dag->access_analyzer.IsStrictInlineable(op);
+  } else {
+    return task->compute_dag->access_analyzer.IsStrictInlineable(op);
+  }
+}
+
+// Return whether an op is an output op
+inline bool IsOutputOp(const SearchTask& task, const State& state, const te::Operation& op) {
+  if (state->task_dag.defined()) {
+    return state->task_dag->access_analyzer.IsOutput(op);
+  } else {
+    return task->compute_dag->access_analyzer.IsOutput(op);
+  }
+}
+
+// Return whether the stage has an attribute flag
+inline bool HasAttrsFlag(const State& state, int stage_id, const char* target) {
+  if (state->stages[stage_id]->op->attrs.count(target)) {
+    return GetStringParam(state->stages[stage_id]->op->attrs, target) == "True";
+  }
+  return false;
+}
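
One caveat to keep in mind when reading the helpers that follow: GetExtent returns -1 as a sentinel for a missing or non-constant extent, while NeedsRfactor below multiplies raw extents into cum_space_len and cum_reduce_len, implicitly assuming every extent is a known constant. A hedged sketch (hypothetical helper, not in the patch) of a product that propagates the sentinel instead of letting it flip the sign:

#include <cstdint>
#include <vector>

// Multiply extents, propagating the "unknown" sentinel (-1) rather than
// letting a single unknown extent silently negate the product.
int64_t ProductOrUnknown(const std::vector<int64_t>& extents) {
  int64_t prod = 1;
  for (int64_t e : extents) {
    if (e < 0) return -1;  // unknown extent poisons the whole product
    prod *= e;
  }
  return prod;
}
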
+// Return whether the stage has reduce iterators
+inline bool HasReduceIter(const Stage& stage) {
+  for (const auto& iter : stage->iters) {
+    if (iter->iter_type != kSpace) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Return whether an op needs multi-level tiling
+inline bool NeedsMultilevelTiling(const SearchTask& task, const State& state,
+                                  const te::Operation& op) {
+  if (state->task_dag.defined()) {
+    return state->task_dag->access_analyzer.NeedsMultiLevelTiling(op);
+  } else {
+    return task->compute_dag->access_analyzer.NeedsMultiLevelTiling(op);
+  }
+}
+
+// Get all consumers of an op. This takes inlining into consideration.
+inline void GetConsumers(const SearchTask& task, const State& state, const te::Operation& op,
+                         std::unordered_set<te::Operation, ObjectHash, ObjectEqual>* consumers) {
+  if (state->task_dag.defined()) {
+    state->task_dag->access_analyzer.GetConsumers(state, op, consumers);
+  } else {
+    task->compute_dag->access_analyzer.GetConsumers(state, op, consumers);
+  }
+}
+
+// Get all producers of an op
+inline void GetProducers(const SearchTask& task, const State& state, const te::Operation& op,
+                         std::unordered_set<te::Operation, ObjectHash, ObjectEqual>* producers) {
+  if (state->task_dag.defined()) {
+    state->task_dag->access_analyzer.GetProducers(state, op, producers);
+  } else {
+    task->compute_dag->access_analyzer.GetProducers(state, op, producers);
+  }
+}
+
+// Return whether two ops are elementwise-matched
+inline bool ElementwiseMatch(const SearchTask& task, const State& state, const te::Operation& op,
+                             const te::Operation& target_op) {
+  if (state->task_dag.defined()) {
+    return state->task_dag->access_analyzer.ElementWiseMatch(op, target_op);
+  } else {
+    return task->compute_dag->access_analyzer.ElementWiseMatch(op, target_op);
+  }
+}
+
+// Return whether the stage has only one consumer and they are elementwise-matched
+inline bool HasSingleElementwiseMatchedConsumer(const SearchTask& task,
+                                                const State& state, const Stage& stage,
+                                                int* target_stage_id) {
+  std::unordered_set<te::Operation, ObjectHash, ObjectEqual> consumers;
+
+  GetConsumers(task, state, stage->op, &consumers);
+  if (consumers.size() == 1) {
+    *target_stage_id = OperationToStage(*consumers.begin(), state);
+    const Stage& target_stage = state->stages[*target_stage_id];
+    if (ElementwiseMatch(task, state, stage->op, target_stage->op) &&
+        (!(HasReduceIter(stage) && HasReduceIter(target_stage)))) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Return whether this stage needs rfactor
+inline bool NeedsRfactor(const SearchTask& task, const State& state, const te::Operation& op) {
+  if (op->IsInstance<te::ComputeOpNode>()) {
+    // Compute the product of lengths of all space iters and all reduce iters
+    int64_t cum_space_len = 1, cum_reduce_len = 1;
+    int stage_id = OperationToStage(op, state);
+    for (const auto& iter : state->stages[stage_id]->iters) {
+      if (iter->iter_type == kSpace) {
+        cum_space_len *= GetExtent(iter);
+      } else if (iter->iter_type == kReduce) {
+        cum_reduce_len *= GetExtent(iter);
+      }
+    }
+
+    if (NeedsMultilevelTiling(task, state, op)) {
+      // Do not use rfactor if we have enough parallelism on space iters
+      if (cum_space_len > cum_reduce_len ||
+          cum_space_len > task->hardware_params->num_cores * 16) {
+        return false;
+      } else {
+        return true;
+      }
+    } else if (cum_reduce_len > 1) {
+      // Always try rfactor for reduction ops
+      return true;
+    }
+  }
+
+  return false;
+}
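
The three Has*Stage helpers below share one trick: cache_read, cache_write, and rfactor each insert a new stage, shifting the ids of all later stages up by one, so walking transform_steps backwards and decrementing stage_id whenever such a step targeted an earlier stage rewinds the id to what it was when the older steps ran. A standalone sketch of that id-rewinding idea, assuming a plain list of insertion positions (hypothetical, not the patch's data layout):

#include <vector>

// Map a current stage id back to its id before a sequence of stage
// insertions, mirroring the backward walk in HasCacheWriteStage below.
// insert_positions holds, in application order, the stage id that each
// inserting step targeted.
int RewindStageId(int stage_id, const std::vector<int>& insert_positions) {
  for (auto it = insert_positions.rbegin(); it != insert_positions.rend(); ++it) {
    if (stage_id > *it) {
      --stage_id;  // this insertion shifted the stage up by one
    }
  }
  return stage_id;
}
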
+// Return whether the state did cache_write for stage_id
+inline bool HasCacheWriteStage(const State& s, int stage_id) {
+  for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
+    if (auto ps = s->transform_steps[i].as<CacheWriteStepNode>()) {
+      if (stage_id > ps->stage_id) {
+        stage_id--;
+      } else if (stage_id == ps->stage_id) {
+        return true;
+      }
+    } else if (auto ps = s->transform_steps[i].as<CacheReadStepNode>()) {
+      if (stage_id > ps->stage_id) {
+        stage_id--;
+      }
+    } else if (auto ps = s->transform_steps[i].as<RfactorStepNode>()) {
+      if (stage_id > ps->stage_id) {
+        stage_id--;
+      }
+    }
+  }
+  return false;
+}
+
+// Return whether the state did cache_read for stage_id
+inline bool HasCacheReadStage(const State& s, int stage_id) {
+  for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
+    if (auto ps = s->transform_steps[i].as<CacheWriteStepNode>()) {
+      if (stage_id > ps->stage_id) {
+        stage_id--;
+      }
+    } else if (auto ps = s->transform_steps[i].as<CacheReadStepNode>()) {
+      if (stage_id > ps->stage_id) {
+        stage_id--;
+      } else if (stage_id == ps->stage_id) {
+        return true;
+      }
+    } else if (auto ps = s->transform_steps[i].as<RfactorStepNode>()) {
+      if (stage_id > ps->stage_id) {
+        stage_id--;
+      }
+    }
+  }
+  return false;
+}
+
+// Get the ids of the split steps that split space iterators of a stage
+void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector<int>* spatial_split_step_ids);
+
+// Return whether the state did split for stage_id
+inline bool HasSplitStep(const State& s, int stage_id) {
+  for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
+    if (s->transform_steps[i]->IsInstance<CacheWriteStepNode>() ||
+        s->transform_steps[i]->IsInstance<CacheReadStepNode>() ||
+        s->transform_steps[i]->IsInstance<RfactorStepNode>()) {
+      if (stage_id > s->transform_steps[i]->stage_id) {
+        stage_id--;
+      }
+    } else if (s->transform_steps[i]->IsInstance<SplitStepNode>() ||
+               s->transform_steps[i]->IsInstance<FollowSplitStepNode>() ||
+               s->transform_steps[i]->IsInstance<FollowFusedSplitStepNode>()) {
+      if (stage_id == s->transform_steps[i]->stage_id) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+// Return whether the stage has been tiled already
+inline bool IsTiled(const Stage& stage) {
+  auto op = stage->op.as<te::ComputeOpNode>();
+  CHECK(op != nullptr);
+  return stage->iters.size() != op->axis.size() + op->reduce_axis.size();
+}
+
+// Query axes that should not be split, according to attributes from tvm.compute
+std::pair<std::set<std::string>, std::set<std::string> > QueryNoSplitAxis(const Stage& stage);
+// Query axes whose last split factor is one
+std::set<std::string> QueryLastSplitIsOneAxis(const Stage& stage);
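
The helper below assumes the iterator-naming scheme used elsewhere in Ansor: split parts get numeric suffixes separated by '.', and fused iterators are joined with '@', so the primitive axis names can be recovered by cutting at those markers and dropping purely numeric pieces. A self-contained sketch of the same parsing under that assumption:

#include <cctype>
#include <set>
#include <string>

// Recover primitive axis names from a fused/split iterator name, e.g.
//   OriginalIters("i.0@j.0@") == {"i", "j"}
//   OriginalIters("k.1")      == {"k"}
std::set<std::string> OriginalIters(const std::string& name) {
  std::set<std::string> out;
  std::string cur;
  for (char c : name) {
    if (c == '@' || c == '.') {
      if (!cur.empty() && !std::isdigit(static_cast<unsigned char>(cur[0]))) {
        out.insert(cur);
      }
      cur.clear();
    } else {
      cur.push_back(c);
    }
  }
  if (!cur.empty() && !std::isdigit(static_cast<unsigned char>(cur[0]))) {
    out.insert(cur);
  }
  return out;
}
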
+// Extract primitive iterators from the name of a nested fused or split iterator
+inline void ExtractOriginalIterators(const std::string& name, std::set<std::string>* rets) {
+  size_t last_pos = 0;
+  for (size_t i = 0; i < name.size(); ++i) {
+    if (name[i] == '@' || name[i] == '.') {  // '@' for fuse and '.' for split
+      if (!isdigit(name[last_pos]) && name[last_pos] != '@' && name[last_pos] != '.') {
+        rets->insert(name.substr(last_pos, i - last_pos));
+      }
+      last_pos = i + 1;
+    }
+  }
+
+  if (last_pos < name.size() && !isdigit(name[last_pos]) &&
+      name[last_pos] != '@' && name[last_pos] != '.') {
+    rets->insert(name.substr(last_pos, name.size() - last_pos));
+  }
+}
+
+// Get the last space iterator in the outermost tile
+inline const Iterator& GetLastSpaceIteratorInOutermostTile(const Stage& stage) {
+  auto pop = stage->op.as<te::ComputeOpNode>();
+  CHECK(pop != nullptr);
+  std::set<std::string> original_names;
+
+  for (const auto& iter : stage->iters) {
+    ExtractOriginalIterators(iter->name, &original_names);
+    if (original_names.size() == pop->axis.size()) {
+      return iter;
+    }
+  }
+
+  LOG(FATAL) << "Cannot find the iterator.";
+  return stage->iters[0];
+}
+
+// Get the last reduce iterator in the outermost reduce tile
+inline const Iterator& GetLastReduceIteratorInOutermostReduceTile(const Stage& stage) {
+  auto pop = stage->op.as<te::ComputeOpNode>();
+  CHECK(pop != nullptr);
+  std::set<std::string> original_names;
+
+  auto no_split_name_pair = QueryNoSplitAxis(stage);
+  std::set<std::string> no_split_at_inner_name_set = no_split_name_pair.first;
+  size_t axis_size = 0;
+  for (const auto& axis : pop->axis) {
+    if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
+      axis_size++;
+    }
+  }
+  size_t reduce_axis_size = 0;
+  for (const auto& axis : pop->reduce_axis) {
+    if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
+      reduce_axis_size++;
+    }
+  }
+
+  if (reduce_axis_size) {
+    for (const auto& iter : stage->iters) {
+      ExtractOriginalIterators(iter->name, &original_names);
+      if (original_names.size() == axis_size + reduce_axis_size) {
+        return iter;
+      }
+    }
+  } else {
+    for (size_t i = 0; i < stage->iters.size(); i++) {
+      ExtractOriginalIterators(stage->iters[i]->name, &original_names);
+      if (original_names.size() == axis_size + 1) {
+        return stage->iters[i - 1];
+      }
+    }
+  }
+
+  LOG(FATAL) << "Cannot find the iterator.";
+  return stage->iters[0];
+}
+
+// Randomly sample states
+inline void RandomSampleStates(const std::vector<State>& in_states, std::mt19937* random_gen,
+                               size_t out_size, std::vector<State>* out_states) {
+  out_states->clear();
+  for (size_t i = 0; i < out_size; i++) {
+    out_states->push_back(in_states[(*random_gen)() % in_states.size()]);
+  }
+}
+
+// Randomly choose an index according to a prefix-sum probability vector
+int RandomChoose(const std::vector<double>& prefix_sum_probs, std::mt19937* random_gen);
+
+// Prune undefined states.
+void PruneUndefined(std::vector<State>* states);
+
+// Print all states
+inline void PrintAllStates(const std::vector<State>& states) {
+  for (size_t i = 0; i < states.size(); ++i) {
+    std::cerr << i << std::endl;
+    std::cerr << states[i];
+    std::cerr << "==============================================" << std::endl;
+  }
+}
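
A compact way to read the tiling-format strings documented below: a space iterator is split into as many levels as there are 'S' letters and a reduce iterator into as many as there are 'R' letters, with the levels interleaved in the order the letters appear. A sketch that derives those level counts, assuming the format contains only 'S' and 'R':

#include <string>
#include <utility>

// For "SSRSRS": 4 space levels and 2 reduce levels, so matmul's (i, j, k)
// tiles to i0,j0, i1,j1, k0, i2,j2, k1, i3,j3 as in the comment below.
std::pair<int, int> TileLevels(const std::string& format) {
  int space = 0, reduce = 0;
  for (char c : format) {
    (c == 'S' ? space : reduce)++;
  }
  return {space, reduce};
}
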
+// Apply a multi-level tiling structure according to a format string,
+// where "S" stands for a space level and "R" stands for a reduction level.
+// For example, if the format is "SSRSRS", then we will use the tiling
+// structure: space_L0, space_L1, reduce_L0, space_L2, reduce_L1, space_L3.
+// For example, applying "SSRSRS" to matrix multiplication, where we have
+// space iterators i and j and reduce iterator k, gives the tiling
+// structure: i0, j0, i1, j1, k0, i2, j2, k1, i3, j3.
+State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
+                         std::vector<int>* spatial_split_step_ids);
+
+// Apply tiling structure: space, space,
+// but use tile sizes from other SplitSteps
+State FollowTiling(const State& state, int stage_id,
+                   const std::vector<int>& split_step_ids, int n_split);
+
+// Randomly mutate the tile size of one SplitStep
+State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split_memo,
+                           std::mt19937* random_gen, int max_innermost_split_factor);
+
+// Randomly mutate the value of one auto_unroll_max_step PragmaStep
+State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen,
+                                const std::vector<int>& auto_unroll_configs);
+
+// Mutate a parallel loop.
+State MutateParallel(const State& old_state, SplitFactorizationMemo* split_memo,
+                     std::mt19937* random_gen, SearchTask& task, int verbose = 0);
+
+// Create all possible tile size states for all SplitStep
+void GridMutateTileSize(const State& old_state, std::vector<State>* cands,
+                        SplitFactorizationMemo* split_memo, int max_innermost_split_factor);
+
+// GA: Crossover two states
+State CrossOverState(const State& p1, const State& p2);
+
+}  // namespace ansor
+}  // namespace tvm
+
+#endif  // TVM_ANSOR_SEARCH_POLICY_UTILS_H_
diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc
index c43ec5c0a751..bbcef05f31fc 100644
--- a/tests/cpp/ansor_test.cc
+++ b/tests/cpp/ansor_test.cc
@@ -24,6 +24,8 @@
 #include
 #include "../../src/ansor/loop_state.h"
 #include "../../src/ansor/serialization.h"
+#include "../../src/ansor/feature.h"
+#include "../../src/ansor/search_policy/meta_tile_rewrite_policy.h"
 
 tvm::Array<tvm::te::Tensor> matmul_func(int n, int m, int k) {
   using namespace tvm;
@@ -84,11 +86,13 @@ tvm::Array<tvm::te::Tensor> conv2d_nchw_bn_relu_func(int N, int H, int W,
   return {data, kernel, bias, bn_scale, bn_offset, out};
 }
 
+using namespace tvm::ansor;
+
 TEST(ComputeDAG, Basic) {
   const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3);
-  const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors);
-  const auto& state = tvm::ansor::StateNode::make(dag->ops);
-  CHECK(std::equal_to<tvm::ansor::State>()(state, dag.GetInitState()));
+  const auto& dag = ComputeDAGNode::make(tensors);
+  const auto& state = StateNode::make(dag->ops);
+  CHECK(std::equal_to<State>()(state, dag.GetInitState()));
 
   LOG(INFO) << "\n" << state;
   LOG(INFO) << "\n" << dag;
@@ -96,8 +100,6 @@
 }
 
 TEST(ComputeDAG, GetProducersConsumers) {
-  using namespace tvm::ansor;
-
   const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3);
   const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors);
   int data = 0, padding = 1, kernel = 2, conv = 3, bias = 4, bias_add = 5;
@@ -159,8 +161,6 @@
 }
 
 TEST(ComputeDAG, InferBoundSerialization) {
-  using namespace tvm::ansor;
-
   const auto& tensors = matmul_func(512, 512, 512);
   const auto& dag = ComputeDAGNode::make(tensors);
   int A = 0, B = 1, C = 2;
@@ -216,8 +216,6 @@
 }
 
 TEST(Step, SplitFuseReorder) {
-  using namespace tvm::ansor;
-
   const auto& tensors = matmul_func(512, 512, 512);
   const auto& dag = ComputeDAGNode::make(tensors);
 
@@ -257,8 +255,6 @@
 }
 
 TEST(Step, ComputeAtRootInline) {
-  using namespace tvm::ansor;
-
   const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3);
   const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors);
   // int data = 0, padding =
1, kernel = 2; @@ -334,7 +330,6 @@ TEST(Step, ComputeAtRootInline) { TEST(Step, CacheReadWrite) { using namespace tvm; using namespace tvm::te; - using namespace tvm::ansor; const auto& test_func = []() -> Array { int N = 4, H = 7, W = 7, CO = 512, CI = 512, KH = 3, KW = 3, stride = 1; @@ -591,8 +586,6 @@ TEST(Step, CacheReadWrite) { } TEST(Step, FollowSplitFollowFusedSplit) { - using namespace tvm::ansor; - const auto& tensors = matmul_func(512, 512, 512); const auto& dag = ComputeDAGNode::make(tensors); @@ -660,6 +653,84 @@ TEST(Step, Rfactor) { // todo } +TEST(Feature, ExtractionMatmul) { + const auto& tensors = matmul_func(512, 512, 512); + const auto& dag = ComputeDAGNode::make(tensors); + State s0 = dag.GetInitState(); + + Iterator ti = s0->stages[2]->iters[0]; + Iterator tj = s0->stages[2]->iters[1]; + Iterator tk = s0->stages[2]->iters[2]; + std::vector its; + its = s0.split(2, ti, {16}); + Iterator tio = its[0], tii = its[1]; + its = s0.split(2, tj, {8}); + Iterator tjo = its[0], tji = its[1]; + s0.reorder(2, {tio, tjo, tk, tji, tii}); + s0.vectorize(2, tji); + s0.parallel(2, tio); + s0.parallel(2, tjo); + s0.unroll(2, tk); + + int max_n_bufs = 5; + std::vector> features; + std::vector feature_names; + GetPerStmtFeatureName(max_n_bufs, &feature_names); + GetPerStmtFeaturesFromStates({s0}, + SearchTaskNode::make(dag, "test", tvm::target::llvm(), + tvm::target::llvm(), + HardwareParams()), + max_n_bufs, 0, &features); + int num_states = 1; + CHECK_EQ(feature_names.size(), (features[0].size() - 1) / num_states); + // TODO(...): Add feature check here +} + +namespace tvm { +namespace ansor { +class MetaTileRewritePolicyNodeTest { + public: + MetaTileRewritePolicyNodeTest(CostModel cost_model, SearchTask task) { + policy = make_object(); + policy->program_cost_model = std::move(cost_model); + policy->rand_gen_ = std::mt19937(0); + policy->params.Set("cpu_multi_level_tiling_structure", + te::StringImmNode::make("SSRSRS")); + policy->params.Set("disable_change_compute_location", + IntImm(DataType::Int(32), 0)); + policy->cur_task_ = task; + } + void SynthesizeMetaStructure(std::vector* meta_structures) { + policy->SynthesizeMetaStructure(meta_structures); + } + void SampleInitPopulation(const std::vector& meta_structures, + int out_size, std::vector* out_states) { + policy->SampleInitPopulation(meta_structures, out_size, out_states); + } + tvm::runtime::ObjectPtr policy; +}; +} // namespace ansor +} // namespace tvm + +TEST(MetaTileRewritePolicy, Basic) { + const auto& tensors = matmul_func(512, 512, 512); + const auto& dag = ComputeDAGNode::make(tensors); + const auto& task = SearchTaskNode::make( + dag, "test", tvm::target::llvm(), tvm::target::llvm(), HardwareParams()); + const auto& cost_model = RandomModelNode::make(); + MetaTileRewritePolicyNodeTest test(cost_model, task); + + std::vector meta_structures, init_population; + test.SynthesizeMetaStructure(&meta_structures); + CHECK_GE(meta_structures.size(), 0); + LOG(INFO) << "SynthesizeMetaStructure get " << meta_structures.size() + << " states."; + test.SampleInitPopulation(meta_structures, 100, &init_population); + CHECK_GE(init_population.size(), 0); + LOG(INFO) << "SampleInitPopulation get " << init_population.size() + << " states."; +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; From 359905a0dd2b161dab662ebf7a7ae911812ee29b Mon Sep 17 00:00:00 2001 From: Chenfan Date: Wed, 3 Jun 2020 15:36:02 +0800 Subject: [PATCH 05/78] Basic Python API for 
State (#6) * Add Basic Python API for State * Add UTs for State --- python/tvm/ansor/__init__.py | 20 + python/tvm/ansor/_ffi_api.py | 21 + python/tvm/ansor/compute_dag.py | 34 ++ python/tvm/ansor/state.py | 387 +++++++++++++++++ src/ansor/compute_dag.cc | 2 + src/ansor/loop_state.cc | 149 +++++++ src/ansor/transform_step.cc | 1 + src/ansor/transform_step.h | 5 + tests/cpp/ansor_test.cc | 212 +++++---- tests/python/unittest/test_ansor_common.py | 475 +++++++++++++++++++++ 10 files changed, 1222 insertions(+), 84 deletions(-) create mode 100644 python/tvm/ansor/__init__.py create mode 100644 python/tvm/ansor/_ffi_api.py create mode 100644 python/tvm/ansor/compute_dag.py create mode 100644 python/tvm/ansor/state.py create mode 100644 tests/python/unittest/test_ansor_common.py diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py new file mode 100644 index 000000000000..aaa0e9c9174d --- /dev/null +++ b/python/tvm/ansor/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import, redefined-builtin +"""Namespace for Ansor autoSchedule""" + +from .compute_dag import ComputeDAG diff --git a/python/tvm/ansor/_ffi_api.py b/python/tvm/ansor/_ffi_api.py new file mode 100644 index 000000000000..177299e67d21 --- /dev/null +++ b/python/tvm/ansor/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.ansor""" +import tvm._ffi + + +tvm._ffi._init_api("ansor", __name__) diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py new file mode 100644 index 000000000000..3c46440f75ba --- /dev/null +++ b/python/tvm/ansor/compute_dag.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +""" ... """ + +import tvm._ffi +from tvm.runtime import Object + +from .state import State + +from . import _ffi_api + + +@tvm._ffi.register_object("ansor.ComputeDAG") +class ComputeDAG(Object): + def __init__(self, tensors): + self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, tensors) + + def get_init_state(self) -> State: + return self.init_state diff --git a/python/tvm/ansor/state.py b/python/tvm/ansor/state.py new file mode 100644 index 000000000000..9a8810190199 --- /dev/null +++ b/python/tvm/ansor/state.py @@ -0,0 +1,387 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +""" ... """ + +import tvm._ffi +from tvm.runtime import Object + +from . 
import _ffi_api + + +@tvm._ffi.register_object("ansor.Iterator") +class Iterator(Object): + pass + + +@tvm._ffi.register_object("ansor.Stage") +class Stage(Object): + + def iterator(self, index): + return _ffi_api.StageGetIterator(self, index) + + def iterators(self): + return _ffi_api.StageGetIterators(self) + + +@tvm._ffi.register_object("ansor.State") +class State(Object): + + def stage(self, index): + """ + Parameters + ---------- + index : Int + + Returns + ------- + stage : Stage + """ + return _ffi_api.StateGetStage(self, index) + + def transform_steps_size(self): + """ Return the size of transform_steps + """ + return _ffi_api.StateGetTransformStepsSize(self) + + def reorder(self, stage_id, order): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + order : List[Iterator] + Iterators in expected order + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateReorder(self, stage_id, order) + return state + + def split(self, stage_id, it, lengths, inner_to_outer=True): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + The target Iterator + lengths: List[Int] + The split factor + inner_to_outer: Bool + True to use `factor` for split from inner to outer, + False to use `nparts` for split from outer to inner + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateSplit(self, stage_id, it, lengths, + inner_to_outer) + return state + + def follow_split(self, stage_id, it, src_step_id, n_split): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + The target Iterator + src_step_id : Int + The index of target step that this split follows + n_split : Int + Indecate how many level needs to be split out + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateFollowSplit(self, stage_id, it, src_step_id, + n_split) + return state + + def follow_fused_split(self, stage_id, it, src_step_ids, level, + factor_or_nparts): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + The target Iterator + src_step_ids : List[Int] + The indexes of target step that this split follows + level : Int + factor_or_nparts : Bool + True to use `factor` for split from inner to outer, + False to use `nparts` for split from outer to inner + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateFollowFusedSplit(self, stage_id, it, src_step_ids, + level, factor_or_nparts) + return state + + def fuse(self, stage_id, iters): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + iters : List[Iterator] + The target Iterators to be fused + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateFuse(self, stage_id, iters) + return state + + def vectorize(self, stage_id, it): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + The target Iterator to be vectorized + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateVectorize(self, stage_id, it) + return state + + def parallel(self, stage_id, it): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + The target Iterator to be paralleled + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateParallel(self, stage_id, it) + return state + + def unroll(self, stage_id, it, max_unroll=-1): + """ + 
Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + The target Iterator to be unrolled + max_unroll : Int + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateUnroll(self, stage_id, it, max_unroll) + return state + + def bind_thread(self, stage_id, it, thread_type): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + The target Iterator to be vectorized + thread_type : ... + Supported type: kVThread, kBlockX, kThreadX, kThreadY + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateBindThread(self, stage_id, it, thread_type) + return state + + def compute_at(self, stage_id, target_stage_id, target_iter): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + target_stage_id : Int + The index of compute at target stage + target_iter : Iterator + The target Iterator to be compute at + + Returns + ------- + state : State + The updated state + """ + return _ffi_api.StateComputeAt(self, stage_id, target_stage_id, + target_iter) + + def compute_root(self, stage_id): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + + Returns + ------- + state : State + The updated state + """ + return _ffi_api.StateComputeRoot(self, stage_id) + + def compute_inline(self, stage_id): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + + Returns + ------- + state : State + The updated state + """ + return _ffi_api.StateComputeInline(self, stage_id) + + def pack_for_vec(self, stage_id, target_iter, vec_size): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + target_iter : Iterator + The target Iterator + vec_size : Int + + Returns + ------- + state : State + The updated state + """ + return _ffi_api.StatePackForVec(self, stage_id, target_iter, vec_size) + + def cache_read(self, stage_id, scope_name, reader_stage_ids, task_dag): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + scope_name : Str + reader_stage_ids : List[Int] + task_dag : ComputeDAG + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateCacheRead(self, stage_id, scope_name, + reader_stage_ids, task_dag) + return state + + def cache_write(self, stage_id, scope_name, task_dag): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + scope_name : Str + task_dag : ComputeDAG + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateCacheWrite(self, stage_id, scope_name, task_dag) + return state + + def pragma(self, stage_id, it, pragma_type): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + The target Iterator + pragma_type : Str + + Returns + ------- + state : State + The updated state + """ + return _ffi_api.StatePragma(self, stage_id, it, pragma_type) + + def rfactor(self, stage_id, it, factor_iter_id, task_dag): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + factor_iter_id : Int + task_dag : ComputeDAG + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateRfactor(self, stage_id, it, factor_iter_id, + task_dag) + return state + + def storage_align(self, stage_id, it, factor, offset): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + factor : Int + offset : Int + + Returns + ------- + state : State + The updated state + 
""" + return _ffi_api.StateStorageAlign(self, stage_id, it, factor, offset) diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index feaefe9f8e9f..1e33068e4965 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -1166,6 +1166,8 @@ std::pair > ComputeDAG::ReplaySteps( return std::make_pair(schedule, operator->()->tensors); } +TVM_REGISTER_GLOBAL("ansor.ComputeDAG") +.set_body_typed([](Array tensors) { return ComputeDAGNode::make(tensors); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index f01899c4c793..ebea5a1e472a 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -3,11 +3,13 @@ */ #include "loop_state.h" #include +#include #include "utils.h" namespace tvm { namespace ansor { +TVM_REGISTER_OBJECT_TYPE(StageNode); TVM_REGISTER_NODE_TYPE(StateNode); Stage StageNode::make(te::Operation op) { @@ -65,6 +67,16 @@ Stage StageNode::make(te::Operation op, StageType op_type, return Stage(node); } +TVM_REGISTER_GLOBAL("ansor.StageGetIterator") + .set_body_typed([](const Stage& stage, int index) { + return stage->iters[index]; + }); + +TVM_REGISTER_GLOBAL("ansor.StageGetIterators") + .set_body_typed([](const Stage& stage) { + return Array(stage->iters); + }); + State StateNode::make_empty_state() { auto node = make_object(); node->attach_map = AttachMapNode::make(); @@ -873,6 +885,143 @@ std::string State::ToStr(bool delete_trivial_loop) const { return os.str(); } +TVM_REGISTER_GLOBAL("ansor.StateGetStage") + .set_body_typed([](const State& state, int index) { + return state->stages[index]; + }); + +TVM_REGISTER_GLOBAL("ansor.StateGetTransformStepsSize") + .set_body_typed([](const State& state) { + return static_cast(state->transform_steps.size()); + }); + +TVM_REGISTER_GLOBAL("ansor.StateReorder") + .set_body_typed([](State state, int stage_id, + const Array& order) { + std::vector ord; + for (const auto& i : order) { + ord.push_back(i); + } + state.reorder(stage_id, ord); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateSplit") + .set_body_typed([](State state, int stage_id, const Iterator& it, + const Array& lengths, + bool inner_to_outer) { + std::vector len; + for (const auto& i : lengths) { + len.push_back(i); + } + state.split(stage_id, it, len, inner_to_outer); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateFollowSplit") + .set_body_typed([](State state, int stage_id, const Iterator& it, + int src_step_id, int n_split) { + state.follow_split(stage_id, it, src_step_id, n_split); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateFollowFusedSplit") + .set_body_typed([](State state, int stage_id, const Iterator& it, + const Array& src_step_ids, int level, + bool factor_or_nparts) { + std::vector array_src_step_ids; + for (const auto& i : src_step_ids) { + array_src_step_ids.push_back(i->value); + } + state.follow_fused_split(stage_id, it, array_src_step_ids, level, + factor_or_nparts); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateFuse") + .set_body_typed([](State state, int stage_id, + const Array& iters) { + std::vector its; + for (const auto& i : iters) { + its.push_back(i); + } + state.fuse(stage_id, its); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateVectorize") + .set_body_typed([](State state, int stage_id, + const Iterator& it) { + state.vectorize(stage_id, it); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateParallel") + .set_body_typed([](State state, int 
stage_id, + const Iterator& it) { + state.parallel(stage_id, it); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateUnroll") + .set_body_typed([](State state, int stage_id, + const Iterator& it, int max_unroll) { + state.unroll(stage_id, it, max_unroll); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateBindThread") + .set_body_typed([](State state, int stage_id, + const Iterator& it, int thread_type) { + state.bind_thread(stage_id, it, IteratorAnnotation(thread_type)); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateComputeAt") + .set_body_typed([](State state, int stage_id, int target_stage_id, + const Iterator& target_iter) { + state.compute_at(stage_id, target_stage_id, target_iter); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateComputeRoot") + .set_body_typed([](State state, int stage_id) { + state.compute_root(stage_id); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateComputeInline") + .set_body_typed([](State state, int stage_id) { + state.compute_inline(stage_id); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StatePackForVec") + .set_body_typed([](State state, int stage_id, + const Iterator& target_iter, int vec_size) { + state.pack_for_vec(stage_id, target_iter, vec_size); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateCacheRead") + .set_body_typed([](State state, int stage_id, const std::string& scope_name, + const Array& reader_stage_ids, + const ComputeDAG& task_dag) { + std::vector array_reader_stage_ids; + for (const auto& i : reader_stage_ids) { + array_reader_stage_ids.push_back(i->value); + } + state.cache_read(stage_id, scope_name, array_reader_stage_ids, task_dag); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateCacheWrite") + .set_body_typed([](State state, int stage_id, const std::string& scope_name, + const ComputeDAG& task_dag) { + state.cache_write(stage_id, scope_name, task_dag); + return state; + }); + void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) { AttachMapNode* pnode = CopyOnWrite(); diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index 8cd8233ae9be..5f4a6a8dcef9 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -8,6 +8,7 @@ namespace tvm { namespace ansor { +TVM_REGISTER_NODE_TYPE(IteratorNode); TVM_REGISTER_OBJECT_TYPE(StepNode); /********** Reorder **********/ diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index 9b430be99bd3..627ce02b60e1 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -69,6 +69,11 @@ class IteratorNode : public Object { IteratorType iter_type, IteratorAnnotation annotation, const std::vector* ori_iters = nullptr); + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("range", &range); + } + static constexpr const char *_type_key = "ansor.Iterator"; TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); }; diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc index bbcef05f31fc..75a6cc00b802 100644 --- a/tests/cpp/ansor_test.cc +++ b/tests/cpp/ansor_test.cc @@ -19,13 +19,15 @@ #include #include -#include -#include #include -#include "../../src/ansor/loop_state.h" -#include "../../src/ansor/serialization.h" +#include + +#include + #include "../../src/ansor/feature.h" +#include "../../src/ansor/loop_state.h" #include "../../src/ansor/search_policy/meta_tile_rewrite_policy.h" +#include "../../src/ansor/serialization.h" tvm::Array matmul_func(int n, int m, int k) { using namespace tvm; @@ -35,16 +37,17 @@ 
tvm::Array matmul_func(int n, int m, int k) { Tensor B = placeholder({k, m}, DataType::Float(32), "B"); IterVar K = IterVarNode::make({0, k}, Var("k"), kCommReduce); const auto& C = compute( - {n, m}, - [&](Var i, Var j) { return tvm::sum(A[i][K] * B[K][j], {K}); }, + {n, m}, [&](Var i, Var j) { return tvm::sum(A[i][K] * B[K][j], {K}); }, "C"); return {A, B, C}; } tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, - int CI, int CO, int kernel_size, int strides, int padding, - int dilation = 1) { + int CI, int CO, + int kernel_size, + int strides, int padding, + int dilation = 1) { using namespace tvm; using namespace tvm::te; @@ -58,27 +61,27 @@ tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, int OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1; int OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1; - const auto& conv = topi::conv2d_nchw(data, kernel, padding, padding, strides, - strides); + const auto& conv = + topi::conv2d_nchw(data, kernel, padding, padding, strides, strides); CHECK(conv->shape[2].as()->value == OH); CHECK(conv->shape[3].as()->value == OW); const auto& bias_add = compute( {N, CO, OH, OW}, [&](Var i, Var j, Var k, Var l) { - return conv[i][j][k][l] + bias[j][0][0]; + return conv[i][j][k][l] + bias[j][0][0]; }, "Bias_add"); const auto& bn_mul = compute( {N, CO, OH, OW}, [&](Var i, Var j, Var k, Var l) { - return bias_add[i][j][k][l] * bn_scale[j][0][0]; + return bias_add[i][j][k][l] * bn_scale[j][0][0]; }, "Bn_mul"); const auto& bn_add = compute( {N, CO, OH, OW}, [&](Var i, Var j, Var k, Var l) { - return bn_mul[i][j][k][l] + bn_offset[j][0][0]; + return bn_mul[i][j][k][l] + bn_offset[j][0][0]; }, "Bn_add"); const auto& out = topi::relu(bn_add); @@ -109,20 +112,22 @@ TEST(ComputeDAG, GetProducersConsumers) { std::unordered_set set; { std::vector> consumer_list = { - {data, padding}, {padding, conv}, {kernel, conv}, {conv, bias_add}, - {bias, bias_add}, {bias_add, bn_mul}, {bn_scale, bn_mul}, - {bn_mul, bn_add}, {bn_offset, bn_add}, {bn_add, relu} - }; + {data, padding}, {padding, conv}, {kernel, conv}, + {conv, bias_add}, {bias, bias_add}, {bias_add, bn_mul}, + {bn_scale, bn_mul}, {bn_mul, bn_add}, {bn_offset, bn_add}, + {bn_add, relu}}; for (const auto& pair : consumer_list) { dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &set); CHECK_EQ(set.size(), 1); CHECK_EQ((*set.begin()), s0->stages[pair.second]->op); } std::vector>> producer_list = { - {padding, {data}}, {conv, {padding, kernel}}, {bias_add, {conv, bias}}, - {bn_mul, {bias_add, bn_scale}}, {bn_add, {bn_mul, bn_offset}}, - {relu, {bn_add}} - }; + {padding, {data}}, + {conv, {padding, kernel}}, + {bias_add, {conv, bias}}, + {bn_mul, {bias_add, bn_scale}}, + {bn_add, {bn_mul, bn_offset}}, + {relu, {bn_add}}}; for (const auto& pair : producer_list) { dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op, &set); CHECK_EQ(set.size(), pair.second.size()); @@ -138,18 +143,19 @@ TEST(ComputeDAG, GetProducersConsumers) { s0.compute_inline(padding); { std::vector> consumer_list = { - {data, conv}, {kernel, conv}, {conv, relu} - }; + {data, conv}, {kernel, conv}, {conv, relu}}; for (const auto& pair : consumer_list) { dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &set); CHECK_EQ(set.size(), 1); CHECK_EQ((*set.begin()), s0->stages[pair.second]->op); } std::vector>> producer_list = { - {padding, {data}}, {conv, {padding, kernel}}, {bias_add, {conv, bias}}, - {bn_mul, {bias_add, bn_scale}}, {bn_add, {bn_mul, 
bn_offset}}, - {relu, {bn_add}} - }; + {padding, {data}}, + {conv, {padding, kernel}}, + {bias_add, {conv, bias}}, + {bn_mul, {bias_add, bn_scale}}, + {bn_add, {bn_mul, bn_offset}}, + {relu, {bn_add}}}; for (const auto& pair : producer_list) { dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op, &set); CHECK_EQ(set.size(), pair.second.size()); @@ -170,15 +176,19 @@ TEST(ComputeDAG, InferBoundSerialization) { C++; const auto& its0 = s0.split(C, s0->stages[C]->iters[0], {4, 8, 8}); const auto& its1 = s0.split(C, s0->stages[C]->iters[4], {8, 4, 4}); - s0.reorder(C, {its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], - its0[3], its1[3]}); + s0.reorder(C, {its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], + its1[3]}); s0.compute_at(C_global, C, s0->stages[C]->iters[3]); s0.split(C_global, s0->stages[C_global]->iters[2], {16}); int B_global = s0.cache_read(B, "global", {C_global}, dag); - C++; C_global++; + C++; + C_global++; s0.compute_at(B_global, C_global, s0->stages[C_global]->iters[0]); int A_global = s0.cache_read(A, "global", {C_global}, dag); - B++; B_global++; C++; C_global++; + B++; + B_global++; + C++; + C_global++; s0.compute_at(A_global, C_global, s0->stages[C_global]->iters[2]); const auto& s1 = dag.InferBound(s0); @@ -186,23 +196,26 @@ TEST(ComputeDAG, InferBoundSerialization) { dag.InferBound(&s2); const auto& s3 = dag.ReplayAndInferBound(s0->transform_steps); - CHECK_EQ(s1->stages[B_global]->iters[0]->range->extent.as()->value, - 512); - CHECK_EQ(s1->stages[B_global]->iters[1]->range->extent.as()->value, - 16); - CHECK_EQ(s1->stages[A_global]->iters[0]->range->extent.as()->value, - 1); - CHECK_EQ(s1->stages[A_global]->iters[1]->range->extent.as()->value, - 16); - CHECK_EQ(s1->stages[C_global]->iters[0]->range->extent.as()->value, - 64); + CHECK_EQ( + s1->stages[B_global]->iters[0]->range->extent.as()->value, + 512); + CHECK_EQ( + s1->stages[B_global]->iters[1]->range->extent.as()->value, + 16); + CHECK_EQ( + s1->stages[A_global]->iters[0]->range->extent.as()->value, 1); + CHECK_EQ( + s1->stages[A_global]->iters[1]->range->extent.as()->value, + 16); + CHECK_EQ( + s1->stages[C_global]->iters[0]->range->extent.as()->value, + 64); CHECK(std::equal_to()(s1, s2[0])); CHECK(std::equal_to()(s1, s3)); const auto& minp0 = MeasureInputNode::make( SearchTaskNode::make(dag, "test", tvm::target::llvm(), - tvm::target::llvm(), - HardwareParams()), + tvm::target::llvm(), HardwareParams()), s0); const auto& mres0 = MeasureResultNode::make({0.1}, 0, "", 0.1, 0.1); std::stringstream ss; @@ -242,7 +255,8 @@ TEST(Step, SplitFuseReorder) { CHECK_EQ(s0->stages[2]->iters[2]->range->extent.as()->value, 512); s0.fuse(2, {tio, tjo}); - CHECK_EQ(s0->stages[2]->iters[0]->range->extent.as()->value, 2048); + CHECK_EQ(s0->stages[2]->iters[0]->range->extent.as()->value, + 2048); s1.split(2, ti, {8, 2}); s1.split(2, tj, {32, 8}, false); @@ -271,10 +285,12 @@ TEST(Step, ComputeAtRootInline) { s0.compute_inline(bn_mul); s0.compute_inline(bias_add); s0.compute_at(conv, relu, s0->stages[relu]->iters[2]); - const auto& conv_stage_attach = s0->attach_map->stage_to_attach_iter.find(conv); + const auto& conv_stage_attach = + s0->attach_map->stage_to_attach_iter.find(conv); std::pair iterkey(relu, 2); CHECK(conv_stage_attach->second == iterkey); - const auto& conv_iter_attach = s0->attach_map->iter_to_attached_stages.find(iterkey); + const auto& conv_iter_attach = + s0->attach_map->iter_to_attached_stages.find(iterkey); CHECK_EQ(conv_iter_attach->second.size(), 1); 
CHECK_EQ(conv_iter_attach->second[0], conv); std::stringstream ss; @@ -335,25 +351,28 @@ TEST(Step, CacheReadWrite) { int N = 4, H = 7, W = 7, CO = 512, CI = 512, KH = 3, KW = 3, stride = 1; int padding = 1; Tensor data = placeholder({N, CI, H, W}, DataType::Float(32), "Data"); - Tensor kernel_data = placeholder({CO, CI, KH, KW}, DataType::Float(32), - "kernel_data"); - const auto& k_split = compute(kernel_data->shape, + Tensor kernel_data = + placeholder({CO, CI, KH, KW}, DataType::Float(32), "Kernel_data"); + const auto& k_split = compute( + kernel_data->shape, [&](const Array& i) { - return Array({kernel_data[i[0]][i[1]][i[2]][i[3]] + 1, - div(kernel_data[i[0]][i[1]][i[2]][i[3]], 2)}); + return Array({kernel_data[i[0]][i[1]][i[2]][i[3]] + 1, + div(kernel_data[i[0]][i[1]][i[2]][i[3]], 2)}); }, "Kernel_split"); - const auto& kernel = compute(kernel_data->shape, + const auto& kernel = compute( + kernel_data->shape, [&](Var i, Var j, Var k, Var l) { - return (k_split[0])[i][j][k][l] + (k_split[1])[i][j][k][l]; + return (k_split[0])[i][j][k][l] + (k_split[1])[i][j][k][l]; }, "Kernel"); - const auto& conv = topi::conv2d_nchw(data, kernel, padding, padding, stride, - stride); + const auto& conv = + topi::conv2d_nchw(data, kernel, padding, padding, stride, stride); const auto& relu = topi::relu(conv); - const auto& out = compute(relu->shape, + const auto& out = compute( + relu->shape, [&](Var i, Var j, Var k, Var l) { - return data[i][j][k][l] + relu[i][j][k][l]; + return data[i][j][k][l] + relu[i][j][k][l]; }, "Add"); return {data, kernel_data, out}; @@ -372,15 +391,20 @@ TEST(Step, CacheReadWrite) { // 1: simple cache_write with compute_at int conv_global = s0.cache_write(conv, "global", dag); - conv++; relu++; add++; + conv++; + relu++; + add++; s0.compute_at(conv_global, conv, s0->stages[conv]->iters[3]); // 2: simple cache_read with compute_at int kernel_global = s0.cache_read(kernel, "global", {conv_global}, dag); - conv_global++; conv++; relu++; add++; + conv_global++; + conv++; + relu++; + add++; s0.compute_at(kernel_global, conv_global, s0->stages[conv_global]->iters[4]); std::stringstream ss; - ss << "Placeholder: Data, kernel_data\n" + ss << "Placeholder: Data, Kernel_data\n" << "for ax0 (0,4)\n" << " for ax1 (0,512)\n" << " for ax2 (0,9)\n" @@ -425,25 +449,45 @@ TEST(Step, CacheReadWrite) { // 3: two level cache_read with compute_at // preparing for GPU's shared memory & local memory int pad_temp_global = s0.cache_read(pad_temp, "global", {conv_global}, dag); - kernel_data++; kernel_split++; kernel++; kernel_global++; - conv_global++; conv++; relu++; add++; - int pad_temp_shared = s0.cache_read(pad_temp_global, "shared", {conv_global}, - dag); - kernel_data++; kernel_split++; kernel++; kernel_global++; - conv_global++; conv++; relu++; add++; + kernel_data++; + kernel_split++; + kernel++; + kernel_global++; + conv_global++; + conv++; + relu++; + add++; + int pad_temp_shared = + s0.cache_read(pad_temp_global, "shared", {conv_global}, dag); + kernel_data++; + kernel_split++; + kernel++; + kernel_global++; + conv_global++; + conv++; + relu++; + add++; s0.compute_at(pad_temp_global, conv_global, s0->stages[conv_global]->iters[2]); s0.compute_at(pad_temp_shared, conv_global, s0->stages[conv_global]->iters[4]); // 4: cache_read with multi readers - // This stage cannot be compute at to its consumer + // This stage cannot be compute at to its consumer s0.cache_read(data, "global", {pad_temp, add}, dag); - pad_temp++; pad_temp_global++; pad_temp_shared++; - kernel_data++; kernel_split++; 
kernel++; kernel_global++; - conv_global++; conv++; relu++; add++; + pad_temp++; + pad_temp_global++; + pad_temp_shared++; + kernel_data++; + kernel_split++; + kernel++; + kernel_global++; + conv_global++; + conv++; + relu++; + add++; ss.str(std::string()); - ss << "Placeholder: Data, kernel_data\n" + ss << "Placeholder: Data, Kernel_data\n" << "for ax0 (0,4)\n" << " for ax1 (0,512)\n" << " for ax2 (0,7)\n" @@ -517,7 +561,7 @@ TEST(Step, CacheReadWrite) { // To be fixed in the future s0.cache_write(kernel_split, "global", dag); ss.str(std::string()); - ss << "Placeholder: Data, kernel_data\n" + ss << "Placeholder: Data, Kernel_data\n" << "for ax0 (0,4)\n" << " for ax1 (0,512)\n" << " for ax2 (0,7)\n" @@ -598,8 +642,8 @@ TEST(Step, FollowSplitFollowFusedSplit) { // FollowSplitStep currently only support `inner_to_outer = true` const auto& its0 = s0.split(C, s0->stages[C]->iters[0], {4, 2, 8, 4}, true); int split_step0 = s0->transform_steps.size() - 1; - // const auto& its1 = s0.split(C, s0->stages[C]->iters[5], {4, 2, 8, 4}, false); - // int split_step1 = s0->transform_steps.size() - 1; + // const auto& its1 = s0.split(C, s0->stages[C]->iters[5], {4, 2, 8, 4}, + // false); int split_step1 = s0->transform_steps.size() - 1; for (int level = 1; level <= 5; level++) { State tmp = s0; tmp.follow_split(C_global, s0->stages[C_global]->iters[0], split_step0, @@ -610,7 +654,7 @@ TEST(Step, FollowSplitFollowFusedSplit) { const auto& stage_C_global = tmp->stages[C_global]; for (int i = 0; i < level; i++) { CHECK_EQ(stage_C->iters[i]->range->extent.as()->value, - stage_C_global->iters[i]->range->extent.as()->value); + stage_C_global->iters[i]->range->extent.as()->value); } // for (int i = 0; i < level; i++) { // CHECK(stage_C->iters[i+5]->range->extent.as()->value == @@ -627,7 +671,7 @@ TEST(Step, FollowSplitFollowFusedSplit) { } s0.reorder(C, its); for (int i = 0; i < 5; i++) { - s0.fuse(C, {s0->stages[C]->iters[i], s0->stages[C]->iters[i+1]}); + s0.fuse(C, {s0->stages[C]->iters[i], s0->stages[C]->iters[i + 1]}); } for (int level = 0; level < 4; level++) { State tmp = s0; @@ -635,8 +679,8 @@ TEST(Step, FollowSplitFollowFusedSplit) { {split_step0, split_step1}, level, false); const auto& stage_C = tmp->stages[C]; const auto& stage_C_global = tmp->stages[C_global]; - CHECK_EQ(stage_C->iters[level+1]->range->extent.as()->value, - stage_C_global->iters[0]->range->extent.as()->value); + CHECK_EQ(stage_C->iters[level + 1]->range->extent.as()->value, + stage_C_global->iters[0]->range->extent.as()->value); } for (int level = 0; level < 4; level++) { State tmp = s0; @@ -644,8 +688,8 @@ TEST(Step, FollowSplitFollowFusedSplit) { {split_step0, split_step1}, level, true); const auto& stage_C = tmp->stages[C]; const auto& stage_C_global = tmp->stages[C_global]; - CHECK_EQ(stage_C->iters[level+1]->range->extent.as()->value, - stage_C_global->iters[1]->range->extent.as()->value); + CHECK_EQ(stage_C->iters[level + 1]->range->extent.as()->value, + stage_C_global->iters[1]->range->extent.as()->value); } } @@ -676,10 +720,10 @@ TEST(Feature, ExtractionMatmul) { std::vector> features; std::vector feature_names; GetPerStmtFeatureName(max_n_bufs, &feature_names); - GetPerStmtFeaturesFromStates({s0}, + GetPerStmtFeaturesFromStates( + {s0}, SearchTaskNode::make(dag, "test", tvm::target::llvm(), - tvm::target::llvm(), - HardwareParams()), + tvm::target::llvm(), HardwareParams()), max_n_bufs, 0, &features); int num_states = 1; CHECK_EQ(feature_names.size(), (features[0].size() - 1) / num_states); @@ -704,7 +748,7 @@ class 
MetaTileRewritePolicyNodeTest { policy->SynthesizeMetaStructure(meta_structures); } void SampleInitPopulation(const std::vector& meta_structures, - int out_size, std::vector* out_states) { + int out_size, std::vector* out_states) { policy->SampleInitPopulation(meta_structures, out_size, out_states); } tvm::runtime::ObjectPtr policy; diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py new file mode 100644 index 000000000000..4782f9130cea --- /dev/null +++ b/tests/python/unittest/test_ansor_common.py @@ -0,0 +1,475 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te +from tvm import ansor +import topi + + +def matmul_nkkm(N, M, K): + A = te.placeholder((N, K), name='A') + B = te.placeholder((K, M), name='B') + k = te.reduce_axis((0, K), name='k') + C = te.compute((N, M), lambda i, j: te.sum( + A[i][k] * B[k][j], axis=[k]), name='C') + + return [A, B, C] + + +def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): + data = te.placeholder((N, CI, H, W), name='Data') + kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='Kernel') + bias = te.placeholder((CO, 1, 1), name='Bias') + bn_scale = te.placeholder((CO, 1, 1), name='Bn_scale') + bn_offset = te.placeholder((CO, 1, 1), name='Bn_offset') + + OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + + conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation) + conv = te.compute((N, CO, OH, OW), + lambda i, j, k, l: conv[i, j, k, l] + bias[j, 0, 0], + name='Bias_add') + conv = te.compute((N, CO, OH, OW), + lambda i, j, k, l: conv[i, j, k, l] * bn_scale[j, 0, 0], + name='Bn_mul') + conv = te.compute((N, CO, OH, OW), + lambda i, j, k, l: conv[i, j, k, l] + bn_offset[j, 0, 0], + name='Bn_add') + out = topi.nn.relu(conv) + + return [data, kernel, bias, bn_offset, bn_scale, out] + + +def test_compute_dag_basic(): + dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) + + print(dag) + print(dag.access_analyzer) + print(dag.get_init_state()) + + +def test_state_split_fuse_reorder(): + dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + s0 = dag.get_init_state() + s1 = s0 + ti = s0.stage(2).iterator(0) + tj = s0.stage(2).iterator(1) + tk = s0.stage(2).iterator(2) + + assert ti.range.extent == 512 + + s0 = s0.split(2, ti, [16]) + assert s0.stage(2).iterator(0).range.extent == 32 + assert s0.stage(2).iterator(1).range.extent == 16 + tio = s0.stage(2).iterator(0) + tii = s0.stage(2).iterator(1) + + s0 = s0.split(2, tj, [8]) + assert s0.stage(2).iterator(2).range.extent == 64 + assert s0.stage(2).iterator(3).range.extent == 8 + tjo = s0.stage(2).iterator(2) + tji = 
s0.stage(2).iterator(3) + + s0 = s0.reorder(2, [tio, tjo, tk, tji, tii]) + assert s0.stage(2).iterator(2).range.extent == 512 + + s0 = s0.fuse(2, [tio, tjo]) + assert s0.stage(2).iterator(0).range.extent == 2048 + + s1 = s1.split(2, ti, [8, 2]) + s1 = s1.split(2, tj, [32, 8], False) + assert s1.stage(2).iterator(0).range.extent == 32 + assert s1.stage(2).iterator(1).range.extent == 8 + assert s1.stage(2).iterator(2).range.extent == 2 + assert s1.stage(2).iterator(3).range.extent == 32 + assert s1.stage(2).iterator(4).range.extent == 8 + assert s1.stage(2).iterator(5).range.extent == 2 + + +def test_state_compute_at_root_inline(): + dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) + + # data, padding, kernel = 0, 1, 2 + conv = 3 + # bias = 4 + bias_add = 5 + # bn_scale = 6 + bn_mul = 7 + # bn_offset = 8 + bn_add, relu = 9, 10 + + s0 = dag.get_init_state() + s0 = s0.compute_inline(bn_add) + s0 = s0.compute_inline(bn_mul) + s0 = s0.compute_inline(bias_add) + s0 = s0.compute_at(conv, relu, s0.stage(relu).iterator(2)) + assert str(s0) == \ + "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ + "for i1 (0,3)\n" + \ + " for i2 (0,230)\n" + \ + " for i3 (0,230)\n" + \ + " pad_temp = ...\n" + \ + "for i1 (0,64)\n" + \ + " for i2 (0,112)\n" + \ + " for nn (None)\n" + \ + " for ff (None)\n" + \ + " for yy (None)\n" + \ + " for xx (None)\n" + \ + " for rc (None)\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " compute = ...\n" + \ + " for i3 (0,112)\n" + \ + " compute = ...\n" + + s0 = s0.compute_root(conv) + s0 = s0.compute_root(bn_mul) + assert str(s0) == \ + "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ + "for i1 (0,3)\n" + \ + " for i2 (0,230)\n" + \ + " for i3 (0,230)\n" + \ + " pad_temp = ...\n" + \ + "for nn (None)\n" + \ + " for ff (None)\n" + \ + " for yy (None)\n" + \ + " for xx (None)\n" + \ + " for rc (None)\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " compute = ...\n" + \ + "for i (None)\n" + \ + " for j (None)\n" + \ + " for k (None)\n" + \ + " for l (None)\n" + \ + " Bn_mul = ...\n" + \ + "for i1 (0,64)\n" + \ + " for i2 (0,112)\n" + \ + " for i3 (0,112)\n" + \ + " compute = ...\n" + + +def test_state_cache_read_write(): + N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, ( + 1, 1), (1, 1) + + data = te.placeholder((N, CI, H, W), name='Data') + kernel_data = te.placeholder((CO, CI, KH, KW), name='Kernel_data') + k0, k1 = te.compute(kernel_data.shape, + lambda *i: (kernel_data(*i)+1, kernel_data(*i)/2), + name='Kernel_split') + kernel = te.compute(kernel_data.shape, + lambda *i: k0(*i) + k1(*i), + name='Kernel') + conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1) + relu = topi.nn.relu(conv) + out = topi.add(data, relu) + + dag = ansor.ComputeDAG([data, kernel_data, out]) + data, pad_temp, kernel_data, kernel_split, kernel, conv, relu, add = 0, 1, 2, 3, 4, 5, 6, 7 + + # 0: init state + s0 = dag.get_init_state() + ori_its = s0.stage(add).iterators() + s0 = s0.split(add, s0.stage(add).iterator(0), [2]) + s0 = s0.reorder(add, [s0.stage(add).iterator(0), ori_its[1], + s0.stage(add).iterator(1), ori_its[2], ori_its[3]]) + s0 = s0.compute_inline(relu) + + # 1: simple cache_write with compute_at + s0 = s0.cache_write(conv, "global", dag) + conv_global = conv + conv += 1 + relu += 1 + add += 1 + s0 = s0.compute_at(conv_global, conv, s0.stage(conv).iterator(3)) + + # 2: simple cache_read with compute_at + s0 = s0.cache_read(kernel, "global", [conv_global], dag) + kernel_global = 
kernel + 1 + conv_global += 1 + conv += 1 + relu += 1 + add += 1 + s0 = s0.compute_at(kernel_global, conv_global, + s0.stage(conv_global).iterator(4)) + assert str(s0) == \ + "Placeholder: Data, Kernel_data\n" + \ + "for i0 (0,4)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,9)\n" + \ + " for i3 (0,9)\n" + \ + " pad_temp = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel_split = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel = ...\n" + \ + "for nn (0,4)\n" + \ + " for ff (0,512)\n" + \ + " for yy (0,7)\n" + \ + " for xx (0,7)\n" + \ + " for nn_c (None)\n" + \ + " for ff_c (None)\n" + \ + " for yy_c (None)\n" + \ + " for xx_c (None)\n" + \ + " for rc (None)\n" + \ + " for ax0 (None)\n" + \ + " for ax1 (None)\n" + \ + " for ax2 (None)\n" + \ + " for ax3 (None)\n" + \ + " Kernel.global = ...\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " compute.global = ...\n" + \ + " compute = ...\n" + \ + "for ax0.0 (0,2)\n" + \ + " for ax1 (0,512)\n" + \ + " for ax0.1 (0,2)\n" + \ + " for ax2 (0,7)\n" + \ + " for ax3 (0,7)\n" + \ + " T_add = ...\n" + + # 3: two level cache_read with compute_at + # preparing for GPU's shared memory & local memory + s0 = s0.cache_read(pad_temp, "global", [conv_global], dag) + pad_temp_global = pad_temp + 1 + kernel_data += 1 + kernel_split += 1 + kernel += 1 + kernel_global += 1 + conv_global += 1 + conv += 1 + relu += 1 + add += 1 + s0 = s0.cache_read(pad_temp_global, "shared", [conv_global], dag) + pad_temp_shared = pad_temp_global + 1 + kernel_data += 1 + kernel_split += 1 + kernel += 1 + kernel_global += 1 + conv_global += 1 + conv += 1 + relu += 1 + add += 1 + s0 = s0.compute_at(pad_temp_global, conv_global, + s0.stage(conv_global).iterator(2)) + s0 = s0.compute_at(pad_temp_shared, conv_global, + s0.stage(conv_global).iterator(4)) + + # 4: cache_read with multi readers + # This stage cannot be compute at to its consumer + s0 = s0.cache_read(data, "global", [pad_temp, add], dag) + pad_temp += 1 + pad_temp_global += 1 + pad_temp_shared += 1 + kernel_data += 1 + kernel_split += 1 + kernel += 1 + kernel_global += 1 + conv_global += 1 + conv += 1 + relu += 1 + add += 1 + assert str(s0) == \ + "Placeholder: Data, Kernel_data\n" + \ + "for ax0 (0,4)\n" + \ + " for ax1 (0,512)\n" + \ + " for ax2 (0,7)\n" + \ + " for ax3 (0,7)\n" + \ + " Data.global = ...\n" + \ + "for i0 (0,4)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,9)\n" + \ + " for i3 (0,9)\n" + \ + " pad_temp = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel_split = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel = ...\n" + \ + "for nn (0,4)\n" + \ + " for ff (0,512)\n" + \ + " for yy (0,7)\n" + \ + " for xx (0,7)\n" + \ + " for nn_c (None)\n" + \ + " for ff_c (None)\n" + \ + " for yy_c (None)\n" + \ + " for ax0 (None)\n" + \ + " for ax1 (None)\n" + \ + " for ax2 (None)\n" + \ + " for ax3 (None)\n" + \ + " pad_temp.global = ...\n" + \ + " for xx_c (None)\n" + \ + " for rc (None)\n" + \ + " for ax0 (None)\n" + \ + " for ax1 (None)\n" + \ + " for ax2 (None)\n" + \ + " for ax3 (None)\n" + \ + " Kernel.global = ...\n" + \ + " for ax0 (None)\n" + \ + " for ax1 (None)\n" + \ + " for ax2 (None)\n" + \ + " for ax3 (None)\n" + \ + " pad_temp.global.shared = ...\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " 
compute.global = ...\n" + \ + " compute = ...\n" + \ + "for ax0.0 (0,2)\n" + \ + " for ax1 (0,512)\n" + \ + " for ax0.1 (0,2)\n" + \ + " for ax2 (0,7)\n" + \ + " for ax3 (0,7)\n" + \ + " T_add = ...\n" + + # 5: cache_write with multi outputs + # See tests/cpp/ansor_test.cc for more information + s0 = s0.cache_write(kernel_split, "global", dag) + assert str(s0) == \ + "Placeholder: Data, Kernel_data\n" + \ + "for ax0 (0,4)\n" + \ + " for ax1 (0,512)\n" + \ + " for ax2 (0,7)\n" + \ + " for ax3 (0,7)\n" + \ + " Data.global = ...\n" + \ + "for i0 (0,4)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,9)\n" + \ + " for i3 (0,9)\n" + \ + " pad_temp = ...\n" + \ + "for i0_c (0,512)\n" + \ + " for i1_c (0,512)\n" + \ + " for i2_c (0,3)\n" + \ + " for i3_c (0,3)\n" + \ + " Kernel_split.global = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel_split = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel_split = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel = ...\n" + \ + "for nn (0,4)\n" + \ + " for ff (0,512)\n" + \ + " for yy (0,7)\n" + \ + " for xx (0,7)\n" + \ + " for nn_c (None)\n" + \ + " for ff_c (None)\n" + \ + " for yy_c (None)\n" + \ + " for ax0 (None)\n" + \ + " for ax1 (None)\n" + \ + " for ax2 (None)\n" + \ + " for ax3 (None)\n" + \ + " pad_temp.global = ...\n" + \ + " for xx_c (None)\n" + \ + " for rc (None)\n" + \ + " for ax0 (None)\n" + \ + " for ax1 (None)\n" + \ + " for ax2 (None)\n" + \ + " for ax3 (None)\n" + \ + " Kernel.global = ...\n" + \ + " for ax0 (None)\n" + \ + " for ax1 (None)\n" + \ + " for ax2 (None)\n" + \ + " for ax3 (None)\n" + \ + " pad_temp.global.shared = ...\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " compute.global = ...\n" + \ + " compute = ...\n" + \ + "for ax0.0 (0,2)\n" + \ + " for ax1 (0,512)\n" + \ + " for ax0.1 (0,2)\n" + \ + " for ax2 (0,7)\n" + \ + " for ax3 (0,7)\n" + \ + " T_add = ...\n" + + +def test_follow_split_follow_fused_split(): + dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + s0 = dag.get_init_state() + C = 2 + + s0 = s0.cache_write(C, "global", dag) + C_global = C + C += 1 + + s0 = s0.split(C, s0.stage(C).iterator(0), [4, 2, 8, 4], True) + split_step0 = s0.transform_steps_size() - 1 + for level in range(1, 6): + tmp = s0 + tmp = tmp.follow_split(C_global, tmp.stage( + C_global).iterator(0), split_step0, level) + for i in range(0, level): + assert tmp.stage(C).iterator(i).range.extent == \ + tmp.stage(C_global).iterator(i).range.extent + + s0 = s0.split(C, s0.stage(C).iterator(5), [2, 2, 4, 8]) + split_step1 = s0.transform_steps_size() - 1 + its = s0.stage(C).iterators() + s0 = s0.reorder(C, [its[0], its[5], its[1], its[6], its[2], its[7], + its[3], its[8], its[4], its[9]]) + s0 = s0.fuse(C, [s0.stage(C).iterator(0), s0.stage(C).iterator(1)]) + s0 = s0.fuse(C, [s0.stage(C).iterator(1), s0.stage(C).iterator(2)]) + s0 = s0.fuse(C, [s0.stage(C).iterator(2), s0.stage(C).iterator(3)]) + s0 = s0.fuse(C, [s0.stage(C).iterator(3), s0.stage(C).iterator(4)]) + s0 = s0.fuse(C, [s0.stage(C).iterator(4), s0.stage(C).iterator(5)]) + for level in range(0, 4): + tmp = s0 + tmp = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), + [split_step0, split_step1], level, False) + assert tmp.stage(C).iterator(level+1).range.extent == \ + tmp.stage(C_global).iterator(0).range.extent + for level in range(0, 4): + tmp = 
s0 + tmp = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), + [split_step0, split_step1], level, True) + assert tmp.stage(C).iterator(level+1).range.extent == \ + tmp.stage(C_global).iterator(1).range.extent + + +def test_rfactor(): + pass + + +if __name__ == "__main__": + test_compute_dag_basic() + test_state_split_fuse_reorder() + test_state_compute_at_root_inline() + test_state_cache_read_write() + test_follow_split_follow_fused_split() + test_rfactor() From 2032a64356dc88e341162281b94746e61cdabfe2 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Thu, 4 Jun 2020 16:05:07 +0800 Subject: [PATCH 06/78] Add Python API: Measure & Task (#7) * Update the return value of state operation * Add task * Copy measure.py & utils.py * Fix LocalBuilder * Fix LocalRunner --- python/tvm/ansor/__init__.py | 2 + python/tvm/ansor/compute_dag.py | 50 +- python/tvm/ansor/measure.py | 434 +++++++++++++++++ python/tvm/ansor/state.py | 91 +++- python/tvm/ansor/task.py | 59 +++ python/tvm/ansor/utils.py | 229 +++++++++ src/ansor/compute_dag.cc | 33 +- src/ansor/compute_dag.h | 8 +- src/ansor/loop_state.cc | 522 +++++++++++---------- src/ansor/measure.cc | 202 +++++--- src/ansor/search_task.cc | 66 ++- src/ansor/search_task.h | 10 +- tests/cpp/ansor_test.cc | 4 +- tests/python/unittest/test_ansor_common.py | 120 +++-- 14 files changed, 1401 insertions(+), 429 deletions(-) create mode 100644 python/tvm/ansor/measure.py create mode 100644 python/tvm/ansor/task.py create mode 100644 python/tvm/ansor/utils.py diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index aaa0e9c9174d..cb039cf07d5f 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -18,3 +18,5 @@ """Namespace for Ansor autoSchedule""" from .compute_dag import ComputeDAG +from .task import SearchTask +from .measure import MeasureInput, LocalBuilder, LocalRunner diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index 3c46440f75ba..a66a181f054c 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -25,10 +25,56 @@ from . 
import _ffi_api +class LayoutRewriteLevel(object): + NO_REWRITE = 0 # No layout rewrite + PLACEHOLDER_REWRITE = 1 # Only rewrite layout of placeholder in the compute dag + COMPUTE_REWRITE = 2 # Only rewrite compute body for new layout in the compute dag + BOTH_REWRITE = 3 # Rewrite both placeholder and compute body in the compute dag + + @tvm._ffi.register_object("ansor.ComputeDAG") class ComputeDAG(Object): + """ + Parameters + ---------- + tensors : List[Tensor] + """ + def __init__(self, tensors): self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, tensors) - def get_init_state(self) -> State: - return self.init_state + def get_init_state(self): + """ Get init state of this ComputeDAG + + Returns + ------- + state : State + """ + return _ffi_api.ComputeDAGGetInitState(self) + + def apply_steps_from_state(self, state, layout_rewrite_level): + """ + Parameters + ---------- + state : State + layout_rewrite_level : LayoutRewriteLevel(***) + + Returns + ------- + sch : Schedule + args : List[Tensor] + """ + sch, args = _ffi_api.ComputeDAGApplyStepsFromState(self, state) + return sch, args + + def print_python_code_from_state(self, state): + """ + Parameters + ---------- + state : State + + Returns + ------- + str : Str + """ + return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state) diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py new file mode 100644 index 000000000000..72dd3cbfcf92 --- /dev/null +++ b/python/tvm/ansor/measure.py @@ -0,0 +1,434 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +"""Distributed measurement infrastructure to measure the runtime costs of tensor programs + +These functions are responsible for building the tvm module, uploading it to +remote devices, recording the running time costs, and checking the correctness of the output. + +We implement these in python to utilize python's multiprocessing and error handling +""" +from typing import List +import os +import time +import shutil +import logging +import traceback +import tempfile +import multiprocessing + +import tvm._ffi +from tvm.runtime import Object, module, ndarray +from tvm.driver import build_module +from tvm.target import build_config +from ..contrib import tar, ndk +from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, check_remote +from .compute_dag import LayoutRewriteLevel + +from . 
import _ffi_api + +logger = logging.getLogger('ansor') + + +@tvm._ffi.register_object("ansor.MeasureInput") +class MeasureInput(Object): + """ + Parameters + ---------- + task : SearchTask + state : State + """ + + def __init__(self, task, state): + self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state) + + +@tvm._ffi.register_object("ansor.BuildResult") +class BuildResult(Object): + """ + Parameters + ---------- + filename : Str + args : List[Tensor] + error_no : Int + error_msg : Str + time_cost : Float + """ + + def __init__(self, filename, args, error_no, error_msg, time_cost): + self.__init_handle_by_constructor__( + _ffi_api.BuildResult, filename, args, error_no, + error_msg if error_msg else "", time_cost) + + +@tvm._ffi.register_object("ansor.MeasureResult") +class MeasureResult(Object): + """ + Parameters + ---------- + costs : List[Float] + error_no : Int + error_msg : Str + all_cost : Float + timestamp : Float + """ + + def __init__(self, costs, error_no, error_msg, all_cost, timestamp): + self.__init_handle_by_constructor__( + _ffi_api.MeasureResult, costs, error_no, + error_msg if error_msg else "", all_cost, timestamp) + + +@tvm._ffi.register_object("ansor.Builder") +class Builder(Object): + def build(self, measure_inputs, verbose=0): + """ + Parameters + ---------- + measure_inputs : List[MeasureInput] + verbost : Int + + Returns + ------- + res : List[BuildResult] + """ + return _ffi_api.BuilderBuild(self, measure_inputs, verbose) + + +@tvm._ffi.register_object("ansor.Runner") +class Runner(Object): + def run(self, measure_inputs, build_results, verbose=0): + """ + Parameters + ---------- + measure_inputs : List[MeasureInput] + build_results : List[BuildResult] + + Returns + ------- + res : List[MeasureResult] + """ + return _ffi_api.RunnerRun(self, measure_inputs, build_results, verbose) + + +@tvm._ffi.register_object("ansor.LocalBuilder") +class LocalBuilder(Builder): + """ + Parameters + ---------- + timeout : Int + n_parallel : Int + build_func : Str + """ + + def __init__(self, + timeout=15, + n_parallel=multiprocessing.cpu_count(), + build_func='default'): + self.__init_handle_by_constructor__( + _ffi_api.LocalBuilder, timeout, n_parallel, build_func) + + +@tvm._ffi.register_object("ansor.LocalRunner") +class LocalRunner(Runner): + """ + Parameters + ---------- + timeout : Int + number : Int + repeat : Int + min_repeat_ms : Int + cooldown_interval : Float + """ + + def __init__(self, + timeout=10, + number=3, + repeat=1, + min_repeat_ms=0, + cooldown_interval=0.0): + self.__init_handle_by_constructor__( + _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval) + + +MAX_ERROR_MSG_LEN = 512 + + +class MeasureErrorNo(object): + """Error type for MeasureResult""" + NO_ERROR = 0 # No error + INSTANTIATION_ERROR = 1 # Errors happen when apply transform steps from init state + # Errors happen when compiling code on host (e.g. tvm.build) + COMPILE_HOST = 2 + COMPILE_DEVICE = 3 # Errors happen when compiling code on device + # (e.g. 
OpenCL JIT on the device) + RUNTIME_DEVICE = 4 # Errors happen when run program on device + WRONG_ANSWER = 5 # Answer is wrong when compared to a reference output + BUILD_TIMEOUT = 6 # Timeout during compilation + RUN_TIMEOUT = 7 # Timeout during run + UNKNOWN_ERROR = 8 # Unknown error + + +def make_error_msg(): + error_msg = str(traceback.format_exc()) + if len(error_msg) > MAX_ERROR_MSG_LEN: + error_msg = error_msg[:MAX_ERROR_MSG_LEN//2] + \ + "\n...\n" + error_msg[-MAX_ERROR_MSG_LEN//2:] + return error_msg + + +global global_build_arguments +global global_run_arguments + + +def local_build_worker(index): + # We use fork to copy arguments from a global variable. + # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool + measure_inputs, build_func, timeout, verbose = global_build_arguments + assert isinstance(build_func, str) + if build_func == 'default': + build_func = tar.tar + elif build_func == 'ndk': + build_func = ndk.create_shared + else: + raise ValueError("Invalid build_func" + build_func) + + def timed_func(): + tic = time.time() + inp = measure_inputs[index] + task = inp.task + + error_no = MeasureErrorNo.NO_ERROR + error_msg = None + args = [] + + try: + sch, args = task.compute_dag.apply_steps_from_state( + inp.state, LayoutRewriteLevel.BOTH_REWRITE) + except Exception: + error_no = MeasureErrorNo.INSTANTIATION_ERROR + error_msg = make_error_msg() + + if error_no == 0: + dirname = tempfile.mkdtemp() + filename = os.path.join( + dirname, "tmp_func." + build_func.output_format) + + try: + with build_config(unroll_max_extent=task.hardware_params.max_unroll_vec): + func = build_module.build( + sch, args, target=task.target, target_host=task.target_host) + func.export_library(filename, build_func) + except Exception: + error_no = MeasureErrorNo.COMPILE_HOST + error_msg = make_error_msg() + else: + filename = "" + + if verbose >= 1: + if error_no == MeasureErrorNo.NO_ERROR: + print(".", end="") + else: + print(".E", end="") # Build error + return filename, args, error_no, error_msg, time.time() - tic + + res = call_func_with_timeout(timeout, timed_func) + if isinstance(res, TimeoutError): + if verbose >= 1: + print(".T", end="") # Build timeout + res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout + + return res + + +@tvm._ffi.register_func("ansor.local_builder.build") +def local_builder_build(inputs: List[MeasureInput], timeout: float, n_parallel: int, build_func: str, verbose: int): + # We use fork to copy arguments from a global variable. 
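+    # (the tuple is stored in `global_build_arguments` below and read back by
+    # each forked `local_build_worker`, so only an integer index is pickled).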
+ # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool + global global_build_arguments + global_build_arguments = (inputs, build_func, timeout, verbose) + + pool = NoDaemonPool(n_parallel) + tuple_res = pool.map(local_build_worker, range(len(inputs))) + pool.terminate() + pool.join() + del pool + + results = [] + for res in tuple_res: + results.append(BuildResult(*res)) + + return results + + +@tvm._ffi.register_func("ansor.rpc_runner.run") +def rpc_runner_run(inputs: List[MeasureInput], build_results: List[BuildResult], + key: str, host: str, port: int, priority: int, timeout: float, + n_parallel: int, number: int, repeat: int, min_repeat_ms: int, + cooldown_interval: float, verbose: int): + global global_run_arguments + global_run_arguments = (inputs, build_results, key, host, port, priority, timeout, number, + repeat, min_repeat_ms, cooldown_interval, verbose) + + assert len(inputs) == len(build_results), \ + "Measure input size should be equal to build results" + pool = NoDaemonPool(n_parallel) + tuple_res = pool.map(rpc_run_worker, range(len(build_results))) + pool.terminate() + pool.join() + del pool + + results = [] + for res in tuple_res: + results.append(MeasureResult(*res)) + + if verbose >= 1: + print("") + + return results + + +def rpc_run_worker(index): + inputs, build_results, key, host, port, priority, timeout, number, \ + repeat, min_repeat_ms, cooldown_interval, verbose = global_run_arguments + + MAX_FLOAT = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log + inp = inputs[index] + build_res = build_results[index] + + if build_res.error_no != MeasureErrorNo.NO_ERROR: + return (MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, time.time() + + def timed_func(): + tic = time.time() + error_no = 0 + error_msg = None + try: + # upload built module + remote = request_remote(key, host, port, priority, timeout) + remote.upload(build_res.filename) + func = remote.load_module(os.path.split(build_res.filename)[1]) + ctx = remote.context(str(inp.task.target), 0) + time_f = func.time_evaluator( + func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms) + except Exception: + costs = (MAX_FLOAT,) + error_no = MeasureErrorNo.COMPILE_DEVICE + error_msg = make_error_msg() + + if error_no == 0: + try: + args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in + build_res.args] + ctx.sync() + + costs = time_f(*args).results + # clean up remote files + remote.remove(build_res.filename) + remote.remove(os.path.splitext(build_res.filename)[0] + '.so') + remote.remove('') + except Exception: + costs = (MAX_FLOAT,) + error_no = MeasureErrorNo.RUNTIME_DEVICE + error_msg = make_error_msg() + + shutil.rmtree(os.path.dirname(build_res.filename)) + toc = time.time() + + time.sleep(cooldown_interval) + if verbose >= 1: + if error_no == MeasureErrorNo.NO_ERROR: + print("*", end="") + else: + print("*E", end="") # Run error + + return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc + + res = call_func_with_timeout(timeout, timed_func) + + if isinstance(res, TimeoutError): + if verbose >= 1: + print("*T", end="") # Run timeout + res = (MAX_FLOAT,), MeasureErrorNo.RUN_TIMEOUT, None, build_res.time_cost + \ + timeout, time.time() + return res + + +@tvm._ffi.register_func("ansor.local_runner.run") +def local_run(inputs: List[MeasureInput], build_results: List[BuildResult], + timeout: float, number: int, repeat: int, min_repeat_ms: int, + cooldown_interval: float, verbose: 
int): + MAX_FLOAT = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log + + def timed_func(inp, build_res): + tic = time.time() + error_no = 0 + error_msg = None + try: + func = module.load_module(build_res.filename) + ctx = ndarray.context(str(inp.task.target), 0) + time_f = func.time_evaluator( + func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms) + except Exception: + costs = (MAX_FLOAT,) + error_no = MeasureErrorNo.COMPILE_DEVICE + error_msg = make_error_msg() + + if error_no == 0: + try: + args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in + build_res.args] + ctx.sync() + + costs = time_f(*args).results + except Exception: + costs = (MAX_FLOAT,) + error_no = MeasureErrorNo.RUNTIME_DEVICE + error_msg = make_error_msg() + + shutil.rmtree(os.path.dirname(build_res.filename)) + toc = time.time() + time.sleep(cooldown_interval) + + if verbose >= 1: + if error_no == MeasureErrorNo.NO_ERROR: + print("*", end="") + else: + print("*E", end="") # Run error + return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc + + measure_results = [] + assert len(inputs) == len(build_results), \ + "Measure input size should be equal to build results" + for inp, build_res in zip(inputs, build_results): + if build_res.error_no != 0: + res = ( + MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, time.time() + else: + res = call_func_with_timeout( + timeout, timed_func, args=(inp, build_res)) + if isinstance(res, TimeoutError): + if verbose >= 1: + print("*T", end="") # Run timeout + res = ( + MAX_FLOAT,), MeasureErrorNo.RUN_TIMEOUT, None, build_res.time_cost + timeout, time.time() + measure_results.append(MeasureResult(*res)) + + if verbose >= 1: + print("") + + return measure_results diff --git a/python/tvm/ansor/state.py b/python/tvm/ansor/state.py index 9a8810190199..7de95a8a74af 100644 --- a/python/tvm/ansor/state.py +++ b/python/tvm/ansor/state.py @@ -25,21 +25,41 @@ @tvm._ffi.register_object("ansor.Iterator") class Iterator(Object): + """ ... + """ pass @tvm._ffi.register_object("ansor.Stage") class Stage(Object): + """ ... + """ def iterator(self, index): + """ + Parameters + ---------- + index : Int + + Returns + ------- + iter : Iterator + """ return _ffi_api.StageGetIterator(self, index) def iterators(self): + """ + Returns + ------- + iters : List[Iterator] + """ return _ffi_api.StageGetIterators(self) @tvm._ffi.register_object("ansor.State") class State(Object): + """ ... 
+ """ def stage(self, index): """ @@ -93,10 +113,12 @@ def split(self, stage_id, it, lengths, inner_to_outer=True): ------- state : State The updated state + res_its : List[Iterator] + The splited Iterators result """ - state = _ffi_api.StateSplit(self, stage_id, it, lengths, - inner_to_outer) - return state + state, res_its = _ffi_api.StateSplit(self, stage_id, it, lengths, + inner_to_outer) + return state, res_its def follow_split(self, stage_id, it, src_step_id, n_split): """ @@ -115,10 +137,12 @@ def follow_split(self, stage_id, it, src_step_id, n_split): ------- state : State The updated state + res_its : List[Iterator] + The splited Iterators result """ - state = _ffi_api.StateFollowSplit(self, stage_id, it, src_step_id, - n_split) - return state + state, res_its = _ffi_api.StateFollowSplit(self, stage_id, it, + src_step_id, n_split) + return state, res_its def follow_fused_split(self, stage_id, it, src_step_ids, level, factor_or_nparts): @@ -140,10 +164,13 @@ def follow_fused_split(self, stage_id, it, src_step_ids, level, ------- state : State The updated state + res_its : List[Iterator] + The splited Iterators result """ - state = _ffi_api.StateFollowFusedSplit(self, stage_id, it, src_step_ids, - level, factor_or_nparts) - return state + state, res_its = _ffi_api.StateFollowFusedSplit(self, stage_id, it, + src_step_ids, level, + factor_or_nparts) + return state, res_its def fuse(self, stage_id, iters): """ @@ -158,9 +185,11 @@ def fuse(self, stage_id, iters): ------- state : State The updated state + res_it : Iterator + The fused Iterator """ - state = _ffi_api.StateFuse(self, stage_id, iters) - return state + state, res_it = _ffi_api.StateFuse(self, stage_id, iters) + return state, res_it def vectorize(self, stage_id, it): """ @@ -175,9 +204,11 @@ def vectorize(self, stage_id, it): ------- state : State The updated state + res_it : Iterator + The vectorized Iterator """ - state = _ffi_api.StateVectorize(self, stage_id, it) - return state + state, res_it = _ffi_api.StateVectorize(self, stage_id, it) + return state, res_it def parallel(self, stage_id, it): """ @@ -192,9 +223,11 @@ def parallel(self, stage_id, it): ------- state : State The updated state + res_it : Iterator + The paralleled Iterator """ - state = _ffi_api.StateParallel(self, stage_id, it) - return state + state, res_it = _ffi_api.StateParallel(self, stage_id, it) + return state, res_it def unroll(self, stage_id, it, max_unroll=-1): """ @@ -210,9 +243,11 @@ def unroll(self, stage_id, it, max_unroll=-1): ------- state : State The updated state + res_it : Iterator + The unrolled Iterator """ - state = _ffi_api.StateUnroll(self, stage_id, it, max_unroll) - return state + state, res_it = _ffi_api.StateUnroll(self, stage_id, it, max_unroll) + return state, res_it def bind_thread(self, stage_id, it, thread_type): """ @@ -229,9 +264,12 @@ def bind_thread(self, stage_id, it, thread_type): ------- state : State The updated state + res_it : Iterator + The thread binded Iterator """ - state = _ffi_api.StateBindThread(self, stage_id, it, thread_type) - return state + state, res_it = _ffi_api.StateBindThread(self, stage_id, it, + thread_type) + return state, res_it def compute_at(self, stage_id, target_stage_id, target_iter): """ @@ -311,10 +349,12 @@ def cache_read(self, stage_id, scope_name, reader_stage_ids, task_dag): ------- state : State The updated state + new_stage_id : Int + The added staged id """ - state = _ffi_api.StateCacheRead(self, stage_id, scope_name, - reader_stage_ids, task_dag) - return state + state, 
new_stage_id = _ffi_api.StateCacheRead(self, stage_id, + scope_name, reader_stage_ids, task_dag) + return state, int(new_stage_id) def cache_write(self, stage_id, scope_name, task_dag): """ @@ -329,9 +369,12 @@ def cache_write(self, stage_id, scope_name, task_dag): ------- state : State The updated state + new_stage_id : Int + The added staged id """ - state = _ffi_api.StateCacheWrite(self, stage_id, scope_name, task_dag) - return state + state, new_stage_id = _ffi_api.StateCacheWrite(self, stage_id, + scope_name, task_dag) + return state, int(new_stage_id) def pragma(self, stage_id, it, pragma_type): """ diff --git a/python/tvm/ansor/task.py b/python/tvm/ansor/task.py new file mode 100644 index 000000000000..245cf4c727ae --- /dev/null +++ b/python/tvm/ansor/task.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +""" ... """ + +import tvm._ffi +from tvm.runtime import Object + +from . import _ffi_api + +@tvm._ffi.register_object("ansor.HardwareParams") +class HardwareParams(Object): + """ + Parameters + ---------- + num_cores : Int + vector_unit_bytes : Int + cache_line_bytes : Int + max_unroll_vec : Int + max_innermost_split_factor : Int + """ + + def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes, + max_unroll_vec, max_innermost_split_factor): + self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores, + vector_unit_bytes, cache_line_bytes, max_unroll_vec, + max_innermost_split_factor) + + +@tvm._ffi.register_object("ansor.SearchTask") +class SearchTask(Object): + """ + Parameters + ---------- + dag : ComputeDAG + workload_key : Str + target : tvm.target + target_host : tvm.target + hardware_params : HardwareParams + """ + + def __init__(self, dag, workload_key, target, target_host=None, + hardware_params=None): + self.__init_handle_by_constructor__(_ffi_api.SearchTask, dag, + workload_key, target, target_host, hardware_params) diff --git a/python/tvm/ansor/utils.py b/python/tvm/ansor/utils.py new file mode 100644 index 000000000000..0216549c184a --- /dev/null +++ b/python/tvm/ansor/utils.py @@ -0,0 +1,229 @@ +"""Common utilities""" +import multiprocessing +import multiprocessing.pool +import queue +import signal +import threading +import os + +import numpy as np + +try: + import psutil +except ImportError: + psutil = None + +from .. 
import rpc as _rpc +from tvm.tir import expr +from tvm.tir.transform import Simplify +from tvm.ir.transform import Sequential + + +def get_func_name(func): + """Get name of a function + + Parameters + ---------- + func: Function + The function + Returns + ------- + name: str + The name + """ + + return func.func_name if hasattr(func, 'func_name') else func.__name__ + + +def get_const_int(exp): + """Verifies expr is integer and get the constant value. + + Parameters + ---------- + exp : tvm.Expr or int + The input expression. + + Returns + ------- + out_value : int + The output. + """ + if isinstance(exp, int): + return exp + if not isinstance(exp, (expr.IntImm)): + opt = Sequential([Simplify()]) + exp = opt(exp) + if not isinstance(exp, (expr.IntImm)): + raise ValueError("Expect value to be constant int") + return exp.value + + +def get_const_tuple(in_tuple): + """Verifies input tuple is IntImm, returns tuple of int. + + Parameters + ---------- + in_tuple : tuple of Expr + The input. + + Returns + ------- + out_tuple : tuple of int + The output. + """ + return tuple(get_const_int(x) for x in in_tuple) + + +def to_str_round(x, decimal=6): + """Convert object to str and round float numbers""" + if isinstance(x, str): + return x + if isinstance(x, (list, tuple)) or isinstance(x, np.ndarray): + return "[" + ", ".join([to_str_round(y, decimal=decimal) + for y in x]) + "]" + if isinstance(x, dict): + return str({k: eval(to_str_round(v)) for k, v in x.items()}) + if isinstance(x, int): + return str(x) + if isinstance(x, (np.float32, np.float64, float)): + format_str = "%%.%df" % decimal + return format_str % x + raise ValueError("Invalid value: " + str(x) + "\ttype: " + str(type(x))) + + +def array_mean(arr): + """Mean function for tvm array (Array)""" + return sum(x.value for x in arr) / len(arr) + + +class NoDaemonProcess(multiprocessing.Process): + @property + def daemon(self): + return False + + @daemon.setter + def daemon(self, value): + pass + + +class NoDaemonContext(type(multiprocessing.get_context())): + Process = NoDaemonProcess + + +class NoDaemonPool(multiprocessing.pool.Pool): + """A no daemon pool version of multiprocessing.Pool. 
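+    (daemonic processes, which the stock Pool spawns, are not allowed to have
+    children of their own)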
+ This allows us to start new processings inside the worker function""" + + def __init__(self, *args, **kwargs): + kwargs['context'] = NoDaemonContext() + super().__init__(*args, **kwargs) + + +def kill_child_processes(parent_pid, sig=signal.SIGTERM): + """kill all child processes recursively""" + try: + parent = psutil.Process(parent_pid) + except psutil.NoSuchProcess: + return + children = parent.children(recursive=True) + for process in children: + try: + process.send_signal(sig) + except psutil.NoSuchProcess: + return + + +def call_func_with_timeout(timeout, func, args=(), kwargs=None): + """Call a function with timeout""" + def func_wrapper(que): + if kwargs: + que.put(func(*args, **kwargs)) + else: + que.put(func(*args)) + + que = multiprocessing.Queue(2) + process = multiprocessing.Process(target=func_wrapper, args=(que,)) + process.start() + process.join(timeout) + + try: + res = que.get(block=False) + except queue.Empty: + res = TimeoutError() + + # clean queue and process + kill_child_processes(process.pid) + process.terminate() + process.join() + que.close() + que.join_thread() + del process + del que + + return res + + +def request_remote(device_key, host=None, port=None, priority=1, timeout=60): + """Request a remote session + + Parameters + ---------- + device_key: string + The device key of registered device in tracker + host: host, optional + The host address of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_HOST" + port: int, optional + The port of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_PORT" + priority: int, optional + The priority of this request, larger is more prior + timeout: float, optional + The timeout of this session (units: second) + + Returns + ------ + session: RPCSession + """ + # connect to the tracker + host = host or os.environ['TVM_TRACKER_HOST'] + port = port or int(os.environ['TVM_TRACKER_PORT']) + + tracker = _rpc.connect_tracker(host, port) + remote = tracker.request(device_key, priority=priority, + session_timeout=timeout) + return remote + + +def check_remote(device_key, host=None, port=None, priority=100, timeout=10): + """ + Check the availability of a remote device + + Parameters + ---------- + device_key: string + device key of registered device in tracker + host: host, optional + The host address of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_HOST" + port: int, optional + The port address of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_PORT" + priority: int, optional + The priority of this request, larger is more prior + timeout: float, optional + The timeout of this check (units: seconds). 
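+        (the check runs request_remote in a background thread; if the thread
+        is still alive after this timeout, the device is reported unavailable)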
+ + Returns + ------- + available: bool + True if can find available device + """ + + def _check(): + remote = request_remote(device_key, host, port, priority) + + t = threading.Thread(target=_check, ) + t.start() + t.join(timeout) + return not t.is_alive() diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 1e33068e4965..c9415a70c303 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -588,15 +588,6 @@ ComputeDAG ComputeDAGNode::make_by_workload_key(const std::string& workload_key) return ComputeDAGNode::make(std::move(tens)); } -void ComputeDAGNode::VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("tensors", &tensors); - v->Visit("ops", &ops); - v->Visit("flop_ct", &flop_ct); - v->Visit("access_analyzer", &access_analyzer); - State s = Downcast(init_state); - v->Visit("init_state", &s); -} - // Implemented in multi_stage_policy.cc // Extract primitive iterators from a nested fused or splitted iterator's name extern void ExtractOriginalIterators(const std::string& name, std::set* rets); @@ -1166,9 +1157,6 @@ std::pair > ComputeDAG::ReplaySteps( return std::make_pair(schedule, operator->()->tensors); } -TVM_REGISTER_GLOBAL("ansor.ComputeDAG") -.set_body_typed([](Array tensors) { return ComputeDAGNode::make(tensors); }); - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { auto* node = static_cast(ref.get()); @@ -1262,5 +1250,26 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } }); +TVM_REGISTER_GLOBAL("ansor.ComputeDAG") +.set_body_typed([](Array tensors) { + return ComputeDAGNode::make(tensors); +}); + +TVM_REGISTER_GLOBAL("ansor.ComputeDAGGetInitState") +.set_body_method(&ComputeDAG::GetInitState); + +TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") +.set_body_typed([](const ComputeDAG& dag, const State& state) { + te::Schedule sch; + Array return_tensors; + std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps); + return Array{sch, return_tensors}; +}); + +TVM_REGISTER_GLOBAL("ansor.ComputeDAGPrintPythonCodeFromState") +.set_body_typed([](const ComputeDAG& dag, const State& state) { + return dag.PrintStepsAsPython(state->transform_steps); +}); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 9d0708a77f1c..3b4c80c50ad8 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -93,7 +93,13 @@ class ComputeDAGNode : public Object { AccessAnalyzer access_analyzer; // Read/Write accesss static analyzer ObjectRef init_state; // initial states - void VisitAttrs(tvm::AttrVisitor* v); + void VisitAttrs(tvm::AttrVisitor* v) { + LOG(INFO) << "ComputeDAG"; + v->Visit("tensors", &tensors); + v->Visit("ops", &ops); + v->Visit("flop_ct", &flop_ct); + v->Visit("access_analyzer", &access_analyzer); + } static ComputeDAG make(Array tensors); static ComputeDAG make_by_workload_key(const std::string& workload_key); diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index ebea5a1e472a..e18d36e34581 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -2,8 +2,10 @@ * Copyright (c) 2020 by Contributors */ #include "loop_state.h" -#include + #include +#include + #include "utils.h" namespace tvm { @@ -16,15 +18,15 @@ Stage StageNode::make(te::Operation op) { auto node = make_object(); if (op->IsInstance()) { node->op_type = kCompute; - auto *pop = op.as(); + auto* pop = op.as(); for (const auto& axis : pop->axis) { node->iters.push_back(IteratorNode::make(CleanName(axis->var->name_hint), - axis->dom, 
kSpace, kNone)); + axis->dom, kSpace, kNone)); } for (const auto& axis : pop->reduce_axis) { node->iters.push_back(IteratorNode::make(CleanName(axis->var->name_hint), - axis->dom, kReduce, kNone)); + axis->dom, kReduce, kNone)); } } else if (op->IsInstance()) { node->op_type = kPlaceholder; @@ -54,9 +56,8 @@ Stage StageNode::make(te::Operation op, StageType op_type, } Stage StageNode::make(te::Operation op, StageType op_type, - std::vector&& iters, - ComputeAtType compute_at, int16_t auto_unroll_max_step, - int storage_offset) { + std::vector&& iters, ComputeAtType compute_at, + int16_t auto_unroll_max_step, int storage_offset) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; @@ -67,16 +68,6 @@ Stage StageNode::make(te::Operation op, StageType op_type, return Stage(node); } -TVM_REGISTER_GLOBAL("ansor.StageGetIterator") - .set_body_typed([](const Stage& stage, int index) { - return stage->iters[index]; - }); - -TVM_REGISTER_GLOBAL("ansor.StageGetIterators") - .set_body_typed([](const Stage& stage) { - return Array(stage->iters); - }); - State StateNode::make_empty_state() { auto node = make_object(); node->attach_map = AttachMapNode::make(); @@ -97,8 +88,8 @@ State StateNode::make(const Array& ops) { } State StateNode::make(const std::vector& stages, - const std::vector& transform_steps, - bool complete, ObjectRef aux_info) { + const std::vector& transform_steps, bool complete, + ObjectRef aux_info) { auto node = make_object(); node->stages = stages; node->transform_steps = transform_steps; @@ -131,31 +122,32 @@ std::vector State::split(int stage_id, const Iterator& it, bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; - SplitStep step = SplitStepNode::make(stage_id, GetIndex(stage->iters, it), - it->range.defined() ? it->range->extent : PrimExpr(), lengths, - inner_to_outer); + SplitStep step = + SplitStepNode::make(stage_id, GetIndex(stage->iters, it), + it->range.defined() ? 
it->range->extent : PrimExpr(), + lengths, inner_to_outer); CopyOnWrite()->transform_steps.push_back(step); return DoSplitStep(step); } -std::vector State::follow_split(int stage_id, - const Iterator& it, int src_step_id, int n_split) { +std::vector State::follow_split(int stage_id, const Iterator& it, + int src_step_id, int n_split) { const Stage& stage = operator->()->stages[stage_id]; - FollowSplitStep step = FollowSplitStepNode::make(stage_id, - GetIndex(stage->iters, it), src_step_id, n_split); + FollowSplitStep step = FollowSplitStepNode::make( + stage_id, GetIndex(stage->iters, it), src_step_id, n_split); CopyOnWrite()->transform_steps.push_back(step); return DoFollowSplitStep(step); } - std::vector State::follow_fused_split( int stage_id, const Iterator& it, const std::vector& src_step_ids, int level, bool factor_or_nparts) { const Stage& stage = operator->()->stages[stage_id]; - FollowFusedSplitStep step = FollowFusedSplitStepNode::make(stage_id, - GetIndex(stage->iters, it), src_step_ids, level, factor_or_nparts); + FollowFusedSplitStep step = + FollowFusedSplitStepNode::make(stage_id, GetIndex(stage->iters, it), + src_step_ids, level, factor_or_nparts); CopyOnWrite()->transform_steps.push_back(step); return DoFollowFusedSplitStep(step); } @@ -179,16 +171,16 @@ Iterator State::vectorize(int stage_id, const Iterator& it) { Iterator State::parallel(int stage_id, const Iterator& it) { const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = AnnotationStepNode::make( - stage_id, GetIndex(stage->iters, it), kParallel); + AnnotationStep step = + AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), kParallel); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); } Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = AnnotationStepNode::make(stage_id, - GetIndex(stage->iters, it), kUnroll); + AnnotationStep step = + AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), kUnroll); // don't unroll if the extent is larger than max_unroll if (max_unroll != -1 && it->range.defined()) { @@ -206,8 +198,8 @@ Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { const Stage& target_stage = operator->()->stages[target_stage_id]; - ComputeAtStep step = ComputeAtStepNode::make(stage_id, target_stage_id, - GetIndex(target_stage->iters, target_iter)); + ComputeAtStep step = ComputeAtStepNode::make( + stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter)); CopyOnWrite()->transform_steps.push_back(step); return DoComputeAtStep(step); } @@ -227,8 +219,8 @@ void State::compute_inline(int stage_id) { void State::pack_for_vec(int stage_id, const Iterator& target_iter, int vec_size) { const Stage& stage = operator->()->stages[stage_id]; - PackForVecStep step = PackForVecStepNode::make(stage_id, - GetIndex(stage->iters, target_iter), vec_size); + PackForVecStep step = PackForVecStepNode::make( + stage_id, GetIndex(stage->iters, target_iter), vec_size); CopyOnWrite()->transform_steps.push_back(step); return DoPackForVecStep(step); } @@ -240,8 +232,8 @@ Iterator State::bind_thread(int stage_id, const Iterator& it, LOG(FATAL) << "thread_type error, valide: kVThread, kBlockX, kThreadX, " << "kThreadY"; } - AnnotationStep step = AnnotationStepNode::make(stage_id, - GetIndex(stage->iters, it), thread_type); + AnnotationStep step = 
AnnotationStepNode::make( + stage_id, GetIndex(stage->iters, it), thread_type); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); } @@ -249,14 +241,14 @@ Iterator State::bind_thread(int stage_id, const Iterator& it, int State::cache_read(int stage_id, const std::string& scope_name, const std::vector& reader_stage_ids, const ComputeDAG& task_dag) { - CacheReadStep step = CacheReadStepNode::make(stage_id, scope_name, - reader_stage_ids); + CacheReadStep step = + CacheReadStepNode::make(stage_id, scope_name, reader_stage_ids); CopyOnWrite()->transform_steps.push_back(step); return DoCacheReadStep(step, task_dag); } int State::cache_write(int stage_id, const std::string& scope_name, - const ComputeDAG& task_dag) { + const ComputeDAG& task_dag) { CacheWriteStep step = CacheWriteStepNode::make(stage_id, scope_name); CopyOnWrite()->transform_steps.push_back(step); return DoCacheWriteStep(step, task_dag); @@ -265,14 +257,14 @@ int State::cache_write(int stage_id, const std::string& scope_name, void State::pragma(int stage_id, const Iterator& it, const std::string& pragma_type) { const Stage& stage = operator->()->stages[stage_id]; - PragmaStep step = PragmaStepNode::make(stage_id, GetIndex(stage->iters, it), - pragma_type); + PragmaStep step = + PragmaStepNode::make(stage_id, GetIndex(stage->iters, it), pragma_type); CopyOnWrite()->transform_steps.push_back(step); return DoPragmaStep(step); } int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, - const ComputeDAG& task_dag) { + const ComputeDAG& task_dag) { const Stage& stage = operator->()->stages[stage_id]; RfactorStep step = RfactorStepNode::make(stage_id, GetIndex(stage->iters, it), factor_iter_id); @@ -283,8 +275,8 @@ int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, void State::storage_align(int stage_id, const Iterator& it, int factor, int offset) { const Stage& stage = operator->()->stages[stage_id]; - StorageAlignStep step = StorageAlignStepNode::make(stage_id, - GetIndex(stage->iters, it), factor, offset); + StorageAlignStep step = StorageAlignStepNode::make( + stage_id, GetIndex(stage->iters, it), factor, offset); CopyOnWrite()->transform_steps.push_back(step); return DoStorageAlignStep(step); } @@ -299,11 +291,9 @@ void State::DoReorderStep(const ReorderStep& step) { } StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, - std::move(iters), - stage->compute_at, - stage->auto_unroll_max_step, - stage->storage_offset); + pstate->stages[step->stage_id] = StageNode::make( + stage->op, stage->op_type, std::move(iters), stage->compute_at, + stage->auto_unroll_max_step, stage->storage_offset); } // common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep @@ -324,7 +314,8 @@ std::vector State::DoSplitStepCommon( std::vector outs; for (size_t i = 0; i < lengths.size(); ++i) { - PrimExpr l; std::string name; + PrimExpr l; + std::string name; if (inner_to_outer) { l = lengths[lengths.size() - i - 1]; name = it->name + "." + std::to_string(lengths.size() - i); @@ -350,26 +341,26 @@ std::vector State::DoSplitStepCommon( range = Range::make_by_min_extent(tosplit_min, tosplit_extent); } if (inner_to_outer) { - outs.push_back(IteratorNode::make(it->name + ".0", range, it->iter_type, - kNone)); + outs.push_back( + IteratorNode::make(it->name + ".0", range, it->iter_type, kNone)); std::reverse(outs.begin(), outs.end()); } else { - outs.push_back(IteratorNode::make( - it->name + "." 
+ std::to_string(lengths.size()), range, it->iter_type, - kNone)); + outs.push_back( + IteratorNode::make(it->name + "." + std::to_string(lengths.size()), + range, it->iter_type, kNone)); } std::vector new_iters; new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id); new_iters.insert(new_iters.end(), outs.begin(), outs.end()); - new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id+1, + new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages[stage_id] = StageNode::make(stage->op, stage->op_type, - std::move(new_iters), stage->compute_at, stage->auto_unroll_max_step, - stage->storage_offset); + pstate->stages[stage_id] = StageNode::make( + stage->op, stage->op_type, std::move(new_iters), stage->compute_at, + stage->auto_unroll_max_step, stage->storage_offset); // we have to replace the iterators in attach map, // these two vectors keep the replacement mapping @@ -396,8 +387,8 @@ std::vector State::DoFollowSplitStep(const FollowSplitStep& step) { std::vector State::DoFollowFusedSplitStep( const FollowFusedSplitStep& step) { - const PrimExpr& length = step->ExtractSplitLength( - operator->()->transform_steps); + const PrimExpr& length = + step->ExtractSplitLength(operator->()->transform_steps); return DoSplitStepCommon(step->stage_id, step->iter_id, {length}, step->factor_or_nparts); } @@ -414,15 +405,14 @@ Iterator State::DoFuseStep(const FuseStep& step) { std::vector ori_iters; for (size_t i = 0; i < step->fused_ids.size(); ++i) { if (i > 0) { - CHECK_EQ(step->fused_ids[i], step->fused_ids[i-1] + 1); + CHECK_EQ(step->fused_ids[i], step->fused_ids[i - 1] + 1); } if (i != step->fused_ids.size() - 1) { const auto& iter_to_attached_stage = - operator->()->attach_map->iter_to_attached_stages; - if (iter_to_attached_stage.find(std::make_pair(stage_id, - step->fused_ids[i])) - != iter_to_attached_stage.end()) { + operator->()->attach_map->iter_to_attached_stages; + if (iter_to_attached_stage.find(std::make_pair( + stage_id, step->fused_ids[i])) != iter_to_attached_stage.end()) { LOG(FATAL) << "Invalid Fuse. 
Because you want to fuse iterators " "that have been attached by some stages"; } @@ -451,8 +441,8 @@ Iterator State::DoFuseStep(const FuseStep& step) { if (new_extent.defined()) { range = Range::make_by_min_extent(0, new_extent); } - Iterator new_it = IteratorNode::make(new_name, range, new_iter_type, kNone, - &ori_iters); + Iterator new_it = + IteratorNode::make(new_name, range, new_iter_type, kNone, &ori_iters); std::vector new_iters; new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + step->fused_ids.front()); @@ -462,9 +452,9 @@ Iterator State::DoFuseStep(const FuseStep& step) { stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages[stage_id] = StageNode::make(stage->op, stage->op_type, - std::move(new_iters), stage->compute_at, stage->auto_unroll_max_step, - stage->storage_offset); + pstate->stages[stage_id] = StageNode::make( + stage->op, stage->op_type, std::move(new_iters), stage->compute_at, + stage->auto_unroll_max_step, stage->storage_offset); // we have to replace the iterators in attach map, // these two vectors keep the replacement mapping @@ -477,7 +467,7 @@ Iterator State::DoFuseStep(const FuseStep& step) { } else if (i > end_id) { // move forward from_iters.emplace_back(stage_id, i); to_iters.emplace_back(stage_id, i - end_id + begin_id); - } else { // move to the fused id + } else { // move to the fused id from_iters.emplace_back(stage_id, i); to_iters.emplace_back(stage_id, begin_id); } @@ -491,7 +481,7 @@ Iterator State::DoAnnotationStep(const AnnotationStep& step) { Iterator it = stage->iters[step->iter_id]; Iterator new_it = IteratorNode::make(it->name, it->range, it->iter_type, - step->annotation, &it->ori_iters); + step->annotation, &it->ori_iters); Stage new_stage = stage; new_stage.CopyOnWrite()->iters[step->iter_id] = new_it; StateNode* pstate = CopyOnWrite(); @@ -508,8 +498,8 @@ void State::DoComputeAtStep(const ComputeAtStep& step) { std::vector new_iters; for (const Iterator& it : stage->iters) { size_t s = it->name.size(); - if (s >= 2 && it->name[s-2] == '.' && it->name[s-1] >= '1' && - it->name[s-1] <= '4') { + if (s >= 2 && it->name[s - 2] == '.' && it->name[s - 1] >= '1' && + it->name[s - 1] <= '4') { // We use a dangerous heuristic rule here : For multi level splitted // iterators, we assume their length does not change after compute_at. 
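      // (For example, splitting `i` by lengths [4, 8] yields iterators named
      // i.0, i.1 and i.2; only the ".1"-".4" suffixed pieces keep their
      // ranges here, all other iterators are cleared for later re-inference.)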
// Reason: These iterators are generated in MultiStagePolicy by multi @@ -519,14 +509,14 @@ void State::DoComputeAtStep(const ComputeAtStep& step) { new_iters.push_back(it); } else { new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters)); + it->annotation, &it->ori_iters)); } } StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, - std::move(new_iters), kIter, stage->auto_unroll_max_step, - stage->storage_offset); + pstate->stages[step->stage_id] = + StageNode::make(stage->op, stage->op_type, std::move(new_iters), kIter, + stage->auto_unroll_max_step, stage->storage_offset); pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, step->target_iter_id); } @@ -540,14 +530,14 @@ void State::DoComputeRootStep(const ComputeRootStep& step) { std::vector new_iters; for (const Iterator& it : stage->iters) { new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters)); + it->annotation, &it->ori_iters)); } // update attach map StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, - std::move(new_iters), kRoot, stage->auto_unroll_max_step, - stage->storage_offset); + pstate->stages[step->stage_id] = + StageNode::make(stage->op, stage->op_type, std::move(new_iters), kRoot, + stage->auto_unroll_max_step, stage->storage_offset); pstate->attach_map.DeleteStage(step->stage_id); } @@ -560,9 +550,10 @@ void State::DoComputeInlineStep(const ComputeInlineStep& step) { const auto& iter_to_attached_stages = pstate->attach_map->iter_to_attached_stages; for (size_t i = 0; i < stage->iters.size(); ++i) { - CHECK_EQ(iter_to_attached_stages.count(std::make_pair(step->stage_id, i)), 0) - << "Invalid compute_inline: Because there are some other stages " - "that are attached to the target stage"; + CHECK_EQ(iter_to_attached_stages.count(std::make_pair(step->stage_id, i)), + 0) + << "Invalid compute_inline: Because there are some other stages " + "that are attached to the target stage"; } pstate->stages[step->stage_id].CopyOnWrite()->compute_at = kInlined; @@ -576,7 +567,8 @@ void State::DoPackForVecStep(const PackForVecStep& step) { // Common part for steps that add new stages // (e.g. 
CacheReadStep, CacheWriteStep, RfactorStep) void AddStageModificationSteps(size_t step_id, - const std::vector& transform_steps, std::vector* replay_steps) { + const std::vector& transform_steps, + std::vector* replay_steps) { const Step& step = transform_steps[step_id]; if (step->IsInstance() || step->IsInstance()) { @@ -615,14 +607,15 @@ int State::DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag) { // target -> target + target_store // Should update target's op, insert new stage, update the later stage's op pstate->stages[step->stage_id].CopyOnWrite()->op = - operator->()->task_dag->ops[step->stage_id]; - pstate->stages.insert(pstate->stages.begin() + step->stage_id + 1, + operator->()->task_dag->ops[step->stage_id]; + pstate->stages.insert( + pstate->stages.begin() + step->stage_id + 1, StageNode::make(operator->()->task_dag->ops[step->stage_id + 1])); for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; } - pstate->attach_map = - operator->()->attach_map.ApplyStageIdOfffset(step->stage_id + 1, 1); + pstate->attach_map = operator->()->attach_map.ApplyStageIdOfffset( + step->stage_id + 1, 1); return step->stage_id + 1; } @@ -637,8 +630,9 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { } } - int last_dag_op_size = pstate->task_dag.defined() ? - pstate->task_dag->ops.size() : dag->ops.size(); + int last_dag_op_size = pstate->task_dag.defined() + ? pstate->task_dag->ops.size() + : dag->ops.size(); dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag)); int added_ops = pstate->task_dag->ops.size() - last_dag_op_size; CHECK_GE(added_ops, 1); @@ -646,7 +640,8 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { // target -> target_compute + target // Assume target stage has never been applied any steps before cache_write // Should insert new stage, update target stage, update the later stage's op - pstate->stages.insert(pstate->stages.begin() + step->stage_id, + pstate->stages.insert( + pstate->stages.begin() + step->stage_id, StageNode::make(operator->()->task_dag->ops[step->stage_id])); pstate->stages[step->stage_id + 1] = StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]); @@ -657,7 +652,8 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { // for more information // TODO(jcf94): Fix this if (added_ops == 2) { - pstate->stages.insert(pstate->stages.begin() + next_stage_id, + pstate->stages.insert( + pstate->stages.begin() + next_stage_id, StageNode::make(operator->()->task_dag->ops[next_stage_id])); next_stage_id++; } else if (added_ops > 2) { @@ -666,8 +662,8 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { for (size_t i = next_stage_id; i < operator->()->task_dag->ops.size(); ++i) { pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; } - pstate->attach_map = - operator->()->attach_map.ApplyStageIdOfffset(step->stage_id, added_ops); + pstate->attach_map = operator->()->attach_map.ApplyStageIdOfffset( + step->stage_id, added_ops); return step->stage_id; } @@ -702,18 +698,20 @@ int State::DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag) { // target -> target_compute + target // Should insert new stage, update target stage, update the later stage's op - pstate->stages.insert(pstate->stages.begin() + step->stage_id, + pstate->stages.insert( + pstate->stages.begin() + step->stage_id, 
StageNode::make(operator->()->task_dag->ops[step->stage_id])); // maintain the compute_at type of target stage - Stage target_stage = StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]); + Stage target_stage = + StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]); target_stage.CopyOnWrite()->compute_at = compute_at_type; pstate->stages[step->stage_id + 1] = target_stage; for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; } - pstate->attach_map = - operator->()->attach_map.ApplyStageIdOfffset(step->stage_id, 1); + pstate->attach_map = operator->()->attach_map.ApplyStageIdOfffset( + step->stage_id, 1); return step->stage_id; } @@ -777,7 +775,6 @@ void State::DoSteps(const std::vector& steps, const ComputeDAG& dag) { } } - void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t base_indent, bool delete_trivial_loop) { const Stage& stage = state->stages[stage_id]; @@ -786,15 +783,15 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, for (size_t j = 0; j < base_indent; ++j) { *os << " "; } - *os << stage->op->func_name() << " auto_unroll: " - << stage->auto_unroll_max_step << "\n"; + *os << stage->op->func_name() + << " auto_unroll: " << stage->auto_unroll_max_step << "\n"; } if (stage->storage_offset != 0) { for (size_t j = 0; j < base_indent; ++j) { *os << " "; } - *os << stage->op->func_name() << " storage_offset: " - << stage->storage_offset << "\n"; + *os << stage->op->func_name() + << " storage_offset: " << stage->storage_offset << "\n"; } size_t indent = 0; @@ -802,26 +799,46 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, const Iterator& iter = stage->iters[i]; if (!(delete_trivial_loop && iter->range.defined() && - is_one(iter->range->extent))) { + is_one(iter->range->extent))) { for (size_t j = 0; j < base_indent + indent; ++j) { *os << " "; } switch (iter->annotation) { - case kNone: *os << "for "; break; - case kUnroll: *os << "unroll "; break; - case kParallel: *os << "parallel "; break; - case kVectorize: *os << "vectorize "; break; - case kVThread: *os << "vthread "; break; - case kBlockX: *os << "gpu.blockIdx.x "; break; - case kBlockY: *os << "gpu.blockIdx.y "; break; - case kThreadX: *os << "gpu.threadIdx.x "; break; - case kThreadY: *os << "gpu.threadIdx.y "; break; + case kNone: + *os << "for "; + break; + case kUnroll: + *os << "unroll "; + break; + case kParallel: + *os << "parallel "; + break; + case kVectorize: + *os << "vectorize "; + break; + case kVThread: + *os << "vthread "; + break; + case kBlockX: + *os << "gpu.blockIdx.x "; + break; + case kBlockY: + *os << "gpu.blockIdx.y "; + break; + case kThreadX: + *os << "gpu.threadIdx.x "; + break; + case kThreadY: + *os << "gpu.threadIdx.y "; + break; } if (iter->range.defined()) { *os << iter->name << " (" << iter->range->min << "," - << iter->range->extent << ")" << "\n"; + << iter->range->extent << ")" + << "\n"; } else { - *os << iter->name << " (None)" << "\n"; + *os << iter->name << " (None)" + << "\n"; } indent += 2; @@ -885,6 +902,110 @@ std::string State::ToStr(bool delete_trivial_loop) const { return os.str(); } +void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, + int target_iter_id) { + AttachMapNode* pnode = CopyOnWrite(); + + // delete the current entry of stage + DeleteStageEntry(pnode, stage_id); + + // store the new relation + IterKey iter_key(target_stage_id, target_iter_id); + 
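+  // AttachMap stores the compute_at relation in two synchronized maps:
+  //   stage_to_attach_iter:    stage_id -> (target_stage_id, target_iter_id)
+  //   iter_to_attached_stages: (target_stage_id, target_iter_id) -> ids of
+  //                            all stages attached there
+  // Every update below writes both maps so the two views never diverge.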
pnode->stage_to_attach_iter[stage_id] = + std::make_pair(target_stage_id, target_iter_id); + pnode->iter_to_attached_stages[iter_key].push_back(stage_id); +} + +void AttachMap::DeleteStage(int stage_id) { + AttachMapNode* pnode = CopyOnWrite(); + + // delete the entry of old stage + DeleteStageEntry(pnode, stage_id); +} + +void AttachMap::ReplaceIters(const std::vector& old_iters, + const std::vector& new_iters) { + AttachMapNode* pnode = CopyOnWrite(); + + CHECK_EQ(old_iters.size(), new_iters.size()); + for (size_t i = 0; i < old_iters.size(); ++i) { + auto entry = pnode->iter_to_attached_stages.find(old_iters[i]); + if (entry == pnode->iter_to_attached_stages.end()) { + continue; + } + + // replace iter in the value of `stage_to_attach_iter` + for (const auto& s : entry->second) { + pnode->stage_to_attach_iter[s] = new_iters[i]; + } + + // replace iter in the key of `iter_to_attached_stages` + std::vector attached_stages = std::move(entry->second); + pnode->iter_to_attached_stages.erase(entry); + pnode->iter_to_attached_stages[new_iters[i]] = std::move(attached_stages); + } +} + +void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) { + auto old_entry = pnode->stage_to_attach_iter.find(stage_id); + if (old_entry != pnode->stage_to_attach_iter.end()) { + // delete value in `iter_to_attached_stages` + auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second); + DeleteItem(&entry2->second, stage_id); + if (entry2->second.size() == 0) { + pnode->iter_to_attached_stages.erase(entry2); + } + // delete key in `stage_to_attach_iter` + pnode->stage_to_attach_iter.erase(old_entry); + } +} + +AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const { + AttachMap map = AttachMapNode::make(); + auto pmap = map.CopyOnWrite(); + for (const auto& x : operator->()->stage_to_attach_iter) { + auto key = x.first; + if (key >= start_id) { + key += offset; + } + auto value = x.second; + if (value.first >= start_id) { + value.first += offset; + } + pmap->stage_to_attach_iter.insert(std::make_pair(key, value)); + } + for (const auto& x : operator->()->iter_to_attached_stages) { + auto key = x.first; + if (key.first >= start_id) { + key.first += offset; + } + auto value = x.second; + for (auto& i : value) { + if (i >= start_id) { + i += offset; + } + } + pmap->iter_to_attached_stages.insert(std::make_pair(key, value)); + } + return map; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + PrintState(&p->stream, node, true); + }); + +TVM_REGISTER_GLOBAL("ansor.StageGetIterator") + .set_body_typed([](const Stage& stage, int index) { + return stage->iters[index]; + }); + +TVM_REGISTER_GLOBAL("ansor.StageGetIterators") + .set_body_typed([](const Stage& stage) { + return Array(stage->iters); + }); + TVM_REGISTER_GLOBAL("ansor.StateGetStage") .set_body_typed([](const State& state, int index) { return state->stages[index]; @@ -908,21 +1029,20 @@ TVM_REGISTER_GLOBAL("ansor.StateReorder") TVM_REGISTER_GLOBAL("ansor.StateSplit") .set_body_typed([](State state, int stage_id, const Iterator& it, - const Array& lengths, - bool inner_to_outer) { + const Array& lengths, bool inner_to_outer) { std::vector len; for (const auto& i : lengths) { len.push_back(i); } - state.split(stage_id, it, len, inner_to_outer); - return state; + const auto& res = state.split(stage_id, it, len, inner_to_outer); + return Array{state, Array(res)}; }); TVM_REGISTER_GLOBAL("ansor.StateFollowSplit") 
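// (The bindings below follow the same convention as ansor.StateSplit above:
// primitives that create new iterators return an Array{new_state, result},
// so the Python side can unpack them as `state, res = ...`.)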
.set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id, int n_split) { - state.follow_split(stage_id, it, src_step_id, n_split); - return state; + const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); + return Array{state, Array(res)}; }); TVM_REGISTER_GLOBAL("ansor.StateFollowFusedSplit") @@ -933,9 +1053,9 @@ TVM_REGISTER_GLOBAL("ansor.StateFollowFusedSplit") for (const auto& i : src_step_ids) { array_src_step_ids.push_back(i->value); } - state.follow_fused_split(stage_id, it, array_src_step_ids, level, - factor_or_nparts); - return state; + const auto& res = state.follow_fused_split( + stage_id, it, array_src_step_ids, level, factor_or_nparts); + return Array{state, Array(res)}; }); TVM_REGISTER_GLOBAL("ansor.StateFuse") @@ -945,36 +1065,35 @@ TVM_REGISTER_GLOBAL("ansor.StateFuse") for (const auto& i : iters) { its.push_back(i); } - state.fuse(stage_id, its); - return state; + const auto& res = state.fuse(stage_id, its); + return Array{state, res}; }); TVM_REGISTER_GLOBAL("ansor.StateVectorize") - .set_body_typed([](State state, int stage_id, - const Iterator& it) { - state.vectorize(stage_id, it); - return state; + .set_body_typed([](State state, int stage_id, const Iterator& it) { + const auto& res = state.vectorize(stage_id, it); + return Array{state, res}; }); TVM_REGISTER_GLOBAL("ansor.StateParallel") - .set_body_typed([](State state, int stage_id, - const Iterator& it) { - state.parallel(stage_id, it); - return state; + .set_body_typed([](State state, int stage_id, const Iterator& it) { + const auto& res = state.parallel(stage_id, it); + return Array{state, res}; }); TVM_REGISTER_GLOBAL("ansor.StateUnroll") - .set_body_typed([](State state, int stage_id, - const Iterator& it, int max_unroll) { - state.unroll(stage_id, it, max_unroll); - return state; + .set_body_typed([](State state, int stage_id, const Iterator& it, + int max_unroll) { + const auto& res = state.unroll(stage_id, it, max_unroll); + return Array{state, res}; }); TVM_REGISTER_GLOBAL("ansor.StateBindThread") - .set_body_typed([](State state, int stage_id, - const Iterator& it, int thread_type) { - state.bind_thread(stage_id, it, IteratorAnnotation(thread_type)); - return state; + .set_body_typed([](State state, int stage_id, const Iterator& it, + int thread_type) { + const auto& res = + state.bind_thread(stage_id, it, IteratorAnnotation(thread_type)); + return Array{state, res}; }); TVM_REGISTER_GLOBAL("ansor.StateComputeAt") @@ -997,8 +1116,8 @@ TVM_REGISTER_GLOBAL("ansor.StateComputeInline") }); TVM_REGISTER_GLOBAL("ansor.StatePackForVec") - .set_body_typed([](State state, int stage_id, - const Iterator& target_iter, int vec_size) { + .set_body_typed([](State state, int stage_id, const Iterator& target_iter, + int vec_size) { state.pack_for_vec(stage_id, target_iter, vec_size); return state; }); @@ -1011,110 +1130,17 @@ TVM_REGISTER_GLOBAL("ansor.StateCacheRead") for (const auto& i : reader_stage_ids) { array_reader_stage_ids.push_back(i->value); } - state.cache_read(stage_id, scope_name, array_reader_stage_ids, task_dag); - return state; + int res = state.cache_read(stage_id, scope_name, array_reader_stage_ids, + task_dag); + return Array{state, IntImm(DataType::Int(32), res)}; }); TVM_REGISTER_GLOBAL("ansor.StateCacheWrite") .set_body_typed([](State state, int stage_id, const std::string& scope_name, const ComputeDAG& task_dag) { - state.cache_write(stage_id, scope_name, task_dag); - return state; + int res = state.cache_write(stage_id, scope_name, task_dag); + return 
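+          // `res` is the id of the newly inserted cache stage; it is boxed
+          // as an IntImm so the integer can cross the FFI boundary together
+          // with the new state.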
Array{state, IntImm(DataType::Int(32), res)}; }); -void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, - int target_iter_id) { - AttachMapNode* pnode = CopyOnWrite(); - - // delete the current entry of stage - DeleteStageEntry(pnode, stage_id); - - // store the new relation - IterKey iter_key(target_stage_id, target_iter_id); - pnode->stage_to_attach_iter[stage_id] = std::make_pair(target_stage_id, - target_iter_id); - pnode->iter_to_attached_stages[iter_key].push_back(stage_id); -} - -void AttachMap::DeleteStage(int stage_id) { - AttachMapNode* pnode = CopyOnWrite(); - - // delete the entry of old stage - DeleteStageEntry(pnode, stage_id); -} - -void AttachMap::ReplaceIters(const std::vector& old_iters, - const std::vector& new_iters) { - AttachMapNode* pnode = CopyOnWrite(); - - CHECK_EQ(old_iters.size(), new_iters.size()); - for (size_t i = 0; i < old_iters.size(); ++i) { - auto entry = pnode->iter_to_attached_stages.find(old_iters[i]); - if (entry == pnode->iter_to_attached_stages.end()) { - continue; - } - - // replace iter in the value of `stage_to_attach_iter` - for (const auto& s : entry->second) { - pnode->stage_to_attach_iter[s] = new_iters[i]; - } - - // replace iter in the key of `iter_to_attached_stages` - std::vector attached_stages = std::move(entry->second); - pnode->iter_to_attached_stages.erase(entry); - pnode->iter_to_attached_stages[new_iters[i]] = std::move(attached_stages); - } -} - -void AttachMap::DeleteStageEntry(AttachMapNode *pnode, int stage_id) { - auto old_entry = pnode->stage_to_attach_iter.find(stage_id); - if (old_entry != pnode->stage_to_attach_iter.end()) { - // delete value in `iter_to_attached_stages` - auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second); - DeleteItem(&entry2->second, stage_id); - if (entry2->second.size() == 0) { - pnode->iter_to_attached_stages.erase(entry2); - } - // delete key in `stage_to_attach_iter` - pnode->stage_to_attach_iter.erase(old_entry); - } -} - -AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const { - AttachMap map = AttachMapNode::make(); - auto pmap = map.CopyOnWrite(); - for (const auto& x : operator->()->stage_to_attach_iter) { - auto key = x.first; - if (key >= start_id) { - key += offset; - } - auto value = x.second; - if (value.first >= start_id) { - value.first += offset; - } - pmap->stage_to_attach_iter.insert(std::make_pair(key, value)); - } - for (const auto& x : operator->()->iter_to_attached_stages) { - auto key = x.first; - if (key.first >= start_id) { - key.first += offset; - } - auto value = x.second; - for (auto& i : value) { - if (i >= start_id) { - i += offset; - } - } - pmap->iter_to_attached_stages.insert(std::make_pair(key, value)); - } - return map; -} - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { - auto* node = static_cast(ref.get()); - PrintState(&p->stream, node, true); -}); - } // namespace ansor } // namespace tvm diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 1bae02b3f2c5..b2cff24973bc 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -3,12 +3,13 @@ */ #include "measure.h" // #include -#include #include +#include + +#include #include #include #include -#include // #include "search_policy/search_policy.h" namespace tvm { @@ -25,16 +26,16 @@ TVM_REGISTER_OBJECT_TYPE(RPCRunnerNode); TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode); TVM_REGISTER_OBJECT_TYPE(ProgramMeasurerNode); -const char *ErrorNoToStr[] = { - "NoError", - "InstantiationError", - 
"CompileHostError", - "CompileDeviceError", - "RuntimeDeviceError", - "WrongAnswerError", - "BuildTimeoutError", - "RunTimeoutError", - "UnknownError", +const char* ErrorNoToStr[] = { + "NoError", + "InstantiationError", + "CompileHostError", + "CompileDeviceError", + "RuntimeDeviceError", + "WrongAnswerError", + "BuildTimeoutError", + "RunTimeoutError", + "UnknownError", }; // Maker @@ -52,8 +53,9 @@ MeasureInput MeasureInputNode::copy() const { return MeasureInput(node); } -BuildResult BuildResultNode::make(std::string filename, Array args, int error_no, - std::string error_msg, double time_cost) { +BuildResult BuildResultNode::make(std::string filename, Array args, + int error_no, std::string error_msg, + double time_cost) { auto node = make_object(); node->filename = std::move(filename); node->args = std::move(args); @@ -64,7 +66,8 @@ BuildResult BuildResultNode::make(std::string filename, Array args, } MeasureResult MeasureResultNode::make(Array costs, int error_no, - std::string error_msg, double all_cost, double timestamp) { + std::string error_msg, double all_cost, + double timestamp) { auto node = make_object(); node->costs = std::move(costs); node->error_no = error_no; @@ -84,7 +87,8 @@ MeasureResult MeasureResultNode::copy() const { return MeasureResult(node); } -Builder LocalBuilderNode::make(int timeout, int n_parallel, const std::string& build_func) { +Builder LocalBuilderNode::make(int timeout, int n_parallel, + const std::string& build_func) { auto node = make_object(); node->timeout = timeout; node->n_parallel = n_parallel; @@ -93,9 +97,11 @@ Builder LocalBuilderNode::make(int timeout, int n_parallel, const std::string& b } // LocalBuilder and LocalRunner -Array LocalBuilderNode::Build(const Array &inputs, int verbose) { +Array LocalBuilderNode::Build(const Array& inputs, + int verbose) { if (const auto* f = runtime::Registry::Get("ansor.local_builder.build")) { - Array results = (*f)(inputs, timeout, n_parallel, build_func, verbose); + Array results = + (*f)(inputs, timeout, n_parallel, build_func, verbose); return results; } else { LOG(FATAL) << "ansor.local_builder.build is not registered"; @@ -103,9 +109,10 @@ Array LocalBuilderNode::Build(const Array &inputs, in return Array(); } -Runner RPCRunnerNode::make(const std::string& key, const std::string& host, int port, - int priority, int timeout, int n_parallel, int number, - int repeat, int min_repeat_ms, double cooldown_interval) { +Runner RPCRunnerNode::make(const std::string& key, const std::string& host, + int port, int priority, int timeout, int n_parallel, + int number, int repeat, int min_repeat_ms, + double cooldown_interval) { auto node = make_object(); node->key = key; node->host = host; @@ -124,9 +131,9 @@ Array RPCRunnerNode::Run(const Array& inputs, const Array& build_results, int verbose) { if (const auto* f = runtime::Registry::Get("ansor.rpc_runner.run")) { - Array results = (*f)(inputs, build_results, key, host, port, priority, - timeout, n_parallel, number, repeat, - min_repeat_ms, cooldown_interval, verbose); + Array results = (*f)( + inputs, build_results, key, host, port, priority, timeout, n_parallel, + number, repeat, min_repeat_ms, cooldown_interval, verbose); return results; } else { LOG(FATAL) << "ansor.rpc_runner.run is not registered"; @@ -145,12 +152,13 @@ Runner LocalRunnerNode::make(int timeout, int number, int repeat, return Runner(node); } -Array LocalRunnerNode::Run(const Array& inputs, - const Array& build_results, - int verbose) { +Array LocalRunnerNode::Run( + const Array& inputs, 
const Array& build_results, + int verbose) { if (const auto* f = runtime::Registry::Get("ansor.local_runner.run")) { - Array results = (*f)(inputs, build_results, timeout, number, - repeat, min_repeat_ms, cooldown_interval, verbose); + Array results = + (*f)(inputs, build_results, timeout, number, repeat, min_repeat_ms, + cooldown_interval, verbose); return results; } else { LOG(FATAL) << "ansor.local_runner.run is not registered"; @@ -167,8 +175,9 @@ ProgramMeasurer ProgramMeasurerNode::make(Builder builder, Runner runner, node->runner = std::move(runner); node->callbacks = std::move(callbacks); node->verbose = verbose; - node->max_continous_error = max_continous_error < 0 ? - DEFAULT_MAX_CONTINOUS_ERROR : max_continous_error; + node->max_continous_error = max_continous_error < 0 + ? DEFAULT_MAX_CONTINOUS_ERROR + : max_continous_error; return ProgramMeasurer(node); } @@ -192,12 +201,14 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, batch_size = builder->n_parallel * 2; } - StdCout(verbose) << "Get " << inputs.size() << " programs for measure. (This may take a while)" - << std::endl; + StdCout(verbose) << "Get " << inputs.size() + << " programs for measure. (This may take a while)" + << std::endl; for (size_t i = 0; i < inputs.size(); i += batch_size) { - std::vector input_batch(inputs.begin() + i, - inputs.begin() + std::min(i + batch_size, inputs.size())); + std::vector input_batch( + inputs.begin() + i, + inputs.begin() + std::min(i + batch_size, inputs.size())); std::vector result_batch; // build and run @@ -207,7 +218,8 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, for (size_t j = 0; j < input_batch.size(); ++j) { double flops; if (result_batch[j]->error_no == 0) { - flops = task->compute_dag->flop_ct / FloatArrayMean(result_batch[j]->costs); + flops = + task->compute_dag->flop_ct / FloatArrayMean(result_batch[j]->costs); error_ct = 0; } else { flops = 0.0; @@ -225,8 +237,8 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, if (verbose >= 1) { std::cout << std::fixed << std::setprecision(2); std::cout << "===============================================\n"; - std::cout << "No: " << ct - << "\tGFLOPS: " << flops / 1e9 << " / " << best_flops[workload_key] / 1e9 + std::cout << "No: " << ct << "\tGFLOPS: " << flops / 1e9 << " / " + << best_flops[workload_key] / 1e9 << "\tresults: " << result_batch[j] << "\n"; std::cout << "===============================================\n"; std::cout << input_batch[j]->state << "\n"; @@ -261,7 +273,8 @@ void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, // Call builder and runner Array build_res_batch = builder->Build(input_batch, verbose); - Array result_batch = runner->Run(input_batch, build_res_batch, verbose); + Array result_batch = + runner->Run(input_batch, build_res_batch, verbose); // Store result batch for (auto& res : result_batch) { @@ -271,44 +284,89 @@ void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, // Printing functions TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { - p->stream << "MeasureInput()"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "MeasureInput()"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { - auto* node = static_cast(ref.get()); - if (node->error_no == kNoError) { - p->stream << "MeasureResult(cost:["; - auto old_config = p->stream.precision(4); - for (size_t i = 0; i < node->costs.size(); ++i) { - auto pf = 
node->costs[i].as(); - CHECK(pf != nullptr); - p->stream << pf->value; - if (i != node->costs.size() - 1) { - p->stream << ","; + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + if (node->error_no == kNoError) { + p->stream << "MeasureResult(cost:["; + auto old_config = p->stream.precision(4); + for (size_t i = 0; i < node->costs.size(); ++i) { + auto pf = node->costs[i].as(); + CHECK(pf != nullptr); + p->stream << pf->value; + if (i != node->costs.size() - 1) { + p->stream << ","; + } + } + p->stream.precision(old_config); + p->stream << "], "; + p->stream << "error_no:" << 0 << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; + } else { + p->stream << "MeasureResult(" + << "error_type:" << ErrorNoToStr[node->error_no] << ", " + << "error_msg:" << node->error_msg << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; } - } - p->stream.precision(old_config); - p->stream << "], "; - p->stream << "error_no:" << 0 << ", " - << "all_cost:" << node->all_cost << ", " - << "Tstamp:" << node->timestamp << ")"; - } else { - p->stream << "MeasureResult(" - << "error_type:" << ErrorNoToStr[node->error_no] << ", " - << "error_msg:" << node->error_msg << ", " - << "all_cost:" << node->all_cost << ", " - << "Tstamp:" << node->timestamp << ")"; - } -}); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { - auto* node = static_cast(ref.get()); - p->stream << "BuildResult(" << node->filename << ", " << node->error_no - << ", " << node->time_cost << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "BuildResult(" << node->filename << ", " << node->error_no + << ", " << node->time_cost << ")"; + }); + +TVM_REGISTER_GLOBAL("ansor.MeasureInput") + .set_body_typed([](SearchTask task, State state) { + return MeasureInputNode::make(task, state); + }); + +TVM_REGISTER_GLOBAL("ansor.BuildResult") + .set_body_typed([](std::string filename, Array args, + int error_no, std::string error_msg, double time_cost) { + return BuildResultNode::make(filename, args, error_no, error_msg, + time_cost); + }); + +TVM_REGISTER_GLOBAL("ansor.MeasureResult") + .set_body_typed([](Array costs, int error_no, + std::string error_msg, double all_cost, + double timestamp) { + return MeasureResultNode::make(costs, error_no, error_msg, all_cost, + timestamp); + }); + +TVM_REGISTER_GLOBAL("ansor.BuilderBuild") + .set_body_typed([](const Builder& builder, + const Array& inputs, int verbose) { + return builder->Build(inputs, verbose); + }); + +TVM_REGISTER_GLOBAL("ansor.RunnerRun") + .set_body_typed([](const Runner& runner, const Array& inputs, + const Array& build_results, int verbose) { + return runner->Run(inputs, build_results, verbose); + }); + +TVM_REGISTER_GLOBAL("ansor.LocalBuilder") + .set_body_typed([](int timeout, int n_parallel, + const std::string& build_func) { + return LocalBuilderNode::make(timeout, n_parallel, build_func); + }); + +TVM_REGISTER_GLOBAL("ansor.LocalRunner") + .set_body_typed([](int timeout, int number, int repeat, int min_repeat_ms, + double cooldown_interval) { + return LocalRunnerNode::make(timeout, number, repeat, min_repeat_ms, + cooldown_interval); + }); } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index b9cda9168b9e..93f3f60ea768 100644 --- a/src/ansor/search_task.cc +++ 
b/src/ansor/search_task.cc @@ -2,20 +2,23 @@ * Copyright (c) 2020 by Contributors */ #include "search_task.h" -#include -#include + #include -#include +#include +#include + #include +#include namespace tvm { namespace ansor { -TVM_REGISTER_OBJECT_TYPE(HardwareParamsNode); -TVM_REGISTER_OBJECT_TYPE(SearchTaskNode); +TVM_REGISTER_NODE_TYPE(HardwareParamsNode); +TVM_REGISTER_NODE_TYPE(SearchTaskNode); HardwareParams HardwareParamsNode::make(int num_cores, int vector_unit_bytes, - int cache_line_bytes, int max_unroll_vec, + int cache_line_bytes, + int max_unroll_vec, int max_innermost_split_factor) { auto node = make_object(); node->num_cores = num_cores; @@ -40,21 +43,19 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams( auto ctx = TVMContext{kDLGPU, 0}; auto func = tvm::runtime::Registry::Get("device_api.gpu"); CHECK(func != nullptr) << "Cannot find GPU device_api in registry"; - auto device_api = static_cast(((*func)()).operator void*()); + auto device_api = + static_cast(((*func)()).operator void*()); tvm::runtime::TVMRetValue ret; - device_api->GetAttr(ctx, - tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, - &ret); + device_api->GetAttr( + ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret); p_hardware_params->max_shared_memory_per_block = ret; - device_api->GetAttr(ctx, - tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock, - &ret); + device_api->GetAttr( + ctx, tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock, &ret); p_hardware_params->max_registers_per_block = ret; - device_api->GetAttr(ctx, - tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, + device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret); p_hardware_params->max_threads_per_block = ret; @@ -73,16 +74,15 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams( auto ctx = TVMContext{kDLOpenCL, 0}; auto func = tvm::runtime::Registry::Get("device_api.opencl"); CHECK(func != nullptr) << "Cannot find GPU device_api in registry"; - auto device_api = static_cast(((*func)()).operator void*()); + auto device_api = + static_cast(((*func)()).operator void*()); tvm::runtime::TVMRetValue ret; - device_api->GetAttr(ctx, - tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, - &ret); + device_api->GetAttr( + ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret); p_hardware_params->max_shared_memory_per_block = ret; - device_api->GetAttr(ctx, - tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, + device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret); p_hardware_params->max_threads_per_block = ret; @@ -99,9 +99,10 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams( return HardwareParams(); } - -SearchTask SearchTaskNode::make(ComputeDAG compute_dag, std::string workload_key, - Target target, Target target_host, HardwareParams hardware_params) { +SearchTask SearchTaskNode::make(ComputeDAG compute_dag, + std::string workload_key, Target target, + Target target_host, + HardwareParams hardware_params) { auto node = make_object(); node->compute_dag = std::move(compute_dag); node->workload_key = std::move(workload_key); @@ -116,5 +117,22 @@ SearchTask SearchTaskNode::make(ComputeDAG compute_dag, std::string workload_key return SearchTask(node); } +TVM_REGISTER_GLOBAL("ansor.HardwareParams") + .set_body_typed([](int num_cores, int vector_unit_bytes, + int cache_line_bytes, int max_unroll_vec, + int max_innermost_split_factor) { + return HardwareParamsNode::make(num_cores, vector_unit_bytes, + cache_line_bytes, 
max_unroll_vec, + max_innermost_split_factor); + }); + +TVM_REGISTER_GLOBAL("ansor.SearchTask") + .set_body_typed([](ComputeDAG compute_dag, std::string workload_key, + Target target, Target target_host, + HardwareParams hardware_params) { + return SearchTaskNode::make(compute_dag, workload_key, target, + target_host, hardware_params); + }); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h index 7db98a5197a5..9512013848b6 100644 --- a/src/ansor/search_task.h +++ b/src/ansor/search_task.h @@ -8,13 +8,16 @@ #define TVM_ANSOR_SEARCH_TASK_H_ #include + #include + #include "compute_dag.h" namespace tvm { namespace ansor { -class HardwareParams; class SearchTask; +class HardwareParams; +class SearchTask; /*! \brief Hardware related parameters */ class HardwareParamsNode : public Object { @@ -54,12 +57,11 @@ class HardwareParamsNode : public Object { static HardwareParams GetDefaultHardwareParams(const Target& target, const Target& target_host); - static constexpr const char *_type_key = "ansor.HardwareParams"; + static constexpr const char* _type_key = "ansor.HardwareParams"; TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object); }; TVM_DEFINE_COW_NODE_REF(HardwareParams, ObjectRef, HardwareParamsNode); - /*! \brief Meta-info for a search task */ class SearchTaskNode : public Object { public: @@ -81,7 +83,7 @@ class SearchTaskNode : public Object { Target target, Target target_host, HardwareParams hardware_params); - static constexpr const char *_type_key = "ansor.SearchTask"; + static constexpr const char* _type_key = "ansor.SearchTask"; TVM_DECLARE_FINAL_OBJECT_INFO(SearchTaskNode, Object); }; TVM_DEFINE_COW_NODE_REF(SearchTask, ObjectRef, SearchTaskNode); diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc index 75a6cc00b802..e5a2c98c02a9 100644 --- a/tests/cpp/ansor_test.cc +++ b/tests/cpp/ansor_test.cc @@ -242,15 +242,15 @@ TEST(Step, SplitFuseReorder) { CHECK_EQ(s1->stages[2]->iters[0]->range->extent.as()->value, 512); its = s0.split(2, ti, {16}); + Iterator tio = its[0], tii = its[1]; CHECK_EQ(s0->stages[2]->iters[0]->range->extent.as()->value, 32); CHECK_EQ(s0->stages[2]->iters[1]->range->extent.as()->value, 16); - Iterator tio = its[0], tii = its[1]; its = s0.split(2, tj, {8}); + Iterator tjo = its[0], tji = its[1]; CHECK_EQ(s0->stages[2]->iters[2]->range->extent.as()->value, 64); CHECK_EQ(s0->stages[2]->iters[3]->range->extent.as()->value, 8); - Iterator tjo = its[0], tji = its[1]; s0.reorder(2, {tio, tjo, tk, tji, tii}); CHECK_EQ(s0->stages[2]->iters[2]->range->extent.as()->value, 512); diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index 4782f9130cea..da87ea5fe9cf 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -73,26 +73,26 @@ def test_state_split_fuse_reorder(): assert ti.range.extent == 512 - s0 = s0.split(2, ti, [16]) + s0, its = s0.split(2, ti, [16]) + tio = its[0] + tii = its[1] assert s0.stage(2).iterator(0).range.extent == 32 assert s0.stage(2).iterator(1).range.extent == 16 - tio = s0.stage(2).iterator(0) - tii = s0.stage(2).iterator(1) - s0 = s0.split(2, tj, [8]) + s0, its = s0.split(2, tj, [8]) + tjo = its[0] + tji = its[1] assert s0.stage(2).iterator(2).range.extent == 64 assert s0.stage(2).iterator(3).range.extent == 8 - tjo = s0.stage(2).iterator(2) - tji = s0.stage(2).iterator(3) s0 = s0.reorder(2, [tio, tjo, tk, tji, tii]) assert s0.stage(2).iterator(2).range.extent == 512 - s0 = 
s0.fuse(2, [tio, tjo]) - assert s0.stage(2).iterator(0).range.extent == 2048 + s0, res_it = s0.fuse(2, [tio, tjo]) + assert res_it.range.extent == 2048 - s1 = s1.split(2, ti, [8, 2]) - s1 = s1.split(2, tj, [32, 8], False) + s1, _ = s1.split(2, ti, [8, 2]) + s1, _ = s1.split(2, tj, [32, 8], False) assert s1.stage(2).iterator(0).range.extent == 32 assert s1.stage(2).iterator(1).range.extent == 8 assert s1.stage(2).iterator(2).range.extent == 2 @@ -186,22 +186,19 @@ def test_state_cache_read_write(): # 0: init state s0 = dag.get_init_state() ori_its = s0.stage(add).iterators() - s0 = s0.split(add, s0.stage(add).iterator(0), [2]) - s0 = s0.reorder(add, [s0.stage(add).iterator(0), ori_its[1], - s0.stage(add).iterator(1), ori_its[2], ori_its[3]]) + s0, its = s0.split(add, s0.stage(add).iterator(0), [2]) + s0 = s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]]) s0 = s0.compute_inline(relu) # 1: simple cache_write with compute_at - s0 = s0.cache_write(conv, "global", dag) - conv_global = conv + s0, conv_global = s0.cache_write(conv, "global", dag) conv += 1 relu += 1 add += 1 s0 = s0.compute_at(conv_global, conv, s0.stage(conv).iterator(3)) # 2: simple cache_read with compute_at - s0 = s0.cache_read(kernel, "global", [conv_global], dag) - kernel_global = kernel + 1 + s0, kernel_global = s0.cache_read(kernel, "global", [conv_global], dag) conv_global += 1 conv += 1 relu += 1 @@ -252,8 +249,7 @@ def test_state_cache_read_write(): # 3: two level cache_read with compute_at # preparing for GPU's shared memory & local memory - s0 = s0.cache_read(pad_temp, "global", [conv_global], dag) - pad_temp_global = pad_temp + 1 + s0, pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global], dag) kernel_data += 1 kernel_split += 1 kernel += 1 @@ -262,8 +258,8 @@ def test_state_cache_read_write(): conv += 1 relu += 1 add += 1 - s0 = s0.cache_read(pad_temp_global, "shared", [conv_global], dag) - pad_temp_shared = pad_temp_global + 1 + s0, pad_temp_shared = s0.cache_read( + pad_temp_global, "shared", [conv_global], dag) kernel_data += 1 kernel_split += 1 kernel += 1 @@ -279,7 +275,7 @@ def test_state_cache_read_write(): # 4: cache_read with multi readers # This stage cannot be compute at to its consumer - s0 = s0.cache_read(data, "global", [pad_temp, add], dag) + s0, data_global = s0.cache_read(data, "global", [pad_temp, add], dag) pad_temp += 1 pad_temp_global += 1 pad_temp_shared += 1 @@ -350,7 +346,7 @@ def test_state_cache_read_write(): # 5: cache_write with multi outputs # See tests/cpp/ansor_test.cc for more information - s0 = s0.cache_write(kernel_split, "global", dag) + s0, _ = s0.cache_write(kernel_split, "global", dag) assert str(s0) == \ "Placeholder: Data, Kernel_data\n" + \ "for ax0 (0,4)\n" + \ @@ -424,40 +420,39 @@ def test_follow_split_follow_fused_split(): s0 = dag.get_init_state() C = 2 - s0 = s0.cache_write(C, "global", dag) - C_global = C + s0, C_global = s0.cache_write(C, "global", dag) C += 1 - s0 = s0.split(C, s0.stage(C).iterator(0), [4, 2, 8, 4], True) + s0, its0 = s0.split(C, s0.stage(C).iterator(0), [4, 2, 8, 4], True) split_step0 = s0.transform_steps_size() - 1 for level in range(1, 6): tmp = s0 - tmp = tmp.follow_split(C_global, tmp.stage( + tmp, _ = tmp.follow_split(C_global, tmp.stage( C_global).iterator(0), split_step0, level) for i in range(0, level): assert tmp.stage(C).iterator(i).range.extent == \ tmp.stage(C_global).iterator(i).range.extent - s0 = s0.split(C, s0.stage(C).iterator(5), [2, 2, 4, 8]) + s0, its1 = s0.split(C, s0.stage(C).iterator(5), [2, 2, 
4, 8]) split_step1 = s0.transform_steps_size() - 1 - its = s0.stage(C).iterators() - s0 = s0.reorder(C, [its[0], its[5], its[1], its[6], its[2], its[7], - its[3], its[8], its[4], its[9]]) - s0 = s0.fuse(C, [s0.stage(C).iterator(0), s0.stage(C).iterator(1)]) - s0 = s0.fuse(C, [s0.stage(C).iterator(1), s0.stage(C).iterator(2)]) - s0 = s0.fuse(C, [s0.stage(C).iterator(2), s0.stage(C).iterator(3)]) - s0 = s0.fuse(C, [s0.stage(C).iterator(3), s0.stage(C).iterator(4)]) - s0 = s0.fuse(C, [s0.stage(C).iterator(4), s0.stage(C).iterator(5)]) + its = [] + for i0, i1 in zip(its0, its1): + its.append(i0) + its.append(i1) + s0 = s0.reorder(C, its) + for i in range(0, 5): + s0, _ = s0.fuse(C, [s0.stage(C).iterator(i), + s0.stage(C).iterator(i+1)]) for level in range(0, 4): tmp = s0 - tmp = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), - [split_step0, split_step1], level, False) + tmp, _ = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), + [split_step0, split_step1], level, False) assert tmp.stage(C).iterator(level+1).range.extent == \ tmp.stage(C_global).iterator(0).range.extent for level in range(0, 4): tmp = s0 - tmp = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), - [split_step0, split_step1], level, True) + tmp, _ = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), + [split_step0, split_step1], level, True) assert tmp.stage(C).iterator(level+1).range.extent == \ tmp.stage(C_global).iterator(1).range.extent @@ -466,6 +461,49 @@ def test_rfactor(): pass +def test_measure_local_builder_runner(): + dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + + s0 = dag.get_init_state() + A, B, C = 0, 1, 2 + s0, C_global = s0.cache_write(C, "global", dag) + C += 1 + s0, its0 = s0.split(C, s0.stage(C).iterator(0), [4, 8, 8]) + s0, its1 = s0.split(C, s0.stage(C).iterator(4), [8, 4, 4]) + s0 = s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], + its0[3], its1[3]]) + s0 = s0.compute_at(C_global, C, s0.stage(C).iterator(3)) + s0, _ = s0.split(C_global, s0.stage(C_global).iterator(2), [16]) + s0, B_global = s0.cache_read(B, "global", [C_global], dag) + C += 1 + C_global += 1 + s0 = s0.compute_at(B_global, C_global, s0.stage(C_global).iterator(0)) + s0, A_global = s0.cache_read(A, "global", [C_global], dag) + B += 1 + B_global += 1 + C += 1 + C_global += 1 + s0 = s0.compute_at(A_global, C_global, s0.stage(C_global).iterator(2)) + + tgt = tvm.target.create("llvm") + task = ansor.SearchTask(dag, "test", tgt) + + minp = ansor.MeasureInput(task, s0) + local_builder = ansor.LocalBuilder() + local_runner = ansor.LocalRunner() + + bress = local_builder.build([minp]) + assert bress[0].error_no == 0 + mress = local_runner.run([minp], bress) + assert mress[0].error_no == 0 + + +def test_search_basic(): + dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + tgt = tvm.target.create("llvm") + task = ansor.SearchTask(dag, "test", tgt) + + if __name__ == "__main__": test_compute_dag_basic() test_state_split_fuse_reorder() @@ -473,3 +511,5 @@ def test_rfactor(): test_state_cache_read_write() test_follow_split_follow_fused_split() test_rfactor() + test_measure_local_builder_runner() + # test_search_basic() From 6b21dc6e7318bb64827382f60ca07871860efa0a Mon Sep 17 00:00:00 2001 From: Chenfan Date: Thu, 4 Jun 2020 21:02:38 +0800 Subject: [PATCH 07/78] Add ansor.auto_schedule() API; First AutoSchedule working version(#8) * Add basic Python support for ansor.auto_schedule * Update AutoSchedule API * Bug fix for get the attach point of a fused iter * 
Update UT after infer bug fix --- python/tvm/ansor/__init__.py | 4 +- python/tvm/ansor/compute_dag.py | 2 +- python/tvm/ansor/cost_model/__init__.py | 20 +++ python/tvm/ansor/cost_model/cost_model.py | 48 ++++++ python/tvm/ansor/state.py | 6 +- python/tvm/ansor/task.py | 162 +++++++++++++++++- src/ansor/auto_schedule.cc | 85 +++++++++ src/ansor/auto_schedule.h | 61 +++++++ src/ansor/cost_model/cost_model.cc | 42 +++-- src/ansor/loop_state.cc | 21 +++ .../search_policy/meta_tile_rewrite_policy.cc | 11 +- src/ansor/search_policy/utils.h | 11 +- src/te/schedule/schedule_dataflow_rewrite.cc | 66 ++++++- tests/python/unittest/test_ansor_common.py | 40 ++++- 14 files changed, 547 insertions(+), 32 deletions(-) create mode 100644 python/tvm/ansor/cost_model/__init__.py create mode 100644 python/tvm/ansor/cost_model/cost_model.py create mode 100644 src/ansor/auto_schedule.cc create mode 100644 src/ansor/auto_schedule.h diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index cb039cf07d5f..70834ba8936f 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -18,5 +18,7 @@ """Namespace for Ansor autoSchedule""" from .compute_dag import ComputeDAG -from .task import SearchTask +from .task import SearchTask, MetaTileRewritePolicy, TuneOption +from .task import auto_schedule from .measure import MeasureInput, LocalBuilder, LocalRunner +from .cost_model import RandomModel diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index a66a181f054c..f3d27884d622 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -52,7 +52,7 @@ def get_init_state(self): """ return _ffi_api.ComputeDAGGetInitState(self) - def apply_steps_from_state(self, state, layout_rewrite_level): + def apply_steps_from_state(self, state, layout_rewrite_level=None): """ Parameters ---------- diff --git a/python/tvm/ansor/cost_model/__init__.py b/python/tvm/ansor/cost_model/__init__.py new file mode 100644 index 000000000000..aac062e964fd --- /dev/null +++ b/python/tvm/ansor/cost_model/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import, redefined-builtin +""" ... """ + +from .cost_model import RandomModel diff --git a/python/tvm/ansor/cost_model/cost_model.py b/python/tvm/ansor/cost_model/cost_model.py new file mode 100644 index 000000000000..aebc89f465a1 --- /dev/null +++ b/python/tvm/ansor/cost_model/cost_model.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +""" ... """ +import ctypes +import numpy as np + +import tvm._ffi +from tvm.runtime import Object + +from .. import _ffi_api + + +@tvm._ffi.register_object("ansor.CostModel") +class CostModel(Object): + pass + + +@tvm._ffi.register_object("ansor.RandomModel") +class RandomModel(Object): + """ + """ + + def __init__(self): + self.__init_handle_by_constructor__(_ffi_api.RandomModel) + +# A random number generator func for c++'s RandomModel +@tvm._ffi.register_func("ansor.cost_model.random_number") +def random_number(n, return_ptr): + if n == 0: + return + return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float)) + array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,)) + array_wrapper[:] = np.random.uniform(0, 1, (n,)) diff --git a/python/tvm/ansor/state.py b/python/tvm/ansor/state.py index 7de95a8a74af..aa231ab6f4c6 100644 --- a/python/tvm/ansor/state.py +++ b/python/tvm/ansor/state.py @@ -408,9 +408,9 @@ def rfactor(self, stage_id, it, factor_iter_id, task_dag): state : State The updated state """ - state = _ffi_api.StateRfactor(self, stage_id, it, factor_iter_id, - task_dag) - return state + state, new_stage_id = _ffi_api.StateRfactor(self, stage_id, it, + factor_iter_id, task_dag) + return state, new_stage_id def storage_align(self, stage_id, it, factor, offset): """ diff --git a/python/tvm/ansor/task.py b/python/tvm/ansor/task.py index 245cf4c727ae..5fab57c28f48 100644 --- a/python/tvm/ansor/task.py +++ b/python/tvm/ansor/task.py @@ -16,12 +16,16 @@ # under the License. # pylint: disable=unused-import """ ... """ +import random import tvm._ffi from tvm.runtime import Object +from .measure import LocalBuilder, LocalRunner +from .cost_model import RandomModel from . 
import _ffi_api + @tvm._ffi.register_object("ansor.HardwareParams") class HardwareParams(Object): """ @@ -37,8 +41,9 @@ class HardwareParams(Object): def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes, max_unroll_vec, max_innermost_split_factor): self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores, - vector_unit_bytes, cache_line_bytes, max_unroll_vec, - max_innermost_split_factor) + vector_unit_bytes, cache_line_bytes, + max_unroll_vec, + max_innermost_split_factor) @tvm._ffi.register_object("ansor.SearchTask") @@ -56,4 +61,155 @@ class SearchTask(Object): def __init__(self, dag, workload_key, target, target_host=None, hardware_params=None): self.__init_handle_by_constructor__(_ffi_api.SearchTask, dag, - workload_key, target, target_host, hardware_params) + workload_key, target, target_host, + hardware_params) + + +@tvm._ffi.register_object("ansor.SearchPolicy") +class SearchPolicy(Object): + pass + + +@tvm._ffi.register_object("ansor.MetaTileRewritePolicy") +class MetaTileRewritePolicy(Object): + """ The search policy that searches with meta tiling and random rewrite + + Parameters + ---------- + program_cost_model: CostModel + Cost model for complete programs + params: int + Parameters of the search policy, go meta_tile_rewrite_policy.h to find the + definitions. See code below to find the default values + seed: int + Random seed + """ + + def __init__(self, + program_cost_model, + params=None, + seed=None): + # set default parameters + default_params = { + "eps_greedy": 0.05, + + 'evolutionary_search_population': 2048, + 'evolutionary_search_num_iters': 15, + "evolutionary_search_mutation_prob": 0.85, + "evolutionary_search_use_measured_ratio": 0.2, + + 'cpu_multi_level_tiling_structure': 'SSRSRS', + 'gpu_multi_level_tiling_structure': 'SSSRRSRS', + + 'disable_change_compute_location': 0, + } + + if params is None: + params = default_params + else: + for key, value in default_params.items(): + if key not in params: + params[key] = value + + self.__init_handle_by_constructor__( + _ffi_api.MetaTileRewritePolicy, program_cost_model, params, + seed or random.randint(1, 1 << 30)) + + +@tvm._ffi.register_object("ansor.TuneOption") +class TuneOption(Object): + """ The options for tuning + + Parameters + ---------- + n_trials: int + Number of total measurement trials + early_stopping: int + Stops early the tuning if no improvement after n measurements + num_measure_per_iter: int + The number of programs to be measured at each iteration + verbose: int + Verbosity level. 0 means silent. + builder: Builder + Builder which builds the program + runner: Runner + Runner which runs the program and measure time costs + callbacks: List[MeasureCallback] + Callback functions + """ + + def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, + verbose=1, builder='local', runner='local', callbacks=None): + if isinstance(builder, str): + if builder == 'local': + builder = LocalBuilder() + else: + raise ValueError("Invalid builder: " + builder) + + if isinstance(runner, str): + if runner == 'local': + runner = LocalRunner() + else: + raise ValueError("Invalid builder: " + runner) + + if callbacks is None: + callbacks = [] + + self.__init_handle_by_constructor__( + _ffi_api.TuneOption, n_trials, early_stopping, num_measure_per_iter, + verbose, builder, runner, callbacks) + + +def auto_schedule(workload, search_policy='default', target=None, + target_host=None, hardware_params=None, + tune_option=None): + """ Do auto schedule for a compute declaration. 
+
+    The workload parameter can be a `string` used as a workload_key, or a
+    `SearchTask` passed in directly.
+
+    Parameters
+    ----------
+    workload : Str or SearchTask
+
+    target : Target
+
+    target_host : Target = None
+
+    search_policy : Union[SearchPolicy, str]
+
+    hardware_params : HardwareParams
+
+    tune_option : TuneOption
+
+    Returns
+    -------
+    state : State
+
+    sch : tvm.Schedule
+
+    tensors : List[Tensor]
+    """
+    if isinstance(search_policy, str):
+        if search_policy == 'default':
+            search_policy = MetaTileRewritePolicy(RandomModel())
+        else:
+            raise ValueError("Invalid search policy: " + search_policy)
+
+    if tune_option is None:
+        tune_option = TuneOption(n_trials=0)
+
+    if isinstance(workload, str):
+        sch, tensors = _ffi_api.AutoScheduleByWorkloadKey(
+            workload, target, target_host, search_policy, hardware_params,
+            tune_option)
+        return sch, tensors
+    elif isinstance(workload, SearchTask):
+        state = _ffi_api.AutoScheduleBySearchTask(workload, search_policy,
+                                                  tune_option)
+        return state
+    else:
+        raise ValueError("Invalid workload: " + str(workload) +
+                         ", should be String or SearchTask")
diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc
new file mode 100644
index 000000000000..974e7e5d9f58
--- /dev/null
+++ b/src/ansor/auto_schedule.cc
@@ -0,0 +1,85 @@
+#include "auto_schedule.h"
+
+#include
+
+#include
+#include
+
+#include "search_policy/meta_tile_rewrite_policy.h"
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_NODE_TYPE(TuneOptionNode);
+
+TuneOption TuneOptionNode::make(int n_trials, int early_stopping,
+                                int num_measure_per_iter, int verbose,
+                                Builder builder, Runner runner,
+                                Array callbacks) {
+  auto node = make_object();
+  node->n_trials = n_trials;
+  node->early_stopping = early_stopping;
+  node->num_measure_per_iter = num_measure_per_iter;
+  node->verbose = verbose;
+  node->builder = std::move(builder);
+  node->runner = std::move(runner);
+  node->callbacks = std::move(callbacks);
+  return TuneOption(node);
+}
+
+State AutoSchedule(SearchTask task, SearchPolicy search_policy,
+                   TuneOption tune_option) {
+  // Search for the best schedule
+  ProgramMeasurer measurer =
+      ProgramMeasurerNode::make(tune_option->builder, tune_option->runner,
+                                tune_option->callbacks, tune_option->verbose);
+
+  return search_policy->Search(
+      task, tune_option->n_trials, tune_option->early_stopping,
+      tune_option->num_measure_per_iter, tune_option->verbose, measurer);
+}
+
+std::pair > AutoSchedule(
+    std::string workload_key, Target target, Target target_host,
+    SearchPolicy search_policy, HardwareParams hardware_params,
+    TuneOption tune_option) {
+  ComputeDAG dag = ComputeDAGNode::make_by_workload_key(workload_key);
+  SearchTask task = SearchTaskNode::make(
+      std::move(dag), std::move(workload_key), std::move(target),
+      std::move(target_host), std::move(hardware_params));
+  // Note: do not move `task` here; it is still used below to apply the steps.
+  State state = AutoSchedule(task, std::move(search_policy),
+                             std::move(tune_option));
+
+  return task->compute_dag.ApplySteps(state->transform_steps);
+}
+
+TVM_REGISTER_GLOBAL("ansor.TuneOption")
+    .set_body_typed([](int n_trials, int early_stopping,
+                       int num_measure_per_iter, int verbose, Builder builder,
+                       Runner runner, Array callbacks) {
+      return TuneOptionNode::make(n_trials, early_stopping,
+                                  num_measure_per_iter, verbose, builder,
+                                  runner, callbacks);
+    });
+
+TVM_REGISTER_GLOBAL("ansor.AutoScheduleBySearchTask")
+    .set_body_typed([](SearchTask task, SearchPolicy search_policy,
+                       TuneOption tune_option) {
+      return AutoSchedule(task, search_policy,
tune_option);
+    });
+
+TVM_REGISTER_GLOBAL("ansor.AutoScheduleByWorkloadKey")
+    .set_body_typed([](std::string workload_key, Target target,
+                       Target target_host, SearchPolicy search_policy,
+                       HardwareParams hardware_params, TuneOption tune_option) {
+      te::Schedule sch;
+      Array return_tensors;
+      std::tie(sch, return_tensors) =
+          AutoSchedule(workload_key, target, target_host, search_policy,
+                       hardware_params, tune_option);
+
+      return Array{sch, return_tensors};
+    });
+
+}  // namespace ansor
+}  // namespace tvm
diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h
new file mode 100644
index 000000000000..c354751390fe
--- /dev/null
+++ b/src/ansor/auto_schedule.h
@@ -0,0 +1,61 @@
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file ansor/auto_schedule.h
+ * \brief The user interface of the auto-scheduler
+ */
+
+#ifndef TVM_ANSOR_AUTO_SCHEDULE_H_
+#define TVM_ANSOR_AUTO_SCHEDULE_H_
+
+#include "measure.h"
+
+namespace tvm {
+namespace ansor {
+
+/*! \brief Tuning and measurement options */
+class TuneOption;
+class TuneOptionNode : public Object {
+ public:
+  int n_trials;        // Number of total measurement trials
+  int early_stopping;  // Stop the tuning early if no improvement is seen
+                       // after n measurements
+  int num_measure_per_iter;  // The number of programs to be measured at each
+                             // iteration
+  int verbose;      // Verbosity level. 0 means silent.
+  Builder builder;  // Builder which builds the program
+  Runner runner;    // Runner which runs the program and measures time
+                    // costs
+  Array callbacks;  // Callback functions
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("n_trials", &n_trials);
+    v->Visit("early_stopping", &early_stopping);
+    v->Visit("num_measure_per_iter", &num_measure_per_iter);
+    v->Visit("verbose", &verbose);
+    v->Visit("builder", &builder);
+    v->Visit("runner", &runner);
+    v->Visit("callbacks", &callbacks);
+  }
+
+  static TuneOption make(int n_trials, int early_stopping,
+                         int num_measure_per_iter, int verbose, Builder builder,
+                         Runner runner, Array callbacks);
+
+  static constexpr const char* _type_key = "ansor.TuneOption";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TuneOptionNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(TuneOption, ObjectRef, TuneOptionNode);
+
+/*!
\brief Auto schedule for a compute declaration */ +State AutoSchedule(SearchTask task, SearchPolicy search_policy, + TuneOption tune_option); + +std::pair > AutoSchedule( + std::string workload_key, Target target, Target target_host, + SearchPolicy search_policy, HardwareParams hardware_params, + TuneOption tune_option); + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_AUTO_SCHEDULE_H_ \ No newline at end of file diff --git a/src/ansor/cost_model/cost_model.cc b/src/ansor/cost_model/cost_model.cc index d4304bccb4bf..060d2b703287 100644 --- a/src/ansor/cost_model/cost_model.cc +++ b/src/ansor/cost_model/cost_model.cc @@ -2,8 +2,10 @@ * Copyright (c) 2020 by Contributors */ #include "cost_model.h" -#include + #include +#include + #include namespace tvm { @@ -39,8 +41,7 @@ CostModel RandomModelNode::make() { } void RandomModelNode::Update(const Array& inputs, - const Array& results) { -} + const Array& results) {} void RandomModelNode::Predict(const SearchTask& task, const std::vector& states, @@ -51,14 +52,13 @@ void RandomModelNode::Predict(const SearchTask& task, CostModel MeasureModelNode::make(Builder builder, Runner runner) { ObjectPtr node = make_object(); - node->measurer = ProgramMeasurerNode::make(std::move(builder), std::move(runner), - Array(), 0); + node->measurer = ProgramMeasurerNode::make( + std::move(builder), std::move(runner), Array(), 0); return CostModel(node); } void MeasureModelNode::Update(const Array& inputs, - const Array& results) { -} + const Array& results) {} void MeasureModelNode::Predict(const SearchTask& task, const std::vector& states, @@ -66,7 +66,8 @@ void MeasureModelNode::Predict(const SearchTask& task, std::vector inputs; std::vector results; - inputs.clear(); inputs.reserve(states.size()); + inputs.clear(); + inputs.reserve(states.size()); for (const auto& state : states) { inputs.push_back(MeasureInputNode::make(task, state)); } @@ -79,7 +80,8 @@ void MeasureModelNode::Predict(const SearchTask& task, } } -CostModel PythonBasedCostModelNode::make(PackedFunc update_func, PackedFunc predict_func, +CostModel PythonBasedCostModelNode::make(PackedFunc update_func, + PackedFunc predict_func, PackedFunc predict_stage_func) { auto node = make_object(); node->update_func = std::move(update_func); @@ -89,7 +91,7 @@ CostModel PythonBasedCostModelNode::make(PackedFunc update_func, PackedFunc pred } void PythonBasedCostModelNode::Update(const Array& inputs, - const Array& results) { + const Array& results) { update_func(inputs, results); } @@ -101,14 +103,15 @@ void PythonBasedCostModelNode::Predict(const SearchTask& task, static_cast(scores->data())); } -void PythonBasedCostModelNode::PredictStages(const SearchTask& task, - const std::vector& states, - std::vector* state_scores, - std::vector>* stage_scores) { +void PythonBasedCostModelNode::PredictStages( + const SearchTask& task, const std::vector& states, + std::vector* state_scores, + std::vector>* stage_scores) { int n_states = states.size(); int n_stages = task->compute_dag.GetInitState()->stages.size(); std::vector flatten_scores; - flatten_scores.resize(n_states * n_stages * 2); // Allocate sufficient spaces. + // Allocate sufficient spaces. + flatten_scores.resize(n_states * n_stages * 2); predict_stage_func(task, Array(states.begin(), states.end()), static_cast(flatten_scores.data())); @@ -134,8 +137,9 @@ void PythonBasedCostModelNode::PredictStages(const SearchTask& task, int offset = 0; if ((*state_scores)[i] > -INFINITY) { - // If the score is valid. 
Copy scored stages and assign 0 to placeholder and inlined stages. - // If the score is 0, meaning this state failed to be lowered. Just bypass to update offset. + // If the score is valid. Copy scored stages and assign 0 to placeholder + // and inlined stages. If the score is 0, meaning this state failed to + // be lowered. Just bypass to update offset. for (const Stage& stage : states[i]->stages) { if (stage->op_type == kPlaceholder) { scores.push_back(0); @@ -159,5 +163,9 @@ void PythonBasedCostModelNode::PredictStages(const SearchTask& task, } } +TVM_REGISTER_GLOBAL("ansor.RandomModel").set_body_typed([]() { + return RandomModelNode::make(); +}); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index e18d36e34581..32940da0773a 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -1142,5 +1142,26 @@ TVM_REGISTER_GLOBAL("ansor.StateCacheWrite") return Array{state, IntImm(DataType::Int(32), res)}; }); +TVM_REGISTER_GLOBAL("ansor.StatePragma") + .set_body_typed([](State state, int stage_id, const Iterator& it, + const std::string& pragma_type) { + state.pragma(stage_id, it, pragma_type); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateRfactor") + .set_body_typed([](State state, int stage_id, const Iterator& it, + int factor_iter_id, const ComputeDAG& task_dag) { + int res = state.rfactor(stage_id, it, factor_iter_id, task_dag); + return Array{state, IntImm(DataType::Int(32), res)}; + }); + +TVM_REGISTER_GLOBAL("ansor.StateStorageAlign") + .set_body_typed([](State state, int stage_id, const Iterator& it, + int factor, int offset) { + state.storage_align(stage_id, it, factor, offset); + return state; + }); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc index b3b93ec9c839..b4501804607a 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc @@ -3,6 +3,7 @@ */ #include "meta_tile_rewrite_policy.h" +#include #include #include #include @@ -586,7 +587,8 @@ class RuleAddRfactor : public StructureSynthesisRule { } }; -void MetaTileRewritePolicyNode::SynthesizeMetaStructure(std::vector* out_states) { +void MetaTileRewritePolicyNode::SynthesizeMetaStructure( + std::vector* out_states) { State init_state = cur_task_->compute_dag.GetInitState(); std::string cpu_multi_level_tiling_structure = GetStringParam(params, "cpu_multi_level_tiling_structure"); @@ -1416,5 +1418,12 @@ void MetaTileRewritePolicyNode::EvolutionarySearch( << std::fixed << std::setprecision(2) << duration << std::endl; } +TVM_REGISTER_GLOBAL("ansor.MetaTileRewritePolicy") +.set_body_typed([](CostModel program_cost_model, + Map params, + int seed){ + return MetaTileRewritePolicyNode::make(program_cost_model, params, seed); +}); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/utils.h b/src/ansor/search_policy/utils.h index 05b50775b52d..3337975d7a88 100644 --- a/src/ansor/search_policy/utils.h +++ b/src/ansor/search_policy/utils.h @@ -50,10 +50,15 @@ inline double GetDoubleParam(const Map& attr_dict, // Get a string from a tvm str Map inline std::string GetStringParam(const Map& attr_dict, const std::string& key) { - CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; - auto pstr = attr_dict[key].as(); + CHECK_GT(attr_dict.count(key), 0) + << "Cannot find key: \"" << key << "\" in " << attr_dict; + const 
auto& target = attr_dict[key];
+  if (auto pstr = target.as()) {
+    return pstr->value;
+  }
+  auto pstr = target.as();
   CHECK(pstr != nullptr);
-  return pstr->value;
+  return pstr->data;
 }
 
 // Get an iterator name set from a tvm str Map
diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc
index af72d3b1a1df..04a3f0b25bee 100644
--- a/src/te/schedule/schedule_dataflow_rewrite.cc
+++ b/src/te/schedule/schedule_dataflow_rewrite.cc
@@ -461,7 +461,7 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
   for (IterVar iv : root_iter_vars) {
     size_t idx = FindNodeRef(leaf_vars, iv);
     auto it = s->iter_var_attrs.find(iv);
-    // don;t need to rebase path that are binded.
+    // don't need to rebase paths that are bound.
     if (it != s->iter_var_attrs.end() && (*it).second->bind_thread.defined()) {
       continue;
     }
@@ -614,10 +614,74 @@ void InjectInline(ScheduleNode* sch) {
   }
 }
 
+void LegalizeInvalidAttach(ScheduleNode* sch) {
+  std::unordered_map replace_map;
+
+  for (Stage stage : sch->stages) {
+    for (Stage s = stage; s.defined();) {
+      Stage spec = s.GetAttachSpec();
+      if (spec->attach_type != kScope) {
+        break;
+      }
+      bool start_attach = false;
+      IterVar attach_ivar = spec->attach_ivar;
+      s = spec->attach_stage;
+      CHECK(attach_ivar.defined());
+      CHECK(s.defined());
+
+      for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
+        IterVar iv = s->leaf_iter_vars[i - 1];
+        if (!start_attach && iv.same_as(attach_ivar)) {
+          start_attach = true;
+        }
+      }
+      if (!start_attach) {
+        // If the attach_var is fused into another iter_var, update the
+        // attach_var to be the fused one.
+        // Do this recursively.
+        IterVar new_attach_ivar = attach_ivar;
+        bool updated = true;
+        while (updated) {
+          updated = false;
+          for (const auto& rel : s->relations) {
+            if (const FuseNode* r = rel.as()) {
+              if (new_attach_ivar.same_as(r->inner)) {
+                new_attach_ivar = r->fused;
+                updated = true;
+              }
+            } else if (const SplitNode* r = rel.as()) {
+              if (new_attach_ivar.same_as(r->parent)) {
+                new_attach_ivar = r->inner;
+                updated = true;
+              }
+            }
+          }
+          replace_map[attach_ivar] = new_attach_ivar;
+        }
+      }
+    }
+  }
+
+  // remap the parent relation
+  for (Stage s : sch->stages) {
+    if (s->attach_type != kScope) continue;
+    if (replace_map.count(s->attach_ivar)) {
+      s->attach_ivar = replace_map.at(s->attach_ivar);
+    }
+  }
+  for (Stage s : sch->groups) {
+    if (s->attach_type != kScope) continue;
+    if (replace_map.count(s->attach_ivar)) {
+      s->attach_ivar = replace_map.at(s->attach_ivar);
+    }
+  }
+}
+
 Schedule Schedule::normalize() {
   Schedule sn = copy();
   InjectInline(sn.operator->());
   RebaseNonZeroMinLoop(sn);
+  LegalizeInvalidAttach(sn.operator->());
   return sn;
 }
diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py
index da87ea5fe9cf..8f04d003ff94 100644
--- a/tests/python/unittest/test_ansor_common.py
+++ b/tests/python/unittest/test_ansor_common.py
@@ -14,6 +14,9 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
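+# Note: the schedule-search test below pins a fixed random seed and reports
+# it on failure, so a failing search can be reproduced exactly.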
+import random
+import numpy as np
+
 import tvm
 from tvm import te
 from tvm import ansor
@@ -499,10 +502,43 @@ def test_measure_local_builder_runner():
 
 
 def test_search_basic():
-    dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512))
+    print("Test schedule search with default search policy")
+
+    N = 128
+    A, B, C = matmul_nkkm(N, N, N)
+    dag = ansor.ComputeDAG([A, B, C])
     tgt = tvm.target.create("llvm")
     task = ansor.SearchTask(dag, "test", tgt)
 
+    cost_model = ansor.RandomModel()
+    # seed = random.randint(1, 1 << 30)
+    seed = 944563397
+    search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed)
+    state = ansor.auto_schedule(task, search_policy,
+                                tune_option=ansor.TuneOption(n_trials=2))
+    sch, args = dag.apply_steps_from_state(state)
+
+    print("==== Get State ====")
+    print(state)
+    print("==== Get Python Code ====")
+    print(dag.print_python_code_from_state(state))
+
+    try:
+        print("==== Get Lowered Stmt ====")
+        print(tvm.lower(sch, args, simple_mode=True))
+        mod = tvm.build(sch, args, tgt)
+
+        ctx = tvm.context("llvm", 0)
+        a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((N, N), dtype=C.dtype), ctx)
+        mod(a, b, c)
+        tvm.testing.assert_allclose(c.asnumpy(), np.dot(
+            a.asnumpy(), b.asnumpy()), rtol=1e-5)
+        print("==== Verification passed ====")
+    except Exception:
+        raise Exception("Error encountered with seed: %d" % (seed))
+
 
 if __name__ == "__main__":
     test_compute_dag_basic()
@@ -512,4 +548,4 @@
     test_follow_split_follow_fused_split()
     test_rfactor()
     test_measure_local_builder_runner()
-    # test_search_basic()
+    test_search_basic()

From e52135f37418d86ca547ff090fdbd47cca38e706 Mon Sep 17 00:00:00 2001
From: Chenfan
Date: Fri, 5 Jun 2020 17:25:31 +0800
Subject: [PATCH 08/78] Bug fix & Add python serialization API (#10)

* Delete C++ UT hack since Python is ready

* Add ndarray.non_empty

* Update Serialization python API
---
 include/tvm/runtime/c_runtime_api.h           |  23 +++
 include/tvm/runtime/ndarray.h                 |  12 +-
 python/tvm/ansor/__init__.py                  |   1 +
 python/tvm/ansor/compute_dag.py               |  12 ++
 python/tvm/ansor/measure.py                   |   8 +-
 python/tvm/ansor/serialization.py             |  98 ++++++++++++
 python/tvm/runtime/ndarray.py                 |  33 ++++
 src/ansor/compute_dag.cc                      |   5 +
 .../search_policy/meta_tile_rewrite_policy.h  |  71 +++++----
 src/ansor/serialization.cc                    | 143 +++++++++---------
 src/ansor/serialization.h                     |  31 ++--
 src/runtime/ndarray.cc                        |  80 +++++++++-
 tests/cpp/ansor_test.cc                       |  45 ------
 tests/python/unittest/test_ansor_common.py    |  18 ++-
 14 files changed, 408 insertions(+), 172 deletions(-)
 create mode 100644 python/tvm/ansor/serialization.py

diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h
index 213c7059a5f9..5a32ac7d3d9f 100644
--- a/include/tvm/runtime/c_runtime_api.h
+++ b/include/tvm/runtime/c_runtime_api.h
@@ -384,6 +384,29 @@ TVM_DLL int TVMFuncListGlobalNames(int* out_size, const char*** out_array);
 TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits,
                           int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out);
+/*!
+ * \brief Allocate an nd-array's memory filled with non-empty (random) values,
+ *  including space of shape, of given spec.
+ *
+ * \param shape The shape of the array.
+ * \param ndim The number of dimensions of the array.
+ * \param dtype_code The type code of the dtype + * \param dtype_bits The number of bits of dtype + * \param dtype_lanes The number of lanes in the dtype. + * \param device_type The device type of context + * \param device_id The device id of context. + * \param out The output handle. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMArrayAllocNonEmpty(const tvm_index_t* shape, + int ndim, + int dtype_code, + int dtype_bits, + int dtype_lanes, + int device_type, + int device_id, + TVMArrayHandle* out); + /*! * \brief Free the TVM Array. * \param handle The array handle to be freed. diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index e69d802652fd..9cc66a371974 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -138,7 +138,17 @@ class NDArray : public ObjectRef { * \param ctx The context of the Array. * \return The created Array */ - TVM_DLL static NDArray Empty(std::vector shape, DLDataType dtype, DLContext ctx); + TVM_DLL static NDArray Empty(std::vector shape, + DLDataType dtype, DLContext ctx); + /*! + * \brief Create an NDArray with non-empty values. + * \param shape The shape of the new array. + * \param dtype The data type of the new array. + * \param ctx The context of the Array. + * \return The created Array + */ + TVM_DLL static NDArray NonEmpty(std::vector shape, + DLDataType dtype, DLContext ctx); /*! * \brief Create a NDArray backed by a dlpack tensor. * diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 70834ba8936f..af1ca24cd4dc 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -22,3 +22,4 @@ from .task import auto_schedule from .measure import MeasureInput, LocalBuilder, LocalRunner from .cost_model import RandomModel +from .serialization import LogToFile, LogReader, best_measure_pair_in_file diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index f3d27884d622..aa50864548a4 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -78,3 +78,15 @@ def print_python_code_from_state(self, state): str : Str """ return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state) + + def infer_bound_from_state(self, state): + """ + Parameters + ---------- + state : State + + Returns + ------- + state : State + """ + return _ffi_api.ComputeDAGInferBoundFromState(self, state) diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 72dd3cbfcf92..d7d0e64eb14b 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -44,6 +44,10 @@ logger = logging.getLogger('ansor') +@tvm._ffi.register_object("ansor.MeasureCallback") +class MeasureCallback(Object): + pass + @tvm._ffi.register_object("ansor.MeasureInput") class MeasureInput(Object): """ @@ -332,7 +336,7 @@ def timed_func(): if error_no == 0: try: - args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in + args = [ndarray.non_empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] ctx.sync() @@ -390,7 +394,7 @@ def timed_func(inp, build_res): if error_no == 0: try: - args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in + args = [ndarray.non_empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] ctx.sync() diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py new file mode 100644 index 000000000000..172405ce7ddb --- /dev/null +++ b/python/tvm/ansor/serialization.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software 
Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+""" ... """
+import numpy as np
+
+import tvm._ffi
+from tvm.runtime import Object
+
+from .measure import MeasureCallback, MeasureErrorNo
+
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.LogToFile")
+class LogToFile(MeasureCallback):
+    """
+    Parameters
+    ----------
+    filename : Str
+    """
+
+    def __init__(self, filename="ansor_tuning.json"):
+        self.__init_handle_by_constructor__(_ffi_api.LogToFile, filename)
+
+
+@tvm._ffi.register_object("ansor.LogReader")
+class LogReader(Object):
+    def __init__(self, filename="ansor_tuning.json"):
+        self.__init_handle_by_constructor__(_ffi_api.LogReader, filename)
+
+    def read_lines(self, max_size=-1, skip_size=0):
+        inputs, results = _ffi_api.LogReaderReadLines(
+            self, max_size, skip_size)
+        return inputs, results
+
+    def __iter__(self):
+        while True:
+            ret = _ffi_api.LogReaderReadNext(self)
+            if ret is None or not len(ret):
+                break
+            yield ret[0], ret[1]  # (input, result)
+
+
+def best_measure_pair_in_file(filename, workload_key=None, target=None):
+    """ Return the best measurement pair from a log file
+
+    Parameters
+    ----------
+    filename : Str
+
+    workload_key : Str
+
+    target : Str
+
+    Returns
+    -------
+    inp : MeasureInput
+
+    res : MeasureResult
+    """
+    log_reader = LogReader(filename)
+    best_cost = 1e30
+    best_inp = None
+    best_res = None
+
+    for inp, res in log_reader:
+        if res.error_no != MeasureErrorNo.NO_ERROR:
+            continue
+        if workload_key and inp.task.workload_key != workload_key:
+            continue
+        if target and inp.task.target.target_name != target.target_name:
+            continue
+
+        costs = []
+        for value in res.costs:
+            costs.append(value.value)
+        cost = np.mean(costs)
+        if cost < best_cost:
+            best_cost = cost
+            best_inp = inp
+            best_res = res
+
+    return best_inp, best_res
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 060673dc19c6..967bfcdd3cde 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -279,6 +279,39 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
     return _make_array(handle, False, False)
 
 
+def non_empty(shape, dtype="float32", ctx=context(1, 0)):
+    """Create a non-empty array given shape and device
+
+    Parameters
+    ----------
+    shape : tuple of int
+        The shape of the array
+
+    dtype : type or str
+        The data type of the array.
+
+    ctx : TVMContext
+        The context of the array
+
+    Returns
+    -------
+    arr : tvm.nd.NDArray
+        The array tvm supported.
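+
+    Examples
+    --------
+    A minimal usage sketch. Per the C++ implementation added later in this
+    patch (`NDArray::NonEmpty` in src/runtime/ndarray.cc), the buffer is
+    pre-filled with random values drawn from [1.0, 10.0):
+
+    .. code-block:: python
+
+        from tvm.runtime.ndarray import context, non_empty
+
+        # A 2x3 float32 array on CPU whose contents are already initialized,
+        # unlike `empty`, which leaves the memory untouched.
+        arr = non_empty((2, 3), "float32", context(1, 0))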
+ """ + shape = c_array(tvm_shape_index_t, shape) + ndim = ctypes.c_int(len(shape)) + handle = TVMArrayHandle() + dtype = DataType(dtype) + check_call(_LIB.TVMArrayAllocNonEmpty( + shape, ndim, + ctypes.c_int(dtype.type_code), + ctypes.c_int(dtype.bits), + ctypes.c_int(dtype.lanes), + ctx.device_type, + ctx.device_id, + ctypes.byref(handle))) + return _make_array(handle, False, False) + def from_dlpack(dltensor): """Produce an array from a DLPack tensor without memory copy. Retreives the underlying DLPack tensor's pointer to create an array from the diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index c9415a70c303..7fad0ce5b28a 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -1271,5 +1271,10 @@ TVM_REGISTER_GLOBAL("ansor.ComputeDAGPrintPythonCodeFromState") return dag.PrintStepsAsPython(state->transform_steps); }); +TVM_REGISTER_GLOBAL("ansor.ComputeDAGInferBoundFromState") +.set_body_typed([](const ComputeDAG& dag, const State& state) { + return dag.ReplayAndInferBound(state->transform_steps); +}); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/meta_tile_rewrite_policy.h index 56a75f8e52fe..ca9033ad866e 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.h +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.h @@ -1,91 +1,100 @@ /*! * Copyright (c) 2020 by Contributors * \file ansor/meta_tile_rewrite_policy.h - * \brief A search policy that search with meta tiling structure and random rewrite + * \brief A search policy that search with meta tiling structure and random + * rewrite */ #ifndef TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ #define TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ -#include +#include #include -#include #include -#include -#include "search_policy.h" +#include +#include + #include "../cost_model/cost_model.h" #include "../utils.h" - +#include "search_policy.h" namespace tvm { namespace ansor { /*! 
Multi stage search policy */
-class MetaTileRewritePolicyNode: public SearchPolicyNode {
+class MetaTileRewritePolicyNode : public SearchPolicyNode {
  public:
   CostModel program_cost_model;
 
   /* this->params is used to store the following arguments
-   * int evolutionary_search_population    // The population size for evolutionary search
-   * int evolutionary_search_mutation_prob // The probability of mutation for evolutionary search
-   * int evolutionary_search_num_iters;    // The number of iterations for evolutionary search
-   * double local_mutation_use_measured_ratio;  // The maximum percentage of measured states in the initial
-   *                                            // population for evolutionary search
-   * double eps_greedy;  // Always allocate this percentage of measurements to random sampled states
-   * str cpu_multi_level_tiling_structure  // The structure of multi-level tiling for CPU
-   * str gpu_multi_level_tiling_structure  // The structure of multi-level tiling for GPU
+   * int evolutionary_search_population
+   *     The population size for evolutionary search
+   * int evolutionary_search_mutation_prob
+   *     The probability of mutation for evolutionary search
+   * int evolutionary_search_num_iters
+   *     The number of iterations for evolutionary search
+   * double local_mutation_use_measured_ratio
+   *     The maximum percentage of measured states in the initial population
+   *     for evolutionary search
+   * double eps_greedy
+   *     Always allocate this percentage of measurements to random sampled states
+   * str cpu_multi_level_tiling_structure
+   *     The structure of multi-level tiling for CPU
+   * str gpu_multi_level_tiling_structure
+   *     The structure of multi-level tiling for GPU
    */
   Map params;
 
   static SearchPolicy make(CostModel program_cost_model,
-                           Map params,
-                           int seed);
+                           Map params, int seed);
 
   // Search and make n_trials measurements
   // Return the best state
-  State Search(SearchTask task, int n_trials,
-               int early_stopping, int num_measure_per_iter,
-               int verbose, ProgramMeasurer measurer) final;
+  State Search(SearchTask task, int n_trials, int early_stopping,
+               int num_measure_per_iter, int verbose,
+               ProgramMeasurer measurer) final;
 
   // Continue search.
This is used by JointTuner std::pair, Array > ContinueSearchOneRound( - SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) final; + SearchTask task, int num_measure, int verbose, + ProgramMeasurer measurer) final; - static constexpr const char *_type_key = "ansor.MetaTileRewritePolicy"; + static constexpr const char* _type_key = "ansor.MetaTileRewritePolicy"; static const std::vector auto_unroll_configs; TVM_DECLARE_FINAL_OBJECT_INFO(MetaTileRewritePolicyNode, SearchPolicyNode); - SearchTask cur_task_; // The current task + SearchTask cur_task_; // The current task - friend class MetaTileRewritePolicyNodeTest; // Hack friend class for UT protected: // Pick states from best states and random states with eps-greedy policy void PickStatesWithEpsGreedy(std::vector* inputs, const std::vector& best_states, - const std::vector& random_states, int remaining_n_trials); + const std::vector& random_states, + int remaining_n_trials); private: // Run one round of the search pipeline - void SearchOneRound(std::vector* best_states, - int num_random_states, std::vector* random_states); + void SearchOneRound(std::vector* best_states, int num_random_states, + std::vector* random_states); // Synthesize meta tiling structure without tile size void SynthesizeMetaStructure(std::vector* out_states); // Sample init population void SampleInitPopulation(const std::vector& meta_structures, - int out_size, std::vector* out_states); + int out_size, std::vector* out_states); // Perform evolutionary search void EvolutionarySearch(const std::vector& init_population, - int num_best_states, std::vector* best_states); + int num_best_states, std::vector* best_states); SplitFactorizationMemo split_memo_; // Memorize split space for Split std::mt19937 rand_gen_; // Random generator int verbose_; // Verbose level (0 means silent) - int num_measure_per_iter_; // The number of states to measure per iteration + int num_measure_per_iter_; // The number of states to measure per iteration - // The set of the already measured states. We store the string format for redundancy check + // The set of the already measured states. We store the string format for + // redundancy check std::unordered_set measured_states_set_; // The array of already measured states. diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index 0e2b0be42587..fc4917409cc0 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -1,15 +1,17 @@ /*! 
* Copyright (c) 2020 by Contributors */ +#include "serialization.h" + #include -// #include #include + #include #include -#include #include #include -#include "serialization.h" +#include + #include "loop_state.h" #include "utils.h" @@ -18,10 +20,10 @@ namespace dmlc { namespace json { -inline std::vector& FloatArrayToVector(std::vector* out, - const ::tvm::Array<::tvm::PrimExpr>& data) { +inline std::vector& FloatArrayToVector( + std::vector* out, const ::tvm::Array<::tvm::PrimExpr>& data) { out->clear(); - for (const auto&x : data) { + for (const auto& x : data) { auto pf = x.as<::tvm::tir::FloatImmNode>(); CHECK(pf != nullptr) << "Cost can only contain float values"; out->push_back(pf->value); @@ -29,10 +31,10 @@ inline std::vector& FloatArrayToVector(std::vector* out, return *out; } -inline std::vector& IntArrayToVector(std::vector* out, - const ::tvm::Array<::tvm::PrimExpr>& data) { +inline std::vector& IntArrayToVector( + std::vector* out, const ::tvm::Array<::tvm::PrimExpr>& data) { out->clear(); - for (const auto&x : data) { + for (const auto& x : data) { auto pi = x.as<::tvm::tir::IntImmNode>(); CHECK(pi != nullptr) << "Cost can only contain int values"; out->push_back(pi->value); @@ -41,15 +43,15 @@ inline std::vector& IntArrayToVector(std::vector* out, } template <> -struct Handler > { +struct Handler> { inline static void Write(dmlc::JSONWriter* writer, - const std::vector<::tvm::ansor::Stage> & data) { + const std::vector<::tvm::ansor::Stage>& data) { // todo(lmzheng): support serialization of Stage writer->BeginArray(false); writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, - std::vector<::tvm::ansor::Stage> * data) { + std::vector<::tvm::ansor::Stage>* data) { bool s; reader->BeginArray(); s = reader->NextArrayItem(); CHECK(!s); @@ -57,9 +59,9 @@ struct Handler > { }; template <> -struct Handler > { +struct Handler> { inline static void Write(dmlc::JSONWriter* writer, - const std::vector<::tvm::ansor::Step> & data) { + const std::vector<::tvm::ansor::Step>& data) { std::vector tmp; writer->BeginArray(false); for (size_t i = 0; i < data.size(); ++i) { @@ -92,7 +94,8 @@ struct Handler > { writer->WriteArrayItem(ps->iter_id); writer->WriteArrayItem(ps->src_step_id); writer->WriteArrayItem(ps->n_split); - } else if (auto ps = data[i].as<::tvm::ansor::FollowFusedSplitStepNode>()) { + } else if (auto ps = + data[i].as<::tvm::ansor::FollowFusedSplitStepNode>()) { writer->WriteArrayItem(std::string("FFSS")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); @@ -165,7 +168,7 @@ struct Handler > { writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, - std::vector<::tvm::ansor::Step> * data) { + std::vector<::tvm::ansor::Step>* data) { std::vector int_list; bool s, inner_to_outer, factor_or_nparts; std::string name, scope_name, pragma_type; @@ -183,7 +186,8 @@ struct Handler > { reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); - data->push_back(::tvm::ansor::ReorderStepNode::make(stage_id, int_list)); + data->push_back( + ::tvm::ansor::ReorderStepNode::make(stage_id, int_list)); } else if (name == "SS") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -236,8 +240,8 @@ struct Handler > { reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&ann); - data->push_back(::tvm::ansor::AnnotationStepNode::make(stage_id, - iter_id, ::tvm::ansor::IteratorAnnotation(ann))); + data->push_back(::tvm::ansor::AnnotationStepNode::make( + stage_id, iter_id, 
::tvm::ansor::IteratorAnnotation(ann))); } else if (name == "CA") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -269,8 +273,8 @@ struct Handler > { reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&scope_name); - data->push_back(::tvm::ansor::CacheWriteStepNode::make( - stage_id, scope_name)); + data->push_back( + ::tvm::ansor::CacheWriteStepNode::make(stage_id, scope_name)); } else if (name == "PS") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -278,8 +282,8 @@ struct Handler > { reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&pragma_type); - data->push_back(::tvm::ansor::PragmaStepNode::make( - stage_id, iter_id, pragma_type)); + data->push_back( + ::tvm::ansor::PragmaStepNode::make(stage_id, iter_id, pragma_type)); } else if (name == "RFS") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -287,8 +291,8 @@ struct Handler > { reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&factor_iter_id); - data->push_back(::tvm::ansor::RfactorStepNode::make( - stage_id, iter_id, factor_iter_id)); + data->push_back(::tvm::ansor::RfactorStepNode::make(stage_id, iter_id, + factor_iter_id)); } else if (name == "SA") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -388,7 +392,7 @@ struct Handler<::tvm::ansor::MeasureResultNode> { writer->BeginArray(false); writer->WriteArraySeperator(); writer->BeginArray(false); - for (const auto&x : data.costs) { + for (const auto& x : data.costs) { auto pf = x.as<::tvm::tir::FloatImmNode>(); CHECK(pf != nullptr) << "Cost can only contain float values"; writer->WriteArrayItem(pf->value); @@ -430,7 +434,7 @@ namespace ansor { TVM_REGISTER_OBJECT_TYPE(LogToFileNode); TVM_REGISTER_OBJECT_TYPE(LogReaderNode); -const std::string ansor_LOG_VERSION = "v0.1"; // NOLINT(*) +const std::string ansor_LOG_VERSION = "v0.1"; // NOLINT(*) MeasureCallback LogToFileNode::make(std::string filename) { auto node = make_object(); @@ -438,8 +442,7 @@ MeasureCallback LogToFileNode::make(std::string filename) { return MeasureCallback(node); } -void WriteMeasureRecords(std::ostream* os, - const Array& inputs, +void WriteMeasureRecords(std::ostream* os, const Array& inputs, const Array& results) { dmlc::JSONWriter writer(os); for (size_t i = 0; i < inputs.size(); ++i) { @@ -452,10 +455,8 @@ void WriteMeasureRecords(std::ostream* os, } } -void ReadMeasureRecords(std::string str, - MeasureInputNode* inp, - MeasureResultNode* res, - std::string* log_version) { +void ReadMeasureRecords(std::string str, MeasureInputNode* inp, + MeasureResultNode* res, std::string* log_version) { std::istringstream ss(str); dmlc::JSONReader reader(&ss); std::string key; @@ -474,15 +475,6 @@ void ReadMeasureRecords(std::string str, } } -TVM_REGISTER_GLOBAL("ansor.write_measure_records_to_file") -.set_body([](TVMArgs args, TVMRetValue *ret) { - std::string filename = args[0]; - Array in = args[1]; - Array res = args[2]; - std::ofstream ofs(filename, std::ofstream::app); - WriteMeasureRecords(&ofs, in, res); -}); - void LogToFileNode::callback(const SearchPolicy& policy, const Array& inputs, const Array& results) { @@ -518,8 +510,8 @@ bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { return false; } -std::pair, Array > LogReaderNode::ReadLines( - int max_size, int skip_size) { +std::pair, Array> LogReaderNode::ReadLines( + int max_size, int skip_size) { auto inp = make_object(); auto res = make_object(); Array inputs; @@ 
-542,32 +534,41 @@ std::pair, Array > LogReaderNode::ReadLines( return std::make_pair(inputs, results); } -std::pair BestMeasurePairInFile(const std::string& filename, - const std::string& workload_key, - const Target& target) { - std::pair best_pair; - double best_cost = 1e30; - - auto inp = make_object(); - auto res = make_object(); - LogReader reader = LogReaderNode::make(filename); - - while (reader->ReadNext(inp.get(), res.get())) { - if (res->error_no != kNoError || inp->task->workload_key != workload_key - || inp->task->target->target_name != target->target_name) { - continue; - } - - double cost = FloatArrayMean(res->costs); - - if (cost < best_cost) { - best_cost = cost; - best_pair = std::make_pair(inp->copy(), res->copy()); - } - } - - return best_pair; -} +TVM_REGISTER_GLOBAL("ansor.write_measure_records_to_file") + .set_body([](TVMArgs args, TVMRetValue* ret) { + std::string filename = args[0]; + Array in = args[1]; + Array res = args[2]; + std::ofstream ofs(filename, std::ofstream::app); + WriteMeasureRecords(&ofs, in, res); + }); + +TVM_REGISTER_GLOBAL("ansor.LogToFile") + .set_body_typed([](const std::string& filename) { + return LogToFileNode::make(filename); + }); + +TVM_REGISTER_GLOBAL("ansor.LogReader") + .set_body_typed([](const std::string& filename) { + return LogReaderNode::make(filename); + }); + +TVM_REGISTER_GLOBAL("ansor.LogReaderReadLines") + .set_body_typed([](LogReader reader, int size, int skip_size) { + const auto& res = reader->ReadLines(size, skip_size); + return Array{res.first, res.second}; + }); + +TVM_REGISTER_GLOBAL("ansor.LogReaderReadNext") + .set_body_typed([](LogReader reader) { + auto inp = make_object(); + auto res = make_object(); + if (reader->ReadNext(inp.get(), res.get())) { + return Array{ObjectRef(inp), ObjectRef(res)}; + } else { + return Array(); + } + }); } // namespace ansor -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h index 96dfb0ee320b..ef4132169652 100644 --- a/src/ansor/serialization.h +++ b/src/ansor/serialization.h @@ -7,11 +7,11 @@ #ifndef TVM_ANSOR_SERIALIZATION_H_ #define TVM_ANSOR_SERIALIZATION_H_ -#include #include +#include #include + #include "measure.h" -// #include "search_policy/search_policy.h" namespace tvm { namespace ansor { @@ -19,23 +19,22 @@ namespace ansor { class LogReader; /*! \brief Log the input and results of measurments to file */ -class LogToFileNode: public MeasureCallbackNode { +class LogToFileNode : public MeasureCallbackNode { public: std::string filename; static MeasureCallback make(std::string filename); /*! \brief Log measure pairs to file. This is called by the search policy */ - void callback(const SearchPolicy& policy, - const Array& inputs, + void callback(const SearchPolicy& policy, const Array& inputs, const Array& results) final; - static constexpr const char *_type_key = "ansor.LogToFile"; + static constexpr const char* _type_key = "ansor.LogToFile"; TVM_DECLARE_FINAL_OBJECT_INFO(LogToFileNode, MeasureCallbackNode); }; /*! \brief Log reader */ -class LogReaderNode: public Object { +class LogReaderNode : public Object { public: std::string filename; std::ifstream infile; @@ -50,27 +49,25 @@ class LogReaderNode: public Object { * \param max_size The maximum number of lines. 
-1 means read all lines * \param skip_size Skip the first n lines */ std::pair, Array > ReadLines( - int max_size = -1, int skip_size = 0); + int max_size = -1, int skip_size = 0); static constexpr const char* _type_key = "ansor.LogReader"; TVM_DECLARE_FINAL_OBJECT_INFO(LogReaderNode, Object); + private: std::string cur_line; }; TVM_DEFINE_MUTABLE_NODE_REF(LogReader, LogReaderNode); -void WriteMeasureRecords(std::ostream* os, - const Array& inputs, +void WriteMeasureRecords(std::ostream* os, const Array& inputs, const Array& results); -void ReadMeasureRecords(std::string str, - MeasureInputNode* inp, - MeasureResultNode* res, - std::string* log_version); +void ReadMeasureRecords(std::string str, MeasureInputNode* inp, + MeasureResultNode* res, std::string* log_version); -std::pair BestMeasurePairInFile(const std::string& filename, - const std::string& workload_key, - const Target& target); +std::pair BestMeasurePairInFile( + const std::string& filename, const std::string& workload_key, + const Target& target); } // namespace ansor } // namespace tvm diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 800a9167dadc..714535ecc8a6 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -26,6 +26,9 @@ #include #include +#include +#include + #include "runtime_base.h" extern "C" { @@ -180,7 +183,8 @@ NDArray NDArray::CreateView(std::vector shape, DLDataType dtype) { DLManagedTensor* NDArray::ToDLPack() const { return Internal::ToDLPack(get_mutable()); } -NDArray NDArray::Empty(std::vector shape, DLDataType dtype, DLContext ctx) { +NDArray NDArray::Empty(std::vector shape, DLDataType dtype, + DLContext ctx) { NDArray ret = Internal::Create(shape, dtype, ctx); // setup memory content size_t size = GetDataSize(ret.get_mutable()->dl_tensor); @@ -190,6 +194,59 @@ NDArray NDArray::Empty(std::vector shape, DLDataType dtype, DLContext c return ret; } + +NDArray NDArray::NonEmpty(std::vector shape, DLDataType dtype, + DLContext ctx) { + NDArray ret = Internal::Create(shape, dtype, ctx); + NDArray dummy_cpu_arr = Internal::Create(shape, dtype, {kDLCPU, 0}); + + // setup memory content + size_t size = GetDataSize(ret.get_mutable()->dl_tensor); + size_t alignment = GetDataAlignment(ret.get_mutable()->dl_tensor); + dummy_cpu_arr.get_mutable()->dl_tensor.data = + DeviceAPI::Get(dummy_cpu_arr->ctx)->AllocDataSpace( + {kDLCPU, 0}, size, alignment, dummy_cpu_arr->dtype); + size_t elem_cnt = 1; + for (tvm_index_t i = 0; i < dummy_cpu_arr->ndim; ++i) { + elem_cnt *= static_cast(dummy_cpu_arr->shape[i]); + } + + // TODO(..): maybe we could have better solution for assigning values + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution<> dis(1.0, 10.0); + // Use float representation could make us work well on float / int type too. 
+ for (size_t i = 0; i < elem_cnt; ++i) { + if (dummy_cpu_arr->dtype.bits == 1) { + (reinterpret_cast( + dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = dis(gen); + } else if (dummy_cpu_arr->dtype.bits == 8) { + (reinterpret_cast( + dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = dis(gen); + } else if (dummy_cpu_arr->dtype.bits == 16) { + (reinterpret_cast( + dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = + __truncXfYf2__( + static_cast(dis(gen))); + } else if (dummy_cpu_arr->dtype.bits == 32) { + (reinterpret_cast( + dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = dis(gen); + } else if (dummy_cpu_arr->dtype.bits == 64) { + (reinterpret_cast( + dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = dis(gen); + } else { + LOG(FATAL) << "Doesn't support dtype code " << dtype.code + << " dtype bits " << dtype.bits; + } + } + ret.get_mutable()->dl_tensor.data = + DeviceAPI::Get(ret->ctx)->AllocDataSpace( + ret->ctx, size, alignment, ret->dtype); + CopyFromTo(&(dummy_cpu_arr.get_mutable()->dl_tensor), + &(ret.get_mutable()->dl_tensor)); + return ret; +} + NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { NDArray::Container* data = new NDArray::Container(); // construct header @@ -257,8 +314,9 @@ int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex) { API_END(); } -int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits, - int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out) { +int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, + int dtype_bits, int dtype_lanes, int device_type, + int device_id, TVMArrayHandle* out) { API_BEGIN(); DLDataType dtype; dtype.code = static_cast(dtype_code); @@ -272,6 +330,22 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_ API_END(); } +int TVMArrayAllocNonEmpty(const tvm_index_t* shape, int ndim, int dtype_code, + int dtype_bits, int dtype_lanes, int device_type, + int device_id, TVMArrayHandle* out) { + API_BEGIN(); + DLDataType dtype; + dtype.code = static_cast(dtype_code); + dtype.bits = static_cast(dtype_bits); + dtype.lanes = static_cast(dtype_lanes); + DLContext ctx; + ctx.device_type = static_cast(device_type); + ctx.device_id = device_id; + *out = NDArray::Internal::MoveToFFIHandle( + NDArray::NonEmpty(std::vector(shape, shape + ndim), dtype, ctx)); + API_END(); +} + int TVMArrayFree(TVMArrayHandle handle) { API_BEGIN(); NDArray::Internal::FFIDecRef(handle); diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc index e5a2c98c02a9..00e748204fde 100644 --- a/tests/cpp/ansor_test.cc +++ b/tests/cpp/ansor_test.cc @@ -730,51 +730,6 @@ TEST(Feature, ExtractionMatmul) { // TODO(...): Add feature check here } -namespace tvm { -namespace ansor { -class MetaTileRewritePolicyNodeTest { - public: - MetaTileRewritePolicyNodeTest(CostModel cost_model, SearchTask task) { - policy = make_object(); - policy->program_cost_model = std::move(cost_model); - policy->rand_gen_ = std::mt19937(0); - policy->params.Set("cpu_multi_level_tiling_structure", - te::StringImmNode::make("SSRSRS")); - policy->params.Set("disable_change_compute_location", - IntImm(DataType::Int(32), 0)); - policy->cur_task_ = task; - } - void SynthesizeMetaStructure(std::vector* meta_structures) { - policy->SynthesizeMetaStructure(meta_structures); - } - void SampleInitPopulation(const std::vector& meta_structures, - int out_size, std::vector* out_states) { - policy->SampleInitPopulation(meta_structures, out_size, out_states); - } - tvm::runtime::ObjectPtr policy; -}; -} 
// namespace ansor -} // namespace tvm - -TEST(MetaTileRewritePolicy, Basic) { - const auto& tensors = matmul_func(512, 512, 512); - const auto& dag = ComputeDAGNode::make(tensors); - const auto& task = SearchTaskNode::make( - dag, "test", tvm::target::llvm(), tvm::target::llvm(), HardwareParams()); - const auto& cost_model = RandomModelNode::make(); - MetaTileRewritePolicyNodeTest test(cost_model, task); - - std::vector meta_structures, init_population; - test.SynthesizeMetaStructure(&meta_structures); - CHECK_GE(meta_structures.size(), 0); - LOG(INFO) << "SynthesizeMetaStructure get " << meta_structures.size() - << " states."; - test.SampleInitPopulation(meta_structures, 100, &init_population); - CHECK_GE(init_population.size(), 0); - LOG(INFO) << "SampleInitPopulation get " << init_population.size() - << " states."; -} - int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index 8f04d003ff94..d701ef5b7bbd 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import random +import os import numpy as np import tvm @@ -510,12 +511,17 @@ def test_search_basic(): tgt = tvm.target.create("llvm") task = ansor.SearchTask(dag, "test", tgt) - cost_model = ansor.RandomModel() # seed = random.randint(1, 1 << 30) seed = 944563397 + log_file = "/tmp/_ansor_python_ut_test.json" + + random.seed(seed) + cost_model = ansor.RandomModel() search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) + tune_option = ansor.TuneOption(n_trials=2, + callbacks=[ansor.LogToFile(log_file)]) state = ansor.auto_schedule(task, search_policy, - tune_option=ansor.TuneOption(n_trials=2)) + tune_option=tune_option) sch, args = dag.apply_steps_from_state(state) print("==== Get State ====") @@ -539,6 +545,14 @@ def test_search_basic(): except Exception: raise Exception("Error encounterd with seed: %d" % (seed)) + inp, res = ansor.best_measure_pair_in_file(log_file) + s0 = dag.infer_bound_from_state(state) + s1 = dag.infer_bound_from_state(inp.state) + assert str(s0) == str(s1) + + if os.path.isfile(log_file): + os.system("rm -rf %s" % log_file) + if __name__ == "__main__": test_compute_dag_basic() From 1fe663878d00fa490b2f2b4ce2a200882aec0317 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 7 Jun 2020 08:31:54 -0700 Subject: [PATCH 09/78] Improve code style, python wrapper and test cases (#11) * Update c++ code style and unit test * Update python State wrapper and test cases --- python/tvm/ansor/__init__.py | 8 + python/tvm/ansor/_ffi_api.py | 3 +- python/tvm/ansor/compute_dag.py | 18 +- python/tvm/ansor/cost_model/__init__.py | 2 +- python/tvm/ansor/cost_model/cost_model.py | 8 +- python/tvm/ansor/loop_state.py | 439 +++++++++++++ python/tvm/ansor/measure.py | 3 +- python/tvm/ansor/serialization.py | 24 +- python/tvm/ansor/state.py | 430 ------------- python/tvm/ansor/task.py | 7 +- python/tvm/ansor/utils.py | 20 +- src/ansor/auto_schedule.cc | 69 +- src/ansor/auto_schedule.h | 30 +- src/ansor/compute_dag.cc | 78 ++- src/ansor/compute_dag.h | 53 +- src/ansor/cost_model/cost_model.cc | 27 +- src/ansor/cost_model/cost_model.h | 42 +- src/ansor/expr_hasher.h | 97 --- src/ansor/feature.cc | 3 +- src/ansor/loop_state.cc | 337 +++++----- src/ansor/loop_state.h | 158 ++++- src/ansor/measure.cc | 138 
++-- src/ansor/measure.h | 69 +- .../search_policy/meta_tile_rewrite_policy.cc | 28 +- .../search_policy/meta_tile_rewrite_policy.h | 94 +-- src/ansor/search_policy/search_policy.cc | 23 +- src/ansor/search_policy/search_policy.h | 27 +- src/ansor/search_policy/utils.cc | 56 +- src/ansor/search_policy/utils.h | 93 ++- src/ansor/search_task.cc | 51 +- src/ansor/search_task.h | 39 +- src/ansor/serialization.cc | 219 ++++--- src/ansor/serialization.h | 56 +- src/ansor/transform_step.cc | 78 +-- src/ansor/transform_step.h | 272 +++----- src/ansor/utils.cc | 23 +- src/ansor/utils.h | 133 ++-- tests/cpp/ansor_test.cc | 597 +----------------- tests/python/unittest/test_ansor_common.py | 515 +-------------- .../python/unittest/test_ansor_compute_dag.py | 66 ++ .../python/unittest/test_ansor_loop_state.py | 475 ++++++++++++++ tests/python/unittest/test_ansor_measure.py | 67 ++ .../unittest/test_ansor_search_policy.py | 81 +++ 43 files changed, 2461 insertions(+), 2595 deletions(-) create mode 100644 python/tvm/ansor/loop_state.py delete mode 100644 python/tvm/ansor/state.py delete mode 100644 src/ansor/expr_hasher.h create mode 100644 tests/python/unittest/test_ansor_compute_dag.py create mode 100644 tests/python/unittest/test_ansor_loop_state.py create mode 100644 tests/python/unittest/test_ansor_measure.py create mode 100644 tests/python/unittest/test_ansor_search_policy.py diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index af1ca24cd4dc..1be7ed404c17 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -17,6 +17,14 @@ # pylint: disable=unused-import, redefined-builtin """Namespace for Ansor autoSchedule""" +from . import compute_dag +from . import measure +from . import serialization +from . import loop_state +from . import task +from . import utils + +# Shortcut from .compute_dag import ComputeDAG from .task import SearchTask, MetaTileRewritePolicy, TuneOption from .task import auto_schedule diff --git a/python/tvm/ansor/_ffi_api.py b/python/tvm/ansor/_ffi_api.py index 177299e67d21..e7b8a59eb83b 100644 --- a/python/tvm/ansor/_ffi_api.py +++ b/python/tvm/ansor/_ffi_api.py @@ -14,7 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""FFI APIs for tvm.ansor""" + +"""Register FFI APIs from C++ for the namespace tvm.ansor""" import tvm._ffi diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index aa50864548a4..0b51ebb402cc 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -14,14 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=unused-import -""" ... """ + +""" Computational graph and its analysis tools """ import tvm._ffi from tvm.runtime import Object - -from .state import State - +from .loop_state import State from . 
import _ffi_api
@@ -50,13 +48,13 @@ def get_init_state(self):
         -------
         state : State
         """
-        return _ffi_api.ComputeDAGGetInitState(self)
+        return State(_ffi_api.ComputeDAGGetInitState(self))
 
     def apply_steps_from_state(self, state, layout_rewrite_level=None):
         """
         Parameters
         ----------
-        state : State
+        state : StateObject
 
         layout_rewrite_level : LayoutRewriteLevel(***)
 
         Returns
@@ -71,7 +69,7 @@ def print_python_code_from_state(self, state):
         """
         Parameters
         ----------
-        state : State
+        state : StateObject
 
         Returns
         -------
@@ -83,10 +81,10 @@ def infer_bound_from_state(self, state):
         """
         Parameters
         ----------
-        state : State
+        state : StateObject
 
         Returns
         -------
-        state : State
+        state : StateObject
         """
         return _ffi_api.ComputeDAGInferBoundFromState(self, state)
diff --git a/python/tvm/ansor/cost_model/__init__.py b/python/tvm/ansor/cost_model/__init__.py
index aac062e964fd..fc3821cf7998 100644
--- a/python/tvm/ansor/cost_model/__init__.py
+++ b/python/tvm/ansor/cost_model/__init__.py
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=unused-import, redefined-builtin
-""" ... """
+""" Cost model that estimates the performance of programs """
 
 from .cost_model import RandomModel
diff --git a/python/tvm/ansor/cost_model/cost_model.py b/python/tvm/ansor/cost_model/cost_model.py
index aebc89f465a1..a0e586d69cec 100644
--- a/python/tvm/ansor/cost_model/cost_model.py
+++ b/python/tvm/ansor/cost_model/cost_model.py
@@ -14,14 +14,13 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=unused-import
-""" ... """
+
+""" Cost model that estimates the performance of programs """
 import ctypes
 import numpy as np
 
 import tvm._ffi
 from tvm.runtime import Object
-
 from .. import _ffi_api
 
 
@@ -32,9 +31,6 @@ class CostModel(Object):
 
 @tvm._ffi.register_object("ansor.RandomModel")
 class RandomModel(Object):
-    """
-    """
-
     def __init__(self):
         self.__init_handle_by_constructor__(_ffi_api.RandomModel)
 
diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py
new file mode 100644
index 000000000000..557bb9d3102b
--- /dev/null
+++ b/python/tvm/ansor/loop_state.py
@@ -0,0 +1,439 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+
+"""
+The definition of the "state" in search. A state consists of a current loop
+structure and the transform history used to reach it.
+To enable flexible manipulation of the loop structure, we implemented a lightweight
+loop structure IR (Intermediate Representation) specifically for search.
+
+Basically this is a simplified TVM IR with schedule primitives.
+We don't use the existing TVM IR because
+1.
diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py
new file mode 100644
index 000000000000..557bb9d3102b
--- /dev/null
+++ b/python/tvm/ansor/loop_state.py
@@ -0,0 +1,439 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+
+"""
+The definition of the "state" in search. A state consists of a current loop structure
+and the transform history to reach this loop structure.
+To enable flexible manipulation of the loop structure, we implemented a lightweight
+loop structure IR (Intermediate Representation) specifically for search.
+
+Basically this is a simplified TVM IR with schedule primitives.
+We don't use the existing TVM IR because
+1. We want fast incremental change to the loop structures
+2. We want serializable history for replay and backtracking
+3. We may create some macro schedule primitives
+
+After search is done, we will lower this IR to TVM IR with TVM schedule primitives.
+Because we share a lot of common objects during search, the transformation is
+implemented in copy-on-write style. All objects are immutable, which is
+similar to TVM IR.
+"""
+
+import tvm._ffi
+from tvm.runtime import Object
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.Iterator")
+class Iterator(Object):
+    """A for-loop iterator"""
+
+
+@tvm._ffi.register_object("ansor.Stage")
+class Stage(Object):
+    """A stage in the compute declaration. Similar to tvm.te.schedule.Stage"""
+
+    @property
+    def iters(self):
+        """
+        Returns
+        -------
+        iters : List[Iterator]
+        """
+        if not hasattr(self, "iterators_cache"):
+            setattr(self, "iterators_cache", _ffi_api.StageGetIterators(self))
+        return getattr(self, "iterators_cache")
+
+    def iter(self, index):
+        """
+        Parameters
+        ----------
+        index : Int
+
+        Returns
+        -------
+        iter : Iterator
+        """
+        return _ffi_api.StageGetIterator(self, index)
+
+
+@tvm._ffi.register_object("ansor.State")
+class StateObject(Object):
+    """The internal State object"""
+    def __eq__(self, other):
+        return _ffi_api.StateEqual(self, other)
+
+
+class State:
+    """
+    A state in the search process. It consists of the current loop structure
+    and the history steps to reach this state.
+
+    Notes
+    -----
+    This is a wrapper class of StateObject to deal with the copy-on-write property
+    """
+    def __init__(self, state_object):
+        self.state_object = state_object
+        self.stages_cache = None
+
+    def clear_cache(self):
+        self.stages_cache = None
+
+    def copy(self):
+        return State(self.state_object)
+
+    @property
+    def stages(self):
+        """
+        Returns
+        -------
+        stages : List[Stage]
+        """
+        if not self.stages_cache:
+            self.stages_cache = _ffi_api.StateGetStages(self.state_object)
+        return self.stages_cache
+
+    def transform_steps_size(self):
+        """ Return the size of transform_steps
+        """
+        return _ffi_api.StateGetTransformStepsSize(self.state_object)
+
+    def reorder(self, stage_id, order):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to reorder
+        order : List[Iterator]
+            Iterators in the expected order
+        """
+        self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order)
+        self.clear_cache()
+
+    def split(self, stage_id, it, lengths, inner_to_outer=True):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to split
+        it : Iterator
+            The iterator to split
+        lengths: List[Int]
+            The split factors
+        inner_to_outer: Bool
+            True to use `factor` to split from inner to outer,
+            False to use `nparts` to split from outer to inner
+
+        Returns
+        -------
+        res_its : List[Iterator]
+            The new Iterators produced by the split
+        """
+        self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, it, lengths,
+                                                     inner_to_outer)
+        self.clear_cache()
+        return res
+
+    def follow_split(self, stage_id, it, src_step_id, n_split):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to split
+        it : Iterator
+            The iterator to split
+        src_step_id : Int
+            The index of the split step to follow in the history
+        n_split : Int
+            The number of split levels
+
+        Returns
+        -------
+        res_its : List[Iterator]
+            The new Iterators produced by the split
+        """
+        self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, stage_id, it,
+                                                           src_step_id, n_split)
+        self.clear_cache()
+        return res
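
[Editor's note] Each primitive above returns its result iterators and rebinds self.state_object, while the previous StateObject remains untouched; that is the copy-on-write contract described in the module docstring, and it makes snapshots for backtracking almost free. A small sketch of the split arithmetic and snapshotting under that contract (the stage index and extents are illustrative):

    # Assume stage 2 of `s0` has a first iterator of extent 512.
    snapshot = s0.copy()              # cheap: both wrappers share one immutable StateObject

    io, ii = s0.split(2, s0.stages[2].iters[0], [32])
    # inner_to_outer=True, so the factor 32 becomes the inner extent:
    # io has extent 512 // 32 == 16, ii has extent 32.

    # s0 now points at a new StateObject; `snapshot` still sees the old loop
    # structure, so backtracking is just `s0 = snapshot.copy()`.
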
+
+    def follow_fused_split(self, stage_id, it, src_step_ids, level,
+                           factor_or_nparts):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to split
+        it : Iterator
+            The iterator to split
+        src_step_ids : List[Int]
+            The indices of the split steps to follow in the history
+        level : Int
+            Use the length in this split level
+        factor_or_nparts : Bool
+            True to use `factor` for split from inner to outer,
+            False to use `nparts` for split from outer to inner
+
+        Returns
+        -------
+        res_its : List[Iterator]
+            The new Iterators produced by the split
+        """
+        self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, stage_id, it,
+                                                                src_step_ids, level,
+                                                                factor_or_nparts)
+        self.clear_cache()
+        return res
+
+    def fuse(self, stage_id, iters):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to fuse
+        iters : List[Iterator]
+            The iterators to be fused
+
+        Returns
+        -------
+        res_it : Iterator
+            The fused Iterator
+        """
+        self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters)
+        self.clear_cache()
+        return res
+
+    def vectorize(self, stage_id, it):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to vectorize
+        it : Iterator
+            The iterator to be vectorized
+
+        Returns
+        -------
+        res_it : Iterator
+            The vectorized Iterator
+        """
+        self.state_object, res = _ffi_api.StateVectorize(self.state_object, stage_id, it)
+        self.clear_cache()
+        return res
+
+    def parallel(self, stage_id, it):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to parallelize
+        it : Iterator
+            The iterator to be parallelized
+
+        Returns
+        -------
+        res_it : Iterator
+            The parallelized Iterator
+        """
+        self.state_object, res = _ffi_api.StateParallel(self.state_object, stage_id, it)
+        self.clear_cache()
+        return res
+
+    def unroll(self, stage_id, it, max_unroll=-1):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to unroll
+        it : Iterator
+            The iterator to be unrolled
+        max_unroll: Int
+            The maximum length of the iterator that can be unrolled
+
+        Returns
+        -------
+        res_it : Iterator
+            The unrolled Iterator
+        """
+        self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id, it, max_unroll)
+        self.clear_cache()
+        return res
+
+    def bind_thread(self, stage_id, it, thread_name):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to bind
+        it : Iterator
+            The iterator to be bound
+        thread_name : str
+            The name of the thread (e.g. "blockIdx.x", "threadIdx.y", "vthread")
+
+        Returns
+        -------
+        res_it : Iterator
+            The bound Iterator
+        """
+        trans_table = {
+            "vthread": 4,
+            "blockIdx.x": 5,
+            "threadIdx.x": 6,
+            "blockIdx.y": 7,
+            "threadIdx.y": 8,
+        }
+        thread_id = trans_table[thread_name]
+
+        self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id, it,
+                                                          thread_id)
+        self.clear_cache()
+        return res
+
+    def compute_at(self, stage_id, target_stage_id, target_iter):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the source stage
+        target_stage_id : Int
+            The index of the target stage of compute_at
+        target_iter : Iterator
+            The target Iterator of compute_at
+        """
+        self.state_object = _ffi_api.StateComputeAt(self.state_object, stage_id,
+                                                    target_stage_id, target_iter)
+        self.clear_cache()
+
+    def compute_root(self, stage_id):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to compute root
+        """
+        self.state_object = _ffi_api.StateComputeRoot(self.state_object, stage_id)
+        self.clear_cache()
+
+    def compute_inline(self, stage_id):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to compute inline
+        """
+        self.state_object = _ffi_api.StateComputeInline(self.state_object, stage_id)
+        self.clear_cache()
+
+    def cache_read(self, stage_id, scope_name, reader_stage_ids, task_dag):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to do cache_read
+        scope_name : Str
+        reader_stage_ids : List[Int]
+        task_dag : ComputeDAG
+
+        Returns
+        -------
+        new_stage_id : Int
+            The added stage id
+        """
+        self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object, stage_id,
+                                                                  scope_name, reader_stage_ids,
+                                                                  task_dag)
+        self.clear_cache()
+        return int(new_stage_id)
+
+    def cache_write(self, stage_id, scope_name, task_dag):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to do cache_write
+        scope_name : Str
+        task_dag : ComputeDAG
+
+        Returns
+        -------
+        new_stage_id : Int
+            The added stage id
+        """
+        self.state_object, new_stage_id = _ffi_api.StateCacheWrite(self.state_object, stage_id,
+                                                                   scope_name, task_dag)
+        self.clear_cache()
+        return int(new_stage_id)
+
+    def pragma(self, stage_id, it, pragma_type):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to add pragma
+        it : Iterator
+            The iterator to add pragma
+        pragma_type : Str
+        """
+        self.state_object = _ffi_api.StatePragma(self.state_object, stage_id, it, pragma_type)
+        self.clear_cache()
+
+    def rfactor(self, stage_id, it, factor_iter_id, task_dag):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to do rfactor
+        it : Iterator
+        factor_iter_id : Int
+        task_dag : ComputeDAG
+
+        Returns
+        -------
+        new_stage_id : Int
+            The added stage id
+        """
+        self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object, stage_id, it,
+                                                                factor_iter_id, task_dag)
+        self.clear_cache()
+        return int(new_stage_id)
+
+    def storage_align(self, stage_id, it, factor, offset):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to do storage align
+        it : Iterator
+        factor : Int
+        offset : Int
+        """
+        self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, it, factor,
+                                                       offset)
+        self.clear_cache()
+
+    def __str__(self):
+        return str(self.state_object)
+
+    def __eq__(self, other):
+        return _ffi_api.StateEqual(self.state_object, other.state_object)
diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py
index d7d0e64eb14b..5438edfaa6b2 100644
--- a/python/tvm/ansor/measure.py
+++ b/python/tvm/ansor/measure.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=unused-import
+
 """Distributed measurement infrastructure to measure the runtime costs of tensor programs
 
 These functions are responsible for building the tvm module, uploading it to
@@ -38,7 +38,6 @@
 from ..contrib import tar, ndk
 from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, check_remote
 from .compute_dag import LayoutRewriteLevel
-
 from . import _ffi_api
 
 logger = logging.getLogger('ansor')
diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py
index 172405ce7ddb..bd9a69944057 100644
--- a/python/tvm/ansor/serialization.py
+++ b/python/tvm/ansor/serialization.py
@@ -14,21 +14,22 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=unused-import
-""" ... """
+
+"""Tuning log I/O utilities"""
+
 import numpy as np
 
 import tvm._ffi
 from tvm.runtime import Object
-
 from .measure import MeasureCallback, MeasureErrorNo
-
 from . import _ffi_api
 
 
 @tvm._ffi.register_object("ansor.LogToFile")
 class LogToFile(MeasureCallback):
     """
+    A measurement callback that writes tuning logs into a file
+
     Parameters
     ----------
     filename : Str
@@ -40,6 +41,13 @@ def __init__(self, filename="ansor_tuning.json"):
 
 @tvm._ffi.register_object("ansor.LogReader")
 class LogReader(Object):
+    """
+    Reader of the json log file
+
+    Parameters
+    ----------
+    filename : Str
+    """
     def __init__(self, filename="ansor_tuning.json"):
         self.__init_handle_by_constructor__(_ffi_api.LogReader, filename)
@@ -56,21 +64,23 @@ def __iter__(self):
         yield ret[0], ret[1]  # (input, result)
 
 
+def write_measure_records_to_file(filename, inputs, results):
+    """Write (append) measure records to a file"""
+    _ffi_api.WriteMeasureRecordsToFile(filename, inputs, results)
+
+
 def best_measure_pair_in_file(filename, workload_key=None, target=None):
     """ Return the best measure pair from a log file
 
     Parameters
     ----------
     filename : Str
-    workload_key : Str
-    target : Str
 
     Returns
     -------
     inp : MeasureInput
-    res : MeasureResult
     """
     log_reader = LogReader(filename)
diff --git a/python/tvm/ansor/state.py b/python/tvm/ansor/state.py
deleted file mode 100644
index aa231ab6f4c6..000000000000
--- a/python/tvm/ansor/state.py
+++ /dev/null
@@ -1,430 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=unused-import
-""" ... """
-
-import tvm._ffi
-from tvm.runtime import Object
-
-from . 
import _ffi_api - - -@tvm._ffi.register_object("ansor.Iterator") -class Iterator(Object): - """ ... - """ - pass - - -@tvm._ffi.register_object("ansor.Stage") -class Stage(Object): - """ ... - """ - - def iterator(self, index): - """ - Parameters - ---------- - index : Int - - Returns - ------- - iter : Iterator - """ - return _ffi_api.StageGetIterator(self, index) - - def iterators(self): - """ - Returns - ------- - iters : List[Iterator] - """ - return _ffi_api.StageGetIterators(self) - - -@tvm._ffi.register_object("ansor.State") -class State(Object): - """ ... - """ - - def stage(self, index): - """ - Parameters - ---------- - index : Int - - Returns - ------- - stage : Stage - """ - return _ffi_api.StateGetStage(self, index) - - def transform_steps_size(self): - """ Return the size of transform_steps - """ - return _ffi_api.StateGetTransformStepsSize(self) - - def reorder(self, stage_id, order): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - order : List[Iterator] - Iterators in expected order - - Returns - ------- - state : State - The updated state - """ - state = _ffi_api.StateReorder(self, stage_id, order) - return state - - def split(self, stage_id, it, lengths, inner_to_outer=True): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - The target Iterator - lengths: List[Int] - The split factor - inner_to_outer: Bool - True to use `factor` for split from inner to outer, - False to use `nparts` for split from outer to inner - - Returns - ------- - state : State - The updated state - res_its : List[Iterator] - The splited Iterators result - """ - state, res_its = _ffi_api.StateSplit(self, stage_id, it, lengths, - inner_to_outer) - return state, res_its - - def follow_split(self, stage_id, it, src_step_id, n_split): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - The target Iterator - src_step_id : Int - The index of target step that this split follows - n_split : Int - Indecate how many level needs to be split out - - Returns - ------- - state : State - The updated state - res_its : List[Iterator] - The splited Iterators result - """ - state, res_its = _ffi_api.StateFollowSplit(self, stage_id, it, - src_step_id, n_split) - return state, res_its - - def follow_fused_split(self, stage_id, it, src_step_ids, level, - factor_or_nparts): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - The target Iterator - src_step_ids : List[Int] - The indexes of target step that this split follows - level : Int - factor_or_nparts : Bool - True to use `factor` for split from inner to outer, - False to use `nparts` for split from outer to inner - - Returns - ------- - state : State - The updated state - res_its : List[Iterator] - The splited Iterators result - """ - state, res_its = _ffi_api.StateFollowFusedSplit(self, stage_id, it, - src_step_ids, level, - factor_or_nparts) - return state, res_its - - def fuse(self, stage_id, iters): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - iters : List[Iterator] - The target Iterators to be fused - - Returns - ------- - state : State - The updated state - res_it : Iterator - The fused Iterator - """ - state, res_it = _ffi_api.StateFuse(self, stage_id, iters) - return state, res_it - - def vectorize(self, stage_id, it): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - The target Iterator to be vectorized - - Returns - 
------- - state : State - The updated state - res_it : Iterator - The vectorized Iterator - """ - state, res_it = _ffi_api.StateVectorize(self, stage_id, it) - return state, res_it - - def parallel(self, stage_id, it): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - The target Iterator to be paralleled - - Returns - ------- - state : State - The updated state - res_it : Iterator - The paralleled Iterator - """ - state, res_it = _ffi_api.StateParallel(self, stage_id, it) - return state, res_it - - def unroll(self, stage_id, it, max_unroll=-1): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - The target Iterator to be unrolled - max_unroll : Int - - Returns - ------- - state : State - The updated state - res_it : Iterator - The unrolled Iterator - """ - state, res_it = _ffi_api.StateUnroll(self, stage_id, it, max_unroll) - return state, res_it - - def bind_thread(self, stage_id, it, thread_type): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - The target Iterator to be vectorized - thread_type : ... - Supported type: kVThread, kBlockX, kThreadX, kThreadY - - Returns - ------- - state : State - The updated state - res_it : Iterator - The thread binded Iterator - """ - state, res_it = _ffi_api.StateBindThread(self, stage_id, it, - thread_type) - return state, res_it - - def compute_at(self, stage_id, target_stage_id, target_iter): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - target_stage_id : Int - The index of compute at target stage - target_iter : Iterator - The target Iterator to be compute at - - Returns - ------- - state : State - The updated state - """ - return _ffi_api.StateComputeAt(self, stage_id, target_stage_id, - target_iter) - - def compute_root(self, stage_id): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - - Returns - ------- - state : State - The updated state - """ - return _ffi_api.StateComputeRoot(self, stage_id) - - def compute_inline(self, stage_id): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - - Returns - ------- - state : State - The updated state - """ - return _ffi_api.StateComputeInline(self, stage_id) - - def pack_for_vec(self, stage_id, target_iter, vec_size): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - target_iter : Iterator - The target Iterator - vec_size : Int - - Returns - ------- - state : State - The updated state - """ - return _ffi_api.StatePackForVec(self, stage_id, target_iter, vec_size) - - def cache_read(self, stage_id, scope_name, reader_stage_ids, task_dag): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - scope_name : Str - reader_stage_ids : List[Int] - task_dag : ComputeDAG - - Returns - ------- - state : State - The updated state - new_stage_id : Int - The added staged id - """ - state, new_stage_id = _ffi_api.StateCacheRead(self, stage_id, - scope_name, reader_stage_ids, task_dag) - return state, int(new_stage_id) - - def cache_write(self, stage_id, scope_name, task_dag): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - scope_name : Str - task_dag : ComputeDAG - - Returns - ------- - state : State - The updated state - new_stage_id : Int - The added staged id - """ - state, new_stage_id = _ffi_api.StateCacheWrite(self, stage_id, - scope_name, task_dag) - return state, int(new_stage_id) - - def pragma(self, stage_id, 
it, pragma_type): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - The target Iterator - pragma_type : Str - - Returns - ------- - state : State - The updated state - """ - return _ffi_api.StatePragma(self, stage_id, it, pragma_type) - - def rfactor(self, stage_id, it, factor_iter_id, task_dag): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - factor_iter_id : Int - task_dag : ComputeDAG - - Returns - ------- - state : State - The updated state - """ - state, new_stage_id = _ffi_api.StateRfactor(self, stage_id, it, - factor_iter_id, task_dag) - return state, new_stage_id - - def storage_align(self, stage_id, it, factor, offset): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - factor : Int - offset : Int - - Returns - ------- - state : State - The updated state - """ - return _ffi_api.StateStorageAlign(self, stage_id, it, factor, offset) diff --git a/python/tvm/ansor/task.py b/python/tvm/ansor/task.py index 5fab57c28f48..affcf4a6e195 100644 --- a/python/tvm/ansor/task.py +++ b/python/tvm/ansor/task.py @@ -14,15 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=unused-import -""" ... """ + +"""Meta information for a search task""" + import random import tvm._ffi from tvm.runtime import Object from .measure import LocalBuilder, LocalRunner from .cost_model import RandomModel - from . import _ffi_api @@ -137,7 +137,6 @@ class TuneOption(Object): callbacks: List[MeasureCallback] Callback functions """ - def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, verbose=1, builder='local', runner='local', callbacks=None): if isinstance(builder, str): diff --git a/python/tvm/ansor/utils.py b/python/tvm/ansor/utils.py index 0216549c184a..5ed9bd46d355 100644 --- a/python/tvm/ansor/utils.py +++ b/python/tvm/ansor/utils.py @@ -1,4 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + """Common utilities""" + import multiprocessing import multiprocessing.pool import queue @@ -7,7 +25,6 @@ import os import numpy as np - try: import psutil except ImportError: @@ -31,7 +48,6 @@ def get_func_name(func): name: str The name """ - return func.func_name if hasattr(func, 'func_name') else func.__name__ diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index 974e7e5d9f58..a0fa18874a69 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -1,10 +1,31 @@ -#include "auto_schedule.h" +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ -#include +/*! + * \file ansor/auto_schedule.cc + * \brief The user interface of the auto-scheduler + */ +#include "auto_schedule.h" +#include #include #include - #include "search_policy/meta_tile_rewrite_policy.h" namespace tvm { @@ -54,32 +75,32 @@ std::pair > AutoSchedule( } TVM_REGISTER_GLOBAL("ansor.TuneOption") - .set_body_typed([](int n_trials, int early_stopping, - int num_measure_per_iter, int verbose, Builder builder, - Runner runner, Array callbacks) { - return TuneOptionNode::make(n_trials, early_stopping, - num_measure_per_iter, verbose, builder, - runner, callbacks); - }); +.set_body_typed([](int n_trials, int early_stopping, + int num_measure_per_iter, int verbose, Builder builder, + Runner runner, Array callbacks) { + return TuneOptionNode::make(n_trials, early_stopping, + num_measure_per_iter, verbose, builder, + runner, callbacks); +}); TVM_REGISTER_GLOBAL("ansor.AutoScheduleBySearchTask") - .set_body_typed([](SearchTask task, SearchPolicy search_policy, - TuneOption tune_option) { - return AutoSchedule(task, search_policy, tune_option); - }); +.set_body_typed([](SearchTask task, SearchPolicy search_policy, + TuneOption tune_option) { + return AutoSchedule(task, search_policy, tune_option); +}); TVM_REGISTER_GLOBAL("ansor.AutoScheduleByWorkloadKey") - .set_body_typed([](std::string workload_key, Target target, - Target target_host, SearchPolicy search_policy, - HardwareParams hardware_params, TuneOption tune_option) { - te::Schedule sch; - Array return_tensors; - std::tie(sch, return_tensors) = - AutoSchedule(workload_key, target, target_host, search_policy, - hardware_params, tune_option); +.set_body_typed([](std::string workload_key, Target target, + Target target_host, SearchPolicy search_policy, + HardwareParams hardware_params, TuneOption tune_option) { + te::Schedule sch; + Array return_tensors; + std::tie(sch, return_tensors) = + AutoSchedule(workload_key, target, target_host, search_policy, + hardware_params, tune_option); - return Array{sch, return_tensors}; - }); + return Array{sch, return_tensors}; +}); } // namespace ansor } // namespace tvm diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index c354751390fe..f68e844ba776 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -1,12 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors - * \file ansor/search_task.h - * \brief Meta information for a search task + * \file ansor/auto_schedule.h + * \brief The user interface of the auto-scheduler */ #ifndef TVM_ANSOR_AUTO_SCHEDULE_H_ #define TVM_ANSOR_AUTO_SCHEDULE_H_ +#include +#include #include "measure.h" namespace tvm { @@ -44,7 +64,7 @@ class TuneOptionNode : public Object { static constexpr const char* _type_key = "ansor.TuneOption"; TVM_DECLARE_FINAL_OBJECT_INFO(TuneOptionNode, Object); }; -TVM_DEFINE_COW_NODE_REF(TuneOption, ObjectRef, TuneOptionNode); +TVM_DEFINE_COW_OBJECT_REF(TuneOption, ObjectRef, TuneOptionNode); /*! \brief Auto schedule for a compute declaration */ State AutoSchedule(SearchTask task, SearchPolicy search_policy, @@ -58,4 +78,4 @@ std::pair > AutoSchedule( } // namespace ansor } // namespace tvm -#endif // TVM_ANSOR_AUTO_SCHEDULE_H_ \ No newline at end of file +#endif // TVM_ANSOR_AUTO_SCHEDULE_H_ diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 7fad0ce5b28a..f3979ef0d259 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -1,6 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
- * Copyright (c) 2020 by Contributors + * \file ansor/compute_dag.cc + * \brief Compute declaration graph and its related analysis tools */ + #include "compute_dag.h" #include #include @@ -15,7 +36,7 @@ #include #include #include -#include "loop_state.h" +#include "transform_step.h" #include "utils.h" // #include "../relay/pass/kernel_layout_transform.h" @@ -385,6 +406,7 @@ void AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op, collect(op); } +// Return whether two int arrays are elementwise-equal bool IntArrayEqual(const Array& arr1, const Array& arr2) { if (arr1.size() != arr2.size()) { return false; @@ -543,23 +565,6 @@ class FlopEstimator: public ExprFunctor { bool fail{false}; }; -void UpdateStageAxis(const te::Stage& stage, StageToAxesMap *stage_to_axes) { - if (auto pop = stage->op.as()) { - std::vector& axes = (*stage_to_axes)[stage]; - axes.clear(); - for (const auto& axis : pop->axis) { - axes.push_back(axis); - } - for (const auto& axis : pop->reduce_axis) { - axes.push_back(axis); - } - } else if (stage->op->IsInstance()) { - {} // do nothing - } else { - LOG(FATAL) << "Invalid op " << stage->op; - } -} - State ComputeDAG::GetInitState() const { return Downcast(operator->()->init_state); } @@ -588,13 +593,6 @@ ComputeDAG ComputeDAGNode::make_by_workload_key(const std::string& workload_key) return ComputeDAGNode::make(std::move(tens)); } -// Implemented in multi_stage_policy.cc -// Extract primitive iterators from a nested fused or splitted iterator's name -extern void ExtractOriginalIterators(const std::string& name, std::set* rets); - -// Implemented in loop_state.cc -extern std::string CleanName(const std::string& str); - std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); } @@ -680,8 +678,8 @@ std::string BaseName(const std::string& str) { // const Operation& op = stage->op; // if (op->IsInstance()) { // const Map& attrs = op->attrs; -// if (attrs.count(_layout_free_placeholders_key)) { -// const ObjectRef& attr_value = attrs[_layout_free_placeholders_key]; +// if (attrs.count(layout_free_placeholders_key)) { +// const ObjectRef& attr_value = attrs[layout_free_placeholders_key]; // Array placeholders = Downcast>(attr_value); // for (auto& placeholder : placeholders) { // const auto placeholder_op = placeholder->op; @@ -907,7 +905,8 @@ std::string BaseName(const std::string& str) { // auto index = old_tensor->value_index; // ptensors->data[i] = new_op.output(index); // } else if (layout_rewrite_level == kComputeRewrite) { -// TensorNode* old_tensor_node = const_cast(old_tensor.as()); +// TensorNode* old_tensor_node = +// const_cast(old_tensor.as()); // old_tensor_node->op = new_op; // } // } @@ -918,6 +917,24 @@ std::string BaseName(const std::string& str) { // } // end for stage // } + +void UpdateStageAxis(const te::Stage& stage, StageToAxesMap *stage_to_axes) { + if (auto pop = stage->op.as()) { + std::vector& axes = (*stage_to_axes)[stage]; + axes.clear(); + for (const auto& axis : pop->axis) { + axes.push_back(axis); + } + for (const auto& axis : pop->reduce_axis) { + axes.push_back(axis); + } + } else if (stage->op->IsInstance()) { + {} // do nothing + } else { + LOG(FATAL) << "Invalid op " << stage->op; + } +} + std::pair > ComputeDAG::ApplySteps( const std::vector& transform_steps, LayoutRewriteLevel layout_rewrite_level) const { @@ -1104,9 +1121,6 @@ std::pair > ComputeDAG::ReplaySteps( UpdateStageAxis(stage, stage_to_axes); } - // todo(lmzheng): should we maintain the attach_map and keep the validity of - // 
compute_at an splitted axis? - // Use complete rate for the study in the paper const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); double complete_rate = -1.0; diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 3b4c80c50ad8..60c1790a0cfb 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -1,5 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors * \file ansor/compute_dag.h * \brief Compute declaration graph and its related analysis tools */ @@ -22,12 +40,6 @@ namespace ansor { class ComputeDAG; class AccessAnalyzer; class StateNode; class State; class Step; -typedef std::unordered_map, ObjectHash, ObjectEqual> - StageToAxesMap; - -// Update StageToAxes Map during replay -void UpdateStageAxis(const tvm::te::Stage& stage, StageToAxesMap *stage_to_axes); - /*! \brief Read/Write access static analysis result */ class AccessAnalyzerNode : public Object { public: @@ -60,9 +72,11 @@ class AccessAnalyzer : public ObjectRef { // Get all producers of an op void GetProducers(const State& state, const te::Operation& op, std::unordered_set* producers) const; + // Get all consumers of an op. This func deals with inlined op correctly. void GetConsumers(const State& state, const te::Operation& op, std::unordered_set* consumers) const; + // Check whether two ops are elementwise matched // (e.g. conv2d and relu are elementwise matched) bool ElementWiseMatch(const te::Operation& op, @@ -84,17 +98,23 @@ class AccessAnalyzer : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode); }; +typedef std::unordered_map, ObjectHash, ObjectEqual> + StageToAxesMap; + +// Update StageToAxes Map during replay +void UpdateStageAxis(const tvm::te::Stage& stage, StageToAxesMap *stage_to_axes); + + /*! 
\brief Compute declaration graph */ class ComputeDAGNode : public Object { public: - Array tensors; // Input and output tensors - Array ops; // All related operations in topo order - double flop_ct; // Number of float operations + Array tensors; // Input and output tensors + Array ops; // All related operations in topo order + double flop_ct; // Number of float operations AccessAnalyzer access_analyzer; // Read/Write accesss static analyzer - ObjectRef init_state; // initial states + ObjectRef init_state; // The initial state void VisitAttrs(tvm::AttrVisitor* v) { - LOG(INFO) << "ComputeDAG"; v->Visit("tensors", &tensors); v->Visit("ops", &ops); v->Visit("flop_ct", &flop_ct); @@ -126,7 +146,7 @@ class ComputeDAG: public ObjectRef { // Rewrite the the layout of "layout free" placeholders according to transform steps void RewriteLayout(const std::vector& transform_steps, - LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const {}; + LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const {} // Print transform steps as equivalent python schedule API std::string PrintStepsAsPython(const std::vector& steps) const; @@ -134,19 +154,21 @@ class ComputeDAG: public ObjectRef { // Replay the transform steps and call ir_pass::InferBound to fill correct bound information State ReplayAndInferBound(const std::vector& transform_steps) const; - // Fill the correct bound information for a given state + // Fill the correct bound information for a given state by calling ir_pass::InferBound State InferBound(const State& state) const; // Fill the correct bound information for a list of given states. // Return the new states inplace void InferBound(std::vector* states) const; - // Replay the transform steps and get the new ops + // Replay the transform steps and get the new DAG void ReplayAndGetDAG(const std::vector& steps, ComputeDAG* task_dag) const; // Get the init state State GetInitState() const; + static constexpr const char* layout_free_placeholders_key = "layout_free_placeholders"; + TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode); @@ -155,7 +177,6 @@ class ComputeDAG: public ObjectRef { std::pair > ReplaySteps( const std::vector& transform_steps, std::vector* stages, StageToAxesMap* stage_to_axes) const; - static constexpr const char* _layout_free_placeholders_key = "layout_free_placeholders"; // Internal common parts for inferring bound void InferBoundCommon(StateNode* pstate) const; diff --git a/src/ansor/cost_model/cost_model.cc b/src/ansor/cost_model/cost_model.cc index 060d2b703287..8e0936071774 100644 --- a/src/ansor/cost_model/cost_model.cc +++ b/src/ansor/cost_model/cost_model.cc @@ -1,6 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
- * Copyright (c) 2020 by Contributors + * \file ansor/cost_model.h + * \brief Cost model that estimates the performance of programs */ + #include "cost_model.h" #include @@ -23,7 +44,7 @@ void RandomNumber(TVMArgs args, TVMRetValue* rv) { void* data = args[1]; float* fdata = reinterpret_cast(data); for (int i = 0; i < n; i++) { - fdata[i] = static_cast(rand_r(0)) / (static_cast(RAND_MAX)); + fdata[i] = static_cast(rand_r(nullptr)) / (static_cast(RAND_MAX)); } } @@ -130,7 +151,7 @@ void PythonBasedCostModelNode::PredictStages( CHECK_LE(idx, flatten_scores.size()); // Number of scored stages of this state. - int s_length = (int)flatten_scores[idx++]; + int s_length = static_cast(flatten_scores[idx++]); if (s_length > 0) { std::vector scores; diff --git a/src/ansor/cost_model/cost_model.h b/src/ansor/cost_model/cost_model.h index 36179573c617..9daf01197bbf 100644 --- a/src/ansor/cost_model/cost_model.h +++ b/src/ansor/cost_model/cost_model.h @@ -1,8 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors * \file ansor/cost_model.h - * \brief Base class of cost model - */ + * \brief Cost model that estimates the performance of programs +*/ #ifndef TVM_ANSOR_COST_MODEL_COST_MODEL_H_ #define TVM_ANSOR_COST_MODEL_COST_MODEL_H_ @@ -23,17 +41,24 @@ class CostModel; /*! \brief The base class for cost model */ class CostModelNode: public Object { public: + // Update the cost model according to new measurement pairs virtual void Update(const Array& inputs, const Array& results) = 0; + + // Predict the scores of states virtual void Predict(const SearchTask& task, const std::vector& states, std::vector* scores) = 0; + + // Predict the scores of all stages in states virtual void PredictStages(const SearchTask& task, const std::vector& states, std::vector* state_scores, - std::vector>* stage_scores) = 0; + std::vector>* stage_scores) { + LOG(FATAL) << "Not Implemented"; + } static constexpr const char *_type_key = "ansor.CostModel"; TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object); }; -TVM_DEFINE_MUTABLE_NODE_REF(CostModel, CostModelNode); +TVM_DEFINE_MUTABLE_OBJECT_REF(CostModel, CostModelNode); /*! \brief The cost model returns random value for all predictions */ class RandomModelNode: public CostModelNode { @@ -45,14 +70,12 @@ class RandomModelNode: public CostModelNode { void Update(const Array& inputs, const Array& results) final; void Predict(const SearchTask& task, const std::vector& states, std::vector* scores) final; - void PredictStages(const SearchTask& task, const std::vector& states, - std::vector* state_scores, - std::vector>* stage_scores) { ; } static constexpr const char *_type_key = "ansor.RandomModel"; TVM_DECLARE_FINAL_OBJECT_INFO(RandomModelNode, CostModelNode); }; +/*! 
\brief The cost model returns actual cost by measurement */ class MeasureModelNode : public CostModelNode { public: ProgramMeasurer measurer; @@ -62,9 +85,6 @@ class MeasureModelNode : public CostModelNode { void Update(const Array& inputs, const Array& results) final; void Predict(const SearchTask& task, const std::vector& states, std::vector* scores) final; - void PredictStages(const SearchTask& task, const std::vector& states, - std::vector* state_scores, - std::vector>* stage_scores) { ; } static constexpr const char* _type_key = "ansor.MeasureModel"; TVM_DECLARE_FINAL_OBJECT_INFO(MeasureModelNode, CostModelNode); diff --git a/src/ansor/expr_hasher.h b/src/ansor/expr_hasher.h deleted file mode 100644 index 1c743ed9a5c4..000000000000 --- a/src/ansor/expr_hasher.h +++ /dev/null @@ -1,97 +0,0 @@ -/*! - * Copyright (c) 2020 by Contributors - * \file auto_scheduler/expr_hasher.h - * \brief Hash function for a tvm::Expr - */ - -#ifndef TVM_ANSOR_EXPR_HASHER_H_ -#define TVM_ANSOR_EXPR_HASHER_H_ - -#include -#include -#include -#include - -namespace tvm { - -/*! \brief Assign a hash value for a tvm::Expr */ -class ExprHasher: public tir::ExprFunctor { - public: - size_t VisitExpr_(const tir::AddNode* op) final { - return VisitExpr(op->a) + VisitExpr(op->b); - } - - size_t VisitExpr_(const tir::SubNode* op) final { - return VisitExpr(op->a) - VisitExpr(op->b); - } - - size_t VisitExpr_(const tir::MulNode* op) final { - return VisitExpr(op->a) * VisitExpr(op->b); - } - - size_t VisitExpr_(const tir::DivNode* op) final { - size_t t = VisitExpr(op->b); - if (t != 0) { - return VisitExpr(op->a) / t; - } else { - return dmlc::HashCombine(VisitExpr(op->a), 0x5A); - } - } - - size_t VisitExpr_(const tir::FloorDivNode* op) final { - size_t t = VisitExpr(op->b); - if (t != 0) { - return VisitExpr(op->a) / t; - } else { - return dmlc::HashCombine(VisitExpr(op->a), 0x5B); - } - } - - size_t VisitExpr_(const tir::ModNode* op) final { - size_t t = VisitExpr(op->b); - if (t != 0) { - return VisitExpr(op->a) % t; - } else { - return dmlc::HashCombine(VisitExpr(op->a), 0x5C); - } - } - - size_t VisitExpr_(const tir::FloorModNode* op) final { - size_t t = VisitExpr(op->b); - if (t != 0) { - return VisitExpr(op->a) % t; - } else { - return dmlc::HashCombine(VisitExpr(op->a), 0x5D); - } - } - - size_t VisitExpr_(const tir::CallNode* op) final { - size_t ret = ObjectHash()(op->func); - for (size_t i = 0; i < op->args.size(); ++i) { - ret = dmlc::HashCombine(ret, VisitExpr(op->args[i])); - } - return ret; - } - - size_t VisitExpr_(const tir::VarNode* op) final { - return std::hash()(op); - } - - size_t VisitExpr_(const tir::FloatImmNode* op) final { - return std::hash()(op->value); - } - - size_t VisitExpr_(const tir::IntImmNode* op) final { - return std::hash()(op->value); - } - - size_t VisitExprDefault_(const Object* op) final { - LOG(WARNING) << "Encounter undefined node in ExprHasher: " - << Object::_type_key; - return std::hash()(op); - } -}; - -} // namespace tvm - -#endif // TVM_ANSOR_EXPR_HASHER_H_ diff --git a/src/ansor/feature.cc b/src/ansor/feature.cc index cb865bc3b5ae..31afe931361c 100644 --- a/src/ansor/feature.cc +++ b/src/ansor/feature.cc @@ -272,7 +272,8 @@ class BufferAccessExtractor : public StmtExprVisitor { this->VisitExpr(expr); } - void InsertAccess(const te::Tensor& ten, BufferAccessType acc_type, const Array& indices) { + void InsertAccess(const te::Tensor& ten, BufferAccessType acc_type, + const Array& indices) { BufferAccess& acc = buf_accesses[ten]; acc.acc_type = acc_type; 
acc.indices.push_back(std::vector(indices.begin(), indices.end())); diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 32940da0773a..faaac94f3323 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -1,18 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors + * \file ansor/loop_state.h + * \brief An IR (intermediate representation) for loop structures. */ -#include "loop_state.h" +#include "loop_state.h" #include #include - +#include "transform_step.h" #include "utils.h" namespace tvm { namespace ansor { -TVM_REGISTER_OBJECT_TYPE(StageNode); +TVM_REGISTER_OBJECT_TYPE(StepNode); +TVM_REGISTER_NODE_TYPE(StageNode); TVM_REGISTER_NODE_TYPE(StateNode); +TVM_REGISTER_NODE_TYPE(IteratorNode); + +// Maker for other classes +Iterator IteratorNode::make(std::string name, Range range, + IteratorType iter_type, IteratorAnnotation annotation, + const std::vector* ori_iters) { + auto node = make_object(); + node->name = std::move(name); + node->range = std::move(range); + node->iter_type = iter_type; + node->annotation = annotation; + if (ori_iters != nullptr) { + node->ori_iters = *ori_iters; + } + return Iterator(node); +} + Stage StageNode::make(te::Operation op) { auto node = make_object(); @@ -43,7 +81,7 @@ Stage StageNode::make(te::Operation op) { Stage StageNode::make(te::Operation op, StageType op_type, const std::vector& iters, - ComputeAtType compute_at, int16_t auto_unroll_max_step, + ComputeAtType compute_at, int auto_unroll_max_step, int storage_offset) { auto node = make_object(); node->op = std::move(op); @@ -57,7 +95,7 @@ Stage StageNode::make(te::Operation op, StageType op_type, Stage StageNode::make(te::Operation op, StageType op_type, std::vector&& iters, ComputeAtType compute_at, - int16_t auto_unroll_max_step, int storage_offset) { + int auto_unroll_max_step, int storage_offset) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; @@ -216,15 +254,6 @@ void State::compute_inline(int stage_id) { return DoComputeInlineStep(step); } -void State::pack_for_vec(int stage_id, const Iterator& target_iter, - int vec_size) { - const Stage& stage = operator->()->stages[stage_id]; - PackForVecStep step = PackForVecStepNode::make( - stage_id, GetIndex(stage->iters, target_iter), vec_size); - CopyOnWrite()->transform_steps.push_back(step); - return DoPackForVecStep(step); -} - Iterator State::bind_thread(int stage_id, const Iterator& it, IteratorAnnotation thread_type) { const Stage& stage = operator->()->stages[stage_id]; @@ -560,10 +589,6 @@ void State::DoComputeInlineStep(const ComputeInlineStep& step) { pstate->attach_map.DeleteStage(step->stage_id); } -void State::DoPackForVecStep(const PackForVecStep& step) { - LOG(FATAL) << 
"Not implemented"; -} - // Common part for steps that add new stages // (e.g. CacheReadStep, CacheWriteStep, RfactorStep) void AddStageModificationSteps(size_t step_id, @@ -741,8 +766,6 @@ void State::DoStep(const Step& step, const ComputeDAG& dag) { DoComputeRootStep(GetRef(ps)); } else if (auto ps = step.as()) { DoComputeInlineStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoPackForVecStep(GetRef(ps)); } else if (auto ps = step.as()) { DoCacheReadStep(GetRef(ps), dag); } else if (auto ps = step.as()) { @@ -991,177 +1014,175 @@ AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - PrintState(&p->stream, node, true); - }); - -TVM_REGISTER_GLOBAL("ansor.StageGetIterator") - .set_body_typed([](const Stage& stage, int index) { - return stage->iters[index]; - }); - -TVM_REGISTER_GLOBAL("ansor.StageGetIterators") - .set_body_typed([](const Stage& stage) { - return Array(stage->iters); - }); - -TVM_REGISTER_GLOBAL("ansor.StateGetStage") - .set_body_typed([](const State& state, int index) { - return state->stages[index]; - }); - -TVM_REGISTER_GLOBAL("ansor.StateGetTransformStepsSize") - .set_body_typed([](const State& state) { - return static_cast(state->transform_steps.size()); - }); +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + PrintState(&p->stream, node, true); +}); + + +TVM_REGISTER_GLOBAL("ansor.StageGetIterator").set_body_typed([](const Stage& stage, int index) { + return stage->iters[index]; +}); + +TVM_REGISTER_GLOBAL("ansor.StageGetIterators").set_body_typed([](const Stage& stage) { + return Array(stage->iters); +}); + +TVM_REGISTER_GLOBAL("ansor.StateGetStages").set_body_typed([](const State& state) { + return Array(state->stages); +}); + +TVM_REGISTER_GLOBAL("ansor.StateGetStage").set_body_typed([](const State& state, int index) { + return state->stages[index]; +}); + +TVM_REGISTER_GLOBAL("ansor.StateGetTransformStepsSize").set_body_typed([](const State& state) { + return static_cast(state->transform_steps.size()); +}); TVM_REGISTER_GLOBAL("ansor.StateReorder") - .set_body_typed([](State state, int stage_id, - const Array& order) { - std::vector ord; - for (const auto& i : order) { - ord.push_back(i); - } - state.reorder(stage_id, ord); - return state; - }); +.set_body_typed([](State state, int stage_id, const Array& order) { + std::vector ord; + for (const auto& i : order) { + ord.push_back(i); + } + state.reorder(stage_id, ord); + return state; +}); TVM_REGISTER_GLOBAL("ansor.StateSplit") - .set_body_typed([](State state, int stage_id, const Iterator& it, - const Array& lengths, bool inner_to_outer) { - std::vector len; - for (const auto& i : lengths) { - len.push_back(i); - } - const auto& res = state.split(stage_id, it, len, inner_to_outer); - return Array{state, Array(res)}; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it, + const Array& lengths, bool inner_to_outer) { + std::vector len; + for (const auto& i : lengths) { + len.push_back(i); + } + const auto& res = state.split(stage_id, it, len, inner_to_outer); + return Array{state, Array(res)}; +}); TVM_REGISTER_GLOBAL("ansor.StateFollowSplit") - .set_body_typed([](State state, int stage_id, const Iterator& it, - int src_step_id, int n_split) { - const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); - return Array{state, Array(res)}; - }); 
+.set_body_typed([](State state, int stage_id, const Iterator& it, + int src_step_id, int n_split) { + const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); + return Array{state, Array(res)}; +}); TVM_REGISTER_GLOBAL("ansor.StateFollowFusedSplit") - .set_body_typed([](State state, int stage_id, const Iterator& it, - const Array& src_step_ids, int level, - bool factor_or_nparts) { - std::vector array_src_step_ids; - for (const auto& i : src_step_ids) { - array_src_step_ids.push_back(i->value); - } - const auto& res = state.follow_fused_split( - stage_id, it, array_src_step_ids, level, factor_or_nparts); - return Array{state, Array(res)}; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it, + const Array& src_step_ids, int level, + bool factor_or_nparts) { + std::vector array_src_step_ids; + for (const auto& i : src_step_ids) { + array_src_step_ids.push_back(i->value); + } + const auto& res = state.follow_fused_split( + stage_id, it, array_src_step_ids, level, factor_or_nparts); + return Array{state, Array(res)}; +}); TVM_REGISTER_GLOBAL("ansor.StateFuse") - .set_body_typed([](State state, int stage_id, - const Array& iters) { - std::vector its; - for (const auto& i : iters) { - its.push_back(i); - } - const auto& res = state.fuse(stage_id, its); - return Array{state, res}; - }); +.set_body_typed([](State state, int stage_id, + const Array& iters) { + std::vector its; + for (const auto& i : iters) { + its.push_back(i); + } + const auto& res = state.fuse(stage_id, its); + return Array{state, res}; +}); TVM_REGISTER_GLOBAL("ansor.StateVectorize") - .set_body_typed([](State state, int stage_id, const Iterator& it) { - const auto& res = state.vectorize(stage_id, it); - return Array{state, res}; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it) { + const auto& res = state.vectorize(stage_id, it); + return Array{state, res}; +}); TVM_REGISTER_GLOBAL("ansor.StateParallel") - .set_body_typed([](State state, int stage_id, const Iterator& it) { - const auto& res = state.parallel(stage_id, it); - return Array{state, res}; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it) { + const auto& res = state.parallel(stage_id, it); + return Array{state, res}; +}); TVM_REGISTER_GLOBAL("ansor.StateUnroll") - .set_body_typed([](State state, int stage_id, const Iterator& it, - int max_unroll) { - const auto& res = state.unroll(stage_id, it, max_unroll); - return Array{state, res}; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it, + int max_unroll) { + const auto& res = state.unroll(stage_id, it, max_unroll); + return Array{state, res}; +}); TVM_REGISTER_GLOBAL("ansor.StateBindThread") - .set_body_typed([](State state, int stage_id, const Iterator& it, - int thread_type) { - const auto& res = - state.bind_thread(stage_id, it, IteratorAnnotation(thread_type)); - return Array{state, res}; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it, + int thread_type) { + const auto& res = + state.bind_thread(stage_id, it, IteratorAnnotation(thread_type)); + return Array{state, res}; +}); TVM_REGISTER_GLOBAL("ansor.StateComputeAt") - .set_body_typed([](State state, int stage_id, int target_stage_id, - const Iterator& target_iter) { - state.compute_at(stage_id, target_stage_id, target_iter); - return state; - }); +.set_body_typed([](State state, int stage_id, int target_stage_id, + const Iterator& target_iter) { + state.compute_at(stage_id, target_stage_id, target_iter); + return state; +}); 
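
[Editor's note] Each of these packed functions takes the State by value, applies the primitive, and returns the updated copy (plus any result iterators), which is exactly what the Python wrapper rebinds self.state_object to. Note that StateBindThread (registered above) takes the annotation as a plain int that must match the IteratorAnnotation enum order declared in loop_state.h (kVThread=4, kBlockX=5, kThreadX=6, kBlockY=7, kThreadY=8); the trans_table in loop_state.py hard-codes the same values. A sketch of the raw round trip the wrapper performs (s0 and it are illustrative):

    # Python side of one copy-on-write FFI call; 5 == kBlockX in IteratorAnnotation.
    new_obj, bound_it = _ffi_api.StateBindThread(s0.state_object, 2, it, 5)
    s0.state_object = new_obj  # the old StateObject is untouched and can still be replayed
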
TVM_REGISTER_GLOBAL("ansor.StateComputeRoot") - .set_body_typed([](State state, int stage_id) { - state.compute_root(stage_id); - return state; - }); +.set_body_typed([](State state, int stage_id) { + state.compute_root(stage_id); + return state; +}); TVM_REGISTER_GLOBAL("ansor.StateComputeInline") - .set_body_typed([](State state, int stage_id) { - state.compute_inline(stage_id); - return state; - }); - -TVM_REGISTER_GLOBAL("ansor.StatePackForVec") - .set_body_typed([](State state, int stage_id, const Iterator& target_iter, - int vec_size) { - state.pack_for_vec(stage_id, target_iter, vec_size); - return state; - }); +.set_body_typed([](State state, int stage_id) { + state.compute_inline(stage_id); + return state; +}); TVM_REGISTER_GLOBAL("ansor.StateCacheRead") - .set_body_typed([](State state, int stage_id, const std::string& scope_name, - const Array& reader_stage_ids, - const ComputeDAG& task_dag) { - std::vector array_reader_stage_ids; - for (const auto& i : reader_stage_ids) { - array_reader_stage_ids.push_back(i->value); - } - int res = state.cache_read(stage_id, scope_name, array_reader_stage_ids, - task_dag); - return Array{state, IntImm(DataType::Int(32), res)}; - }); +.set_body_typed([](State state, int stage_id, const std::string& scope_name, + const Array& reader_stage_ids, + const ComputeDAG& task_dag) { + std::vector array_reader_stage_ids; + for (const auto& i : reader_stage_ids) { + array_reader_stage_ids.push_back(i->value); + } + int res = state.cache_read(stage_id, scope_name, array_reader_stage_ids, + task_dag); + return Array{state, IntImm(DataType::Int(32), res)}; +}); TVM_REGISTER_GLOBAL("ansor.StateCacheWrite") - .set_body_typed([](State state, int stage_id, const std::string& scope_name, - const ComputeDAG& task_dag) { - int res = state.cache_write(stage_id, scope_name, task_dag); - return Array{state, IntImm(DataType::Int(32), res)}; - }); +.set_body_typed([](State state, int stage_id, const std::string& scope_name, + const ComputeDAG& task_dag) { + int res = state.cache_write(stage_id, scope_name, task_dag); + return Array{state, IntImm(DataType::Int(32), res)}; +}); TVM_REGISTER_GLOBAL("ansor.StatePragma") - .set_body_typed([](State state, int stage_id, const Iterator& it, - const std::string& pragma_type) { - state.pragma(stage_id, it, pragma_type); - return state; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it, + const std::string& pragma_type) { + state.pragma(stage_id, it, pragma_type); + return state; +}); TVM_REGISTER_GLOBAL("ansor.StateRfactor") - .set_body_typed([](State state, int stage_id, const Iterator& it, - int factor_iter_id, const ComputeDAG& task_dag) { - int res = state.rfactor(stage_id, it, factor_iter_id, task_dag); - return Array{state, IntImm(DataType::Int(32), res)}; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it, + int factor_iter_id, const ComputeDAG& task_dag) { + int res = state.rfactor(stage_id, it, factor_iter_id, task_dag); + return Array{state, IntImm(DataType::Int(32), res)}; +}); TVM_REGISTER_GLOBAL("ansor.StateStorageAlign") - .set_body_typed([](State state, int stage_id, const Iterator& it, - int factor, int offset) { - state.storage_align(stage_id, it, factor, offset); - return state; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it, + int factor, int offset) { + state.storage_align(stage_id, it, factor, offset); + return state; +}); + +TVM_REGISTER_GLOBAL("ansor.StateEqual") +.set_body_typed([](State state1, State state2) { + return 
std::equal_to<State>()(state1, state2);
+});

 }  // namespace ansor
 }  // namespace tvm
diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h
index dd56e267c0a0..90ba48cd92ac 100644
--- a/src/ansor/loop_state.h
+++ b/src/ansor/loop_state.h
@@ -1,16 +1,36 @@
-/*!
- * Copyright (c) 2020 by Contributors
- * \file ansor/interfaces.h
- * \brief Data structures for loop transformations
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/loop_state.h
+ * \brief The definition of the "state" in search. A state consists of a current loop structure
+ * and the transform history used to reach it.
+ * To enable flexible manipulation of the loop structure, we implemented a lightweight
+ * loop structure IR (Intermediate Representation) specifically for search.
+ *
 * Basically this is a simplified TVM IR with schedule primitives.
 * We don't use the existing TVM IR because
 * 1. We want fast incremental change to the loop structures
 * 2. We want serializable history for replay and backtracking
- * 3. We want simplified IR for easy and clean feature extraction
- * 4. We may create some Macro schedule primitives
-
- * After search is done, we will lower this IR to TVM IR and TVM schedule primitives.
+ * 3. We may create some Macro schedule primitives
+ *
+ * After search is done, we will lower this IR to TVM IR with TVM schedule primitives.
 * Because we share a lot of common objects during search, the transformation is
 * implemented in copy on write style. All objects are immutable, which is
 * similar to TVM IR.
@@ -24,24 +44,77 @@
 #include
 #include
 #include
-#include "transform_step.h"
+#include "compute_dag.h"

 namespace tvm {
 namespace ansor {

 using namespace tvm::tir;

+/*! \brief The type of a stage */
 enum StageType {
   kPlaceholder, kCompute
 };

+/*! \brief The type of compute location */
 enum ComputeAtType {
   kRoot,     // compute at root
   kInlined,  // inlined
   kIter,     // compute at some iterator
 };

+/*! \brief The type of an iterator */
+enum IteratorType {
+  kSpace,    // spatial iterator
+  kReduce,   // reduction iterator
+  kMixed,    // fused spatial and reduction iterator
+  kSpecial   // special iterator (e.g. virtual root iterator)
+};
+
+/*! \brief The type of an iterator's annotation */
+enum IteratorAnnotation {
+  kNone, kUnroll, kVectorize, kParallel,
+  kVThread, kBlockX, kThreadX, kBlockY, kThreadY
+};
+
+class Iterator;
+
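Point 2 above (serializable history) is what enables replay. A rough sketch of that workflow, assuming the DAG exposes its initial state (the accessor name GetInitState is an assumption) and using State::DoSteps declared later in this header:

// Rough sketch only: rebuild a state by replaying recorded transform steps.
tvm::ansor::State ReplayHistory(const tvm::ansor::ComputeDAG& dag,
                                const std::vector<tvm::ansor::Step>& history) {
  tvm::ansor::State state = dag.GetInitState();  // assumed accessor for the initial state
  state.DoSteps(history, dag);                   // re-apply the serialized steps
  return state;
}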
+/*!
+ * \brief A for loop iterator
+ * Similar to tvm::IterVar in `include/tvm/tir/expr.h`
+ */
+class IteratorNode : public Object {
+ public:
+  std::string name;
+  Range range;
+  IteratorType iter_type;
+  IteratorAnnotation annotation;
+  std::vector<Iterator> ori_iters;  // The original iterators before fusion
+
+  static Iterator make(std::string name, Range range,
+                       IteratorType iter_type, IteratorAnnotation annotation,
+                       const std::vector<Iterator>* ori_iters = nullptr);
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("range", &range);
+  }
+
+  static constexpr const char *_type_key = "ansor.Iterator";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object);
+};
+TVM_DEFINE_COW_OBJECT_REF(Iterator, ObjectRef, IteratorNode);
+
+// Forward declarations
+class Stage; class State;
+class AttachMap;
+
+class ReorderStep; class SplitStep; class FollowSplitStep;
+class FollowFusedSplitStep;
+class FuseStep; class AnnotationStep;
+class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep;
+class CacheReadStep; class CacheWriteStep;
+class PragmaStep; class RfactorStep; class StorageAlignStep;

 /*!
 * \brief A stage in the compute declaration
@@ -53,25 +126,32 @@ class StageNode : public Object {
   StageType op_type;
   std::vector<Iterator> iters;
   ComputeAtType compute_at;
-  int16_t auto_unroll_max_step;
+  int auto_unroll_max_step;
   int storage_offset;

+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("op", &op);
+  }
+
   static Stage make(te::Operation op);
   static Stage make(te::Operation op, StageType op_type,
                     const std::vector<Iterator>& iters,
-                    ComputeAtType compute_at, int16_t auto_unroll_max_step,
+                    ComputeAtType compute_at, int auto_unroll_max_step,
                     int storage_offset);
   static Stage make(te::Operation op, StageType op_type,
                     std::vector<Iterator>&& iters,
-                    ComputeAtType compute_at, int16_t auto_unroll_max_step,
+                    ComputeAtType compute_at, int auto_unroll_max_step,
                     int storage_offset);

   static constexpr const char *_type_key = "ansor.Stage";
   TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object);
 };
-TVM_DEFINE_COW_NODE_REF(Stage, ObjectRef, StageNode);
+TVM_DEFINE_COW_OBJECT_REF(Stage, ObjectRef, StageNode);

-/*! \brief stores the compute_at relation between stages */
+/*! \brief stores the compute_at relation between stages
+ * This stores a bi-directional mapping between stages and iterators:
+ * 1. Stage to its attached iterator  2. Iterator to the stage attached to it
+ */
 class AttachMapNode: public Object {
  public:
   using StageKey = int;
@@ -110,6 +190,22 @@ class AttachMap : public ObjectRef {
   static void DeleteStageEntry(AttachMapNode* pnode, int stage_id);
 };

+/*! \brief The base class for a transformation step */
+class StepNode: public Object {
+ public:
+  int stage_id;
+
+  // Print the step as an equivalent python schedule API call
+  virtual std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                                       StageToAxesMap *stage_to_axes,
+                                       te::Schedule *schedule,
+                                       const std::vector<Step>& transform_steps) const = 0;
+
+  static constexpr const char* _type_key = "ansor.Step";
+  TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object);
+};
+TVM_DEFINE_MUTABLE_OBJECT_REF(Step, StepNode);
+
 /*! \brief The loop state and corresponding history steps to reach this state */
 class StateNode: public Object {
  public:
@@ -125,6 +221,7 @@ class StateNode: public Object {
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("complete", &complete);
     v->Visit("aux_info", &aux_info);
+    v->Visit("task_dag", &task_dag);
   }

   static State make_empty_state();
@@ -137,7 +234,8 @@ class StateNode: public Object {
   TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object);
 };

-/*!
\brief The loop state and corresponding history steps to reach this state */ +/*! \brief A state in the search process. + * It consists of the current loop structure and the history steps to reach this state. */ class State : public ObjectRef { public: // Schedule primitives @@ -154,14 +252,12 @@ class State : public ObjectRef { Iterator vectorize(int stage_id, const Iterator& it); Iterator parallel(int stage_id, const Iterator& it); Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); - // Valide thread_type: kVThread, kBlockX, kThreadX, kThreadY Iterator bind_thread(int stage_id, const Iterator& it, IteratorAnnotation thread_type); void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); void compute_root(int stage_id); void compute_inline(int stage_id); - void pack_for_vec(int stage_id, const Iterator& target_iter, int vec_size); int cache_read(int stage_id, const std::string& scope_name, const std::vector& reader_stage_ids, const ComputeDAG& task_dag); @@ -172,8 +268,10 @@ class State : public ObjectRef { const ComputeDAG& task_dag); void storage_align(int stage_id, const Iterator& it, int factor, int offset); - // We separate these functions out, - // so you can call them for replay easily given history steps + /* Do transform steps + * Note: The following functions only change loop state but do not change transform_history. + * We separate these functions out, + * so you can call them for replay easily given history steps */ void DoReorderStep(const ReorderStep& step); std::vector DoSplitStep(const SplitStep& step); std::vector DoFollowSplitStep(const FollowSplitStep& step); @@ -183,38 +281,44 @@ class State : public ObjectRef { void DoComputeAtStep(const ComputeAtStep& step); void DoComputeRootStep(const ComputeRootStep& step); void DoComputeInlineStep(const ComputeInlineStep& step); - void DoPackForVecStep(const PackForVecStep& step); int DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag); int DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag); void DoPragmaStep(const PragmaStep& step); int DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag); void DoStorageAlignStep(const StorageAlignStep& step); - /* Do transform steps - * Note: The following function only change loop state. - * They do not change transform_history. - */ + // General do step functions with a runtime dynamic dispatcher void DoStep(const Step& step, const ComputeDAG& dag); void DoSteps(const std::vector& step, const ComputeDAG& dag); - // Print to str + // Print the state to a string std::string ToStr(bool delete_trivial_loop = true) const; TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); private: - // common function for DoSplitStep and DoFollowSplitStep + // Common function for DoSplitStep and DoFollowSplitStep std::vector DoSplitStepCommon(int stage_id, int iter_id, const std::vector& lengths, bool inner_to_outer); }; +/*! 
\brief Clean the name of an iterator to make it valid in python code */ +inline std::string CleanName(const std::string& str) { + std::string ret = str; + StrReplace(&ret, ".", "_"); + StrReplace(&ret, "@", "_"); + StrReplace(&ret, "outer", "o"); + StrReplace(&ret, "inner", "i"); + return ret; +} + } // namespace ansor } // namespace tvm -// Hash and equal function for State, Stage, Iterator and Step +// Hash and equal function for State namespace std { template <> diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index b2cff24973bc..43be530f2a35 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -2,15 +2,12 @@ * Copyright (c) 2020 by Contributors */ #include "measure.h" -// #include #include #include - #include #include #include #include -// #include "search_policy/search_policy.h" namespace tvm { namespace ansor { @@ -38,7 +35,7 @@ const char* ErrorNoToStr[] = { "UnknownError", }; -// Maker +// Measure input and result MeasureInput MeasureInputNode::make(SearchTask task, State state) { auto node = make_object(); node->task = std::move(task); @@ -87,6 +84,7 @@ MeasureResult MeasureResultNode::copy() const { return MeasureResult(node); } +// LocalBuilder Builder LocalBuilderNode::make(int timeout, int n_parallel, const std::string& build_func) { auto node = make_object(); @@ -96,7 +94,6 @@ Builder LocalBuilderNode::make(int timeout, int n_parallel, return Builder(node); } -// LocalBuilder and LocalRunner Array LocalBuilderNode::Build(const Array& inputs, int verbose) { if (const auto* f = runtime::Registry::Get("ansor.local_builder.build")) { @@ -109,6 +106,7 @@ Array LocalBuilderNode::Build(const Array& inputs, return Array(); } +// RPC Runner Runner RPCRunnerNode::make(const std::string& key, const std::string& host, int port, int priority, int timeout, int n_parallel, int number, int repeat, int min_repeat_ms, @@ -141,6 +139,7 @@ Array RPCRunnerNode::Run(const Array& inputs, return Array(); } +// Local Runner Runner LocalRunnerNode::make(int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval) { ObjectPtr node = make_object(); @@ -166,6 +165,7 @@ Array LocalRunnerNode::Run( return Array(); } +// Program Measurer ProgramMeasurer ProgramMeasurerNode::make(Builder builder, Runner runner, Array callbacks, int verbose, @@ -284,89 +284,89 @@ void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, // Printing functions TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - p->stream << "MeasureInput()"; - }); +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "MeasureInput()"; +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - if (node->error_no == kNoError) { - p->stream << "MeasureResult(cost:["; - auto old_config = p->stream.precision(4); - for (size_t i = 0; i < node->costs.size(); ++i) { - auto pf = node->costs[i].as(); - CHECK(pf != nullptr); - p->stream << pf->value; - if (i != node->costs.size() - 1) { - p->stream << ","; - } - } - p->stream.precision(old_config); - p->stream << "], "; - p->stream << "error_no:" << 0 << ", " - << "all_cost:" << node->all_cost << ", " - << "Tstamp:" << node->timestamp << ")"; - } else { - p->stream << "MeasureResult(" - << "error_type:" << ErrorNoToStr[node->error_no] << ", " - << "error_msg:" << node->error_msg << ", " - << "all_cost:" << node->all_cost << ", " - << "Tstamp:" << node->timestamp << ")"; +.set_dispatch([](const 
ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + if (node->error_no == kNoError) { + p->stream << "MeasureResult(cost:["; + auto old_config = p->stream.precision(4); + for (size_t i = 0; i < node->costs.size(); ++i) { + auto pf = node->costs[i].as(); + CHECK(pf != nullptr); + p->stream << pf->value; + if (i != node->costs.size() - 1) { + p->stream << ","; } - }); + } + p->stream.precision(old_config); + p->stream << "], "; + p->stream << "error_no:" << 0 << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; + } else { + p->stream << "MeasureResult(" + << "error_type:" << ErrorNoToStr[node->error_no] << ", " + << "error_msg:" << node->error_msg << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; + } +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "BuildResult(" << node->filename << ", " << node->error_no - << ", " << node->time_cost << ")"; - }); +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "BuildResult(" << node->filename << ", " << node->error_no + << ", " << node->time_cost << ")"; +}); TVM_REGISTER_GLOBAL("ansor.MeasureInput") - .set_body_typed([](SearchTask task, State state) { - return MeasureInputNode::make(task, state); - }); +.set_body_typed([](SearchTask task, State state) { + return MeasureInputNode::make(task, state); +}); TVM_REGISTER_GLOBAL("ansor.BuildResult") - .set_body_typed([](std::string filename, Array args, - int error_no, std::string error_msg, double time_cost) { - return BuildResultNode::make(filename, args, error_no, error_msg, - time_cost); - }); +.set_body_typed([](std::string filename, Array args, + int error_no, std::string error_msg, double time_cost) { + return BuildResultNode::make(filename, args, error_no, error_msg, + time_cost); +}); TVM_REGISTER_GLOBAL("ansor.MeasureResult") - .set_body_typed([](Array costs, int error_no, - std::string error_msg, double all_cost, - double timestamp) { - return MeasureResultNode::make(costs, error_no, error_msg, all_cost, - timestamp); - }); +.set_body_typed([](Array costs, int error_no, + std::string error_msg, double all_cost, + double timestamp) { + return MeasureResultNode::make(costs, error_no, error_msg, all_cost, + timestamp); +}); TVM_REGISTER_GLOBAL("ansor.BuilderBuild") - .set_body_typed([](const Builder& builder, - const Array& inputs, int verbose) { - return builder->Build(inputs, verbose); - }); +.set_body_typed([](const Builder& builder, + const Array& inputs, int verbose) { + return builder->Build(inputs, verbose); +}); TVM_REGISTER_GLOBAL("ansor.RunnerRun") - .set_body_typed([](const Runner& runner, const Array& inputs, - const Array& build_results, int verbose) { - return runner->Run(inputs, build_results, verbose); - }); +.set_body_typed([](const Runner& runner, const Array& inputs, + const Array& build_results, int verbose) { + return runner->Run(inputs, build_results, verbose); +}); TVM_REGISTER_GLOBAL("ansor.LocalBuilder") - .set_body_typed([](int timeout, int n_parallel, - const std::string& build_func) { - return LocalBuilderNode::make(timeout, n_parallel, build_func); - }); +.set_body_typed([](int timeout, int n_parallel, + const std::string& build_func) { + return LocalBuilderNode::make(timeout, n_parallel, build_func); +}); TVM_REGISTER_GLOBAL("ansor.LocalRunner") - .set_body_typed([](int timeout, int number, int 
repeat, int min_repeat_ms, - double cooldown_interval) { - return LocalRunnerNode::make(timeout, number, repeat, min_repeat_ms, - cooldown_interval); - }); +.set_body_typed([](int timeout, int number, int repeat, int min_repeat_ms, + double cooldown_interval) { + return LocalRunnerNode::make(timeout, number, repeat, min_repeat_ms, + cooldown_interval); +}); } // namespace ansor } // namespace tvm diff --git a/src/ansor/measure.h b/src/ansor/measure.h index 4ea1562315ff..780a30514d46 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -7,7 +7,6 @@ #ifndef TVM_ANSOR_MEASURE_H_ #define TVM_ANSOR_MEASURE_H_ -// #include #include #include #include @@ -22,8 +21,7 @@ class SearchPolicy; class MeasureInput; class BuildResult; class MeasureResult; class Builder; class Runner; class MeasureCallback; class ProgramMeasurer; -extern const char *ErrorNoToStr[]; - +/* \brief The error code of one measurement */ enum MeasureErrorNO { kNoError = 0, // No error kInstantiationError = 1, // Errors happen when apply transform steps from init state @@ -35,14 +33,15 @@ enum MeasureErrorNO { kRunTimeoutError = 7, // Timeout during run kUnknonwError = 8, // Unknown error }; +extern const char *ErrorNoToStr[]; // Inputs and results of one measurement -/* \brief Store the input of a meansurement */ +/* \brief Store the input of a measurement */ class MeasureInputNode: public Object { public: - SearchTask task; - State state; + SearchTask task; // The search task + State state; // The program state to be measured void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("task", &task); @@ -55,16 +54,16 @@ class MeasureInputNode: public Object { static constexpr const char* _type_key = "ansor.MeasureInput"; TVM_DECLARE_FINAL_OBJECT_INFO(MeasureInputNode, Object); }; -TVM_DEFINE_NODE_REF(MeasureInput, MeasureInputNode); +TVM_DEFINE_OBJECT_REF(MeasureInput, MeasureInputNode); /* \brief Store the input of a build */ class BuildResultNode: public Object { public: - std::string filename; - Array args; - int error_no; - std::string error_msg; - double time_cost; + std::string filename; // The filename of built binary file + Array args; // The arguments + int error_no; // The error code (see MeasureErrorNO). 0 means no error. + std::string error_msg; // The error message if there is any error + double time_cost; // The time cost of build void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("filename", &filename); @@ -80,16 +79,16 @@ class BuildResultNode: public Object { static constexpr const char* _type_key = "ansor.BuildResult"; TVM_DECLARE_FINAL_OBJECT_INFO(BuildResultNode, Object); }; -TVM_DEFINE_NODE_REF(BuildResult, BuildResultNode); +TVM_DEFINE_OBJECT_REF(BuildResult, BuildResultNode); /* \brief Store the results of a measurement */ class MeasureResultNode: public Object { public: - Array costs; - int error_no; - std::string error_msg; - double all_cost; - double timestamp; + Array costs; // The time costs of execution + int error_no; // The error code (see MeasureErrorNO). 0 means no error. 
+  std::string error_msg;  // The error message if there is any error
+  double all_cost;        // The time cost of build and run
+  double timestamp;       // The time stamp of this measurement

   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("costs", &costs);
@@ -107,19 +106,21 @@
   static constexpr const char* _type_key = "ansor.MeasureResult";
   TVM_DECLARE_FINAL_OBJECT_INFO(MeasureResultNode, Object);
 };
-TVM_DEFINE_NODE_REF(MeasureResult, MeasureResultNode);
+TVM_DEFINE_OBJECT_REF(MeasureResult, MeasureResultNode);

-// Measure callback
+/* \brief Base class of measurement callbacks */
 class MeasureCallbackNode: public Object {
  public:
+  /*! \brief Callback function that will be called on measurement input/result pairs
+   * after measurement */
   virtual void callback(const SearchPolicy& policy,
                         const Array<MeasureInput>& inputs,
                         const Array<MeasureResult>& results) = 0;

   static constexpr const char *_type_key = "ansor.MeasureCallback";
   TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object);
 };
-TVM_DEFINE_MUTABLE_NODE_REF(MeasureCallback, MeasureCallbackNode);
+TVM_DEFINE_MUTABLE_OBJECT_REF(MeasureCallback, MeasureCallbackNode);
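A minimal sketch of a concrete callback under the interface above (the class name and log message are hypothetical, not part of this patch):

class LoggingCallbackNode : public tvm::ansor::MeasureCallbackNode {
 public:
  void callback(const tvm::ansor::SearchPolicy& policy,
                const tvm::Array<tvm::ansor::MeasureInput>& inputs,
                const tvm::Array<tvm::ansor::MeasureResult>& results) final {
    // Invoked once per measurement batch; report only the batch size here.
    LOG(INFO) << "measured " << inputs.size() << " candidate schedules";
  }

  static constexpr const char* _type_key = "ansor.LoggingCallback";
  TVM_DECLARE_FINAL_OBJECT_INFO(LoggingCallbackNode, tvm::ansor::MeasureCallbackNode);
};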


 // Base class for builder and runner
@@ -127,21 +128,23 @@ TVM_DEFINE_MUTABLE_NODE_REF(MeasureCallback, MeasureCallbackNode);
 /* \brief Builder that builds the programs */
 class BuilderNode: public Object {
  public:
-  int n_parallel;
-  int timeout;
+  int n_parallel;  // The number of tasks to run in parallel
+  int timeout;     // Timeout of a build

+  /*! \brief Build programs and return results */
   virtual Array<BuildResult> Build(const Array<MeasureInput>& inputs, int verbose) = 0;

   static constexpr const char* _type_key = "ansor.Builder";
   TVM_DECLARE_BASE_OBJECT_INFO(BuilderNode, Object);
 };
-TVM_DEFINE_MUTABLE_NODE_REF(Builder, BuilderNode);
+TVM_DEFINE_MUTABLE_OBJECT_REF(Builder, BuilderNode);

 /* \brief Runner that runs the built programs and measures the time cost */
 class RunnerNode: public Object {
  public:
-  int timeout;
+  int timeout;  // Timeout of a run

+  /*! \brief Run measurement and return results */
   virtual Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
                                    const Array<BuildResult>& build_results,
                                    int verbose) = 0;
@@ -149,14 +152,14 @@ class RunnerNode: public Object {
   static constexpr const char* _type_key = "ansor.Runner";
   TVM_DECLARE_BASE_OBJECT_INFO(RunnerNode, Object);
 };
-TVM_DEFINE_MUTABLE_NODE_REF(Runner, RunnerNode);
+TVM_DEFINE_MUTABLE_OBJECT_REF(Runner, RunnerNode);


 // Implementation of various builders and runners

 /* \brief LocalBuilder uses local CPU cores to build programs in parallel */
 class LocalBuilderNode: public BuilderNode {
  public:
-  std::string build_func;
+  std::string build_func;  // Build function

   static Builder make(int timeout, int n_parallel,
                       const std::string& build_func);
@@ -166,6 +169,7 @@ class LocalBuilderNode: public BuilderNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(LocalBuilderNode, BuilderNode);
 };

+/* \brief RPCRunner that uses RPC calls to measure the time cost of programs on remote devices */
 class RPCRunnerNode : public RunnerNode {
  public:
   std::string key;
@@ -182,6 +186,7 @@ class RPCRunnerNode : public RunnerNode {
                      int priority, int timeout, int n_parallel, int number,
                      int repeat, int min_repeat_ms, double cooldown_interval);

+  /*! \brief Run measurement and return results */
   Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
                            const Array<BuildResult>& build_results,
                            int verbose) final;
@@ -190,7 +195,7 @@ class RPCRunnerNode : public RunnerNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(RPCRunnerNode, RunnerNode);
 };

-/* \brief LocalRunner use local CPU/GPU to runs programs in serial and measure the time cost */
+/* \brief LocalRunner that uses local CPU/GPU to measure the time cost of programs */
 class LocalRunnerNode: public RunnerNode {
  public:
   int number;
@@ -201,6 +206,7 @@ class LocalRunnerNode: public RunnerNode {
   static Runner make(int timeout, int number, int repeat,
                      int min_repeat_ms, double cooldown_interval);

+  /*! \brief Run measurement and return results */
   Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
                            const Array<BuildResult>& build_results,
                            int verbose) final;
@@ -211,9 +217,8 @@

 /*!
- * \brief Measurer measures the time costs of tvm programs
- * This class combines Builder and Runner, and provides a simpler API
- */
+ * \brief Measurer that measures the time costs of tvm programs
+ * This class combines Builder and Runner, and provides a simpler API */
 class ProgramMeasurerNode: public Object {
  public:
   static const int DEFAULT_MAX_CONTINOUS_ERROR = 150;
@@ -253,7 +258,7 @@ class ProgramMeasurerNode: public Object {
   static constexpr const char* _type_key = "ansor.ProgramMeasurer";
   TVM_DECLARE_FINAL_OBJECT_INFO(ProgramMeasurerNode, Object);
 };
-TVM_DEFINE_MUTABLE_NODE_REF(ProgramMeasurer, ProgramMeasurerNode);
+TVM_DEFINE_MUTABLE_OBJECT_REF(ProgramMeasurer, ProgramMeasurerNode);

 }  // namespace ansor

diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc
index b4501804607a..c22d890a8b51 100644
--- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc
+++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc
@@ -1,5 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
 /*!
- * Copyright (c) 2020 by Contributors + * \file ansor/search_policy/meta_tile_rewrite_policy.h + * \brief The search policy that searches by program sampling and evolutionary search */ #include "meta_tile_rewrite_policy.h" @@ -776,7 +796,7 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, // Fuse the outermost space tile as blockIdx for (size_t i = 0; i < pop->axis.size(); i++) { const auto& it = (*state)->stages[stage_id]->iters[i]; - if (!StringEndWith(it->name, ".0")) { + if (!StrEndsWith(it->name, ".0")) { break; } to_fuse.push_back(it); @@ -788,7 +808,7 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, to_fuse.clear(); for (size_t i = 1; i < pop->axis.size() + 1; i++) { const auto& it = (*state)->stages[stage_id]->iters[i]; - if (!StringEndWith(it->name, ".1")) { + if (!StrEndsWith(it->name, ".1")) { break; } to_fuse.push_back((*state)->stages[stage_id]->iters[i]); @@ -804,7 +824,7 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, to_fuse.clear(); for (size_t i = 2; i < pop->axis.size() + 2; i++) { const auto& it = (*state)->stages[stage_id]->iters[i]; - if (!StringEndWith(it->name, ".2")) { + if (!StrEndsWith(it->name, ".2")) { break; } to_fuse.push_back((*state)->stages[stage_id]->iters[i]); diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/meta_tile_rewrite_policy.h index ca9033ad866e..0c8c44b9c5ea 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.h +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.h @@ -1,100 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors - * \file ansor/meta_tile_rewrite_policy.h - * \brief A search policy that search with meta tiling structure and random - * rewrite + * \file ansor/search_policy/meta_tile_rewrite_policy.h + * \brief The search policy that searches by program sampling and evolutionary search */ + #ifndef TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ #define TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ -#include +#include #include -#include #include -#include - +#include +#include +#include "search_policy.h" #include "../cost_model/cost_model.h" #include "../utils.h" -#include "search_policy.h" + namespace tvm { namespace ansor { /*! 
Multi stage search policy */
-class MetaTileRewritePolicyNode : public SearchPolicyNode {
+class MetaTileRewritePolicyNode: public SearchPolicyNode {
  public:
   CostModel program_cost_model;

   /* this->params is used to store the following arguments
-   * int evolutionary_search_population
-   *     The population size for evolutionary search
-   * int evolutionary_search_mutation_prob
-   *     The probability of mutation for evolutionary search
-   * int evolutionary_search_num_iters
-   *     The number of iterations for evolutionary search
-   * double local_mutation_use_measured_ratio
-   *     The maximum percentage of measured states in the initial population
-   *     for evolutionary search
-   * double eps_greedy
-   *     Always allocate this percentage of measurements to random sampled states
-   * str cpu_multi_level_tiling_structure
-   *     The structure of multi-level tiling for CPU
-   * str gpu_multi_level_tiling_structure
-   *     The structure of multi-level tiling for GPU
+   * int evolutionary_search_population  // The population size for evolutionary search
+   * int evolutionary_search_mutation_prob  // The probability of mutation for evolutionary search
+   * int evolutionary_search_num_iters;  // The number of iterations for evolutionary search
+   * double local_mutation_use_measured_ratio;  // The maximum percentage of measured states in the initial
+   *                                            // population for evolutionary search
+   * double eps_greedy;  // Always allocate this percentage of measurements to random sampled states
+   * str cpu_multi_level_tiling_structure  // The structure of multi-level tiling for CPU
+   * str gpu_multi_level_tiling_structure  // The structure of multi-level tiling for GPU
    */
   Map params;

   static SearchPolicy make(CostModel program_cost_model,
-                           Map params, int seed);
+                           Map params,
+                           int seed);

   // Search and make n_trials measurements
   // Return the best state
-  State Search(SearchTask task, int n_trials, int early_stopping,
-               int num_measure_per_iter, int verbose,
-               ProgramMeasurer measurer) final;
+  State Search(SearchTask task, int n_trials,
+               int early_stopping, int num_measure_per_iter,
+               int verbose, ProgramMeasurer measurer) final;
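For illustration, the params map could be populated like this before constructing the policy; the key names follow the comment above, while the values, the Map element types, and the helper name are assumptions, not code from this patch:

// Illustrative sketch only; values are made up.
tvm::ansor::SearchPolicy MakePolicySketch(tvm::ansor::CostModel cost_model) {
  tvm::Map<std::string, tvm::ObjectRef> params;  // assumed element types
  params.Set("evolutionary_search_population", tvm::Integer(2048));
  params.Set("evolutionary_search_num_iters", tvm::Integer(10));
  params.Set("eps_greedy", tvm::FloatImm(tvm::DataType::Float(64), 0.05));
  return tvm::ansor::MetaTileRewritePolicyNode::make(cost_model, params, /*seed=*/0);
}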
   // Continue search. This is used by JointTuner
   std::pair<Array<MeasureInput>, Array<MeasureResult> > ContinueSearchOneRound(
-      SearchTask task, int num_measure, int verbose,
-      ProgramMeasurer measurer) final;
+      SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) final;

-  static constexpr const char* _type_key = "ansor.MetaTileRewritePolicy";
+  static constexpr const char *_type_key = "ansor.MetaTileRewritePolicy";
   static const std::vector<int> auto_unroll_configs;

   TVM_DECLARE_FINAL_OBJECT_INFO(MetaTileRewritePolicyNode, SearchPolicyNode);

-  SearchTask cur_task_;   // The current task
+  SearchTask cur_task_;  // The current task
+
+  friend class MetaTileRewritePolicyNodeTest;  // Hack friend class for UT

  protected:
   // Pick states from best states and random states with eps-greedy policy
   void PickStatesWithEpsGreedy(std::vector<MeasureInput>* inputs,
                                const std::vector<State>& best_states,
-                               const std::vector<State>& random_states,
-                               int remaining_n_trials);
+                               const std::vector<State>& random_states, int remaining_n_trials);

  private:
   // Run one round of the search pipeline
-  void SearchOneRound(std::vector<State>* best_states, int num_random_states,
-                      std::vector<State>* random_states);
+  void SearchOneRound(std::vector<State>* best_states,
+                      int num_random_states, std::vector<State>* random_states);

   // Synthesize meta tiling structure without tile size
   void SynthesizeMetaStructure(std::vector<State>* out_states);

   // Sample init population
   void SampleInitPopulation(const std::vector<State>& meta_structures,
-                            int out_size, std::vector<State>* out_states);
+      int out_size, std::vector<State>* out_states);

   // Perform evolutionary search
   void EvolutionarySearch(const std::vector<State>& init_population,
-                          int num_best_states, std::vector<State>* best_states);
+      int num_best_states, std::vector<State>* best_states);

   SplitFactorizationMemo split_memo_;  // Memorize split space for Split
   std::mt19937 rand_gen_;              // Random generator
   int verbose_;                        // Verbose level (0 means silent)
-  int num_measure_per_iter_;  // The number of states to measure per iteration
+  int num_measure_per_iter_;           // The number of states to measure per iteration

-  // The set of the already measured states. We store the string format for
-  // redundancy check
+  // The set of the already measured states. We store the string format for redundancy check
   std::unordered_set<std::string> measured_states_set_;

   // The array of already measured states.
diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc
index 89bfeb1a8edd..866922d0001e 100644
--- a/src/ansor/search_policy/search_policy.cc
+++ b/src/ansor/search_policy/search_policy.cc
@@ -1,5 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
 /*!
- * Copyright (c) 2020 by Contributors + * \file ansor/search_policy/search_policy.cc + * \brief The base class for search policy */ #include "search_policy.h" @@ -11,4 +31,3 @@ TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); } // namespace ansor } // namespace tvm - diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index 5bd9fb3118b1..f2071deab447 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -1,8 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors - * \file ansor/search_policy.h - * \brief Base class of search policy + * \file ansor/search_policy/search_policy.h + * \brief The base class for search policy */ + #ifndef TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ #define TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ @@ -45,7 +64,7 @@ class SearchPolicyNode : public Object { static constexpr const char *_type_key = "ansor.SearchPolicy"; TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); }; -TVM_DEFINE_MUTABLE_NODE_REF(SearchPolicy, SearchPolicyNode); +TVM_DEFINE_MUTABLE_OBJECT_REF(SearchPolicy, SearchPolicyNode); } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/utils.cc b/src/ansor/search_policy/utils.cc index 9c597b4eb811..608b89da118c 100644 --- a/src/ansor/search_policy/utils.cc +++ b/src/ansor/search_policy/utils.cc @@ -1,5 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
- * Copyright (c) 2020 by Contributors + * \file ansor/search_policy/utils.cc + * \brief Common utilities for search policies */ #include "utils.h" @@ -42,27 +62,6 @@ void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatia } } -// Query axes that should not be splitted according to the attribute from tvm.compute -std::pair, std::set > QueryNoSplitAxis(const Stage& stage) { - std::pair, std::set > ret; - if (stage->op->attrs.count(SearchPolicyNode::no_split_at_inner_key)) { - ret.first = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::no_split_at_inner_key); - } - if (stage->op->attrs.count(SearchPolicyNode::no_split_at_outer_key)) { - ret.second = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::no_split_at_outer_key); - } - return ret; -} - -// Query axes that last split is one -std::set QueryLastSplitIsOneAxis(const Stage& stage) { - std::set ret; - if (stage->op->attrs.count(SearchPolicyNode::last_split_is_one_key)) { - ret = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::last_split_is_one_key); - } - return ret; -} - // Apply multi-tiling structure according to a string format State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format, std::vector* spatial_split_step_ids) { @@ -413,7 +412,7 @@ State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen // Mutate a parallel loop. State MutataParallel(const State& state, SplitFactorizationMemo* split_memo, - std::mt19937* random_gen, SearchTask& task, int verbose) { + std::mt19937* random_gen, const SearchTask& task, int verbose) { // To make this mutation simple but promising, we only focus on a specific case that // parallel was added to the outermost loop and the loop is generated by fusing other loops. // In short, we mutate the step pattern of (fuse -> parallel). @@ -574,17 +573,6 @@ void GridMutateTileSize(const State& old_state, std::vector* cands, } } -// Random choose an index according to a prefix sum probability -int RandomChoose(const std::vector& prefix_sum_probs, std::mt19937* random_gen) { - std::uniform_real_distribution<> dis(0.0, 1.0); - double x = dis(*random_gen); - - CHECK(!prefix_sum_probs.empty()); - - return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) - - prefix_sum_probs.begin(); -} - // Prune undefined states. void PruneUndefined(std::vector* states) { size_t pt = 0; diff --git a/src/ansor/search_policy/utils.h b/src/ansor/search_policy/utils.h index 3337975d7a88..607a549e1b8a 100644 --- a/src/ansor/search_policy/utils.h +++ b/src/ansor/search_policy/utils.h @@ -1,7 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
- * Copyright (c) 2020 by Contributors - * \file ansor/search_policy/utils.h - * \brief Common utilities for local mutation in search policy + * \file ansor/search_policy/utils.cc + * \brief Common utilities for search policies */ #ifndef TVM_ANSOR_SEARCH_POLICY_UTILS_H_ @@ -15,20 +33,13 @@ #include #include "../cost_model/cost_model.h" #include "../utils.h" +#include "../loop_state.h" +#include "../transform_step.h" #include "search_policy.h" namespace tvm { namespace ansor { -inline bool StringEndWith(const std::string& str, const std::string& target) { - int str_len = str.length(); - int target_len = target.length(); - if (str_len <= target_len) { - return false; - } - return str.compare(str_len - target_len, target_len, target) == 0; -} - // Get an integer from a tvm str Map inline int GetIntParam(const Map& attr_dict, const std::string& key) { @@ -96,7 +107,8 @@ inline int64_t GetExtent(const Iterator& it) { } // Return whether an op is strict inlineable -inline bool IsStrictInlineable(const SearchTask& task, const State& state, const te::Operation& op) { +inline bool IsStrictInlineable(const SearchTask& task, + const State& state, const te::Operation& op) { if (state->task_dag.defined()) { return state->task_dag->access_analyzer.IsStrictInlineable(op); } else { @@ -132,7 +144,8 @@ inline bool HasReduceIter(const Stage& stage) { } // Return whether an op needs multi level tiling -inline bool NeedsMultilevelTiling(const SearchTask& task, const State& state, const te::Operation& op) { +inline bool NeedsMultilevelTiling(const SearchTask& task, + const State& state, const te::Operation& op) { if (state->task_dag.defined()) { return state->task_dag->access_analyzer.NeedsMultiLevelTiling(op); } else { @@ -142,7 +155,7 @@ inline bool NeedsMultilevelTiling(const SearchTask& task, const State& state, co // Get all consumers for an op. 
This will take inline into consideration inline void GetConsumers(const SearchTask& task, const State& state, const te::Operation& op, - std::unordered_set* consumers) { + std::unordered_set* consumers) { if (state->task_dag.defined()) { state->task_dag->access_analyzer.GetConsumers(state, op, consumers); } else { @@ -161,7 +174,7 @@ inline void GetProducers(const SearchTask& task, const State& state, const te::O // Return whether two ops are elementwise-matched inline bool ElementwiseMatch(const SearchTask& task, const State& state, const te::Operation& op, - const te::Operation& target_op) { + const te::Operation& target_op) { if (state->task_dag.defined()) { return state->task_dag->access_analyzer.ElementWiseMatch(op, target_op); } else { @@ -171,8 +184,7 @@ inline bool ElementwiseMatch(const SearchTask& task, const State& state, const t // Return whether the stage has only one consumer and they are elementwise-matched inline bool HasSingleElementwiseMatchedConsumer(const SearchTask& task, - const State& state, const Stage& stage, - int* target_stage_id) { + const State& state, const Stage& stage, int* target_stage_id) { std::unordered_set consumers; GetConsumers(task, state, stage->op, &consumers); @@ -203,8 +215,8 @@ inline bool NeedsRfactor(const SearchTask& task, const State& state, const te::O if (NeedsMultilevelTiling(task, state, op)) { // Do not use rfactor if we have enough parallelism on space iters - if (cum_space_len > cum_reduce_len - || cum_space_len > task->hardware_params->num_cores * 16) { + if (cum_space_len > cum_reduce_len || + cum_space_len > task->hardware_params->num_cores * 16) { return false; } else { return true; @@ -240,6 +252,7 @@ inline bool HasCacheWriteStage(const State& s, int stage_id) { return false; } +// Return whether the state did cache_read for stage_id inline bool HasCacheReadStage(const State& s, int stage_id) { for (int i = static_cast(s->transform_steps.size()) - 1; i >= 0; --i) { if (auto ps = s->transform_steps[i].as()) { @@ -261,8 +274,10 @@ inline bool HasCacheReadStage(const State& s, int stage_id) { return false; } +// Get all split step on spatial iterators void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatial_split_step_ids); +// Return whether the state did split/follow_split/follow_fused_split in stage_id inline bool HasSplitStep(const State& s, int stage_id) { for (int i = static_cast(s->transform_steps.size()) - 1; i >= 0; --i) { if (s->transform_steps[i]->IsInstance() || @@ -290,9 +305,26 @@ inline bool IsTiled(const Stage& stage) { } // Query axes that should not be splitted according to the attribute from tvm.compute -std::pair, std::set > QueryNoSplitAxis(const Stage& stage); +inline std::pair, std::set > QueryNoSplitAxis( + const Stage& stage) { + std::pair, std::set > ret; + if (stage->op->attrs.count(SearchPolicyNode::no_split_at_inner_key)) { + ret.first = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::no_split_at_inner_key); + } + if (stage->op->attrs.count(SearchPolicyNode::no_split_at_outer_key)) { + ret.second = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::no_split_at_outer_key); + } + return ret; +} + // Query axes that last split is one -std::set QueryLastSplitIsOneAxis(const Stage& stage); +inline std::set QueryLastSplitIsOneAxis(const Stage& stage) { + std::set ret; + if (stage->op->attrs.count(SearchPolicyNode::last_split_is_one_key)) { + ret = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::last_split_is_one_key); + } + return ret; +} // Extract primitive iterators 
from a nested fused or splitted iterator's name inline void ExtractOriginalIterators(const std::string& name, std::set* rets) { @@ -329,6 +361,7 @@ inline const Iterator& GetLastSpaceIteratorInOutermostTile(const Stage& stage) { return stage->iters[0]; } +// Get the last reduce iterator in the outermost reduce tile inline const Iterator& GetLastReduceIteratorInOutermostReduceTile(const Stage& stage) { auto pop = stage->op.as(); CHECK(pop != nullptr); @@ -379,10 +412,15 @@ inline void RandomSampleStates(const std::vector& in_states, std::mt19937 } // Random choose an index according to a prefix sum probability -int RandomChoose(const std::vector& prefix_sum_probs, std::mt19937* random_gen); +inline int RandomChoose(const std::vector& prefix_sum_probs, std::mt19937* random_gen) { + std::uniform_real_distribution<> dis(0.0, 1.0); + double x = dis(*random_gen); -// Prune undefined states. -void PruneUndefined(std::vector* states); + CHECK(!prefix_sum_probs.empty()); + + return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) - + prefix_sum_probs.begin(); +} // Print all states inline void PrintAllStates(const std::vector& states) { @@ -418,7 +456,7 @@ State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen // Mutate a parallel loop. State MutataParallel(const State& old_state, SplitFactorizationMemo* split_memo, - std::mt19937* random_gen, SearchTask& task, int verbose = 0); + std::mt19937* random_gen, const SearchTask& task, int verbose = 0); // Create all possible tile size states for all SplitStep void GridMutateTileSize(const State& old_state, std::vector* cands, @@ -427,6 +465,9 @@ void GridMutateTileSize(const State& old_state, std::vector* cands, // GA: Crossover two states State CrossOverState(const State& p1, const State& p2); +// Prune undefined states. +void PruneUndefined(std::vector* states); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index 93f3f60ea768..c65516150f30 100644 --- a/src/ansor/search_task.cc +++ b/src/ansor/search_task.cc @@ -1,12 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
- * Copyright (c) 2020 by Contributors + * \file ansor/search_task.cc + * \brief Meta information and hardware parameters for a search task */ -#include "search_task.h" +#include "search_task.h" #include #include #include - #include #include @@ -118,21 +137,21 @@ SearchTask SearchTaskNode::make(ComputeDAG compute_dag, } TVM_REGISTER_GLOBAL("ansor.HardwareParams") - .set_body_typed([](int num_cores, int vector_unit_bytes, - int cache_line_bytes, int max_unroll_vec, - int max_innermost_split_factor) { - return HardwareParamsNode::make(num_cores, vector_unit_bytes, - cache_line_bytes, max_unroll_vec, - max_innermost_split_factor); - }); +.set_body_typed([](int num_cores, int vector_unit_bytes, + int cache_line_bytes, int max_unroll_vec, + int max_innermost_split_factor) { + return HardwareParamsNode::make(num_cores, vector_unit_bytes, + cache_line_bytes, max_unroll_vec, + max_innermost_split_factor); +}); TVM_REGISTER_GLOBAL("ansor.SearchTask") - .set_body_typed([](ComputeDAG compute_dag, std::string workload_key, - Target target, Target target_host, - HardwareParams hardware_params) { - return SearchTaskNode::make(compute_dag, workload_key, target, - target_host, hardware_params); - }); +.set_body_typed([](ComputeDAG compute_dag, std::string workload_key, + Target target, Target target_host, + HardwareParams hardware_params) { + return SearchTaskNode::make(compute_dag, workload_key, target, + target_host, hardware_params); +}); } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h index 9512013848b6..cfa5500c39f4 100644 --- a/src/ansor/search_task.h +++ b/src/ansor/search_task.h @@ -1,36 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors * \file ansor/search_task.h - * \brief Meta information for a search task + * \brief Meta information and hardware parameters for a search task */ #ifndef TVM_ANSOR_SEARCH_TASK_H_ #define TVM_ANSOR_SEARCH_TASK_H_ #include - #include - #include "compute_dag.h" namespace tvm { namespace ansor { -class HardwareParams; -class SearchTask; +class HardwareParams; class SearchTask; /*! 
\brief Hardware related parameters */ class HardwareParamsNode : public Object { public: + // The number of cores int num_cores; + // The width of vector units in bytes int vector_unit_bytes; + // The size of cache line in bytes int cache_line_bytes; - // The max length of the axis to be unrolled or vectorized + // The max length of an axis to be unrolled or vectorized int max_unroll_vec; // The max split factor for the innermost tile int max_innermost_split_factor; - // Limit params for GPU schedule + // Limitation params for GPU int max_shared_memory_per_block{INT32_MAX}; int max_registers_per_block{INT32_MAX}; int max_threads_per_block{INT32_MAX}; @@ -54,13 +72,14 @@ class HardwareParamsNode : public Object { static HardwareParams make(int num_cores, int vector_unit_bytes, int cache_line_bytes, int max_unroll_vec, int max_innermost_split_factor); + static HardwareParams GetDefaultHardwareParams(const Target& target, const Target& target_host); static constexpr const char* _type_key = "ansor.HardwareParams"; TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object); }; -TVM_DEFINE_COW_NODE_REF(HardwareParams, ObjectRef, HardwareParamsNode); +TVM_DEFINE_COW_OBJECT_REF(HardwareParams, ObjectRef, HardwareParamsNode); /*! \brief Meta-info for a search task */ class SearchTaskNode : public Object { @@ -86,7 +105,7 @@ class SearchTaskNode : public Object { static constexpr const char* _type_key = "ansor.SearchTask"; TVM_DECLARE_FINAL_OBJECT_INFO(SearchTaskNode, Object); }; -TVM_DEFINE_COW_NODE_REF(SearchTask, ObjectRef, SearchTaskNode); +TVM_DEFINE_COW_OBJECT_REF(SearchTask, ObjectRef, SearchTaskNode); } // namespace ansor } // namespace tvm diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index fc4917409cc0..53c75a13f197 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -1,57 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors + * \file ansor/serialization.cc + * \brief Json serialization format for dumping and loading tuning records */ -#include "serialization.h" #include #include - #include #include +#include #include #include -#include - +#include "serialization.h" #include "loop_state.h" +#include "transform_step.h" #include "utils.h" // Json serialization handler for MeasureInput, MeasureResult -// (and recursively SearchTask, State, Step, ... +// (and recursively for SearchTask, State, Step, ...) 
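For illustration (not a literal record from this patch), each serialized transform step is a JSON array whose first element is the short step tag used by the handlers below. A SplitStep of iterator 1 in stage 2 with a (made-up) extent of 1024, lengths [32, 4], and inner_to_outer = true would be written roughly as:

  ["SP", 2, 1, 1024, [32, 4], 1]

matching the order of the WriteArrayItem calls in the SplitStepNode branch; the exact field layout is defined by the handlers in this file.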
namespace dmlc { namespace json { -inline std::vector& FloatArrayToVector( - std::vector* out, const ::tvm::Array<::tvm::PrimExpr>& data) { +inline std::vector& IntArrayToVector(std::vector* out, + const ::tvm::Array<::tvm::PrimExpr>& data) { out->clear(); - for (const auto& x : data) { - auto pf = x.as<::tvm::tir::FloatImmNode>(); - CHECK(pf != nullptr) << "Cost can only contain float values"; - out->push_back(pf->value); - } - return *out; -} - -inline std::vector& IntArrayToVector( - std::vector* out, const ::tvm::Array<::tvm::PrimExpr>& data) { - out->clear(); - for (const auto& x : data) { + for (const auto&x : data) { auto pi = x.as<::tvm::tir::IntImmNode>(); - CHECK(pi != nullptr) << "Cost can only contain int values"; + CHECK(pi != nullptr) << "Can only contain int values"; out->push_back(pi->value); } return *out; } template <> -struct Handler> { +struct Handler > { inline static void Write(dmlc::JSONWriter* writer, - const std::vector<::tvm::ansor::Stage>& data) { + const std::vector<::tvm::ansor::Stage> & data) { // todo(lmzheng): support serialization of Stage writer->BeginArray(false); writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, - std::vector<::tvm::ansor::Stage>* data) { + std::vector<::tvm::ansor::Stage> * data) { bool s; reader->BeginArray(); s = reader->NextArrayItem(); CHECK(!s); @@ -59,16 +67,16 @@ struct Handler> { }; template <> -struct Handler> { +struct Handler > { inline static void Write(dmlc::JSONWriter* writer, - const std::vector<::tvm::ansor::Step>& data) { + const std::vector<::tvm::ansor::Step> & data) { std::vector tmp; writer->BeginArray(false); for (size_t i = 0; i < data.size(); ++i) { writer->WriteArraySeperator(); writer->BeginArray(false); if (auto ps = data[i].as<::tvm::ansor::ReorderStepNode>()) { - writer->WriteArrayItem(std::string("RS")); + writer->WriteArrayItem(std::string("RE")); writer->WriteArrayItem(ps->stage_id); writer->WriteArraySeperator(); @@ -78,7 +86,7 @@ struct Handler> { } writer->EndArray(); } else if (auto ps = data[i].as<::tvm::ansor::SplitStepNode>()) { - writer->WriteArrayItem(std::string("SS")); + writer->WriteArrayItem(std::string("SP")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); if (ps->extent.defined()) { @@ -89,14 +97,13 @@ struct Handler> { writer->WriteArrayItem(IntArrayToVector(&tmp, ps->lengths)); writer->WriteArrayItem(static_cast(ps->inner_to_outer)); } else if (auto ps = data[i].as<::tvm::ansor::FollowSplitStepNode>()) { - writer->WriteArrayItem(std::string("FSS")); + writer->WriteArrayItem(std::string("FSP")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); writer->WriteArrayItem(ps->src_step_id); writer->WriteArrayItem(ps->n_split); - } else if (auto ps = - data[i].as<::tvm::ansor::FollowFusedSplitStepNode>()) { - writer->WriteArrayItem(std::string("FFSS")); + } else if (auto ps = data[i].as<::tvm::ansor::FollowFusedSplitStepNode>()) { + writer->WriteArrayItem(std::string("FFSP")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); @@ -110,7 +117,7 @@ struct Handler> { writer->WriteArrayItem(ps->level); writer->WriteArrayItem(static_cast(ps->factor_or_nparts)); } else if (auto ps = data[i].as<::tvm::ansor::FuseStepNode>()) { - writer->WriteArrayItem(std::string("FS")); + writer->WriteArrayItem(std::string("FU")); writer->WriteArrayItem(ps->stage_id); writer->WriteArraySeperator(); @@ -120,7 +127,7 @@ struct Handler> { } writer->EndArray(); } else if (auto ps = 
data[i].as<::tvm::ansor::AnnotationStepNode>()) { - writer->WriteArrayItem(std::string("AS")); + writer->WriteArrayItem(std::string("AN")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); writer->WriteArrayItem(static_cast(ps->annotation)); @@ -145,12 +152,12 @@ struct Handler> { writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->scope_name); } else if (auto ps = data[i].as<::tvm::ansor::PragmaStepNode>()) { - writer->WriteArrayItem(std::string("PS")); + writer->WriteArrayItem(std::string("PR")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); writer->WriteArrayItem(ps->pragma_type); } else if (auto ps = data[i].as<::tvm::ansor::RfactorStepNode>()) { - writer->WriteArrayItem(std::string("RFS")); + writer->WriteArrayItem(std::string("RF")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); writer->WriteArrayItem(ps->factor_iter_id); @@ -167,8 +174,9 @@ struct Handler> { } writer->EndArray(); } + inline static void Read(dmlc::JSONReader* reader, - std::vector<::tvm::ansor::Step>* data) { + std::vector<::tvm::ansor::Step> * data) { std::vector int_list; bool s, inner_to_outer, factor_or_nparts; std::string name, scope_name, pragma_type; @@ -181,14 +189,13 @@ struct Handler> { reader->BeginArray(); s = reader->NextArrayItem(); CHECK(s); reader->Read(&name); - if (name == "RS") { + if (name == "RE") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); - data->push_back( - ::tvm::ansor::ReorderStepNode::make(stage_id, int_list)); - } else if (name == "SS") { + data->push_back(::tvm::ansor::ReorderStepNode::make(stage_id, int_list)); + } else if (name == "SP") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); @@ -203,7 +210,7 @@ struct Handler> { stage_id, iter_id, extent, std::vector<::tvm::PrimExpr>(int_list.begin(), int_list.end()), inner_to_outer)); - } else if (name == "FSS") { + } else if (name == "FSP") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); @@ -214,7 +221,7 @@ struct Handler> { reader->Read(&n_split); data->push_back(::tvm::ansor::FollowSplitStepNode::make( stage_id, iter_id, src_step_id, n_split)); - } else if (name == "FFSS") { + } else if (name == "FFSP") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); @@ -227,21 +234,21 @@ struct Handler> { reader->Read(&factor_or_nparts); data->push_back(::tvm::ansor::FollowFusedSplitStepNode::make( stage_id, iter_id, int_list, level, factor_or_nparts)); - } else if (name == "FS") { + } else if (name == "FU") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); data->push_back(::tvm::ansor::FuseStepNode::make(stage_id, int_list)); - } else if (name == "AS") { + } else if (name == "AN") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&ann); - data->push_back(::tvm::ansor::AnnotationStepNode::make( - stage_id, iter_id, ::tvm::ansor::IteratorAnnotation(ann))); + data->push_back(::tvm::ansor::AnnotationStepNode::make(stage_id, + iter_id, ::tvm::ansor::IteratorAnnotation(ann))); } else if (name == "CA") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -273,26 +280,26 @@ struct Handler> { 
reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&scope_name); - data->push_back( - ::tvm::ansor::CacheWriteStepNode::make(stage_id, scope_name)); - } else if (name == "PS") { + data->push_back(::tvm::ansor::CacheWriteStepNode::make( + stage_id, scope_name)); + } else if (name == "PR") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&pragma_type); - data->push_back( - ::tvm::ansor::PragmaStepNode::make(stage_id, iter_id, pragma_type)); - } else if (name == "RFS") { + data->push_back(::tvm::ansor::PragmaStepNode::make( + stage_id, iter_id, pragma_type)); + } else if (name == "RF") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&factor_iter_id); - data->push_back(::tvm::ansor::RfactorStepNode::make(stage_id, iter_id, - factor_iter_id)); + data->push_back(::tvm::ansor::RfactorStepNode::make( + stage_id, iter_id, factor_iter_id)); } else if (name == "SA") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -392,7 +399,7 @@ struct Handler<::tvm::ansor::MeasureResultNode> { writer->BeginArray(false); writer->WriteArraySeperator(); writer->BeginArray(false); - for (const auto& x : data.costs) { + for (const auto&x : data.costs) { auto pf = x.as<::tvm::tir::FloatImmNode>(); CHECK(pf != nullptr) << "Cost can only contain float values"; writer->WriteArrayItem(pf->value); @@ -434,7 +441,7 @@ namespace ansor { TVM_REGISTER_OBJECT_TYPE(LogToFileNode); TVM_REGISTER_OBJECT_TYPE(LogReaderNode); -const std::string ansor_LOG_VERSION = "v0.1"; // NOLINT(*) +const std::string ANSOR_LOG_VERSION = "v0.1"; // NOLINT(*) MeasureCallback LogToFileNode::make(std::string filename) { auto node = make_object(); @@ -442,21 +449,24 @@ MeasureCallback LogToFileNode::make(std::string filename) { return MeasureCallback(node); } -void WriteMeasureRecords(std::ostream* os, const Array& inputs, +void WriteMeasureRecords(std::ostream* os, + const Array& inputs, const Array& results) { dmlc::JSONWriter writer(os); for (size_t i = 0; i < inputs.size(); ++i) { writer.BeginObject(false); writer.WriteObjectKeyValue("i", *inputs[i].operator->()); writer.WriteObjectKeyValue("r", *results[i].operator->()); - writer.WriteObjectKeyValue("v", ansor_LOG_VERSION); + writer.WriteObjectKeyValue("v", ANSOR_LOG_VERSION); writer.EndObject(); *os << "\n"; } } -void ReadMeasureRecords(std::string str, MeasureInputNode* inp, - MeasureResultNode* res, std::string* log_version) { +void ReadMeasureRecord(const std::string& str, + MeasureInputNode* inp, + MeasureResultNode* res, + std::string* log_version) { std::istringstream ss(str); dmlc::JSONReader reader(&ss); std::string key; @@ -499,7 +509,7 @@ bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { } try { - ReadMeasureRecords(cur_line, inp, res, &log_version); + ReadMeasureRecord(cur_line, inp, res, &log_version); } catch (...) 
{ return false; } @@ -510,8 +520,8 @@ bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { return false; } -std::pair, Array> LogReaderNode::ReadLines( - int max_size, int skip_size) { +std::pair, Array > LogReaderNode::ReadLines( + int max_size, int skip_size) { auto inp = make_object(); auto res = make_object(); Array inputs; @@ -534,41 +544,68 @@ std::pair, Array> LogReaderNode::ReadLines( return std::make_pair(inputs, results); } -TVM_REGISTER_GLOBAL("ansor.write_measure_records_to_file") - .set_body([](TVMArgs args, TVMRetValue* ret) { - std::string filename = args[0]; - Array in = args[1]; - Array res = args[2]; - std::ofstream ofs(filename, std::ofstream::app); - WriteMeasureRecords(&ofs, in, res); - }); +std::pair BestMeasurePairInFile(const std::string& filename, + const std::string& workload_key, + const Target& target) { + std::pair best_pair; + double best_cost = 1e30; + + auto inp = make_object(); + auto res = make_object(); + LogReader reader = LogReaderNode::make(filename); + + while (reader->ReadNext(inp.get(), res.get())) { + if (res->error_no != kNoError || inp->task->workload_key != workload_key + || inp->task->target->target_name != target->target_name) { + continue; + } + + double cost = FloatArrayMean(res->costs); + + if (cost < best_cost) { + best_cost = cost; + best_pair = std::make_pair(inp->copy(), res->copy()); + } + } + + return best_pair; +} + +TVM_REGISTER_GLOBAL("ansor.WriteMeasureRecordsToFile") +.set_body([](TVMArgs args, TVMRetValue *ret) { + std::string filename = args[0]; + Array in = args[1]; + Array res = args[2]; + std::ofstream ofs(filename, std::ofstream::app); + WriteMeasureRecords(&ofs, in, res); +}); TVM_REGISTER_GLOBAL("ansor.LogToFile") - .set_body_typed([](const std::string& filename) { - return LogToFileNode::make(filename); - }); +.set_body_typed([](const std::string& filename) { + return LogToFileNode::make(filename); +}); TVM_REGISTER_GLOBAL("ansor.LogReader") - .set_body_typed([](const std::string& filename) { - return LogReaderNode::make(filename); - }); +.set_body_typed([](const std::string& filename) { + return LogReaderNode::make(filename); +}); TVM_REGISTER_GLOBAL("ansor.LogReaderReadLines") - .set_body_typed([](LogReader reader, int size, int skip_size) { - const auto& res = reader->ReadLines(size, skip_size); - return Array{res.first, res.second}; - }); +.set_body_typed([](LogReader reader, int size, int skip_size) { + const auto& res = reader->ReadLines(size, skip_size); + return Array{res.first, res.second}; +}); TVM_REGISTER_GLOBAL("ansor.LogReaderReadNext") - .set_body_typed([](LogReader reader) { - auto inp = make_object(); - auto res = make_object(); - if (reader->ReadNext(inp.get(), res.get())) { - return Array{ObjectRef(inp), ObjectRef(res)}; - } else { - return Array(); - } - }); +.set_body_typed([](LogReader reader) { + auto inp = make_object(); + auto res = make_object(); + if (reader->ReadNext(inp.get(), res.get())) { + return Array{ObjectRef(inp), ObjectRef(res)}; + } else { + return Array(); + } +}); } // namespace ansor } // namespace tvm diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h index ef4132169652..a12760bb3acc 100644 --- a/src/ansor/serialization.h +++ b/src/ansor/serialization.h @@ -1,5 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors * \file ansor/serialization.h * \brief Json serialization format for dumping and loading tuning records */ @@ -7,18 +25,15 @@ #ifndef TVM_ANSOR_SERIALIZATION_H_ #define TVM_ANSOR_SERIALIZATION_H_ -#include #include +#include #include - #include "measure.h" namespace tvm { namespace ansor { -class LogReader; - -/*! \brief Log the input and results of measurments to file */ +/*! \brief Callback for logging the input and results of measurements to file */ class LogToFileNode : public MeasureCallbackNode { public: std::string filename; @@ -26,13 +41,16 @@ class LogToFileNode : public MeasureCallbackNode { static MeasureCallback make(std::string filename); /*! \brief Log measure pairs to file. This is called by the search policy */ - void callback(const SearchPolicy& policy, const Array& inputs, + void callback(const SearchPolicy& policy, + const Array& inputs, const Array& results) final; - static constexpr const char* _type_key = "ansor.LogToFile"; + static constexpr const char *_type_key = "ansor.LogToFile"; TVM_DECLARE_FINAL_OBJECT_INFO(LogToFileNode, MeasureCallbackNode); }; +class LogReader; + /*! \brief Log reader */ class LogReaderNode : public Object { public: @@ -49,7 +67,7 @@ class LogReaderNode : public Object { * \param max_size The maximum number of lines. -1 means read all lines * \param skip_size Skip the first n lines */ std::pair, Array > ReadLines( - int max_size = -1, int skip_size = 0); + int max_size = -1, int skip_size = 0); static constexpr const char* _type_key = "ansor.LogReader"; TVM_DECLARE_FINAL_OBJECT_INFO(LogReaderNode, Object); @@ -57,17 +75,23 @@ class LogReaderNode : public Object { private: std::string cur_line; }; -TVM_DEFINE_MUTABLE_NODE_REF(LogReader, LogReaderNode); +TVM_DEFINE_MUTABLE_OBJECT_REF(LogReader, LogReaderNode); -void WriteMeasureRecords(std::ostream* os, const Array& inputs, +/*! \brief Write measure records to an output stream */ +void WriteMeasureRecords(std::ostream* os, + const Array& inputs, const Array& results); -void ReadMeasureRecords(std::string str, MeasureInputNode* inp, - MeasureResultNode* res, std::string* log_version); +/*! \brief Read one measure record from a string */ +void ReadMeasureRecord(const std::string& str, + MeasureInputNode* inp, + MeasureResultNode* res, + std::string* log_version); -std::pair BestMeasurePairInFile( - const std::string& filename, const std::string& workload_key, - const Target& target); +/*! 
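+ * Example usage of the query declared below (a sketch; the filename and
+ * `key` are illustrative):
+ *
+ *   MeasureInput inp; MeasureResult res;
+ *   std::tie(inp, res) = BestMeasurePairInFile("matmul.json", key, target);
+ *   if (inp.defined()) { double cost = FloatArrayMean(res->costs); }
+ *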
\brief Return the best measure pair with lowest cost in a file */ +std::pair BestMeasurePairInFile(const std::string& filename, + const std::string& workload_key, + const Target& target); } // namespace ansor } // namespace tvm diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index 5f4a6a8dcef9..3f59ff736e9d 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -1,16 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors + * \file ansor/transform_step.cc + * \brief Transformation steps. For each schedule primitive, there is a corresponding transform step. + * + * See the note in transform_step.h on how to add a new step */ + #include "transform_step.h" #include +#include #include "utils.h" namespace tvm { namespace ansor { -TVM_REGISTER_NODE_TYPE(IteratorNode); -TVM_REGISTER_OBJECT_TYPE(StepNode); - /********** Reorder **********/ ReorderStep ReorderStepNode::make(int stage_id, const std::vector& after_ids) { auto node = make_object(); @@ -226,7 +247,8 @@ FollowFusedSplitStep FollowFusedSplitStepNode::make(int stage_id, int iter_id, return FollowFusedSplitStep(node); } -PrimExpr FollowFusedSplitStepNode::ExtractSplitLength(const std::vector& transform_steps) const { +PrimExpr FollowFusedSplitStepNode::ExtractSplitLength( + const std::vector& transform_steps) const { PrimExpr ret(1); for (int src_step_id : src_step_ids) { @@ -402,7 +424,7 @@ std::string AnnotationStepNode::PrintAsPythonAPI(std::vector *stages, return ss.str(); } -/********** Compute at **********/ +/********** Compute At **********/ ComputeAtStep ComputeAtStepNode::make(int stage_id, int target_stage_id, int target_iter_id) { auto node = make_object(); node->stage_id = stage_id; @@ -487,29 +509,7 @@ std::string ComputeInlineStepNode::PrintAsPythonAPI( return ss.str(); } -/********** Pack for vec **********/ -PackForVecStep PackForVecStepNode::make(int stage_id, int iter_id, int vec_size) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->vec_size = vec_size; - return PackForVecStep(node); -} - -void PackForVecStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { - LOG(FATAL) << "Not implemented"; -} - -std::string PackForVecStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - LOG(FATAL) << "Not implemented"; - return ""; -} - -/********** Cache read **********/ +/********** Cache Read **********/ CacheReadStep CacheReadStepNode::make(int stage_id, std::string scope_name, const std::vector& reader_stage_ids) { auto node = make_object(); @@ -572,7 +572,7 @@ std::string 
CacheReadStepNode::PrintAsPythonAPI(std::vector *stages, return ss.str(); } -/********** Cache write **********/ +/********** Cache Write **********/ CacheWriteStep CacheWriteStepNode::make(int stage_id, std::string scope_name) { auto node = make_object(); node->stage_id = stage_id; @@ -770,8 +770,7 @@ std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, return ss.str(); } -/********** StorageAlign **********/ - +/********** Storage Align **********/ StorageAlignStep StorageAlignStepNode::make(int stage_id, int iter_id, int factor, int offset) { auto node = make_object(); @@ -802,20 +801,5 @@ std::string StorageAlignStepNode::PrintAsPythonAPI( return ss.str(); } -// Maker for other classes -Iterator IteratorNode::make(std::string name, Range range, - IteratorType iter_type, IteratorAnnotation annotation, - const std::vector* ori_iters) { - auto node = make_object(); - node->name = std::move(name); - node->range = std::move(range); - node->iter_type = iter_type; - node->annotation = annotation; - if (ori_iters != nullptr) { - node->ori_iters = *ori_iters; - } - return Iterator(node); -} - } // namespace ansor } // namespace tvm diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index 627ce02b60e1..8240623ae3b1 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -1,106 +1,30 @@ -/*! - * Copyright (c) 2020 by Contributors - * \file ansor/transform_step.h - * \brief Data structures for loop transformations - - * Basically this is a simplified TVM IR with schedule primitives. - * We don't use the existing TVM IR because - * 1. We want fast incremental change to the loop structures - * 2. We want serializable history for replay and backtracking - * 3. We want simplified IR for easy and clean feature extraction - * 4. We may create some Macro schedule primitives - - * After search is done, we will lower this IR to TVM IR and TVM schedule primitives. - * Because we share a lot common objects during search, the transformation is - * implemented in copy on write style. All objects are immutable, which is - * similar to TVM IR. +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
*/ -#ifndef TVM_ANSOR_TRANSFORM_STEP_H_ -#define TVM_ANSOR_TRANSFORM_STEP_H_ - -#include -#include -#include -#include "compute_dag.h" - -namespace tvm { -namespace ansor { - -using namespace tvm::tir; - -inline std::string CleanName(const std::string& str) { - // to make the name valid in python code - std::string ret = str; - StrReplace(&ret, ".", "_"); - StrReplace(&ret, "@", "_"); - StrReplace(&ret, "outer", "o"); - StrReplace(&ret, "inner", "i"); - return ret; -} - -enum IteratorType { - kSpace, // spatial iterator - kReduce, // reduction iterator - kMixed, // fused spatial and reduction iterator - kSpecial // special iterator (e.g. virtual root iterator) -}; - -enum IteratorAnnotation { - kNone, kUnroll, kVectorize, kParallel, - kVThread, kBlockX, kThreadX, kBlockY, kThreadY -}; - -class Iterator; - /*! - * \brief An for loop iterator - * Similar to tvm::IterVar in `include/expr.h` - */ -class IteratorNode : public Object { - public: - std::string name; - Range range; // domain of for loop range - IteratorType iter_type; - IteratorAnnotation annotation; - std::vector ori_iters; - - static Iterator make(std::string name, Range range, - IteratorType iter_type, IteratorAnnotation annotation, - const std::vector* ori_iters = nullptr); - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("range", &range); - } - - static constexpr const char *_type_key = "ansor.Iterator"; - TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(Iterator, ObjectRef, IteratorNode); - -/*! \brief The base class for a transformation step */ -class StepNode: public Object { - public: - int stage_id; - - // Print step as equivalent python schedule API - virtual std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const = 0; - - static constexpr const char* _type_key = "ansor.Step"; - TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object); -}; -TVM_DEFINE_MUTABLE_NODE_REF(Step, StepNode); - -/* - * Note on how to add a new transform step + * \file ansor/transform_step.h + * \brief Transformation steps. For each schedule primitive, there is a corresponding transform step. * + * \Note How to add a new transform step. * Take fuse for example: - * 1. Define class FuseStepNode, FuseStep in loop_state.h, and implement its make function - * in FuseStepNode::make(...) loop_state.cc + * 1. Define class FuseStepNode, FuseStep in transform_steps.h, and implement its make function + * in FuseStepNode::make(...) transform_steps.cc * 2. Implement FuseStepNode::ApplyToSchedule and FuseStepNode::PrintAsPythonAPI. * - In these two functions you need to lower this step with tvm's schedule API * 3. Implement State::fuse and State::DoFuseStep. @@ -112,17 +36,24 @@ TVM_DEFINE_MUTABLE_NODE_REF(Step, StepNode); * 6. Add hash support in `struct hash<::tvm::ansor::Step>` (search for this function in this file) */ -class ReorderStep; class SplitStep; class FollowSplitStep; -class FollowFusedSplitStep; -class FuseStep; class AnnotationStep; -class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep; -class PackForVecStep; class CacheReadStep; class CacheWriteStep; -class PragmaStep; class RfactorStep; class StorageAlignStep; -class AttachMap; +#ifndef TVM_ANSOR_TRANSFORM_STEP_H_ +#define TVM_ANSOR_TRANSFORM_STEP_H_ + +#include +#include +#include +#include "loop_state.h" + +namespace tvm { +namespace ansor { + +using namespace tvm::tir; +/*! 
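+ * Example (iterator names are illustrative; see the unit tests):
+ *
+ *   // after splitting i and j of a matmul stage C
+ *   s.reorder(C, {i_outer, j_outer, k, j_inner, i_inner});
+ *
+ * This records a ReorderStep whose after_ids give the new loop order;
+ * at lowering time it becomes a single s[C].reorder(...) call.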
\brief Reorder step that corresponds to te::Stage::reorder */ class ReorderStepNode: public StepNode { public: - std::vector after_ids; + std::vector after_ids; // The iterator ids after reorder. + // This array should specify the order of all iterators. static ReorderStep make(int stage_id, const std::vector& after_ids); @@ -137,15 +68,17 @@ class ReorderStepNode: public StepNode { static constexpr const char* _type_key = "ansor.ReorderStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object); }; -TVM_DEFINE_COW_NODE_REF(ReorderStep, Step, ReorderStepNode); - +TVM_DEFINE_COW_OBJECT_REF(ReorderStep, Step, ReorderStepNode); +/*! \brief Split step that corresponds to te::Stage::split with additional + * support of multiple-level of factors */ class SplitStepNode: public StepNode { public: - int iter_id; - PrimExpr extent; // the extent of the axis to split + int iter_id; // The id of the iter to split + PrimExpr extent; // the extent length of the axis to split std::vector lengths; // The split factors - bool inner_to_outer; + bool inner_to_outer; // If true, the `lengths` denote the lengths of + // iterators from inner level to outer level static SplitStep make(int stage_id, int iter_id, PrimExpr extent, const std::vector& lengths, @@ -162,15 +95,15 @@ class SplitStepNode: public StepNode { static constexpr const char* _type_key = "ansor.SplitStep"; TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object); }; -TVM_DEFINE_COW_NODE_REF(SplitStep, Step, SplitStepNode); +TVM_DEFINE_COW_OBJECT_REF(SplitStep, Step, SplitStepNode); -// Similar to SplitStepNode, but use split factor from another step -// (i.e. Follow another split step) +/*! \brief Similar to SplitStepNode, but use split factor from another step + * (i.e. Follow another split step) */ class FollowSplitStepNode: public StepNode { public: - int iter_id; - int src_step_id; - int n_split; + int iter_id; // The id of the iter to split + int src_step_id; // The index of the split step to follow in the history + int n_split; // The number of split level static FollowSplitStep make(int stage_id, int iter_id, int src_step_id, int n_split); @@ -190,17 +123,17 @@ class FollowSplitStepNode: public StepNode { static constexpr const char* _type_key = "ansor.FollowSplitStep"; TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object); }; -TVM_DEFINE_COW_NODE_REF(FollowSplitStep, Step, FollowSplitStepNode); - +TVM_DEFINE_COW_OBJECT_REF(FollowSplitStep, Step, FollowSplitStepNode); -// Similar to FollowSplitStep, but use split factors from multiple steps -// This can be used for the split in cooperative fetching. +/*! \brief Similar to FollowSplitStep, but use split factors from multiple steps. + * \Note This can be used for the split in cooperative fetching + */ class FollowFusedSplitStepNode: public StepNode { public: - int iter_id; - std::vector src_step_ids; - int level; // Use the length in this split level - bool factor_or_nparts; // If this is true, use factor. Otherwise, use nparts + int iter_id; // The id of the iter to split + std::vector src_step_ids; // The indices of the split steps to follow in the history + int level; // Use the length in this split level + bool factor_or_nparts; // If this is true, use factor. 
Otherwise, use nparts + static FollowFusedSplitStep make(int stage_id, int iter_id, + const std::vector& src_step_ids, + int level, bool factor_or_nparts); @@ -220,12 +153,12 @@ class FollowFusedSplitStepNode: public StepNode { static constexpr const char* _type_key = "ansor.FollowFusedSplitStep"; TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); }; -TVM_DEFINE_COW_NODE_REF(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); - +TVM_DEFINE_COW_OBJECT_REF(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); +/*! \brief Fuse step that corresponds to te::Stage::fuse */ class FuseStepNode: public StepNode { public: - std::vector fused_ids; + std::vector fused_ids; // The ids of iterators to fuse static FuseStep make(int stage_id, const std::vector& fused_ids); @@ -240,9 +173,11 @@ class FuseStepNode: public StepNode { static constexpr const char* _type_key = "ansor.FuseStep"; TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); }; -TVM_DEFINE_COW_NODE_REF(FuseStep, Step, FuseStepNode); - +TVM_DEFINE_COW_OBJECT_REF(FuseStep, Step, FuseStepNode); +/*! \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding. + * (i.e. te::Stage::vectorize, te::Stage::parallel, te::Stage::unroll, te::Stage::bind) + */ class AnnotationStepNode: public StepNode { public: int iter_id; @@ -261,9 +196,9 @@ class AnnotationStepNode: public StepNode { static constexpr const char* _type_key = "ansor.AnnotationStep"; TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object); }; -TVM_DEFINE_COW_NODE_REF(AnnotationStep, Step, AnnotationStepNode); - +TVM_DEFINE_COW_OBJECT_REF(AnnotationStep, Step, AnnotationStepNode); +/*! \brief Compute at step that corresponds to te::Stage::compute_at */ class ComputeAtStepNode: public StepNode { public: int target_stage_id; @@ -283,9 +218,9 @@ class ComputeAtStepNode: public StepNode { static constexpr const char* _type_key = "ansor.ComputeAtStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object); }; -TVM_DEFINE_COW_NODE_REF(ComputeAtStep, Step, ComputeAtStepNode); - +TVM_DEFINE_COW_OBJECT_REF(ComputeAtStep, Step, ComputeAtStepNode); +/*! \brief Compute root step that corresponds to te::Stage::compute_root */ class ComputeRootStepNode: public StepNode { public: static ComputeRootStep make(int stage_id); @@ -301,9 +236,9 @@ class ComputeRootStepNode: public StepNode { static constexpr const char* _type_key = "ansor.ComputeRootStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object); }; -TVM_DEFINE_COW_NODE_REF(ComputeRootStep, Step, ComputeRootStepNode); - +TVM_DEFINE_COW_OBJECT_REF(ComputeRootStep, Step, ComputeRootStepNode); +/*!
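+ * Example (mirrors the unit tests): s.compute_inline(bn_add) records a
+ * ComputeInlineStep; the bn_add stage then no longer appears as a loop
+ * nest of its own and is computed inside its consumer.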
\brief Compute inline step that corresponds to te::Stage::compute_inline */ class ComputeInlineStepNode: public StepNode { public: static ComputeInlineStep make(int stage_id); @@ -319,31 +254,9 @@ class ComputeInlineStepNode: public StepNode { static constexpr const char* _type_key = "ansor.ComputeInlineStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object); }; -TVM_DEFINE_COW_NODE_REF(ComputeInlineStep, Step, ComputeInlineStepNode); +TVM_DEFINE_COW_OBJECT_REF(ComputeInlineStep, Step, ComputeInlineStepNode); -class PackForVecStepNode: public StepNode { - public: - int iter_id; - int vec_size; - - static PackForVecStep make(int stage_id, int iter_id, int vec_size); - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.PackForVecStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(PackForVecStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(PackForVecStep, Step, PackForVecStepNode); - - -/*! \brief Apply cache_read to a stage - * TVM Api: te::Schedule::cache_read(tensor, scope, readers) */ +/*! \brief Cache read step that corresponds to te::Schedule::cache_read */ class CacheReadStepNode: public StepNode { public: std::string scope_name; @@ -363,12 +276,10 @@ class CacheReadStepNode: public StepNode { static constexpr const char* _type_key = "ansor.CacheReadStep"; TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object); }; -TVM_DEFINE_COW_NODE_REF(CacheReadStep, Step, CacheReadStepNode); +TVM_DEFINE_COW_OBJECT_REF(CacheReadStep, Step, CacheReadStepNode); - -/*! \brief Apply cache_write to a stage - * TVM Api: te::Schedule::cache_write(tensor, scope) - * This step will cache_write all output tensors of target stage */ +/*! \brief Cache write step that corresponds to te::Schedule::cache_write + * \Note This step will cache_write all output tensors of target stage */ class CacheWriteStepNode: public StepNode { public: std::string scope_name; @@ -386,9 +297,9 @@ class CacheWriteStepNode: public StepNode { static constexpr const char* _type_key = "ansor.CacheWriteStep"; TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object); }; -TVM_DEFINE_COW_NODE_REF(CacheWriteStep, Step, CacheWriteStepNode); +TVM_DEFINE_COW_OBJECT_REF(CacheWriteStep, Step, CacheWriteStepNode); -/*! \brief Add pragma to a specific iterator */ +/*! \brief Pragma step that corresponds to te::Stage::pragma */ class PragmaStepNode: public StepNode { public: int iter_id; @@ -407,10 +318,9 @@ class PragmaStepNode: public StepNode { static constexpr const char* _type_key = "ansor.PragmaStep"; TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object); }; -TVM_DEFINE_COW_NODE_REF(PragmaStep, Step, PragmaStepNode); +TVM_DEFINE_COW_OBJECT_REF(PragmaStep, Step, PragmaStepNode); -/*! \brief Factor a reduction axis - * TVM Api: te::Schedule::rfactor(tensor, axis, factor_axis) */ +/*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */ class RfactorStepNode: public StepNode { public: int iter_id; @@ -430,8 +340,9 @@ class RfactorStepNode: public StepNode { static constexpr const char* _type_key = "ansor.RfactorStep"; TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object); }; -TVM_DEFINE_COW_NODE_REF(RfactorStep, Step, RfactorStepNode); +TVM_DEFINE_COW_OBJECT_REF(RfactorStep, Step, RfactorStepNode); +/*!
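+ * Example (sketch): StorageAlignStepNode::make(stage_id, iter_id, factor,
+ * offset) records a step that lowers to
+ *   s[stage].storage_align(axes[iter_id], factor, offset);
+ * padding the storage of that axis to avoid bank conflicts.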
\brief Storage align step that corresponds to te::Schedule::storage_align */ class StorageAlignStepNode: public StepNode { public: int iter_id; @@ -452,12 +363,12 @@ class StorageAlignStepNode: public StepNode { static constexpr const char* _type_key = "ansor.StorageAlignStep"; TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object); }; -TVM_DEFINE_COW_NODE_REF(StorageAlignStep, Step, StorageAlignStepNode); +TVM_DEFINE_COW_OBJECT_REF(StorageAlignStep, Step, StorageAlignStepNode); } // namespace ansor } // namespace tvm -// Hash and equal function for State, Stage, Iterator and Step +// Hash and equal function for Step namespace std { template <> @@ -515,32 +426,27 @@ struct hash<::tvm::ansor::Step> { } else if (auto ps = step.as<::tvm::ansor::ComputeInlineStepNode>()) { return ::dmlc::HashCombine(9, ps->stage_id); - } else if (auto ps = step.as<::tvm::ansor::PackForVecStepNode>()) { - return ::dmlc::HashCombine(10, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ps->vec_size))); } else if (auto ps = step.as<::tvm::ansor::CacheReadStepNode>()) { - return ::dmlc::HashCombine(11, + return ::dmlc::HashCombine(10, ::dmlc::HashCombine(std::hash()(ps->stage_id), ::dmlc::HashCombine(std::hash()(ps->scope_name), ps->reader_stage_ids))); } else if (auto ps = step.as<::tvm::ansor::CacheWriteStepNode>()) { - return ::dmlc::HashCombine(12, + return ::dmlc::HashCombine(11, ::dmlc::HashCombine(std::hash()(ps->stage_id), ps->scope_name)); } else if (auto ps = step.as<::tvm::ansor::PragmaStepNode>()) { - return ::dmlc::HashCombine(13, + return ::dmlc::HashCombine(12, ::dmlc::HashCombine(std::hash()(ps->stage_id), ::dmlc::HashCombine(std::hash()(ps->iter_id), ps->pragma_type))); } else if (auto ps = step.as<::tvm::ansor::RfactorStepNode>()) { - return ::dmlc::HashCombine(14, + return ::dmlc::HashCombine(13, ::dmlc::HashCombine(std::hash()(ps->stage_id), ::dmlc::HashCombine(std::hash()(ps->iter_id), ps->factor_iter_id))); } else if (auto ps = step.as<::tvm::ansor::StorageAlignStepNode>()) { - return ::dmlc::HashCombine(15, + return ::dmlc::HashCombine(14, ::dmlc::HashCombine(std::hash()(ps->stage_id), ::dmlc::HashCombine(std::hash()(ps->iter_id), ::dmlc::HashCombine(std::hash()(ps->factor), diff --git a/src/ansor/utils.cc b/src/ansor/utils.cc index 2018cf33d1a2..27aac7e8b315 100644 --- a/src/ansor/utils.cc +++ b/src/ansor/utils.cc @@ -1,5 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
- * Copyright (c) 2020 by Contributors + * \file ansor/utils.cc + * \brief Common utilities */ #include "utils.h" @@ -8,7 +28,6 @@ namespace tvm { namespace ansor { - NullStream& NullStream::Global() { static NullStream stream; return stream; diff --git a/src/ansor/utils.h b/src/ansor/utils.h index 67ebb836c680..cb90364b01b5 100644 --- a/src/ansor/utils.h +++ b/src/ansor/utils.h @@ -1,5 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors * \file ansor/utils.h * \brief Common utilities */ @@ -25,7 +43,7 @@ namespace std { -// hash function for std::pair, std::vector and std::tuple +/*! \brief Hash function for std::pair */ template struct hash > { std::size_t operator()(const std::pair& k) const { @@ -33,6 +51,7 @@ struct hash > { } }; +/*! \brief Hash function for std::tuple */ template struct hash > { std::size_t operator()(const std::tuple& k) const { @@ -42,6 +61,7 @@ struct hash > { } }; +/*! \brief Hash function for std::vector */ template struct hash > { std::size_t operator()(const std::vector& vec) const { @@ -61,38 +81,37 @@ struct hash > { namespace tvm { namespace ansor { -/*! \brief Macro to make it easy to define node ref type given node */ -#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \ +/*! \brief Macro to make it easy to define object ref type given node */ +#define TVM_DEFINE_OBJECT_REF(TypeName, ObjectName) \ class TypeName : public ObjectRef { \ public: \ - TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ObjectRef, NodeName); \ + TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ObjectRef, ObjectName); \ }; \ -/*! \brief Macro to make it easy to define mutable node ref type given node */ -#define TVM_DEFINE_MUTABLE_NODE_REF(TypeName, NodeName) \ +/*! \brief Macro to make it easy to define mutable object ref type given node */ +#define TVM_DEFINE_MUTABLE_OBJECT_REF(TypeName, ObjectName) \ class TypeName : public ObjectRef { \ public: \ - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ObjectRef, NodeName); \ + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ObjectRef, ObjectName); \ }; \ /*! * \brief Macro to make it easy to define node ref type that * has a CopyOnWrite member function. 
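+ * For instance, TVM_DEFINE_COW_OBJECT_REF(SearchTask, ObjectRef,
+ * SearchTaskNode) defines a SearchTask reference class whose
+ * CopyOnWrite() returns a mutable SearchTaskNode*, copying the node
+ * first if it is shared.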
*/ -#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \ - class TypeName : public BaseType { \ - public: \ - TVM_DEFINE_OBJECT_REF_METHODS(TypeName, BaseType, NodeName); \ - TVM_DEFINE_OBJECT_REF_COW_METHOD(NodeName); \ +#define TVM_DEFINE_COW_OBJECT_REF(TypeName, BaseType, ObjectName) \ + class TypeName : public BaseType { \ + public: \ + TVM_DEFINE_OBJECT_REF_METHODS(TypeName, BaseType, ObjectName); \ + TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName); \ }; -/********** Utilities for std::vector, std::set **********/ - +/********** Utilities for std::vector, std::set, std::string **********/ /*! \brief Get the first appearance index of elements in a vector */ template inline void GetIndices(const std::vector& array, - const std::vector& to_locate, - std::vector* indices) { + const std::vector& to_locate, + std::vector* indices) { for (const auto& v : to_locate) { auto it = std::find(array.begin(), array.end(), v); if (it != array.end()) { @@ -133,7 +152,7 @@ inline int64_t ElementProduct(const std::vector& array) { return ret; } -/* \brief Get the maximum element in a vector */ +/*! \brief Get the maximum element in a vector */ template T MaximumElement(const std::vector& array) { CHECK(!array.empty()); @@ -162,7 +181,7 @@ std::vector& ConcatenateMove(std::vector* out, std::vector* first, Args return *out; } -/* \brief Get a random permutation of integers [0, n-1] */ +/*! \brief Get a random permutation of integers [0, n-1] */ template void RandomPermutation(int n, std::vector* out, G* gen) { out->assign(n, 0); @@ -170,7 +189,7 @@ void RandomPermutation(int n, std::vector* out, G* gen) { std::shuffle(out->begin(), out->end(), *gen); } -/* \brief Random sample without replacement */ +/*! \brief Random sample without replacement */ template void RandomSample(std::vector* in_data, size_t out_size, G* gen) { // Note: This function is inefficient in the cases when out_size << in_data.size() @@ -204,43 +223,19 @@ inline void Argsort(const std::vector& scores, std::vector* index) { std::sort(index->begin(), index->end(), cmp); } -// Do x++ for all x in the set such that x >= threshold -inline void SetAddOne(std::set* set, int threshold = 0) { - std::set new_set; - for (int x : *set) { - if (x >= threshold) { - new_set.insert(x + 1); - } else { - new_set.insert(x); - } - } - *set = std::move(new_set); -} - -// Compute Jaccard Similarity of two sets -template -double JaccardSimilarity(std::set s1, std::set s2) { - std::vector intersect; - std::set_intersection(s1.begin(), s1.end(), s2.begin(), s2.end(), - std::back_inserter(intersect)); - return 1.0 * intersect.size() / (s1.size() + s2.size() - intersect.size()); -} - -/********** Utilities for std::string **********/ - -/*! Return whether a string ends with a another substring */ +/*! \brief Return whether a string ends with another substring */ inline bool StrEndsWith(const std::string& a, const std::string& b) { if (b.size() > a.size()) return false; return std::equal(a.begin() + a.size() - b.size(), a.end(), b.begin()); } -/*! Return whether a string starts with a another substring */ +/*! \brief Return whether a string starts with another substring */ inline bool StrStartsWith(const std::string& a, const std::string& b) { if (b.size() > a.size()) return false; return std::equal(a.begin(), a.begin() + b.size(), b.begin()); } -/*! Replace a sub-string to another sub-string in a string */ +/*! 
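+ * Example:
+ *   std::string s = "C.global";
+ *   StrReplace(&s, ".", "_");   // s == "C_global"
+ * All occurrences are replaced, not just the first.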
\brief Replace a sub-string to another sub-string in a string */ inline void StrReplace(std::string* base, const std::string& from, const std::string& to) { auto pos = base->find(from); while (pos != std::string::npos) { @@ -250,7 +245,6 @@ inline void StrReplace(std::string* base, const std::string& from, const std::st } /********** Utilities for TVM Containers / ByteArray **********/ - /*! \brief Compute mean of a FloatImm array */ inline double FloatArrayMean(const Array& float_array) { double sum = 0; @@ -266,51 +260,15 @@ inline double FloatArrayMean(const Array& float_array) { return sum / float_array.size(); } -/*! \brief Serialize a 2-dimensional vector to TVMByteArray. - * This is used for sending data to python code */ -template -inline TVMByteArray Serialize2dVector(std::vector >&& in_data, - std::vector* out_data) { - size_t total_bytes = 0; - std::vector size_vector; - - // serialize sizes - total_bytes += (1 + in_data.size()) * sizeof(int); - size_vector.reserve(in_data.size() + 1); - size_vector.push_back(in_data.size()); - for (const auto& x : in_data) { - size_vector.push_back(static_cast(x.size())); - total_bytes += sizeof(T) * x.size(); - } - - out_data->reserve(total_bytes); - char* ptr = out_data->data(); - memmove(ptr, reinterpret_cast(size_vector.data()), (1 + in_data.size()) * sizeof(int)); - ptr += (1 + in_data.size()) * sizeof(int); - - // serialize in_data - for (auto& x : in_data) { - memmove(ptr, x.data(), sizeof(T) * x.size()); - ptr += sizeof(T) * x.size(); - x.clear(); - } - - CHECK_EQ(ptr - out_data->data(), total_bytes); - - return TVMByteArray{out_data->data(), total_bytes}; -} - /********** Other Utilities **********/ - -// Get an int value from an Expr +/*! \brief Get an int value from an Expr */ inline int64_t GetIntImm(const PrimExpr& expr) { auto pint = expr.as(); CHECK(pint != nullptr); return pint->value; } - -// Compute the product of the lengths of axes +/*! \brief Compute the product of the lengths of axes */ inline int64_t AxisLengthProd(const Array& axes) { int64_t ret = 1.0; for (const auto& x : axes) { @@ -323,8 +281,7 @@ inline int64_t AxisLengthProd(const Array& axes) { return ret; } - -// An empty output stream +/*! 
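+ * Example (a sketch of routing optional debug output):
+ *   std::ostream& os = verbose ? std::cout : NullStream::Global();
+ *   os << "debug message";   // discarded when verbose is false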
\brief An empty output stream */ class NullStream : public std::ostream { public: NullStream() : std::ostream(nullptr) {} diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc index 00e748204fde..5f1dea0f1ea5 100644 --- a/tests/cpp/ansor_test.cc +++ b/tests/cpp/ansor_test.cc @@ -21,28 +21,13 @@ #include #include #include - +#include #include -#include "../../src/ansor/feature.h" +// todo(jcf94): do not use relative path #include "../../src/ansor/loop_state.h" -#include "../../src/ansor/search_policy/meta_tile_rewrite_policy.h" -#include "../../src/ansor/serialization.h" - -tvm::Array matmul_func(int n, int m, int k) { - using namespace tvm; - using namespace tvm::te; - - Tensor A = placeholder({n, k}, DataType::Float(32), "A"); - Tensor B = placeholder({k, m}, DataType::Float(32), "B"); - IterVar K = IterVarNode::make({0, k}, Var("k"), kCommReduce); - const auto& C = compute( - {n, m}, [&](Var i, Var j) { return tvm::sum(A[i][K] * B[K][j], {K}); }, - "C"); - - return {A, B, C}; -} +// Compute declaration for test tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, int CI, int CO, int kernel_size, @@ -91,17 +76,7 @@ tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, using namespace tvm::ansor; -TEST(ComputeDAG, Basic) { - const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); - const auto& dag = ComputeDAGNode::make(tensors); - const auto& state = StateNode::make(dag->ops); - CHECK(std::equal_to()(state, dag.GetInitState())); - - LOG(INFO) << "\n" << state; - LOG(INFO) << "\n" << dag; - LOG(INFO) << "\n" << dag->access_analyzer; -} - +// Test Access Analyzer TEST(ComputeDAG, GetProducersConsumers) { const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors); @@ -166,570 +141,6 @@ TEST(ComputeDAG, GetProducersConsumers) { } } -TEST(ComputeDAG, InferBoundSerialization) { - const auto& tensors = matmul_func(512, 512, 512); - const auto& dag = ComputeDAGNode::make(tensors); - int A = 0, B = 1, C = 2; - - State s0 = dag.GetInitState(); - int C_global = s0.cache_write(C, "global", dag); - C++; - const auto& its0 = s0.split(C, s0->stages[C]->iters[0], {4, 8, 8}); - const auto& its1 = s0.split(C, s0->stages[C]->iters[4], {8, 4, 4}); - s0.reorder(C, {its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], - its1[3]}); - s0.compute_at(C_global, C, s0->stages[C]->iters[3]); - s0.split(C_global, s0->stages[C_global]->iters[2], {16}); - int B_global = s0.cache_read(B, "global", {C_global}, dag); - C++; - C_global++; - s0.compute_at(B_global, C_global, s0->stages[C_global]->iters[0]); - int A_global = s0.cache_read(A, "global", {C_global}, dag); - B++; - B_global++; - C++; - C_global++; - s0.compute_at(A_global, C_global, s0->stages[C_global]->iters[2]); - - const auto& s1 = dag.InferBound(s0); - std::vector s2 = {s0}; - dag.InferBound(&s2); - const auto& s3 = dag.ReplayAndInferBound(s0->transform_steps); - - CHECK_EQ( - s1->stages[B_global]->iters[0]->range->extent.as()->value, - 512); - CHECK_EQ( - s1->stages[B_global]->iters[1]->range->extent.as()->value, - 16); - CHECK_EQ( - s1->stages[A_global]->iters[0]->range->extent.as()->value, 1); - CHECK_EQ( - s1->stages[A_global]->iters[1]->range->extent.as()->value, - 16); - CHECK_EQ( - s1->stages[C_global]->iters[0]->range->extent.as()->value, - 64); - CHECK(std::equal_to()(s1, s2[0])); - CHECK(std::equal_to()(s1, s3)); - - const auto& minp0 = MeasureInputNode::make( - SearchTaskNode::make(dag, "test", tvm::target::llvm(), - 
tvm::target::llvm(), HardwareParams()), - s0); - const auto& mres0 = MeasureResultNode::make({0.1}, 0, "", 0.1, 0.1); - std::stringstream ss; - WriteMeasureRecords(&ss, {minp0}, {mres0}); - auto minp1 = tvm::make_object(); - auto mres1 = tvm::make_object(); - std::string log_version; - ReadMeasureRecords(ss.str(), minp1.get(), mres1.get(), &log_version); - const auto& s4 = dag.ReplayAndInferBound(minp1->state->transform_steps); - CHECK(std::equal_to()(s1, s4)); -} - -TEST(Step, SplitFuseReorder) { - const auto& tensors = matmul_func(512, 512, 512); - const auto& dag = ComputeDAGNode::make(tensors); - - State s0 = dag.GetInitState(); - State s1 = s0; - Iterator ti = s0->stages[2]->iters[0]; - Iterator tj = s0->stages[2]->iters[1]; - Iterator tk = s0->stages[2]->iters[2]; - std::vector its; - - CHECK_EQ(s1->stages[2]->iters[0]->range->extent.as()->value, 512); - - its = s0.split(2, ti, {16}); - Iterator tio = its[0], tii = its[1]; - CHECK_EQ(s0->stages[2]->iters[0]->range->extent.as()->value, 32); - CHECK_EQ(s0->stages[2]->iters[1]->range->extent.as()->value, 16); - - its = s0.split(2, tj, {8}); - Iterator tjo = its[0], tji = its[1]; - CHECK_EQ(s0->stages[2]->iters[2]->range->extent.as()->value, 64); - CHECK_EQ(s0->stages[2]->iters[3]->range->extent.as()->value, 8); - - s0.reorder(2, {tio, tjo, tk, tji, tii}); - CHECK_EQ(s0->stages[2]->iters[2]->range->extent.as()->value, 512); - - s0.fuse(2, {tio, tjo}); - CHECK_EQ(s0->stages[2]->iters[0]->range->extent.as()->value, - 2048); - - s1.split(2, ti, {8, 2}); - s1.split(2, tj, {32, 8}, false); - CHECK_EQ(s1->stages[2]->iters[0]->range->extent.as()->value, 32); - CHECK_EQ(s1->stages[2]->iters[1]->range->extent.as()->value, 8); - CHECK_EQ(s1->stages[2]->iters[2]->range->extent.as()->value, 2); - CHECK_EQ(s1->stages[2]->iters[3]->range->extent.as()->value, 32); - CHECK_EQ(s1->stages[2]->iters[4]->range->extent.as()->value, 8); - CHECK_EQ(s1->stages[2]->iters[5]->range->extent.as()->value, 2); -} - -TEST(Step, ComputeAtRootInline) { - const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); - const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors); - // int data = 0, padding = 1, kernel = 2; - int conv = 3; - // int bias = 4; - int bias_add = 5; - // int bn_scale = 6; - int bn_mul = 7; - // int bn_offset = 8; - int bn_add = 9, relu = 10; - - State s0 = dag.GetInitState(); - s0.compute_inline(bn_add); - s0.compute_inline(bn_mul); - s0.compute_inline(bias_add); - s0.compute_at(conv, relu, s0->stages[relu]->iters[2]); - const auto& conv_stage_attach = - s0->attach_map->stage_to_attach_iter.find(conv); - std::pair iterkey(relu, 2); - CHECK(conv_stage_attach->second == iterkey); - const auto& conv_iter_attach = - s0->attach_map->iter_to_attached_stages.find(iterkey); - CHECK_EQ(conv_iter_attach->second.size(), 1); - CHECK_EQ(conv_iter_attach->second[0], conv); - std::stringstream ss; - ss << "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" - << "for ax1 (0,3)\n" - << " for ax2 (0,230)\n" - << " for ax3 (0,230)\n" - << " T_pad = ...\n" - << "for ax1 (0,64)\n" - << " for ax2 (0,112)\n" - << " for ax0 (None)\n" - << " for ax1 (None)\n" - << " for ax2 (None)\n" - << " for ax3 (None)\n" - << " for i (None)\n" - << " for kh (None)\n" - << " for kw (None)\n" - << " T_conv2d_nchw = ...\n" - << " for ax3 (0,112)\n" - << " T_relu = ...\n"; - CHECK_EQ(s0.ToStr().compare(ss.str()), 0); - - s0.compute_root(conv); - s0.compute_root(bn_mul); - CHECK_EQ(s0->attach_map->stage_to_attach_iter.size(), 0); - 
CHECK_EQ(s0->attach_map->iter_to_attached_stages.size(), 0); - ss.str(std::string()); - ss << "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" - << "for ax1 (0,3)\n" - << " for ax2 (0,230)\n" - << " for ax3 (0,230)\n" - << " T_pad = ...\n" - << "for ax0 (None)\n" - << " for ax1 (None)\n" - << " for ax2 (None)\n" - << " for ax3 (None)\n" - << " for i (None)\n" - << " for kh (None)\n" - << " for kw (None)\n" - << " T_conv2d_nchw = ...\n" - << "for ax0 (None)\n" - << " for ax1 (None)\n" - << " for ax2 (None)\n" - << " for ax3 (None)\n" - << " Bn_mul = ...\n" - << "for ax1 (0,64)\n" - << " for ax2 (0,112)\n" - << " for ax3 (0,112)\n" - << " T_relu = ...\n"; - CHECK_EQ(s0.ToStr().compare(ss.str()), 0); -} - -TEST(Step, CacheReadWrite) { - using namespace tvm; - using namespace tvm::te; - - const auto& test_func = []() -> Array { - int N = 4, H = 7, W = 7, CO = 512, CI = 512, KH = 3, KW = 3, stride = 1; - int padding = 1; - Tensor data = placeholder({N, CI, H, W}, DataType::Float(32), "Data"); - Tensor kernel_data = - placeholder({CO, CI, KH, KW}, DataType::Float(32), "Kernel_data"); - const auto& k_split = compute( - kernel_data->shape, - [&](const Array& i) { - return Array({kernel_data[i[0]][i[1]][i[2]][i[3]] + 1, - div(kernel_data[i[0]][i[1]][i[2]][i[3]], 2)}); - }, - "Kernel_split"); - const auto& kernel = compute( - kernel_data->shape, - [&](Var i, Var j, Var k, Var l) { - return (k_split[0])[i][j][k][l] + (k_split[1])[i][j][k][l]; - }, - "Kernel"); - const auto& conv = - topi::conv2d_nchw(data, kernel, padding, padding, stride, stride); - const auto& relu = topi::relu(conv); - const auto& out = compute( - relu->shape, - [&](Var i, Var j, Var k, Var l) { - return data[i][j][k][l] + relu[i][j][k][l]; - }, - "Add"); - return {data, kernel_data, out}; - }; - const auto& dag = ComputeDAGNode::make(test_func()); - - int data = 0, pad_temp = 1, kernel_data = 2, kernel_split = 3, kernel = 4; - int conv = 5, relu = 6, add = 7; - - // 0: init state - auto s0 = dag.GetInitState(); - std::vector ori_its = s0->stages[add]->iters; - auto its = s0.split(add, s0->stages[add]->iters[0], {2}); - s0.reorder(add, {its[0], ori_its[1], its[1], ori_its[2], ori_its[3]}); - s0.compute_inline(relu); - - // 1: simple cache_write with compute_at - int conv_global = s0.cache_write(conv, "global", dag); - conv++; - relu++; - add++; - s0.compute_at(conv_global, conv, s0->stages[conv]->iters[3]); - - // 2: simple cache_read with compute_at - int kernel_global = s0.cache_read(kernel, "global", {conv_global}, dag); - conv_global++; - conv++; - relu++; - add++; - s0.compute_at(kernel_global, conv_global, s0->stages[conv_global]->iters[4]); - std::stringstream ss; - ss << "Placeholder: Data, Kernel_data\n" - << "for ax0 (0,4)\n" - << " for ax1 (0,512)\n" - << " for ax2 (0,9)\n" - << " for ax3 (0,9)\n" - << " T_pad = ...\n" - << "for ax0 (0,512)\n" - << " for ax1 (0,512)\n" - << " for ax2 (0,3)\n" - << " for ax3 (0,3)\n" - << " Kernel_split = ...\n" - << "for ax0 (0,512)\n" - << " for ax1 (0,512)\n" - << " for ax2 (0,3)\n" - << " for ax3 (0,3)\n" - << " Kernel = ...\n" - << "for ax0 (0,4)\n" - << " for ax1 (0,512)\n" - << " for ax2 (0,7)\n" - << " for ax3 (0,7)\n" - << " for ax0_c (None)\n" - << " for ax1_c (None)\n" - << " for ax2_c (None)\n" - << " for ax3_c (None)\n" - << " for i (None)\n" - << " for ax0 (None)\n" - << " for ax1 (None)\n" - << " for ax2 (None)\n" - << " for ax3 (None)\n" - << " Kernel.global = ...\n" - << " for kh (None)\n" - << " for kw (None)\n" - << " T_conv2d_nchw.global = ...\n" - << " 
T_conv2d_nchw = ...\n" - << "for ax0.0 (0,2)\n" - << " for ax1 (0,512)\n" - << " for ax0.1 (0,2)\n" - << " for ax2 (0,7)\n" - << " for ax3 (0,7)\n" - << " Add = ...\n"; - CHECK_EQ(s0.ToStr().compare(ss.str()), 0); - - // 3: two level cache_read with compute_at - // preparing for GPU's shared memory & local memory - int pad_temp_global = s0.cache_read(pad_temp, "global", {conv_global}, dag); - kernel_data++; - kernel_split++; - kernel++; - kernel_global++; - conv_global++; - conv++; - relu++; - add++; - int pad_temp_shared = - s0.cache_read(pad_temp_global, "shared", {conv_global}, dag); - kernel_data++; - kernel_split++; - kernel++; - kernel_global++; - conv_global++; - conv++; - relu++; - add++; - s0.compute_at(pad_temp_global, conv_global, - s0->stages[conv_global]->iters[2]); - s0.compute_at(pad_temp_shared, conv_global, - s0->stages[conv_global]->iters[4]); - - // 4: cache_read with multi readers - // This stage cannot be compute at to its consumer - s0.cache_read(data, "global", {pad_temp, add}, dag); - pad_temp++; - pad_temp_global++; - pad_temp_shared++; - kernel_data++; - kernel_split++; - kernel++; - kernel_global++; - conv_global++; - conv++; - relu++; - add++; - ss.str(std::string()); - ss << "Placeholder: Data, Kernel_data\n" - << "for ax0 (0,4)\n" - << " for ax1 (0,512)\n" - << " for ax2 (0,7)\n" - << " for ax3 (0,7)\n" - << " Data.global = ...\n" - << "for ax0 (0,4)\n" - << " for ax1 (0,512)\n" - << " for ax2 (0,9)\n" - << " for ax3 (0,9)\n" - << " T_pad = ...\n" - << "for ax0 (0,512)\n" - << " for ax1 (0,512)\n" - << " for ax2 (0,3)\n" - << " for ax3 (0,3)\n" - << " Kernel_split = ...\n" - << "for ax0 (0,512)\n" - << " for ax1 (0,512)\n" - << " for ax2 (0,3)\n" - << " for ax3 (0,3)\n" - << " Kernel = ...\n" - << "for ax0 (0,4)\n" - << " for ax1 (0,512)\n" - << " for ax2 (0,7)\n" - << " for ax3 (0,7)\n" - << " for ax0_c (None)\n" - << " for ax1_c (None)\n" - << " for ax2_c (None)\n" - << " for ax0 (None)\n" - << " for ax1 (None)\n" - << " for ax2 (None)\n" - << " for ax3 (None)\n" - << " T_pad.global = ...\n" - << " for ax3_c (None)\n" - << " for i (None)\n" - << " for ax0 (None)\n" - << " for ax1 (None)\n" - << " for ax2 (None)\n" - << " for ax3 (None)\n" - << " Kernel.global = ...\n" - << " for ax0 (None)\n" - << " for ax1 (None)\n" - << " for ax2 (None)\n" - << " for ax3 (None)\n" - << " T_pad.global.shared = ...\n" - << " for kh (None)\n" - << " for kw (None)\n" - << " T_conv2d_nchw.global = ...\n" - << " T_conv2d_nchw = ...\n" - << "for ax0.0 (0,2)\n" - << " for ax1 (0,512)\n" - << " for ax0.1 (0,2)\n" - << " for ax2 (0,7)\n" - << " for ax3 (0,7)\n" - << " Add = ...\n"; - CHECK_EQ(s0.ToStr().compare(ss.str()), 0); - - // 5: cache_write with multi outputs - // TVM's cache_write actually has a bug with this case: - - // After schedule.cache_write, TVM generate one new stage: - // From: kernel_data -> kernel_split -> kernel - // To: kernel_data -> kernel_split_global -> kernel_split -> kernel - - // But with topo sort analyse, we get: - // kernel_data -> kernel_split_global -> kernel_split -> kernel - // \ / - // ----------------> kernel_split ----------------> - - // Seems there's bug with the input/output tensor. 
Such multi outputs case - // should be unusual, so we make some hack on DoCacheWrite - // To be fixed in the future - s0.cache_write(kernel_split, "global", dag); - ss.str(std::string()); - ss << "Placeholder: Data, Kernel_data\n" - << "for ax0 (0,4)\n" - << " for ax1 (0,512)\n" - << " for ax2 (0,7)\n" - << " for ax3 (0,7)\n" - << " Data.global = ...\n" - << "for ax0 (0,4)\n" - << " for ax1 (0,512)\n" - << " for ax2 (0,9)\n" - << " for ax3 (0,9)\n" - << " T_pad = ...\n" - << "for ax0_c (0,512)\n" - << " for ax1_c (0,512)\n" - << " for ax2_c (0,3)\n" - << " for ax3_c (0,3)\n" - << " Kernel_split.global = ...\n" - << "for ax0 (0,512)\n" - << " for ax1 (0,512)\n" - << " for ax2 (0,3)\n" - << " for ax3 (0,3)\n" - << " Kernel_split = ...\n" - << "for ax0 (0,512)\n" - << " for ax1 (0,512)\n" - << " for ax2 (0,3)\n" - << " for ax3 (0,3)\n" - << " Kernel_split = ...\n" - << "for ax0 (0,512)\n" - << " for ax1 (0,512)\n" - << " for ax2 (0,3)\n" - << " for ax3 (0,3)\n" - << " Kernel = ...\n" - << "for ax0 (0,4)\n" - << " for ax1 (0,512)\n" - << " for ax2 (0,7)\n" - << " for ax3 (0,7)\n" - << " for ax0_c (None)\n" - << " for ax1_c (None)\n" - << " for ax2_c (None)\n" - << " for ax0 (None)\n" - << " for ax1 (None)\n" - << " for ax2 (None)\n" - << " for ax3 (None)\n" - << " T_pad.global = ...\n" - << " for ax3_c (None)\n" - << " for i (None)\n" - << " for ax0 (None)\n" - << " for ax1 (None)\n" - << " for ax2 (None)\n" - << " for ax3 (None)\n" - << " Kernel.global = ...\n" - << " for ax0 (None)\n" - << " for ax1 (None)\n" - << " for ax2 (None)\n" - << " for ax3 (None)\n" - << " T_pad.global.shared = ...\n" - << " for kh (None)\n" - << " for kw (None)\n" - << " T_conv2d_nchw.global = ...\n" - << " T_conv2d_nchw = ...\n" - << "for ax0.0 (0,2)\n" - << " for ax1 (0,512)\n" - << " for ax0.1 (0,2)\n" - << " for ax2 (0,7)\n" - << " for ax3 (0,7)\n" - << " Add = ...\n"; - CHECK_EQ(s0.ToStr().compare(ss.str()), 0); -} - -TEST(Step, FollowSplitFollowFusedSplit) { - const auto& tensors = matmul_func(512, 512, 512); - const auto& dag = ComputeDAGNode::make(tensors); - - State s0 = dag.GetInitState(); - int C = 2; - - int C_global = s0.cache_write(C, "global", dag); - C++; - - // FollowSplitStep currently only support `inner_to_outer = true` - const auto& its0 = s0.split(C, s0->stages[C]->iters[0], {4, 2, 8, 4}, true); - int split_step0 = s0->transform_steps.size() - 1; - // const auto& its1 = s0.split(C, s0->stages[C]->iters[5], {4, 2, 8, 4}, - // false); int split_step1 = s0->transform_steps.size() - 1; - for (int level = 1; level <= 5; level++) { - State tmp = s0; - tmp.follow_split(C_global, s0->stages[C_global]->iters[0], split_step0, - level); - // tmp.follow_split(C_global, s0->stages[C_global]->iters[5], split_step1, - // level); - const auto& stage_C = tmp->stages[C]; - const auto& stage_C_global = tmp->stages[C_global]; - for (int i = 0; i < level; i++) { - CHECK_EQ(stage_C->iters[i]->range->extent.as()->value, - stage_C_global->iters[i]->range->extent.as()->value); - } - // for (int i = 0; i < level; i++) { - // CHECK(stage_C->iters[i+5]->range->extent.as()->value == - // stage_C_global->iters[i+5]->range->extent.as()->value); - // } - } - - const auto& its1 = s0.split(C, s0->stages[C]->iters[5], {2, 2, 4, 8}); - int split_step1 = s0->transform_steps.size() - 1; - std::vector its; - for (int i = 0; i < 5; i++) { - its.push_back(its0[i]); - its.push_back(its1[i]); - } - s0.reorder(C, its); - for (int i = 0; i < 5; i++) { - s0.fuse(C, {s0->stages[C]->iters[i], s0->stages[C]->iters[i + 1]}); - } - for (int 
level = 0; level < 4; level++) { - State tmp = s0; - tmp.follow_fused_split(C_global, tmp->stages[C_global]->iters[0], - {split_step0, split_step1}, level, false); - const auto& stage_C = tmp->stages[C]; - const auto& stage_C_global = tmp->stages[C_global]; - CHECK_EQ(stage_C->iters[level + 1]->range->extent.as()->value, - stage_C_global->iters[0]->range->extent.as()->value); - } - for (int level = 0; level < 4; level++) { - State tmp = s0; - tmp.follow_fused_split(C_global, tmp->stages[C_global]->iters[0], - {split_step0, split_step1}, level, true); - const auto& stage_C = tmp->stages[C]; - const auto& stage_C_global = tmp->stages[C_global]; - CHECK_EQ(stage_C->iters[level + 1]->range->extent.as()->value, - stage_C_global->iters[1]->range->extent.as()->value); - } -} - -TEST(Step, Rfactor) { - // todo -} - -TEST(Feature, ExtractionMatmul) { - const auto& tensors = matmul_func(512, 512, 512); - const auto& dag = ComputeDAGNode::make(tensors); - State s0 = dag.GetInitState(); - - Iterator ti = s0->stages[2]->iters[0]; - Iterator tj = s0->stages[2]->iters[1]; - Iterator tk = s0->stages[2]->iters[2]; - std::vector its; - its = s0.split(2, ti, {16}); - Iterator tio = its[0], tii = its[1]; - its = s0.split(2, tj, {8}); - Iterator tjo = its[0], tji = its[1]; - s0.reorder(2, {tio, tjo, tk, tji, tii}); - s0.vectorize(2, tji); - s0.parallel(2, tio); - s0.parallel(2, tjo); - s0.unroll(2, tk); - - int max_n_bufs = 5; - std::vector> features; - std::vector feature_names; - GetPerStmtFeatureName(max_n_bufs, &feature_names); - GetPerStmtFeaturesFromStates( - {s0}, - SearchTaskNode::make(dag, "test", tvm::target::llvm(), - tvm::target::llvm(), HardwareParams()), - max_n_bufs, 0, &features); - int num_states = 1; - CHECK_EQ(feature_names.size(), (features[0].size() - 1) / num_states); - // TODO(...): Add feature check here -} - int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index d701ef5b7bbd..cd8a1eedb162 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -14,13 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-import random -import os -import numpy as np -import tvm -from tvm import te -from tvm import ansor +"""Common functions for ansor test cases""" + + +from tvm import te, ansor import topi @@ -59,507 +57,26 @@ def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation return [data, kernel, bias, bn_offset, bn_scale, out] -def test_compute_dag_basic(): - dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) - - print(dag) - print(dag.access_analyzer) - print(dag.get_init_state()) - - -def test_state_split_fuse_reorder(): - dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) - s0 = dag.get_init_state() - s1 = s0 - ti = s0.stage(2).iterator(0) - tj = s0.stage(2).iterator(1) - tk = s0.stage(2).iterator(2) - - assert ti.range.extent == 512 - - s0, its = s0.split(2, ti, [16]) - tio = its[0] - tii = its[1] - assert s0.stage(2).iterator(0).range.extent == 32 - assert s0.stage(2).iterator(1).range.extent == 16 - - s0, its = s0.split(2, tj, [8]) - tjo = its[0] - tji = its[1] - assert s0.stage(2).iterator(2).range.extent == 64 - assert s0.stage(2).iterator(3).range.extent == 8 - - s0 = s0.reorder(2, [tio, tjo, tk, tji, tii]) - assert s0.stage(2).iterator(2).range.extent == 512 - - s0, res_it = s0.fuse(2, [tio, tjo]) - assert res_it.range.extent == 2048 - - s1, _ = s1.split(2, ti, [8, 2]) - s1, _ = s1.split(2, tj, [32, 8], False) - assert s1.stage(2).iterator(0).range.extent == 32 - assert s1.stage(2).iterator(1).range.extent == 8 - assert s1.stage(2).iterator(2).range.extent == 2 - assert s1.stage(2).iterator(3).range.extent == 32 - assert s1.stage(2).iterator(4).range.extent == 8 - assert s1.stage(2).iterator(5).range.extent == 2 - - -def test_state_compute_at_root_inline(): - dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) - - # data, padding, kernel = 0, 1, 2 - conv = 3 - # bias = 4 - bias_add = 5 - # bn_scale = 6 - bn_mul = 7 - # bn_offset = 8 - bn_add, relu = 9, 10 - - s0 = dag.get_init_state() - s0 = s0.compute_inline(bn_add) - s0 = s0.compute_inline(bn_mul) - s0 = s0.compute_inline(bias_add) - s0 = s0.compute_at(conv, relu, s0.stage(relu).iterator(2)) - assert str(s0) == \ - "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ - "for i1 (0,3)\n" + \ - " for i2 (0,230)\n" + \ - " for i3 (0,230)\n" + \ - " pad_temp = ...\n" + \ - "for i1 (0,64)\n" + \ - " for i2 (0,112)\n" + \ - " for nn (None)\n" + \ - " for ff (None)\n" + \ - " for yy (None)\n" + \ - " for xx (None)\n" + \ - " for rc (None)\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute = ...\n" + \ - " for i3 (0,112)\n" + \ - " compute = ...\n" - - s0 = s0.compute_root(conv) - s0 = s0.compute_root(bn_mul) - assert str(s0) == \ - "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ - "for i1 (0,3)\n" + \ - " for i2 (0,230)\n" + \ - " for i3 (0,230)\n" + \ - " pad_temp = ...\n" + \ - "for nn (None)\n" + \ - " for ff (None)\n" + \ - " for yy (None)\n" + \ - " for xx (None)\n" + \ - " for rc (None)\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute = ...\n" + \ - "for i (None)\n" + \ - " for j (None)\n" + \ - " for k (None)\n" + \ - " for l (None)\n" + \ - " Bn_mul = ...\n" + \ - "for i1 (0,64)\n" + \ - " for i2 (0,112)\n" + \ - " for i3 (0,112)\n" + \ - " compute = ...\n" - - -def test_state_cache_read_write(): - N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, ( - 1, 1), (1, 1) - - data = te.placeholder((N, CI, H, W), name='Data') - kernel_data = te.placeholder((CO, CI, KH, KW), name='Kernel_data') - k0, 
k1 = te.compute(kernel_data.shape, - lambda *i: (kernel_data(*i)+1, kernel_data(*i)/2), - name='Kernel_split') - kernel = te.compute(kernel_data.shape, - lambda *i: k0(*i) + k1(*i), - name='Kernel') - conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1) - relu = topi.nn.relu(conv) - out = topi.add(data, relu) - - dag = ansor.ComputeDAG([data, kernel_data, out]) - data, pad_temp, kernel_data, kernel_split, kernel, conv, relu, add = 0, 1, 2, 3, 4, 5, 6, 7 - - # 0: init state - s0 = dag.get_init_state() - ori_its = s0.stage(add).iterators() - s0, its = s0.split(add, s0.stage(add).iterator(0), [2]) - s0 = s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]]) - s0 = s0.compute_inline(relu) - - # 1: simple cache_write with compute_at - s0, conv_global = s0.cache_write(conv, "global", dag) - conv += 1 - relu += 1 - add += 1 - s0 = s0.compute_at(conv_global, conv, s0.stage(conv).iterator(3)) - - # 2: simple cache_read with compute_at - s0, kernel_global = s0.cache_read(kernel, "global", [conv_global], dag) - conv_global += 1 - conv += 1 - relu += 1 - add += 1 - s0 = s0.compute_at(kernel_global, conv_global, - s0.stage(conv_global).iterator(4)) - assert str(s0) == \ - "Placeholder: Data, Kernel_data\n" + \ - "for i0 (0,4)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,9)\n" + \ - " for i3 (0,9)\n" + \ - " pad_temp = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel_split = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel = ...\n" + \ - "for nn (0,4)\n" + \ - " for ff (0,512)\n" + \ - " for yy (0,7)\n" + \ - " for xx (0,7)\n" + \ - " for nn_c (None)\n" + \ - " for ff_c (None)\n" + \ - " for yy_c (None)\n" + \ - " for xx_c (None)\n" + \ - " for rc (None)\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " Kernel.global = ...\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute.global = ...\n" + \ - " compute = ...\n" + \ - "for ax0.0 (0,2)\n" + \ - " for ax1 (0,512)\n" + \ - " for ax0.1 (0,2)\n" + \ - " for ax2 (0,7)\n" + \ - " for ax3 (0,7)\n" + \ - " T_add = ...\n" - - # 3: two level cache_read with compute_at - # preparing for GPU's shared memory & local memory - s0, pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global], dag) - kernel_data += 1 - kernel_split += 1 - kernel += 1 - kernel_global += 1 - conv_global += 1 - conv += 1 - relu += 1 - add += 1 - s0, pad_temp_shared = s0.cache_read( - pad_temp_global, "shared", [conv_global], dag) - kernel_data += 1 - kernel_split += 1 - kernel += 1 - kernel_global += 1 - conv_global += 1 - conv += 1 - relu += 1 - add += 1 - s0 = s0.compute_at(pad_temp_global, conv_global, - s0.stage(conv_global).iterator(2)) - s0 = s0.compute_at(pad_temp_shared, conv_global, - s0.stage(conv_global).iterator(4)) - - # 4: cache_read with multi readers - # This stage cannot be compute at to its consumer - s0, data_global = s0.cache_read(data, "global", [pad_temp, add], dag) - pad_temp += 1 - pad_temp_global += 1 - pad_temp_shared += 1 - kernel_data += 1 - kernel_split += 1 - kernel += 1 - kernel_global += 1 - conv_global += 1 - conv += 1 - relu += 1 - add += 1 - assert str(s0) == \ - "Placeholder: Data, Kernel_data\n" + \ - "for ax0 (0,4)\n" + \ - " for ax1 (0,512)\n" + \ - " for ax2 (0,7)\n" + \ - " for ax3 (0,7)\n" + \ - " Data.global = ...\n" + \ - "for i0 (0,4)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 
(0,9)\n" + \ - " for i3 (0,9)\n" + \ - " pad_temp = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel_split = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel = ...\n" + \ - "for nn (0,4)\n" + \ - " for ff (0,512)\n" + \ - " for yy (0,7)\n" + \ - " for xx (0,7)\n" + \ - " for nn_c (None)\n" + \ - " for ff_c (None)\n" + \ - " for yy_c (None)\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " pad_temp.global = ...\n" + \ - " for xx_c (None)\n" + \ - " for rc (None)\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " Kernel.global = ...\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " pad_temp.global.shared = ...\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute.global = ...\n" + \ - " compute = ...\n" + \ - "for ax0.0 (0,2)\n" + \ - " for ax1 (0,512)\n" + \ - " for ax0.1 (0,2)\n" + \ - " for ax2 (0,7)\n" + \ - " for ax3 (0,7)\n" + \ - " T_add = ...\n" - - # 5: cache_write with multi outputs - # See tests/cpp/ansor_test.cc for more information - s0, _ = s0.cache_write(kernel_split, "global", dag) - assert str(s0) == \ - "Placeholder: Data, Kernel_data\n" + \ - "for ax0 (0,4)\n" + \ - " for ax1 (0,512)\n" + \ - " for ax2 (0,7)\n" + \ - " for ax3 (0,7)\n" + \ - " Data.global = ...\n" + \ - "for i0 (0,4)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,9)\n" + \ - " for i3 (0,9)\n" + \ - " pad_temp = ...\n" + \ - "for i0_c (0,512)\n" + \ - " for i1_c (0,512)\n" + \ - " for i2_c (0,3)\n" + \ - " for i3_c (0,3)\n" + \ - " Kernel_split.global = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel_split = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel_split = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel = ...\n" + \ - "for nn (0,4)\n" + \ - " for ff (0,512)\n" + \ - " for yy (0,7)\n" + \ - " for xx (0,7)\n" + \ - " for nn_c (None)\n" + \ - " for ff_c (None)\n" + \ - " for yy_c (None)\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " pad_temp.global = ...\n" + \ - " for xx_c (None)\n" + \ - " for rc (None)\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " Kernel.global = ...\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " pad_temp.global.shared = ...\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute.global = ...\n" + \ - " compute = ...\n" + \ - "for ax0.0 (0,2)\n" + \ - " for ax1 (0,512)\n" + \ - " for ax0.1 (0,2)\n" + \ - " for ax2 (0,7)\n" + \ - " for ax3 (0,7)\n" + \ - " T_add = ...\n" - - -def test_follow_split_follow_fused_split(): - dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) - s0 = dag.get_init_state() - C = 2 - - s0, C_global = s0.cache_write(C, "global", dag) - C += 1 - - s0, its0 = s0.split(C, s0.stage(C).iterator(0), [4, 2, 8, 4], True) - split_step0 = s0.transform_steps_size() - 1 - for level in range(1, 6): - tmp = s0 - tmp, _ = tmp.follow_split(C_global, tmp.stage( - C_global).iterator(0), split_step0, 
level) - for i in range(0, level): - assert tmp.stage(C).iterator(i).range.extent == \ - tmp.stage(C_global).iterator(i).range.extent - - s0, its1 = s0.split(C, s0.stage(C).iterator(5), [2, 2, 4, 8]) - split_step1 = s0.transform_steps_size() - 1 - its = [] - for i0, i1 in zip(its0, its1): - its.append(i0) - its.append(i1) - s0 = s0.reorder(C, its) - for i in range(0, 5): - s0, _ = s0.fuse(C, [s0.stage(C).iterator(i), - s0.stage(C).iterator(i+1)]) - for level in range(0, 4): - tmp = s0 - tmp, _ = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), - [split_step0, split_step1], level, False) - assert tmp.stage(C).iterator(level+1).range.extent == \ - tmp.stage(C_global).iterator(0).range.extent - for level in range(0, 4): - tmp = s0 - tmp, _ = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), - [split_step0, split_step1], level, True) - assert tmp.stage(C).iterator(level+1).range.extent == \ - tmp.stage(C_global).iterator(1).range.extent - - -def test_rfactor(): - pass - - -def test_measure_local_builder_runner(): +def get_tiled_matmul(): dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) s0 = dag.get_init_state() A, B, C = 0, 1, 2 - s0, C_global = s0.cache_write(C, "global", dag) + C_global = s0.cache_write(C, "global", dag) C += 1 - s0, its0 = s0.split(C, s0.stage(C).iterator(0), [4, 8, 8]) - s0, its1 = s0.split(C, s0.stage(C).iterator(4), [8, 4, 4]) - s0 = s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], - its0[3], its1[3]]) - s0 = s0.compute_at(C_global, C, s0.stage(C).iterator(3)) - s0, _ = s0.split(C_global, s0.stage(C_global).iterator(2), [16]) - s0, B_global = s0.cache_read(B, "global", [C_global], dag) + its0 = s0.split(C, s0.stages[C].iters[0], [4, 8, 8]) + its1 = s0.split(C, s0.stages[C].iters[4], [8, 4, 4]) + s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], its1[3]]) + s0.compute_at(C_global, C, s0.stages[C].iters[3]) + s0.split(C_global, s0.stages[C_global].iters[2], [16]) + B_global = s0.cache_read(B, "global", [C_global], dag) C += 1 C_global += 1 - s0 = s0.compute_at(B_global, C_global, s0.stage(C_global).iterator(0)) - s0, A_global = s0.cache_read(A, "global", [C_global], dag) + s0.compute_at(B_global, C_global, s0.stages[C_global].iters[0]) + A_global = s0.cache_read(A, "global", [C_global], dag) B += 1 B_global += 1 C += 1 C_global += 1 - s0 = s0.compute_at(A_global, C_global, s0.stage(C_global).iterator(2)) - - tgt = tvm.target.create("llvm") - task = ansor.SearchTask(dag, "test", tgt) - - minp = ansor.MeasureInput(task, s0) - local_builder = ansor.LocalBuilder() - local_runner = ansor.LocalRunner() - - bress = local_builder.build([minp]) - assert bress[0].error_no == 0 - mress = local_runner.run([minp], bress) - assert mress[0].error_no == 0 - - -def test_search_basic(): - print("Test schedule search with default search policy") - - N = 128 - A, B, C = matmul_nkkm(N, N, N) - dag = ansor.ComputeDAG([A, B, C]) - tgt = tvm.target.create("llvm") - task = ansor.SearchTask(dag, "test", tgt) - - # seed = random.randint(1, 1 << 30) - seed = 944563397 - log_file = "/tmp/_ansor_python_ut_test.json" - - random.seed(seed) - cost_model = ansor.RandomModel() - search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) - tune_option = ansor.TuneOption(n_trials=2, - callbacks=[ansor.LogToFile(log_file)]) - state = ansor.auto_schedule(task, search_policy, - tune_option=tune_option) - sch, args = dag.apply_steps_from_state(state) - - print("==== Get State ====") - print(state) - print("==== Get Python 
Code ====") - print(dag.print_python_code_from_state(state)) - - try: - print("==== Get Lowered Stmt ====") - print(tvm.lower(sch, args, simple_mode=True)) - mod = tvm.build(sch, args, tgt) - - ctx = tvm.context("llvm", 0) - a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(A.dtype), ctx) - b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(B.dtype), ctx) - c = tvm.nd.array(np.zeros((N, N), dtype=C.dtype), ctx) - mod(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), np.dot( - a.asnumpy(), b.asnumpy()), rtol=1e-5) - print("==== Verification passed ====") - except Exception: - raise Exception("Error encounterd with seed: %d" % (seed)) - - inp, res = ansor.best_measure_pair_in_file(log_file) - s0 = dag.infer_bound_from_state(state) - s1 = dag.infer_bound_from_state(inp.state) - assert str(s0) == str(s1) - - if os.path.isfile(log_file): - os.system("rm -rf %s" % log_file) - - -if __name__ == "__main__": - test_compute_dag_basic() - test_state_split_fuse_reorder() - test_state_compute_at_root_inline() - test_state_cache_read_write() - test_follow_split_follow_fused_split() - test_rfactor() - test_measure_local_builder_runner() - test_search_basic() + s0.compute_at(A_global, C_global, s0.stages[C_global].iters[2]) + return dag, s0.state_object diff --git a/tests/python/unittest/test_ansor_compute_dag.py b/tests/python/unittest/test_ansor_compute_dag.py new file mode 100644 index 000000000000..61eb0153a87c --- /dev/null +++ b/tests/python/unittest/test_ansor_compute_dag.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Test ComputeDAG (replay, infer bound)""" + +import tvm +from tvm import ansor, te + +from test_ansor_common import get_tiled_matmul + + +def test_apply_steps(): + dag, s = get_tiled_matmul() + dag.print_python_code_from_state(s) + sch, tensors = dag.apply_steps_from_state(s) + stmt = tvm.lower(sch, tensors, simple_mode=True) + + +def test_infer_bound(): + dag, s = get_tiled_matmul() + s = dag.infer_bound_from_state(s) + s = ansor.loop_state.State(s) + + A_global, B_global, C_global = 1, 3, 4 + assert s.stages[B_global].iters[0].range.extent == 512 + assert s.stages[B_global].iters[1].range.extent == 16 + assert s.stages[A_global].iters[0].range.extent == 1 + assert s.stages[A_global].iters[1].range.extent == 16 + assert s.stages[C_global].iters[0].range.extent == 64 + + +def test_lower_legalize_invalid_attach(): + N, M = 10, 10 + + A = te.compute((N, M), lambda i, j: 1.0, name='A') + B = te.compute((N, M), lambda i, j: A[i][j], name='B') + + dag = ansor.ComputeDAG([A, B]) + s = dag.get_init_state() + + A, B = 0, 1 + s.compute_at(A, B, s.stages[B].iters[1]) + s.split(B, s.stages[B].iters[1], [2]) + + sch, tensors = dag.apply_steps_from_state(s.state_object) + stmt = tvm.lower(sch, tensors, simple_mode=True) + + +if __name__ == "__main__": + test_apply_steps() + test_infer_bound() + test_lower_legalize_invalid_attach() diff --git a/tests/python/unittest/test_ansor_loop_state.py b/tests/python/unittest/test_ansor_loop_state.py new file mode 100644 index 000000000000..34b720e7e1af --- /dev/null +++ b/tests/python/unittest/test_ansor_loop_state.py @@ -0,0 +1,475 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Test loop state and schedule primitives""" + +from tvm import ansor, te +import topi + +from test_ansor_common import matmul_nkkm, conv2d_nchw_bn_relu + + +def test_state_split_fuse_reorder_annotation(): + dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + s0 = dag.get_init_state() + C = 2 + i, j, k = s0.stages[C].iters + + assert i.range.extent == 512 + + io, ii = s0.split(C, i, [16]) + assert s0.stages[C].iters[0] == io + assert s0.stages[C].iters[1] == ii + assert io.range.extent == 32 + assert ii.range.extent == 16 + + jo, ji = s0.split(C, j, [8]) + assert jo.range.extent == 64 + assert ji.range.extent == 8 + + s0.reorder(C, [io, jo, k, ji, ii]) + assert s0.stages[C].iters[2].range.extent == 512 + + fused_it = s0.fuse(C, [io, jo]) + assert fused_it.range.extent == 2048 + + s1 = dag.get_init_state() + i, j, _ = s1.stages[C].iters + i1, i2, i3 = s1.split(C, i, [8, 2]) + j1, j2, j3 = s1.split(C, j, [32, 8], False) + assert s1.stages[C].iters[0].range.extent == 32 + assert s1.stages[C].iters[1].range.extent == 8 + assert s1.stages[C].iters[2].range.extent == 2 + assert s1.stages[C].iters[3].range.extent == 32 + assert s1.stages[C].iters[4].range.extent == 8 + assert s1.stages[C].iters[5].range.extent == 2 + + s1.parallel(C, j1) + s1.unroll(C, j2) + s1.vectorize(C, j3) + s1.bind_thread(C, i1, "blockIdx.x") + s1.bind_thread(C, i2, "vthread") + s1.bind_thread(C, i3, "threadIdx.y") + + +def test_follow_split_follow_fused_split(): + dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + s0 = dag.get_init_state() + C = 2 + + C_global = s0.cache_write(C, "global", dag) + C += 1 + + its0 = s0.split(C, s0.stages[C].iters[0], [4, 2, 8, 4], True) + split_step0 = s0.transform_steps_size() - 1 + for level in range(1, 6): + tmp = s0.copy() + tmp.follow_split(C_global, tmp.stages[C_global].iters[0], split_step0, level) + for i in range(0, level): + assert tmp.stages[C].iters[i].range.extent == \ + tmp.stages[C_global].iters[i].range.extent + + its1 = s0.split(C, s0.stages[C].iters[5], [2, 2, 4, 8]) + split_step1 = s0.transform_steps_size() - 1 + its = [] + for i0, i1 in zip(its0, its1): + its.append(i0) + its.append(i1) + s0.reorder(C, its) + for i in range(0, 5): + s0.fuse(C, [s0.stages[C].iters[i], s0.stages[C].iters[i + 1]]) + + for level in range(0, 4): + tmp = s0.copy() + tmp.follow_fused_split(C_global, tmp.stages[C_global].iters[0], + [split_step0, split_step1], level, False) + assert tmp.stages[C].iters[level + 1].range.extent == \ + tmp.stages[C_global].iters[0].range.extent + + for level in range(0, 4): + tmp = s0.copy() + tmp.follow_fused_split(C_global, tmp.stages[C_global].iters[0], + [split_step0, split_step1], level, True) + assert tmp.stages[C].iters[level + 1].range.extent == \ + tmp.stages[C_global].iters[1].range.extent + + +def test_state_compute_at_root_inline(): + dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) + + # data, padding, kernel = 0, 1, 2 + conv = 3 + # bias = 4 + bias_add = 5 + # bn_scale = 6 + bn_mul = 7 + # bn_offset = 8 + bn_add, relu = 9, 10 + + s0 = dag.get_init_state() + s0.compute_inline(bn_add) + s0.compute_inline(bn_mul) + s0.compute_inline(bias_add) + s0.compute_at(conv, relu, s0.stages[relu].iters[2]) + assert str(s0) == \ + "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ + "for i1 (0,3)\n" + \ + " for i2 (0,230)\n" + \ + " for i3 (0,230)\n" + \ + " pad_temp = ...\n" + \ + "for i1 (0,64)\n" + \ + " for i2 (0,112)\n" + \ + " for nn (None)\n" + \ + " for ff (None)\n" + \ + " for yy (None)\n" + \ + " for xx (None)\n" + \ 
+ " for rc (None)\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " compute = ...\n" + \ + " for i3 (0,112)\n" + \ + " compute = ...\n" + + s0 = s0.compute_root(conv) + s0 = s0.compute_root(bn_mul) + assert str(s0) == \ + "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ + "for i1 (0,3)\n" + \ + " for i2 (0,230)\n" + \ + " for i3 (0,230)\n" + \ + " pad_temp = ...\n" + \ + "for nn (None)\n" + \ + " for ff (None)\n" + \ + " for yy (None)\n" + \ + " for xx (None)\n" + \ + " for rc (None)\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " compute = ...\n" + \ + "for i (None)\n" + \ + " for j (None)\n" + \ + " for k (None)\n" + \ + " for l (None)\n" + \ + " Bn_mul = ...\n" + \ + "for i1 (0,64)\n" + \ + " for i2 (0,112)\n" + \ + " for i3 (0,112)\n" + \ + " compute = ...\n" + + +def test_state_cache_read_write(): + N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, ( + 1, 1), (1, 1) + + data = te.placeholder((N, CI, H, W), name='Data') + kernel_data = te.placeholder((CO, CI, KH, KW), name='Kernel_data') + k0, k1 = te.compute(kernel_data.shape, + lambda *i: (kernel_data(*i)+1, kernel_data(*i)/2), + name='Kernel_split') + kernel = te.compute(kernel_data.shape, + lambda *i: k0(*i) + k1(*i), + name='Kernel') + conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1) + relu = topi.nn.relu(conv) + out = topi.add(data, relu) + + dag = ansor.ComputeDAG([data, kernel_data, out]) + data, pad_temp, kernel_data, kernel_split, kernel, conv, relu, add = 0, 1, 2, 3, 4, 5, 6, 7 + + # 0: init state + s0 = dag.get_init_state() + ori_its = s0.stages[add].iters + its = s0.split(add, s0.stages[add].iters[0], [2]) + s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]]) + s0.compute_inline(relu) + + # 1: simple cache_write with compute_at + conv_global = s0.cache_write(conv, "global", dag) + conv += 1 + relu += 1 + add += 1 + s0.compute_at(conv_global, conv, s0.stages[conv].iters[3]) + + # 2: simple cache_read with compute_at + kernel_global = s0.cache_read(kernel, "global", [conv_global], dag) + conv_global += 1 + conv += 1 + relu += 1 + add += 1 + s0.compute_at(kernel_global, conv_global, + s0.stages[conv_global].iters[4]) + assert str(s0) == \ + "Placeholder: Data, Kernel_data\n" + \ + "for i0 (0,4)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,9)\n" + \ + " for i3 (0,9)\n" + \ + " pad_temp = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel_split = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel = ...\n" + \ + "for nn (0,4)\n" + \ + " for ff (0,512)\n" + \ + " for yy (0,7)\n" + \ + " for xx (0,7)\n" + \ + " for nn_c (None)\n" + \ + " for ff_c (None)\n" + \ + " for yy_c (None)\n" + \ + " for xx_c (None)\n" + \ + " for rc (None)\n" + \ + " for ax0 (None)\n" + \ + " for ax1 (None)\n" + \ + " for ax2 (None)\n" + \ + " for ax3 (None)\n" + \ + " Kernel.global = ...\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " compute.global = ...\n" + \ + " compute = ...\n" + \ + "for ax0.0 (0,2)\n" + \ + " for ax1 (0,512)\n" + \ + " for ax0.1 (0,2)\n" + \ + " for ax2 (0,7)\n" + \ + " for ax3 (0,7)\n" + \ + " T_add = ...\n" + + # 3: two level cache_read with compute_at + # preparing for GPU's shared memory & local memory + pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global], dag) + kernel_data += 1 + kernel_split += 1 + kernel += 1 + kernel_global += 1 + conv_global += 1 + conv += 1 + 
relu += 1
+    add += 1
+    pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global], dag)
+    kernel_data += 1
+    kernel_split += 1
+    kernel += 1
+    kernel_global += 1
+    conv_global += 1
+    conv += 1
+    relu += 1
+    add += 1
+    s0.compute_at(pad_temp_global, conv_global, s0.stages[conv_global].iters[2])
+    s0.compute_at(pad_temp_shared, conv_global, s0.stages[conv_global].iters[4])
+
+    # 4: cache_read with multi readers
+    # This stage cannot be computed at its consumer: it has two readers
+    # (pad_temp and add), so there is no single consumer loop nest to attach to
+    data_global = s0.cache_read(data, "global", [pad_temp, add], dag)
+    pad_temp += 1
+    pad_temp_global += 1
+    pad_temp_shared += 1
+    kernel_data += 1
+    kernel_split += 1
+    kernel += 1
+    kernel_global += 1
+    conv_global += 1
+    conv += 1
+    relu += 1
+    add += 1
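+    # A note on the index bookkeeping above: every cache_read/cache_write
+    # inserts one new stage into the state, so the index of every stage at
+    # or after the insertion point shifts by one. A hypothetical helper
+    # (only a sketch for illustration, not part of the ansor API) could
+    # express the same updates more compactly:
+    #
+    #   def shift_stage_indices(first_shifted, *indices):
+    #       return [i + 1 if i >= first_shifted else i for i in indices]
+    #
+    # The explicit `+= 1` form is kept here so each step reads standalone.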
+    assert str(s0) == \
+        "Placeholder: Data, Kernel_data\n" + \
+        "for ax0 (0,4)\n" + \
+        "  for ax1 (0,512)\n" + \
+        "    for ax2 (0,7)\n" + \
+        "      for ax3 (0,7)\n" + \
+        "        Data.global = ...\n" + \
+        "for i0 (0,4)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,9)\n" + \
+        "      for i3 (0,9)\n" + \
+        "        pad_temp = ...\n" + \
+        "for i0 (0,512)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,3)\n" + \
+        "      for i3 (0,3)\n" + \
+        "        Kernel_split = ...\n" + \
+        "for i0 (0,512)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,3)\n" + \
+        "      for i3 (0,3)\n" + \
+        "        Kernel = ...\n" + \
+        "for nn (0,4)\n" + \
+        "  for ff (0,512)\n" + \
+        "    for yy (0,7)\n" + \
+        "      for xx (0,7)\n" + \
+        "        for nn_c (None)\n" + \
+        "          for ff_c (None)\n" + \
+        "            for yy_c (None)\n" + \
+        "              for ax0 (None)\n" + \
+        "                for ax1 (None)\n" + \
+        "                  for ax2 (None)\n" + \
+        "                    for ax3 (None)\n" + \
+        "                      pad_temp.global = ...\n" + \
+        "              for xx_c (None)\n" + \
+        "                for rc (None)\n" + \
+        "                  for ax0 (None)\n" + \
+        "                    for ax1 (None)\n" + \
+        "                      for ax2 (None)\n" + \
+        "                        for ax3 (None)\n" + \
+        "                          Kernel.global = ...\n" + \
+        "                  for ax0 (None)\n" + \
+        "                    for ax1 (None)\n" + \
+        "                      for ax2 (None)\n" + \
+        "                        for ax3 (None)\n" + \
+        "                          pad_temp.global.shared = ...\n" + \
+        "                  for ry (None)\n" + \
+        "                    for rx (None)\n" + \
+        "                      compute.global = ...\n" + \
+        "        compute = ...\n" + \
+        "for ax0.0 (0,2)\n" + \
+        "  for ax1 (0,512)\n" + \
+        "    for ax0.1 (0,2)\n" + \
+        "      for ax2 (0,7)\n" + \
+        "        for ax3 (0,7)\n" + \
+        "          T_add = ...\n"
+
+    # 5: cache_write with multi outputs
+    # TVM's cache_write actually has a bug with this case:
+    #
+    # After schedule.cache_write, TVM generates one new stage:
+    # From: kernel_data -> kernel_split -> kernel
+    # To:   kernel_data -> kernel_split_global -> kernel_split -> kernel
+    #
+    # But with our topo-sort analysis, we get:
+    #    kernel_data -> kernel_split_global -> kernel_split -> kernel
+    #      \                                                /
+    #       ----------------> kernel_split ---------------->
+    #
+    # There seems to be a bug with the input/output tensors: the duplicated
+    # Kernel_split stage also shows up in the expected output below. Such a
+    # multi-output case should be unusual, so for now we apply a workaround
+    # in DoCacheWrite. To be fixed in the future.
+    s0.cache_write(kernel_split, "global", dag)
+    assert str(s0) == \
+        "Placeholder: Data, Kernel_data\n" + \
+        "for ax0 (0,4)\n" + \
+        "  for ax1 (0,512)\n" + \
+        "    for ax2 (0,7)\n" + \
+        "      for ax3 (0,7)\n" + \
+        "        Data.global = ...\n" + \
+        "for i0 (0,4)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,9)\n" + \
+        "      for i3 (0,9)\n" + \
+        "        pad_temp = ...\n" + \
+        "for i0_c (0,512)\n" + \
+        "  for i1_c (0,512)\n" + \
+        "    for i2_c (0,3)\n" + \
+        "      for i3_c (0,3)\n" + \
+        "        Kernel_split.global = ...\n" + \
+        "for i0 (0,512)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,3)\n" + \
+        "      for i3 (0,3)\n" + \
+        "        Kernel_split = ...\n" + \
+        "for i0 (0,512)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,3)\n" + \
+        "      for i3 (0,3)\n" + \
+        "        Kernel_split = ...\n" + \
+        "for i0 (0,512)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,3)\n" + \
+        "      for i3 (0,3)\n" + \
+        "        Kernel = ...\n" + \
+        "for nn (0,4)\n" + \
+        "  for ff (0,512)\n" + \
+        "    for yy (0,7)\n" + \
+        "      for xx (0,7)\n" + \
+        "        for nn_c (None)\n" + \
+        "          for ff_c (None)\n" + \
+        "            for yy_c (None)\n" + \
+        "              for ax0 (None)\n" + \
+        "                for ax1 (None)\n" + \
+        "                  for ax2 (None)\n" + \
+        "                    for ax3 (None)\n" + \
+        "                      pad_temp.global = ...\n" + \
+        "              for xx_c (None)\n" + \
+        "                for rc (None)\n" + \
+        "                  for ax0 (None)\n" + \
+        "                    for ax1 (None)\n" + \
+        "                      for ax2 (None)\n" + \
+        "                        for ax3 (None)\n" + \
+        "                          Kernel.global = ...\n" + \
+        "                  for ax0 (None)\n" + \
+        "                    for ax1 (None)\n" + \
+        "                      for ax2 (None)\n" + \
+        "                        for ax3 (None)\n" + \
+        "                          pad_temp.global.shared = ...\n" + \
+        "                  for ry (None)\n" + \
+        "                    for rx (None)\n" + \
+        "                      compute.global = ...\n" + \
+        "        compute = ...\n" + \
+        "for ax0.0 (0,2)\n" + \
+        "  for ax1 (0,512)\n" + \
+        "    for ax0.1 (0,2)\n" + \
+        "      for ax2 (0,7)\n" + \
+        "        for ax3 (0,7)\n" + \
+        "          T_add = ...\n"
+
+
+def test_rfactor():
+    dag = ansor.ComputeDAG(matmul_nkkm(8, 8, 512))
+    s0 = dag.get_init_state()
+    C = 2
+
+    ko, ki = s0.split(C, s0.stages[C].iters[2], [16])
+
+    s1 = s0.copy()
+    s1.rfactor(C, ko, 2, dag)
+    assert str(s1) == \
+        "Placeholder: A, B\n" + \
+        "for i (0,8)\n" + \
+        "  for j (0,8)\n" + \
+        "    for k_o (0,32)\n" + \
+        "      for k_i (0,16)\n" + \
+        "        C.rf = ...\n" + \
+        "for ax0 (0,8)\n" + \
+        "  for ax1 (0,8)\n" + \
+        "    for k_o_v (0,32)\n" + \
+        "      C.repl = ...\n"
+
+    s2 = s0.copy()
+    s2.rfactor(C, ki, 2, dag)
+    assert str(s2) == \
+        "Placeholder: A, B\n" + \
+        "for i (0,8)\n" + \
+        "  for j (0,8)\n" + \
+        "    for k_i (0,16)\n" + \
+        "      for k_o (0,32)\n" + \
+        "        C.rf = ...\n" + \
+        "for ax0 (0,8)\n" + \
+        "  for ax1 (0,8)\n" + \
+        "    for k_i_v (0,16)\n" + \
+        "      C.repl = ...\n"
+
+
+if __name__ == "__main__":
+    test_state_split_fuse_reorder_annotation()
+    test_follow_split_follow_fused_split()
+    test_state_cache_read_write()
+    test_rfactor()
diff --git a/tests/python/unittest/test_ansor_measure.py b/tests/python/unittest/test_ansor_measure.py
new file mode 100644
index 000000000000..baf8a0c4efa2
--- /dev/null
+++ b/tests/python/unittest/test_ansor_measure.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test measurement and log serialization""" + +import tvm +from tvm import ansor +import tempfile + +from test_ansor_common import get_tiled_matmul + + +def test_serialization(): + dag, s = get_tiled_matmul() + target = tvm.target.create("llvm") + task = ansor.SearchTask(dag, "test", target) + + inp = ansor.measure.MeasureInput(task, s) + res = ansor.measure.MeasureResult([0.1], 0, "", 0.2, 1) + + with tempfile.NamedTemporaryFile() as fp: + ansor.serialization.write_measure_records_to_file(fp.name, [inp], [res]) + + log_reader = ansor.serialization.LogReader(fp.name) + inputs, results = log_reader.read_lines() + assert len(inputs) == 1 + + s1 = dag.infer_bound_from_state(s) + s2 = dag.infer_bound_from_state(inputs[0].state) + + assert s1 == s2 + assert not (s1 == dag.get_init_state().state_object) + + +def test_measure_local_builder_runner(): + dag, s0 = get_tiled_matmul() + + tgt = tvm.target.create("llvm") + task = ansor.SearchTask(dag, "test", tgt) + + minp = ansor.MeasureInput(task, s0) + local_builder = ansor.LocalBuilder() + local_runner = ansor.LocalRunner() + + bress = local_builder.build([minp]) + assert bress[0].error_no == 0 + mress = local_runner.run([minp], bress) + assert mress[0].error_no == 0 + + +if __name__ == "__main__": + test_serialization() + test_measure_local_builder_runner() diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py new file mode 100644 index 000000000000..eea3f5cfbda3 --- /dev/null +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Test search policy""" + +import random +import os +import numpy as np +import tempfile + +import tvm +from tvm import ansor + +from test_ansor_common import matmul_nkkm + +def test_search_basic(): + print("Test schedule search with the default search policy") + + N = 128 + A, B, C = matmul_nkkm(N, N, N) + dag = ansor.ComputeDAG([A, B, C]) + tgt = tvm.target.create("llvm") + task = ansor.SearchTask(dag, "test", tgt) + + seed = 944563397 + random.seed(seed) + + with tempfile.NamedTemporaryFile() as fp: + log_file = fp.name + + cost_model = ansor.RandomModel() + search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) + tune_option = ansor.TuneOption(n_trials=2, + callbacks=[ansor.LogToFile(log_file)]) + state = ansor.auto_schedule(task, search_policy, + tune_option=tune_option) + sch, args = dag.apply_steps_from_state(state) + + print("==== Get State ====") + print(state) + print("==== Get Python Code ====") + print(dag.print_python_code_from_state(state)) + + try: + print("==== Get Lowered Stmt ====") + print(tvm.lower(sch, args, simple_mode=True)) + mod = tvm.build(sch, args, tgt) + + ctx = tvm.context("llvm", 0) + a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(B.dtype), ctx) + c = tvm.nd.array(np.zeros((N, N), dtype=C.dtype), ctx) + mod(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), np.dot( + a.asnumpy(), b.asnumpy()), rtol=1e-5) + print("==== Verification passed ====") + except Exception: + raise Exception("Error encountered with seed: %d" % (seed)) + + inp, res = ansor.best_measure_pair_in_file(log_file) + s0 = dag.infer_bound_from_state(state) + s1 = dag.infer_bound_from_state(inp.state) + assert s0 == s1 + + +if __name__ == "__main__": + test_search_basic() From 43d1530a253dc65aaf9f8da9cc818e9e0c4a1db0 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 7 Jun 2020 23:30:18 -0700 Subject: [PATCH 10/78] fix unit tests --- tests/python/unittest/test_ansor_loop_state.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/python/unittest/test_ansor_loop_state.py b/tests/python/unittest/test_ansor_loop_state.py index 34b720e7e1af..287a1b773395 100644 --- a/tests/python/unittest/test_ansor_loop_state.py +++ b/tests/python/unittest/test_ansor_loop_state.py @@ -23,7 +23,7 @@ from test_ansor_common import matmul_nkkm, conv2d_nchw_bn_relu -def test_state_split_fuse_reorder_annotation(): +def test_split_fuse_reorder_annotation(): dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) s0 = dag.get_init_state() C = 2 @@ -108,7 +108,7 @@ def test_follow_split_follow_fused_split(): tmp.stages[C_global].iters[1].range.extent -def test_state_compute_at_root_inline(): +def test_compute_at_root_inline(): dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) # data, padding, kernel = 0, 1, 2 @@ -144,8 +144,8 @@ def test_state_compute_at_root_inline(): " for i3 (0,112)\n" + \ " compute = ...\n" - s0 = s0.compute_root(conv) - s0 = s0.compute_root(bn_mul) + s0.compute_root(conv) + s0.compute_root(bn_mul) assert str(s0) == \ "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ "for i1 (0,3)\n" + \ @@ -171,7 +171,7 @@ def test_state_compute_at_root_inline(): " compute = ...\n" -def test_state_cache_read_write(): +def test_cache_read_write(): N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, ( 1, 1), (1, 1) @@ -469,7 +469,8 @@ def test_rfactor(): if __name__ == "__main__": - test_state_split_fuse_reorder_annotation() + 
test_split_fuse_reorder_annotation() test_follow_split_follow_fused_split() - test_state_cache_read_write() + test_compute_at_root_inline() + test_cache_read_write() test_rfactor() From f367d1533a10c2d476b7a12e54c5261f71b08cfb Mon Sep 17 00:00:00 2001 From: Chenfan Date: Mon, 8 Jun 2020 14:36:42 +0800 Subject: [PATCH 11/78] Add RPCRunner & OpenCL/CUDA test (#12) * Add RPCRunner & OpenCL search test * Add CUDA search test * Add RPCRunner test --- python/tvm/ansor/__init__.py | 2 +- python/tvm/ansor/measure.py | 22 +++++++ python/tvm/rpc/server.py | 3 +- src/ansor/measure.cc | 8 +++ .../search_policy/meta_tile_rewrite_policy.h | 1 - tests/python/unittest/test_ansor_measure.py | 29 +++++++++ .../unittest/test_ansor_search_policy.py | 61 +++++++++++++++++-- 7 files changed, 117 insertions(+), 9 deletions(-) diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 1be7ed404c17..7552878a3c50 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -28,6 +28,6 @@ from .compute_dag import ComputeDAG from .task import SearchTask, MetaTileRewritePolicy, TuneOption from .task import auto_schedule -from .measure import MeasureInput, LocalBuilder, LocalRunner +from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner from .cost_model import RandomModel from .serialization import LogToFile, LogReader, best_measure_pair_in_file diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 5438edfaa6b2..b80de7c01633 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -168,6 +168,28 @@ def __init__(self, _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval) +@tvm._ffi.register_object("ansor.RPCRunner") +class RPCRunner(Runner): + def __init__(self, key, host, port, priority=1, + n_parallel=1, + timeout=10, + number=3, + repeat=1, + min_repeat_ms=0, + cooldown_interval=0.0): + self.__init_handle_by_constructor__( + _ffi_api.RPCRunner, key, host, port, priority, timeout, n_parallel, + number, repeat, min_repeat_ms, cooldown_interval) + + if check_remote(key, host, port, priority, timeout): + logger.info("Get devices for measurement successfully!") + else: + raise RuntimeError("Cannot get remote devices from the tracker. 
" + "Please check the status of tracker by " + "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' " + "and make sure you have free devices on the queue status.") + + MAX_ERROR_MSG_LEN = 512 diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 15a3c7de789d..42bcb00a9117 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -348,7 +348,8 @@ def __init__(self, cmd = [sys.executable, "-m", "tvm.exec.rpc_server", "--host=%s" % host, - "--port=%s" % port] + "--port=%s" % port, + "--port-end=%s" % port_end] if tracker_addr: assert key cmd += ["--tracker=%s:%d" % tracker_addr, diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 43be530f2a35..e3593753d3ff 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -368,5 +368,13 @@ TVM_REGISTER_GLOBAL("ansor.LocalRunner") cooldown_interval); }); +TVM_REGISTER_GLOBAL("ansor.RPCRunner") +.set_body_typed([](const std::string& key, const std::string& host, int port, + int priority, int timeout, int n_parallel, int number, + int repeat, int min_repeat_ms, double cooldown_interval) { + return RPCRunnerNode::make(key, host, port, priority, timeout, n_parallel, + number, repeat, min_repeat_ms, cooldown_interval); +}); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/meta_tile_rewrite_policy.h index 0c8c44b9c5ea..823ef6df4983 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.h +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.h @@ -76,7 +76,6 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { SearchTask cur_task_; // The current task - friend class MetaTileRewritePolicyNodeTest; // Hack friend class for UT protected: // Pick states from best states and random states with eps-greedy policy void PickStatesWithEpsGreedy(std::vector* inputs, diff --git a/tests/python/unittest/test_ansor_measure.py b/tests/python/unittest/test_ansor_measure.py index baf8a0c4efa2..0385568894fe 100644 --- a/tests/python/unittest/test_ansor_measure.py +++ b/tests/python/unittest/test_ansor_measure.py @@ -19,6 +19,8 @@ import tvm from tvm import ansor +from tvm.rpc.tracker import Tracker +from tvm.rpc.server import Server import tempfile from test_ansor_common import get_tiled_matmul @@ -62,6 +64,33 @@ def test_measure_local_builder_runner(): assert mress[0].error_no == 0 +def test_measure_local_builder_rpc_runner(): + dag, s0 = get_tiled_matmul() + + tgt = tvm.target.create("llvm") + task = ansor.SearchTask(dag, "test", tgt) + + minp = ansor.MeasureInput(task, s0) + local_builder = ansor.LocalBuilder() + host = '0.0.0.0' + tracker = Tracker(host, port=9000, port_end=10000, silent=True) + device_key = '$local$device$%d' % tracker.port + server = Server(host, port=tracker.port, port_end=10000, + key=device_key, + use_popen=True, silent=True, + tracker_addr=(tracker.host, tracker.port)) + rpc_runner = ansor.RPCRunner(device_key, host, tracker.port) + + bress = local_builder.build([minp]) + assert bress[0].error_no == 0 + mress = rpc_runner.run([minp], bress) + assert mress[0].error_no == 0 + + tracker.terminate() + server.terminate() + + if __name__ == "__main__": test_serialization() test_measure_local_builder_runner() + test_measure_local_builder_rpc_runner() diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index eea3f5cfbda3..9a57691aba22 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ 
b/tests/python/unittest/test_ansor_search_policy.py @@ -24,19 +24,20 @@ import tvm from tvm import ansor +from tvm.rpc.tracker import Tracker +from tvm.rpc.server import Server from test_ansor_common import matmul_nkkm -def test_search_basic(): - print("Test schedule search with the default search policy") +def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local'): + print("Test %s schedule search with the default search policy" % (target)) N = 128 A, B, C = matmul_nkkm(N, N, N) dag = ansor.ComputeDAG([A, B, C]) - tgt = tvm.target.create("llvm") + tgt = tvm.target.create(target) task = ansor.SearchTask(dag, "test", tgt) - seed = 944563397 random.seed(seed) with tempfile.NamedTemporaryFile() as fp: @@ -44,7 +45,7 @@ def test_search_basic(): cost_model = ansor.RandomModel() search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) - tune_option = ansor.TuneOption(n_trials=2, + tune_option = ansor.TuneOption(n_trials=2, runner=runner, callbacks=[ansor.LogToFile(log_file)]) state = ansor.auto_schedule(task, search_policy, tune_option=tune_option) @@ -60,7 +61,7 @@ def test_search_basic(): print(tvm.lower(sch, args, simple_mode=True)) mod = tvm.build(sch, args, tgt) - ctx = tvm.context("llvm", 0) + ctx = tvm.context(target, 0) a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(A.dtype), ctx) b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(B.dtype), ctx) c = tvm.nd.array(np.zeros((N, N), dtype=C.dtype), ctx) @@ -75,7 +76,55 @@ def test_search_basic(): s0 = dag.infer_bound_from_state(state) s1 = dag.infer_bound_from_state(inp.state) assert s0 == s1 + print() + + +def test_search_basic(): + search_common(seed=944563397) + + +def test_search_opencl(): + if tvm.context("opencl", 0).exist: + host = '0.0.0.0' + tracker = Tracker(host, port=9000, port_end=10000, silent=True) + device_key = '$local$device$%d' % tracker.port + server = Server(host, port=tracker.port, port_end=10000, + key=device_key, + use_popen=True, silent=True, + tracker_addr=(tracker.host, tracker.port)) + rpc_runner = ansor.RPCRunner(device_key, host, tracker.port) + + search_common("opencl", 380344973, rpc_runner) + + tracker.terminate() + server.terminate() + else: + print("OpenCL device not found, skip this test.") + + +def test_search_cuda(): + ctx = tvm.context("cuda", 0) + if ctx.exist: + cuda_arch = "sm_" + "".join(ctx.compute_version.split('.')) + tvm.autotvm.measure.measure_methods.set_cuda_target_arch(cuda_arch) + host = '0.0.0.0' + tracker = Tracker(host, port=9000, port_end=10000, silent=True) + device_key = '$local$device$%d' % tracker.port + server = Server(host, port=tracker.port, port_end=10000, + key=device_key, + use_popen=True, silent=True, + tracker_addr=(tracker.host, tracker.port)) + rpc_runner = ansor.RPCRunner(device_key, host, tracker.port) + + search_common("cuda", 903667810, rpc_runner) + + tracker.terminate() + server.terminate() + else: + print("CUDA device not found, skip this test.") if __name__ == "__main__": test_search_basic() + test_search_opencl() + test_search_cuda() From 2bd6471d6cc3126bea111b373bbfc273dbf8e595 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 8 Jun 2020 01:10:27 -0700 Subject: [PATCH 12/78] rebase to upstream/master --- .gitignore | 1 + python/tvm/ansor/measure.py | 4 ++-- src/ansor/compute_dag.cc | 20 +++++++++---------- src/ansor/{feature.cc => feature.ccc} | 0 .../search_policy/meta_tile_rewrite_policy.cc | 4 ++-- .../search_policy/meta_tile_rewrite_policy.h | 4 ++-- src/ansor/search_policy/utils.h | 8 ++++---- 
.../python/unittest/test_ansor_compute_dag.py | 8 ++++++++ 8 files changed, 28 insertions(+), 21 deletions(-) rename src/ansor/{feature.cc => feature.ccc} (100%) diff --git a/.gitignore b/.gitignore index b9357018a64c..506e54d93067 100644 --- a/.gitignore +++ b/.gitignore @@ -196,6 +196,7 @@ tvm_t.* .python_history .pytest_cache .local +cmake-build-debug # Visual Studio Code .vscode diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index b80de7c01633..e10da09e4b5a 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -34,7 +34,7 @@ import tvm._ffi from tvm.runtime import Object, module, ndarray from tvm.driver import build_module -from tvm.target import build_config +from tvm.ir import transform from ..contrib import tar, ndk from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, check_remote from .compute_dag import LayoutRewriteLevel @@ -254,7 +254,7 @@ def timed_func(): dirname, "tmp_func." + build_func.output_format) try: - with build_config(unroll_max_extent=task.hardware_params.max_unroll_vec): + with transform.PassContext(): # todo(lmzheng): port the unroll pass func = build_module.build( sch, args, target=task.target, target_host=task.target_host) func.export_library(filename, build_func) diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index f3979ef0d259..de3b98a5106b 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -129,17 +129,19 @@ class TensorAccessExtractor : public StmtExprVisitor { this->VisitExpr(expr); } - void VisitExpr_(const CallNode *op) final { - if (op->call_type == CallNode::CallType::Halide) { - buf_accesses[Downcast(op->func)].emplace_back( - op->args.begin(), op->args.end()); - } + void VisitExpr_(const CallNode* op) final { if (op->name == tir::intrinsic::tvm_if_then_else) { has_branch = true; } StmtExprVisitor::VisitExpr_(op); } + void VisitExpr_(const ProducerLoadNode* op) final { + buf_accesses[Downcast(op->producer)->op].emplace_back( + op->indices.begin(), op->indices.end()); + StmtExprVisitor::VisitExpr_(op); + } + void VisitStmt_(const IfThenElseNode* op) final { has_branch = true; StmtExprVisitor::VisitStmt_(op); @@ -518,7 +520,7 @@ class FlopEstimator: public ExprFunctor { double VisitExpr_(const FloatImmNode* op) final { return 0.0; } double VisitExpr_(const IntImmNode* op) final { return 0.0; } -// double VisitExpr_(const UIntImm* op) final { return 0.0; } + double VisitExpr_(const ProducerLoadNode* op) final { return 0.0; } double VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } double VisitExpr_(const VarNode* op) final { return 0.0; } @@ -545,11 +547,6 @@ class FlopEstimator: public ExprFunctor { VisitBinary(AndNode); VisitBinary(OrNode); VisitUnary(NotNode); double VisitExpr_(const CallNode* op) final { - if (op->call_type == CallNode::CallType::Halide) { - // ignore flops in index expressions - return 0.0; - } - double ret = 0.0; for (const auto&x : op->args) { ret += VisitExpr(x); @@ -557,6 +554,7 @@ class FlopEstimator: public ExprFunctor { return ret; } + double VisitExprDefault_(const Object* op) final { fail = true; return -1.0; diff --git a/src/ansor/feature.cc b/src/ansor/feature.ccc similarity index 100% rename from src/ansor/feature.cc rename to src/ansor/feature.ccc diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc index c22d890a8b51..86a7eba1da3a 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ 
b/src/ansor/search_policy/meta_tile_rewrite_policy.cc @@ -47,7 +47,7 @@ TVM_REGISTER_OBJECT_TYPE(MetaTileRewritePolicyNode); const std::vector MetaTileRewritePolicyNode::auto_unroll_configs{0, 16, 64, 512, 1024}; SearchPolicy MetaTileRewritePolicyNode::make(CostModel program_cost_model, - Map params, + Map params, int seed) { auto node = make_object(); node->program_cost_model = std::move(program_cost_model); @@ -1440,7 +1440,7 @@ void MetaTileRewritePolicyNode::EvolutionarySearch( TVM_REGISTER_GLOBAL("ansor.MetaTileRewritePolicy") .set_body_typed([](CostModel program_cost_model, - Map params, + Map params, int seed){ return MetaTileRewritePolicyNode::make(program_cost_model, params, seed); }); diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/meta_tile_rewrite_policy.h index 823ef6df4983..f92813b11273 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.h +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.h @@ -53,10 +53,10 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { * str cpu_multi_level_tiling_structure // The structure of multi-level tiling for CPU * str gpu_multi_level_tiling_structure // The structure of multi-level tiling for GPU */ - Map params; + Map params; static SearchPolicy make(CostModel program_cost_model, - Map params, + Map params, int seed); // Search and make n_trails measurements diff --git a/src/ansor/search_policy/utils.h b/src/ansor/search_policy/utils.h index 607a549e1b8a..3d0611173c94 100644 --- a/src/ansor/search_policy/utils.h +++ b/src/ansor/search_policy/utils.h @@ -41,7 +41,7 @@ namespace tvm { namespace ansor { // Get an integer from a tvm str Map -inline int GetIntParam(const Map& attr_dict, +inline int GetIntParam(const Map& attr_dict, const std::string& key) { CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; auto pint = attr_dict[key].as(); @@ -50,7 +50,7 @@ inline int GetIntParam(const Map& attr_dict, } // Get a double from a tvm str Map -inline double GetDoubleParam(const Map& attr_dict, +inline double GetDoubleParam(const Map& attr_dict, const std::string& key) { CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; auto pdouble = attr_dict[key].as(); @@ -59,7 +59,7 @@ inline double GetDoubleParam(const Map& attr_dict, } // Get a string from a tvm str Map -inline std::string GetStringParam(const Map& attr_dict, +inline std::string GetStringParam(const Map& attr_dict, const std::string& key) { CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; @@ -73,7 +73,7 @@ inline std::string GetStringParam(const Map& attr_dict, } // Get a iterator name set from a tvm str Map -inline std::set GetIterNameSetParam(const Map& attr_dict, +inline std::set GetIterNameSetParam(const Map& attr_dict, const std::string& key) { std::set ret; CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; diff --git a/tests/python/unittest/test_ansor_compute_dag.py b/tests/python/unittest/test_ansor_compute_dag.py index 61eb0153a87c..b60136d4265f 100644 --- a/tests/python/unittest/test_ansor_compute_dag.py +++ b/tests/python/unittest/test_ansor_compute_dag.py @@ -43,6 +43,12 @@ def test_infer_bound(): assert s.stages[C_global].iters[0].range.extent == 64 +def test_estimate_flop(): + dag, s = get_tiled_matmul() + + assert abs(dag.flop_ct - 2 * 512 ** 3) < 0.5 + + def test_lower_legalize_invalid_attach(): N, M = 10, 10 @@ -63,4 +69,6 @@ def 
test_lower_legalize_invalid_attach():
 if __name__ == "__main__":
     test_apply_steps()
     test_infer_bound()
+    test_estimate_flop()
     test_lower_legalize_invalid_attach()
+

From c860f2c27f46733798c5deb488e5856f1d63d77c Mon Sep 17 00:00:00 2001
From: Chenfan
Date: Mon, 8 Jun 2020 21:04:42 +0800
Subject: [PATCH 13/78] Add Ansor basic tutorial (#13)

* Add basic tutorial
---
 docs/conf.py                            |   1 +
 tutorials/ansor/README.txt              |   4 +
 tutorials/ansor/tune_simple_subgraph.py | 204 ++++++++++++++++++++++++
 tutorials/autotvm/README.txt            |   4 +-
 4 files changed, 211 insertions(+), 2 deletions(-)
 create mode 100644 tutorials/ansor/README.txt
 create mode 100644 tutorials/ansor/tune_simple_subgraph.py

diff --git a/docs/conf.py b/docs/conf.py
index 7ece63bd7aa8..5cbaab7f7b6d 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -197,6 +197,7 @@
 ['../tutorials/frontend',
  '../tutorials/language',
  '../tutorials/optimize',
+ '../tutorials/ansor',
  '../tutorials/autotvm',
  '../tutorials/dev',
  '../tutorials/topi',

diff --git a/tutorials/ansor/README.txt b/tutorials/ansor/README.txt
new file mode 100644
index 000000000000..85b6ba401dae
--- /dev/null
+++ b/tutorials/ansor/README.txt
@@ -0,0 +1,4 @@
+.. _tutorial-ansor-auto-schedule:
+
+Ansor: Template Free Auto Scheduling
+------------------------------------

diff --git a/tutorials/ansor/tune_simple_subgraph.py b/tutorials/ansor/tune_simple_subgraph.py
new file mode 100644
index 000000000000..8555d6163c32
--- /dev/null
+++ b/tutorials/ansor/tune_simple_subgraph.py
@@ -0,0 +1,204 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Writing a compute expression and using the Ansor auto-scheduler
+===============================================================
+**Author**: `Lianmin Zheng `_, \
+            `Chengfan Jia `_, \
+            `Minmin Sun `_, \
+            `Zhao Wu `_
+
+This is an introduction tutorial to the auto-scheduler module in TVM.
+
+There are two steps in auto-scheduling.
+The first step is defining the target task.
+The second step is running a search algorithm to automatically explore the schedule space.
+In this tutorial, you can learn how to perform these two steps in TVM.
+The whole workflow is illustrated by a matrix multiplication with bias add example.
+"""
+
+######################################################################
+# Install dependencies
+# --------------------
+# To use the Ansor package in TVM, we need to install some extra dependencies.
+# This step (installing xgboost) can be skipped if you do not use the XGBoost
+# cost model (change "3" to "2" if you use python2):
+#
+# .. code-block:: bash
+#
+#   pip3 install --user psutil xgboost
+#
+# To make TVM run faster during tuning, it is recommended to use cython
+# as the FFI of TVM. In the root directory of TVM, execute
+# (change "3" to "2" if you use python2):
+#
+# .. code-block:: bash
+#
+#   pip3 install --user cython
+#   sudo make cython3
+#
+# Now return to python code. Import packages.
+
+import random
+import sys
+
+import numpy as np
+import tvm
+from tvm import te
+
+# the module is called `ansor`
+from tvm import ansor
+
+######################################################################
+# Step 1: Define the target compute subgraph
+# -------------------------------------------
+# In this section, we will write a deterministic TVM compute expression
+# that defines a compute subgraph.
+#
+# .. note:: Compared to :ref:`tutorials-autotvm-sec`
+#
+#   In Ansor, users do not need to provide a schedule template; the only input
+#   is the compute expression written with the :code:`tvm.te` API or the topi op API.
+#
+# Here is how we implement a matrix multiplication subgraph in TVM.
+
+# Matmul with bias add
+def matmul_add(N, L, M, dtype):
+    A = te.placeholder((N, L), name='A', dtype=dtype)
+    B = te.placeholder((L, M), name='B', dtype=dtype)
+    C = te.placeholder((N, M), name='C', dtype=dtype)
+
+    k = te.reduce_axis((0, L), name='k')
+    mul = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
+                     name='Mul')
+    D = te.compute((N, M), lambda i, j: C[i, j] + mul[i, j], name='D')
+
+    return [A, B, C, D]
+
+######################################################################
+# Step 2: Search through the schedule space
+# ------------------------------------------
+# In step 1, we built the compute subgraph.
+# The next step is to pick a cost model as well as a search policy and explore
+# the possible schedules.
+#
+# Auto-scheduler in TVM
+# ^^^^^^^^^^^^^^^^^^^^^
+# The job of the Ansor auto-scheduler can be described by the following pseudo code
+#
+# .. code-block:: c
+#
+#   ct = 0
+#   while ct < max_number_of_trials:
+#       automatically generate a batch of schedules
+#       measure this batch of schedules on real hardware and get results
+#       ct += batch_size
+#
+# When proposing the next batch of schedules, Ansor can use different cost models to
+# guide the schedule generation process.
+#
+# * :any:`RandomModel`: Generate and pick new schedules randomly
+# * :any:`XGBModel`: Use an XGBoost model to estimate the performance of potential schedules, and try to pick schedules with better performance in each step
+#
+# XGBModel can explore more efficiently and find better schedules.
+
+################################################################
+# Begin tuning
+# ^^^^^^^^^^^^
+# Here we continue our matrix multiplication example.
+#
+# :code:`ansor.ComputeDAG` takes the tensor list as input and generates
+# a DAG structure. During this process, :code:`ansor.ComputeDAG` analyzes
+# the target subgraph, and the results will be used by the search policy later.
+#
+# Then we create the :code:`tvm.target` and a tuning task.
+
+N, L, M = 64, 64, 64
+A, B, C, D = matmul_add(N, L, M, 'float32')
+dag = ansor.ComputeDAG([A, B, C, D])
+
+print(dag)
+print(dag.access_analyzer)
+
+tgt = tvm.target.create("llvm")
+task = ansor.SearchTask(dag, "test", tgt)
+
+################################################################
+# Next, we choose the random cost model and create the default search policy:
+# :code:`ansor.MetaTileRewritePolicy`.
+#
+# We make only 5 trials in this tutorial for demonstration. In practice,
+# you can do more trials according to your time budget.
+# The :code:`ansor.LogToFile` callback will log the tuning results into a
+# log file, which can be used to get the best config later.
+# A sketch of swapping in the XGBoost cost model is shown right below.
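A minimal sketch of that swap, assuming the :code:`ansor.XGBModel` interface added in patch 15 later in this series (its constructor accepts an optional :code:`seed`, and it requires xgboost to be installed); the names :code:`task`, :code:`seed`, and :code:`log_file` are the ones defined in the tutorial code below:

.. code-block:: python

    # Sketch: use the XGBoost cost model instead of the random model.
    # `ansor.XGBModel` comes from patch 15 of this series; everything else
    # matches the tutorial code that follows.
    cost_model = ansor.XGBModel(seed=seed)
    search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed)
    tune_option = ansor.TuneOption(n_trials=5,
                                   callbacks=[ansor.LogToFile(log_file)])
    state = ansor.auto_schedule(task, search_policy, tune_option=tune_option)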
+#
+# Then just call :code:`ansor.auto_schedule` and Ansor will try to find a
+# high-performance schedule for the target subgraph automatically.
+
+log_file = "matmul_add.json"
+
+seed = 0
+random.seed(seed)
+cost_model = ansor.RandomModel()
+search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed)
+
+tune_option = ansor.TuneOption(n_trials=5,
+                               callbacks=[ansor.LogToFile(log_file)])
+
+state = ansor.auto_schedule(task, search_policy,
+                            tune_option=tune_option)
+print(state)
+
+#########################################################################
+# Finally, we apply the best record in the search history to get a TVM schedule.
+#
+# We can call the function :code:`apply_steps_from_state` directly with the returned
+# :code:`state` structure.
+# :code:`state` can also be used to print out user-friendly Python code on demand.
+#
+# Since we have recorded the tuning results to a file, we can also use the following
+# code to replay the best schedule from the log file:
+#
+# .. code-block:: python
+#
+#   inp, res = ansor.best_measure_pair_in_file(log_file)
+#   state = inp.state
+#   s, arg_bufs = dag.apply_steps_from_state(state)
+#
+# With the :code:`state` above, we can get the lowered result and its Python code:
+
+s, arg_bufs = dag.apply_steps_from_state(state)
+print("==== Get Lowered Stmt ====")
+print(tvm.lower(s, arg_bufs, simple_mode=True))
+print("==== Get Python Code ====")
+print(dag.print_python_code_from_state(state))
+
+#########################################################################
+# Check the correctness to make sure we generated a correct schedule.
+
+func = tvm.build(s, arg_bufs)
+
+# check correctness
+a_np = np.random.uniform(size=(N, L)).astype(np.float32)
+b_np = np.random.uniform(size=(L, M)).astype(np.float32)
+c_np = np.random.uniform(size=(N, M)).astype(np.float32)
+d_np = a_np.dot(b_np) + c_np
+
+d_tvm = tvm.nd.empty(d_np.shape)
+func(tvm.nd.array(a_np), tvm.nd.array(b_np), tvm.nd.array(c_np), d_tvm)
+
+tvm.testing.assert_allclose(d_np, d_tvm.asnumpy(), rtol=1e-2)

diff --git a/tutorials/autotvm/README.txt b/tutorials/autotvm/README.txt
index 38e3b3343f4e..4ad36c000e3c 100644
--- a/tutorials/autotvm/README.txt
+++ b/tutorials/autotvm/README.txt
@@ -1,4 +1,4 @@
 .. _tutorials-autotvm-sec:

-Auto tuning
------------
+AutoTVM: Template Based Auto Tuning
+-----------------------------------

From f60d1a60dc96ac408625a85f34eea099c78dd8eb Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Mon, 8 Jun 2020 06:28:31 -0700
Subject: [PATCH 14/78] migrate feature extraction (#14)

---
 python/tvm/ansor/__init__.py                |   3 +-
 python/tvm/ansor/feature.py                 | 147 +++++++
 src/ansor/{feature.ccc => feature.cc}       | 401 ++++++++++++++------
 src/ansor/feature.h                         |  35 +-
 tests/python/unittest/test_ansor_feature.py |  97 +++++
 5 files changed, 566 insertions(+), 117 deletions(-)
 create mode 100644 python/tvm/ansor/feature.py
 rename src/ansor/{feature.ccc => feature.cc} (79%)
 create mode 100644 tests/python/unittest/test_ansor_feature.py

diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py
index 7552878a3c50..3e9b76c2f6ad 100644
--- a/python/tvm/ansor/__init__.py
+++ b/python/tvm/ansor/__init__.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=unused-import, redefined-builtin
-"""Namespace for Ansor autoSchedule"""
+"""Namespace for Ansor auto-scheduler"""

 from . import compute_dag
 from . import measure
@@ -23,6 +23,7 @@
 from . import loop_state
 from . import task
 from . import utils
+from . import feature

 # Shortcut
 from .compute_dag import ComputeDAG

diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py
new file mode 100644
index 000000000000..fb5fadf16296
--- /dev/null
+++ b/python/tvm/ansor/feature.py
@@ -0,0 +1,147 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Python API for feature extraction.
+The specification of features can be found in `autoscheduler_doc/per_stage_feature.md`
+"""
+
+from typing import List, Tuple
+import struct
+import numpy as np
+
+from .loop_state import StateObject
+from .task import SearchTask
+from .measure import MeasureInput, MeasureResult
+from . import _ffi_api
+
+
+DEFAULT_MAX_N_BUFS = 5
+
+DEFAULT_FEATURE_VEC_LEN = 164
+
+
+def unpack_feature(byte_arr: bytearray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """Unpack the encoded feature (in byte array format) from C++"""
+    size_of_int = 4
+    size_of_float = 4
+
+    """
+    The format for n records is:
+    {
+        int   n;
+        int[n+2]  sizes
+
+        float[sizes[0]]    feature for record 1
+        float[sizes[1]]    feature for record 2
+        ...                feature for record i...
+        float[sizes[n-1]]  feature for record n
+
+        float[sizes[n]]    normalized throughput for n records
+        int[sizes[n+1]]    task id for n records
+    }
+    """
+    vec_len = DEFAULT_FEATURE_VEC_LEN
+
+    # unpack sizes
+    offset = 0
+    n = struct.unpack_from("1i", byte_arr, offset=offset)[0]
+    offset += size_of_int
+
+    sizes = struct.unpack_from("%di" % (n+2), byte_arr, offset=offset)
+    offset += size_of_int * (n+2)
+
+    # unpack features
+    features = []
+    for size in sizes[:-2]:
+        row = []
+
+        """
+        Now we need to unpack the feature for multiple statements.
+        The format is:
+        {
+            int   n_stmts
+            float[n_stmts][vec_len]  feature_vecs
+        }
+        where vec_len can be calculated by `(size - 1) / n_stmts`
+        """
+        if size == 0:
+            # failed during lowering
+            features.append(np.zeros((1, vec_len)))
+        else:
+            n_stmts = struct.unpack_from("f", byte_arr, offset=offset)
+            offset += size_of_float
+
+            n_stmts = int(n_stmts[0] + 0.5)
+            tmp_vec_len = (size - 1) // n_stmts
+            assert tmp_vec_len == vec_len, "The length of feature vector is wrong. " \
+                "Expected %d but got %d."
% (vec_len, tmp_vec_len) + assert (size - 1) % n_stmts == 0 + for _ in range(n_stmts): + x = struct.unpack_from("%df" % vec_len, byte_arr, offset=offset) + offset += vec_len * size_of_float + row.append(x) + + features.append(np.array(row)) + + # unpack normalized_throughputs + m = sizes[-2] + normalized_throughputs = struct.unpack_from("%df" % m, byte_arr, offset=offset) + offset += m * size_of_int + + # unpack task_ids + m = sizes[-1] + task_ids = struct.unpack_from("%di" % m, byte_arr, offset=offset) + offset += m * size_of_int + + assert offset == len(byte_arr), "%d vs %d" % (offset, len(byte_arr)) + return np.array(features), np.array(normalized_throughputs), np.array(task_ids) + + +def get_per_stmt_features_from_file(filename: str, + n_lines: int, + max_n_bufs: int = None) \ + -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Get per_stmt features from a log file""" + byte_arr = _ffi_api.GetPerStmtFeaturesFromFile( + filename, n_lines, max_n_bufs or DEFAULT_MAX_N_BUFS) + return unpack_feature(byte_arr) + + +def get_per_stmt_features_from_measure_pairs(inputs: List[MeasureInput], + results: List[MeasureResult], + skip_first_n_feature_extraction: int = 0, + max_n_bufs: int = None,) \ + -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Get per_stmt features from measurement pairs""" + byte_arr = _ffi_api.GetPerStmtFeaturesFromMeasurePairs( + inputs, results, skip_first_n_feature_extraction, max_n_bufs or DEFAULT_MAX_N_BUFS) + return unpack_feature(byte_arr) + + +def get_per_stmt_features_from_states(states: List[StateObject], + task: SearchTask, + max_n_bufs: int = None) -> List[np.ndarray]: + """Get per_stmt features from states""" + byte_arr = _ffi_api.GetPerStmtFeaturesFromStates( + states, task, max_n_bufs or DEFAULT_MAX_N_BUFS) + return unpack_feature(byte_arr)[0] + + +def get_per_stmt_feature_names(max_n_bufs: int = None) -> List[str]: + """Get names of the elements in the flatten feature vector""" + return [x for x in + _ffi_api.GetPerStmtFeatureNames(max_n_bufs or DEFAULT_MAX_N_BUFS)] diff --git a/src/ansor/feature.ccc b/src/ansor/feature.cc similarity index 79% rename from src/ansor/feature.ccc rename to src/ansor/feature.cc index 31afe931361c..16ddb73ebf47 100644 --- a/src/ansor/feature.ccc +++ b/src/ansor/feature.cc @@ -1,5 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
- * Copyright (c) 2020 by Contributors + * \file ansor/feature.cc + * \brief Feature extraction for the cost model */ #include @@ -7,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -15,16 +36,12 @@ #include "measure.h" #include "serialization.h" #include "utils.h" -// #include "../arithmetic/compute_expr.h" namespace tvm { -/* Import the function from build_module.cc */ -extern void GetBinds(const Array& args, - bool compact, - const std::unordered_map& binds, - Map* out_binds, - Array* out_arg_list, - const BuildConfig& config); +/* Import the function from driver_api.cc */ +extern void GetBinds(const Array& args, bool compact, + const std::unordered_map& binds, + Map* out_binds, Array* out_arg_list); } // namespace tvm @@ -35,6 +52,9 @@ using namespace tvm::tir; using arith::ConstIntBound; using arith::Analyzer; +template +using BufferMap = std::unordered_map; + static const int ARITH_INTENSITY_CURVE_SAMPLE_N = 10; // Annotation position encoding @@ -61,7 +81,7 @@ enum ReuseType { // Feature for an access of a buffer struct BufferAccessFeature { - std::string tensor_name; + std::string buffer_name; BufferAccessType acc_type; float bytes; float unique_bytes; @@ -169,10 +189,11 @@ AnnotationPosType GetAnnotationPosEncoding( } if (find_ct == 0) { - // If not find in spatial args, then it is a reduce iteartor. + // If not find in spacial args, then it is a reduce iterator. // Use name to match + const std::string& var_name = var->name_hint; for (size_t i = 0; i < reduce_axis.size(); ++i) { - if (var->name_hint.find(reduce_axis[i]->var->name_hint) != std::string::npos) { + if (var_name.find(reduce_axis[i]->var->name_hint) != std::string::npos) { find_i = i; find_ct++; } @@ -238,7 +259,6 @@ class MathOpCounter : public StmtExprVisitor { void VisitExpr_(const NotNode* op) final { bool_op++; StmtExprVisitor::VisitExpr_(op); } void VisitExpr_(const SelectNode* op) final { select_op++; StmtExprVisitor::VisitExpr_(op); } - // TODO(...): CallNode with type CallNode::Halide has been modified to BufferLoadNode void VisitExpr_(const CallNode* op) final { if (op->call_type == CallNode::CallType::PureIntrinsic) { if (op->dtype.is_float()) { @@ -246,8 +266,8 @@ class MathOpCounter : public StmtExprVisitor { } else { int_math_func++; } - } else if (op->call_type != CallNode::CallType::Halide) { - if (op->dtype.is_float()) { + } else { + if (op->dtype.is_float()) { float_other_func++; } else { int_other_func++; @@ -272,42 +292,38 @@ class BufferAccessExtractor : public StmtExprVisitor { this->VisitExpr(expr); } - void InsertAccess(const te::Tensor& ten, BufferAccessType acc_type, + void InsertAccess(const Buffer& buf, BufferAccessType acc_type, const Array& indices) { - BufferAccess& acc = buf_accesses[ten]; + BufferAccess& acc = buf_accesses[buf]; acc.acc_type = acc_type; acc.indices.push_back(std::vector(indices.begin(), indices.end())); } - // TODO(...): CallNode with type CallNode::Halide has been modified to BufferLoadNode - void VisitExpr_(const CallNode *op) final { - if (op->call_type == CallNode::CallType::Halide) { - te::Tensor ten = Downcast(op->func).output(op->value_index); - BufferAccess& acc = buf_accesses[ten]; - switch (acc.acc_type) { - case kRead: - break; - case kWrite: - acc.acc_type = kReadWrite; break; - case kReadWrite: - break; - case kUnknownRW: - default: - acc.acc_type = kRead; break; - } + void VisitExpr_(const BufferLoadNode *op) final { + BufferAccess& acc = buf_accesses[op->buffer]; + switch (acc.acc_type) { + case kRead: + break; + case kWrite: 
+ acc.acc_type = kReadWrite; break; + case kReadWrite: + break; + case kUnknownRW: + default: + acc.acc_type = kRead; break; + } - if (acc.acc_type != kReadWrite) { - // If a buffer is both read and written, in the tvm DSL, it must be a update, - // so the indices should be the same. Then we can skip appending indices for it. - // Otherwise we do the following. - buf_accesses[ten].indices.push_back( - std::vector(op->args.begin(), op->args.end())); - } + if (acc.acc_type != kReadWrite) { + // If a buffer is both read and written, in the tvm DSL, it must be a update, + // so the indices should be the same. Then we can skip appending indices for it. + // Otherwise we do the following. + buf_accesses[op->buffer].indices.push_back( + std::vector(op->indices.begin(), op->indices.end())); } StmtExprVisitor::VisitExpr_(op); } - std::unordered_map buf_accesses; + BufferMap buf_accesses; }; // Compute coefficient for an loop iterator in an expression @@ -430,11 +446,11 @@ void ComputeRegion( // Compute reuse distance and reuse ratio for accesses to a buffer // return values: reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct std::tuple ComputeReuse( - const te::Tensor& t, + const Buffer& buf, const std::vector >& indices, const std::vector& for_loop_stack, - const std::unordered_map > > >& for_touch_regions) { + const std::unordered_map > > >& for_touch_regions) { float reuse_dis_iter = 1.0f; float reuse_dis_bytes = -1.0f; @@ -479,16 +495,16 @@ std::tuple ComputeReuse( return std::make_tuple(kLoopMultipleRead, reuse_dis_iter, reuse_dis_bytes, extent); } - const std::unordered_map > >& - tensor_map = for_touch_regions.at(cur_for); + const BufferMap > >& buffer_map + = for_touch_regions.at(cur_for); - int serial_reuse = static_cast(tensor_map.at(t).size()) - 1; + int serial_reuse = static_cast(buffer_map.at(buf).size()) - 1; if (serial_reuse > 0) { int64_t extent = GetIntImm(cur_for->extent); // Have SerialMultipleReadWrite reuse reuse_dis_iter = std::numeric_limits::max(); - for (const auto& acc_info : tensor_map.at(t)) { + for (const auto& acc_info : buffer_map.at(buf)) { reuse_dis_iter = std::min(reuse_dis_iter, static_cast(std::get<1>(acc_info))); } @@ -600,13 +616,8 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { } } - // TODO(...): ProvideNode is deprecated, move to BufferStoreNode - void VisitStmt_(const ProvideNode* node) final { - te::Operation op = Downcast(node->func); - te::Tensor ten = op.output(node->value_index); - const te::ComputeOpNode* pcompute = op.as(); - - FeatureSet &fea = op_features[ten]; + void VisitStmt_(const BufferStoreNode* node) final { + FeatureSet &fea = buffer_features[node->buffer]; // compute feature MathOpCounter mathops; @@ -641,8 +652,10 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { for (const ForNode* pfor : vec_for_stack) { fea.vec_prod *= GetIntImm(pfor->extent); } - fea.vec_type = GetAnnotationPosEncoding(vec_for_stack.back()->loop_var, - node->args, pcompute->axis, pcompute->reduce_axis); + fea.vec_type = kPosMixed; + // todo(lmzheng): this feature requires operation (tvm.compute) information + //GetAnnotationPosEncoding(vec_for_stack.back()->loop_var, + //node->args, pcompute->axis, pcompute->reduce_axis); } fea.unroll_num = unroll_for_stack.size(); @@ -652,8 +665,9 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { for (const ForNode* pfor : unroll_for_stack) { fea.unroll_prod *= GetIntImm(pfor->extent); } - fea.unroll_type = GetAnnotationPosEncoding(unroll_for_stack.back()->loop_var, - node->args, pcompute->axis, 
pcompute->reduce_axis); + fea.unroll_type = kPosMixed; + //GetAnnotationPosEncoding(unroll_for_stack.back()->loop_var, + //node->args, pcompute->axis, pcompute->reduce_axis); } fea.parallel_num = parallel_for_stack.size(); @@ -663,8 +677,9 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { for (const ForNode* pfor : parallel_for_stack) { fea.parallel_prod *= GetIntImm(pfor->extent); } - fea.parallel_type = GetAnnotationPosEncoding(parallel_for_stack.back()->loop_var, - node->args, pcompute->axis, pcompute->reduce_axis); + fea.parallel_type = kPosMixed; + //GetAnnotationPosEncoding(parallel_for_stack.back()->loop_var, + //node->args, pcompute->axis, pcompute->reduce_axis); } // GPU threads @@ -680,13 +695,13 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { // Extract all buffer access std::vector acc_feas; BufferAccessExtractor buf_extractor; - buf_extractor.InsertAccess(ten, kWrite, node->args); + buf_extractor.InsertAccess(node->buffer, kWrite, node->indices); buf_extractor.ExtractReads(node->value); // Compute touched region for all outer loops Analyzer ana; for (auto x : for_loop_stack) { - ana.Bind(x->loop_var, Range::make_by_min_extent(x->min, 1)); + ana.Bind(x->loop_var, Range::make_by_min_extent(x->min, 1), true); } std::vector mem_bytes_list; @@ -704,22 +719,22 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { const ForNode* p_for = for_loop_stack[i]; ana.Bind(p_for->loop_var, - Range::make_by_min_extent(for_loop_stack[i]->min, for_loop_stack[i]->extent)); + Range::make_by_min_extent(for_loop_stack[i]->min, for_loop_stack[i]->extent), true); // Note, here we do overwrite. // So if there are multiple Provides, the last one will overwrite the first few. // e.g. The update part in gemm will overwrite the init part. 
- std::unordered_map > >& - tensor_regions_map = for_touch_regions[p_for]; + BufferMap > >& + buffer_regions_map = for_touch_regions[p_for]; int64_t mem_bytes = 0; for (const auto &x : buf_extractor.buf_accesses) { - const te::Tensor& t = x.first; + const Buffer& t = x.first; const BufferAccess& acc = x.second; ComputeRegion(acc.indices, &ana, &tmp_region); int64_t touched_size = ElementProduct(tmp_region); - tensor_regions_map[t].push_back(std::make_tuple(acc.acc_type, + buffer_regions_map[t].push_back(std::make_tuple(acc.acc_type, touched_size, t->dtype.bytes())); mem_bytes += touched_size * t->dtype.bytes(); } @@ -759,7 +774,7 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { // Compute buffer access feature for (const auto &x : buf_extractor.buf_accesses) { - const te::Tensor& t = x.first; + const Buffer& t = x.first; const BufferAccess& acc = x.second; std::vector int_shape; @@ -826,7 +841,7 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { acc_feas.emplace_back(); BufferAccessFeature& acc_fea = acc_feas.back(); - acc_fea.tensor_name = t->op->func_name(); + acc_fea.buffer_name = t->name; acc_fea.acc_type = acc.acc_type; acc_fea.stride = stride; acc_fea.bytes = bytes; @@ -854,21 +869,17 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { fea.access_feas = acc_feas; } - // TODO(...): RealizeNode is deprecated, move to BufferRealizeNode - void VisitStmt_(const RealizeNode *node) final { + void VisitStmt_(const BufferRealizeNode *node) final { StmtExprVisitor::VisitStmt_(node); - te::Operation op = Downcast(node->func); - te::Tensor ten = op.output(node->value_index); - - FeatureSet& fea = op_features[ten]; + FeatureSet& fea = buffer_features[node->buffer]; float allocation_size = 1.0f; for (const auto& x : node->bounds) { allocation_size *= GetIntImm(x->extent); } // allocation feature - fea.alloc_size = allocation_size * ten->dtype.bytes(); + fea.alloc_size = allocation_size * node->buffer->dtype.bytes(); fea.alloc_prod = allocation_size * outer_loop_prod; fea.alloc_outer_prod = outer_loop_prod; fea.alloc_inner_prod = fea.outer_prod / outer_loop_prod; @@ -891,12 +902,12 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { int vthread_len{1}; int16_t cur_auto_unroll_max_step{0}; - std::unordered_map op_features; + BufferMap buffer_features; - // for a loop, for all its touched tensors, for all different accesses to the tensors, + // for a loop, for all its touched buffers, for all different accesses to the buffers, // its (access type, number of touched elements, number of bytes of single element) - std::unordered_map > > > for_touch_regions; + std::unordered_map > > > for_touch_regions; private: const int cache_line_size_ = 64; @@ -913,14 +924,12 @@ void GetPerStmtFeature(const Stmt& stmt, int cache_line_size, int max_n_bufs, std::vector* ret) { - LOG(WARNING) << "RealizeNode & ProvideNode deprecated, " - << "need to fix the implementation of PerStmtFeatureExtractor."; PerStmtFeatureExtractor extractor(cache_line_size); extractor(stmt); - ret->push_back(extractor.op_features.size()); + ret->push_back(extractor.buffer_features.size()); - for (const auto& x : extractor.op_features) { + for (const auto& x : extractor.buffer_features) { const FeatureSet& fea_set = x.second; /***** compute feature *****/ @@ -1148,33 +1157,49 @@ void GetPerStmtFeaturesWorkerFunc(const SearchTask& task, const State& state, int max_n_bufs, std::vector* feature, std::atomic* error_ct) { te::Schedule sch; Array tensors; - Map bounds; - GlobalVar g("main"); std::tie(sch, tensors) 
= task->compute_dag.ApplySteps(state->transform_steps); sch = sch.normalize(); - bounds = te::InferBound(sch); + auto bounds = te::InferBound(sch); try { auto stmt = te::ScheduleOps(sch, bounds, false); Map out_binds; Array out_arg_list; bool compact = te::VerifyCompactBuffer(stmt); + const std::string& name = "main"; + GlobalVar global_var(name); + + // Copied from driver_api.cc::lower + auto pass_ctx = tvm::transform::PassContext::Current(); GetBinds(tensors, compact, std::unordered_map(), - &out_binds, &out_arg_list, BuildConfig::Create()); - tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, - std::move(stmt), out_binds); - f = WithAttr(std::move(f), "global_symbol", runtime::String("main")); - auto mod = IRModule(Map({{g, f}})); - auto pass_list = Array(); + &out_binds, &out_arg_list); + tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); + f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); + + bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + bool disable_vectorize = + pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); + bool instrument_bound_checkers = + pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); + + if (noalias) { + f = WithAttr(std::move(f), "tir.noalias", Bool(true)); + } + auto mod = IRModule(Map({{global_var, f}})); + if (task->target->device_type == kDLGPU) { + auto pass_list = Array(); + // Phase 0 pass_list.push_back(tir::transform::InjectPrefetch()); - pass_list.push_back(tir::transform::StorageFlatten(64)); + pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); + // Phase 1 + pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::VectorizeLoop()); + pass_list.push_back(tir::transform::VectorizeLoop(disable_vectorize)); pass_list.push_back(tir::transform::InjectVirtualThread()); pass_list.push_back(tir::transform::StorageRewrite()); pass_list.push_back(tir::transform::Simplify()); - tvm::Map gpu_params { + tvm::Map gpu_params { {"max_shared_memory_per_block", task->hardware_params->max_shared_memory_per_block}, {"max_local_memory_per_block", @@ -1188,11 +1213,9 @@ void GetPerStmtFeaturesWorkerFunc(const SearchTask& task, const State& state, const auto& optimize = tir::transform::Sequential(pass_list); optimize(mod); } - pass_list.clear(); - pass_list.push_back(tir::transform::Simplify()); - const auto& optimize = tir::transform::Sequential(pass_list); + const auto& optimize = tir::transform::Sequential(Array{tir::transform::Simplify()}); mod = optimize(std::move(mod)); - const auto& it = mod->functions.find(g); + const auto& it = mod->functions.find(global_var); CHECK(it != mod->functions.end()); const auto& prim_func = (*it).second.as(); GetPerStmtFeature(prim_func->body, @@ -1205,8 +1228,8 @@ void GetPerStmtFeaturesWorkerFunc(const SearchTask& task, const State& state, void GetPerStmtFeaturesFromStates(const Array& states, const SearchTask& task, - int max_n_bufs, int skip_first_n_feature_extraction, + int max_n_bufs, std::vector >* features) { // extract features features->assign(states.size(), std::vector()); @@ -1230,8 +1253,8 @@ void GetPerStmtFeaturesFromStates(const Array& states, void GetPerStmtFeaturesFromStates(const Array& states, const std::vector& tasks, - int max_n_bufs, int skip_first_n_feature_extraction, + int max_n_bufs, std::vector >* features) { // extract features features->assign(states.size(), 
std::vector()); @@ -1314,13 +1337,13 @@ void GetPerStmtFeaturesFromFile(const std::string& filename, (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i]; } - GetPerStmtFeaturesFromStates(states, tasks, max_n_bufs, 0, features); + GetPerStmtFeaturesFromStates(states, tasks, 0, max_n_bufs, features); } void GetPerStmtFeaturesFromMeasurePairs(const Array& inputs, const Array& results, - int max_n_bufs, int skip_first_n_feature_extraction, + int max_n_bufs, std::vector >* features, std::vector* normalized_throughputs, std::vector* task_ids) { @@ -1379,9 +1402,173 @@ void GetPerStmtFeaturesFromMeasurePairs(const Array& inputs, (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i]; } - GetPerStmtFeaturesFromStates(states, tasks, max_n_bufs, - skip_first_n_feature_extraction, features); + GetPerStmtFeaturesFromStates(states, tasks, skip_first_n_feature_extraction, + max_n_bufs, features); } +TVMByteArray SerializeFeatures(std::vector >&& features, + std::vector&& normalized_throughputs, + std::vector&& task_ids, + std::vector* out_data) { + size_t total_bytes = 0; + std::vector size_vector; + + int n = features.size(); + + // serialize sizes + size_t size_vector_size = 1 + n + 2; + total_bytes += size_vector_size * sizeof(int); + + size_vector.reserve(size_vector_size); + size_vector.push_back(features.size()); + for (const auto& x : features) { + size_vector.push_back(static_cast(x.size())); + total_bytes += sizeof(float) * x.size(); + } + size_vector.push_back(static_cast(normalized_throughputs.size())); + total_bytes += sizeof(float) * normalized_throughputs.size(); + size_vector.push_back(static_cast(task_ids.size())); + total_bytes += sizeof(int) * task_ids.size(); + + CHECK_EQ(size_vector.size(), size_vector_size); + + // allocate memory + out_data->reserve(total_bytes); + char* ptr = out_data->data(); + + // serialize size_vector + memmove(ptr, reinterpret_cast(size_vector.data()), size_vector.size() * sizeof(int)); + ptr += size_vector.size() * sizeof(int); + + // serialize features + for (auto& x : features) { + memmove(ptr, x.data(), sizeof(float) * x.size()); + ptr += sizeof(float) * x.size(); + x.clear(); + } + + // serialize normalized_throughputs + memmove(ptr, reinterpret_cast(normalized_throughputs.data()), + normalized_throughputs.size() * sizeof(int)); + ptr += normalized_throughputs.size() * sizeof(int); + + // serialize task_ids + memmove(ptr, reinterpret_cast(task_ids.data()), task_ids.size() * sizeof(int)); + ptr += task_ids.size() * sizeof(int); + + CHECK_EQ(ptr - out_data->data(), total_bytes); + + return TVMByteArray{out_data->data(), total_bytes}; +} + + +TVM_REGISTER_GLOBAL("ansor.GetPerStmtFeaturesFromFile") +.set_body([](TVMArgs args, TVMRetValue *ret) { + std::string filename = args[0]; + int n_lines = args[1]; + int max_n_bufs = args[2]; + + std::vector > features; + std::vector normalized_throughputs; + std::vector task_ids; + + GetPerStmtFeaturesFromFile(filename, n_lines, max_n_bufs, + &features, &normalized_throughputs, &task_ids); + + // serialization format for n records: + // + // int n; + // int[n+2] sizes + // + // float[sizes[0]] feature for record 1 + // float[sizes[1]] feature for record 2 + // ... feature for record i... 
+ // float[sizes[n-1]] feature for record n + // + // float[sizes[n]] normalized throughput for n records + // int[sizes[n+1]] task id for n records + + std::vector byte_data; + *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs), + std::move(task_ids), &byte_data); +}); + +TVM_REGISTER_GLOBAL("ansor.GetPerStmtFeaturesFromMeasurePairs") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Array inputs = args[0]; + Array results = args[1]; + int skip_first_n_feature_extraction = args[2]; + int max_n_bufs = args[3]; + + std::vector > features; + std::vector normalized_throughputs; + std::vector task_ids; + + GetPerStmtFeaturesFromMeasurePairs(inputs, results, skip_first_n_feature_extraction, max_n_bufs, + &features, &normalized_throughputs, &task_ids); + + // serialization format for n records: + // + // int n; + // int[n+2] sizes + // + // float[sizes[0]] feature for record 1 + // float[sizes[1]] feature for record 2 + // ... feature for record i... + // float[sizes[n-1]] feature for record n + // + // float[sizes[n]] normalized throughput for n records + // int[sizes[n+1]] task id for n records + + std::vector byte_data; + *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs), + std::move(task_ids), &byte_data); +}); + +TVM_REGISTER_GLOBAL("ansor.GetPerStmtFeaturesFromStates") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Array states = args[0]; + SearchTask task = args[1]; + int max_n_bufs = args[2]; + + std::vector > features; + std::vector normalized_throughputs; + std::vector task_ids; + + GetPerStmtFeaturesFromStates(states, task, 0, max_n_bufs, &features); + + // serialization format for n records: + // + // int n; + // int[n+2] sizes + // + // float[sizes[0]] feature for record 1 + // float[sizes[1]] feature for record 2 + // ... feature for record i... + // float[sizes[n-1]] feature for record n + // + // float[sizes[n]] normalized throughput for n records + // int[sizes[n+1]] task id for n records + + std::vector byte_data; + *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs), + std::move(task_ids), &byte_data); +}); + +TVM_REGISTER_GLOBAL("ansor.GetPerStmtFeatureNames") + .set_body([](TVMArgs args, TVMRetValue *ret) { + int max_n_bufs = args[0]; + std::vector names; + + GetPerStmtFeatureName(max_n_bufs, &names); + + Array arr; + for (const auto& x : names) { + arr.push_back(x); + } + *ret = arr; +}); + + } // namespace ansor } // namespace tvm diff --git a/src/ansor/feature.h b/src/ansor/feature.h index 149c59e8cb7d..e507149643e2 100644 --- a/src/ansor/feature.h +++ b/src/ansor/feature.h @@ -1,13 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
- * Copyright (c) 2020 by Contributors - * \file ansor/search_task.h - * \brief Meta inforamtion for a search task + * \file ansor/feature.h + * \brief Feature extraction for the cost model */ #ifndef TVM_ANSOR_FEATURE_H_ #define TVM_ANSOR_FEATURE_H_ -// #include #include #include #include "compute_dag.h" @@ -26,18 +43,18 @@ void GetPerStmtFeature(const Stmt& stmt, void GetPerStmtFeatureName(int max_n_bufs, std::vector *ret); -/*! \brief Get PerStmt feature from states */ +/*! \brief Get PerStmt feature from states and the same task */ void GetPerStmtFeaturesFromStates(const Array& states, const SearchTask& task, - int max_n_bufs, int skip_first_n_feature_extraction, + int max_n_bufs, std::vector >* features); -/*! \brief Get PerStmt feature from states */ +/*! \brief Get PerStmt feature from states and different tasks */ void GetPerStmtFeaturesFromStates(const Array& states, const std::vector& tasks, - int max_n_bufs, int skip_first_n_feature_extraction, + int max_n_bufs, std::vector >* features); /*! \brief Get PerStmt feature from a log file */ @@ -51,8 +68,8 @@ void GetPerStmtFeaturesFromFile(const std::string& filename, /*! \brief Get PerStmt feature from measure pairs */ void GetPerStmtFeaturesFromMeasurePairs(const Array& inputs, const Array& results, - int max_n_bufs, int skip_first_n_feature_extraction, + int max_n_bufs, std::vector >* features, std::vector* normalized_throughputs, std::vector* task_ids); diff --git a/tests/python/unittest/test_ansor_feature.py b/tests/python/unittest/test_ansor_feature.py new file mode 100644 index 000000000000..abd304a9c2d7 --- /dev/null +++ b/tests/python/unittest/test_ansor_feature.py @@ -0,0 +1,97 @@ +"""Test feature extraction""" + +import math +import tempfile + +import tvm +from tvm import te, ansor + +from test_ansor_common import matmul_nkkm + + +def fequal(a, b): + return math.fabs(a - b) < 1e-6 + + +def test_cpu_matmul(): + dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + s = dag.get_init_state() + C = 2 + + i, j, k = s.stages[C].iters + io, ii = s.split(C, i, [16]) + jo, ji = s.split(C, j, [8]) + s.reorder(C, [io, jo, k, ji, ii]) + s.vectorize(C, ji) + s.parallel(C, io) + s.parallel(C, jo) + s.unroll(2, k) + + target = tvm.target.create('llvm') + task = ansor.SearchTask(dag, "test", target) + names = ansor.feature.get_per_stmt_feature_names() + fea = ansor.feature.get_per_stmt_features_from_states([s.state_object], task)[0] + + stage_0 = fea[0] + assert len(stage_0) == len(names), "%d vs %d" % (len(stage_0), len(names)) + fea_dict = {} + for name, value in zip(names, stage_0): + fea_dict[name] = value + + for name in ["B0", "B1", "B2"]: + if fequal(fea_dict[name + ".acc_type.kReadWrite"], 1.0): + c_name = name + if fequal(fea_dict[name + ".acc_type.kRead"], 1.0): + if fequal(fea_dict[name + ".stride"], 0.0): + b_name = name + else: + a_name = name + + assert fequal(fea_dict[c_name + ".bytes"], math.log2(512 ** 3 * 4 + 1)) + assert fequal(fea_dict[b_name + ".unique_bytes"], math.log2(512 ** 2 * 4 + 1)) + assert fequal(fea_dict[c_name + ".reuse_dis_iter"], math.log2(8 * 16 + 1)) + assert fequal(fea_dict[c_name + ".reuse_dis_bytes"], math.log2((8 * 16 + 8 + 16) * 4 + 1)) + assert fequal(fea_dict[c_name + ".reuse_ct"], math.log2(512 + 1)) + + assert fequal(fea_dict["unroll_num"], math.log2(1 + 1)) + # assert fequal(fea_dict["unroll_type.kPosInnerReduce"], 1.0) + assert fequal(fea_dict["vec_num"], math.log2(1 + 1)) + assert fequal(fea_dict["parallel_num"], math.log2(2 + 1)) + assert fequal(fea_dict["parallel_prod"], math.log2((512 * 
512 / 16 / 8) + 1)) + + +def test_cpu_fusion(): + def fusion_test(N, M): + A = te.placeholder((N, M), name='A') + B = te.compute((N, M), lambda i, j: A[i][j], name='B') + C = te.compute((N, M), lambda i, j: B[i][j], name='C') + return [A, B, C] + + dag = ansor.ComputeDAG(fusion_test(64, 32)) + s = dag.get_init_state() + s.compute_at(1, 2, s.stages[2].iters[1]) + + target = tvm.target.create('llvm') + task = ansor.SearchTask(dag, "test", target) + names = ansor.feature.get_per_stmt_feature_names() + fea = ansor.feature.get_per_stmt_features_from_states([s.state_object], task)[0] + + found = False + for stage_fea in fea: + for i, (name, value) in enumerate(zip(names, stage_fea)): + if 'reuse_type.kSerialMultipleReadWrite' in name and value > 0.5: + assert fequal(stage_fea[i + 2], 1.0) + assert fequal(stage_fea[i + 3], math.log2(16 + 1)) + found = True + assert found + + +def test_gpu_feature(): + # todo(lmzheng) + pass + + +if __name__ == "__main__": + test_cpu_matmul() + test_cpu_fusion() + test_gpu_feature() From b839c0f6b8c45f4dcfdd96a7a60338b40387c5d4 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Tue, 9 Jun 2020 13:55:15 +0800 Subject: [PATCH 15/78] Add XGBModel & RPCRunnerWarpper (#15) * Add XGBModel & RPCRunnerWarpper * Revert "Add Parallel Granularity Mutation" --- python/tvm/ansor/__init__.py | 3 +- python/tvm/ansor/cost_model/cost_model.py | 29 ++ python/tvm/ansor/cost_model/xgb_model.py | 476 ++++++++++++++++++ python/tvm/ansor/measure.py | 48 ++ src/ansor/cost_model/cost_model.cc | 29 +- src/ansor/cost_model/cost_model.h | 6 +- .../search_policy/meta_tile_rewrite_policy.cc | 27 +- .../unittest/test_ansor_search_policy.py | 53 +- 8 files changed, 607 insertions(+), 64 deletions(-) create mode 100644 python/tvm/ansor/cost_model/xgb_model.py diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 3e9b76c2f6ad..2d27995e328e 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -29,6 +29,7 @@ from .compute_dag import ComputeDAG from .task import SearchTask, MetaTileRewritePolicy, TuneOption from .task import auto_schedule -from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner +from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, RPCRunnerWarpper from .cost_model import RandomModel +from .cost_model.xgb_model import XGBModel from .serialization import LogToFile, LogReader, best_measure_pair_in_file diff --git a/python/tvm/ansor/cost_model/cost_model.py b/python/tvm/ansor/cost_model/cost_model.py index a0e586d69cec..fd9b67927185 100644 --- a/python/tvm/ansor/cost_model/cost_model.py +++ b/python/tvm/ansor/cost_model/cost_model.py @@ -42,3 +42,32 @@ def random_number(n, return_ptr): return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float)) array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,)) array_wrapper[:] = np.random.uniform(0, 1, (n,)) + +@tvm._ffi.register_object("ansor.PythonBasedModel") +class PythonBasedModel(CostModel): + def __init__(self): + def update_func(inputs, results): + self.update(inputs, results) + + def predict_func(task, states, return_ptr): + return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float)) + array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(len(states),)) + array_wrapper[:] = self.predict(task, states) + + def predict_stage_func(task, states, return_ptr): + ret = self.predict_stages(task, states) + return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float)) + array_wrapper = np.ctypeslib.as_array(return_ptr, shape=ret.shape) 
+            array_wrapper[:] = ret
+
+        self.__init_handle_by_constructor__(_ffi_api.PythonBasedModel, update_func,
+                                            predict_func, predict_stage_func)
+
+    def update(self, inputs, results):
+        raise NotImplementedError
+
+    def predict(self, task, states):
+        raise NotImplementedError
+
+    def predict_stages(self, task, states):
+        raise NotImplementedError

diff --git a/python/tvm/ansor/cost_model/xgb_model.py b/python/tvm/ansor/cost_model/xgb_model.py
new file mode 100644
index 000000000000..e61acfbd168f
--- /dev/null
+++ b/python/tvm/ansor/cost_model/xgb_model.py
@@ -0,0 +1,476 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Cost model based on xgboost"""
+from typing import List
+import multiprocessing
+import logging
+import time
+from collections import defaultdict
+
+import numpy as np
+import xgboost as xgb
+
+from ...autotvm.tuner.xgboost_cost_model import get_rank, recall_curve, max_curve
+from .cost_model import PythonBasedModel
+from ..feature import get_per_stmt_features_from_measure_pairs, get_per_stmt_features_from_states
+from ..serialization import LogReader
+
+logger = logging.getLogger('ansor')
+
+class XGBDMatrixContext:
+    """Context to hold additional attributes of xgb.DMatrix"""
+    def __init__(self):
+        self.context_dict = defaultdict(dict)
+
+    def get(self, key, matrix, default=None):
+        return self.context_dict[key].get(matrix.handle.value, default)
+
+    def put(self, key, matrix, value):
+        self.context_dict[key][matrix.handle.value] = value
+
+dmatrix_context = XGBDMatrixContext()
+
+class XGBModel(PythonBasedModel):
+    """Train an XGBoost model to predict the runtime cost of a program.
+    The cost of a program = the sum of the costs of all stages in this program,
+    i.e. Cost(p) = cost_s0 + cost_s1 + ... + cost_sn, where cost_si is the cost of Stage i.
+
+    The xgboost model makes predictions per stage, then we sum them up.
+    The final prediction made by this class is the normalized throughput
+    (from 0 to 1, larger is better).
+
+    To support this stage decomposition, we have to implement a custom loss function for
+    XGBoost, which is the `pack_sum` in the code below.
+ """ + def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None): + self.xgb_params = { + 'max_depth': 10, + 'gamma': 0.001, + 'min_child_weight': 0, + 'eta': 0.2, + # todo(lmzheng): automatically decrease learning rate when the loss is too large + + 'n_gpus': 0, + 'n_threads': multiprocessing.cpu_count() / 2, + 'silent': 0, + 'seed': seed or 43, + 'disable_default_eval_metric': 1 + } + self.bst = None + self.plan_size = 32 + self.num_warmup_sample = num_warmup_sample + self.verbose_eval = verbose_eval + + super().__init__() + + # measurement input/result pairs + self.inputs = [] + self.results = [] + self.inputs_feature_cache = [] + + def update(self, inputs, results): + if len(inputs) <= 0: + return + + self.inputs.extend(inputs) + self.results.extend(results) + + # extract feature + n_cached = len(self.inputs_feature_cache) + features, normalized_throughputs, task_ids = \ + get_per_stmt_features_from_measure_pairs(self.inputs, self.results, + skip_first_n_feature_extraction=n_cached) + if n_cached > 0: + features = list(features) + features[:n_cached] = self.inputs_feature_cache + features = np.array(features) + self.inputs_feature_cache = features + dtrain = pack_sum_xgbmatrix(features, normalized_throughputs, task_ids, normalized_throughputs) + + # train xgb model + self.bst = xgb.train(self.xgb_params, dtrain, + num_boost_round=10000, + obj=pack_sum_square_error, + callbacks=[custom_callback( + stopping_rounds=50, + metric='tr-p-rmse', + fevals=[ + pack_sum_rmse, pack_sum_average_peak_score(self.plan_size), + ], + evals=[(dtrain, 'tr')], + maximize=False, + verbose_eval=self.verbose_eval)]) + + def predict(self, task, states): + features = get_per_stmt_features_from_states(states, task) + if self.bst is not None and len(self.inputs) > self.num_warmup_sample: + dtest, pack_ids = pack_sum_xgbmatrix_for_prediction(features) + raw_preds = self.bst.predict(dtest) + ret = pack_sum_predict_throughput(raw_preds, pack_ids) + else: + ret = np.random.uniform(0, 1, (len(states),)) + + # Predict 0 for invalid states that failed to be lowered. + for idx, feature in enumerate(features): + if feature.min() == feature.max() == 0: + ret[idx] = float('-inf') + + return ret + + def predict_stages(self, task, states): + # Format: (s0 score, ..., sN score, s0 n_stage, s0 stage 0, ..., s1 n_stage, s1 stage 0,) + + features = get_per_stmt_features_from_states(states, task) + if self.bst is not None and len(self.inputs) > self.num_warmup_sample: + dtest, pack_ids = pack_sum_xgbmatrix_for_prediction(features) + raw_preds = self.bst.predict(dtest) + breakdown = pack_sum_predict_throughput(raw_preds, pack_ids) + stage_scores = [[] for _ in range(len(states))] + for pred, pack_id in zip(raw_preds, pack_ids): + stage_scores[pack_id].append(pred) + for idx, stage_score in enumerate(stage_scores): + breakdown = np.append(breakdown, len(stage_score)) + breakdown = np.concatenate((breakdown, -np.array(stage_score))) + else: + breakdown = np.concatenate( + (np.random.uniform(0, 1, (len(states), )), np.zeros(len(states), ))) + + # Predict 0 for invalid states that failed to be lowered. 
+ for idx, feature in enumerate(features): + if feature.min() == feature.max() == 0: + breakdown[idx] = float('-inf') + + return breakdown + + def load_log_file(self, file_name, n_lines=-1): + inputs, results = LogReader(file_name).read_lines(n_lines) + logger.info("XGBModel: Loaded %s lines of history log from %s", len(inputs), file_name) + self.update(inputs, results) + + def save(self, file_name: str): + self.bst.save_model(file_name) + + def load(self, file_name: str): + if self.bst is None: + self.bst = xgb.Booster(self.xgb_params) + self.bst.load_model(file_name) + self.num_warmup_sample = -1 + + +def pack_sum_xgbmatrix_for_prediction(xs): + x_flatten = [] + pack_ids = [] + + for ct, x in enumerate(xs): + for row in x: + x_flatten.append(row) + pack_ids.append(ct) + + return xgb.DMatrix(x_flatten), pack_ids + + +def pack_sum_xgbmatrix(xs, ys, gids=None, weights=None): + if gids is not None: + # sort by group + indices = gids.argsort() + xs, ys = xs[indices], ys[indices] + group_sizes = np.bincount(gids) + if weights is not None: + weights = weights[indices] + else: + # assume it has only one group + group_sizes = [len(xs)] + + x_flatten = [] + y_flatten = [] + weights_flatten = [] + pack_ids = [] + + if weights is not None: + for ct, (x, y, w) in enumerate(zip(xs, ys, weights)): + for row in x: + x_flatten.append(row) + y_flatten.append(y) + weights_flatten.append(w) + pack_ids.append(ct) + else: + for ct, (x, y) in enumerate(zip(xs, ys)): + for row in x: + x_flatten.append(row) + y_flatten.append(y) + pack_ids.append(ct) + + ret = xgb.DMatrix(x_flatten, y_flatten) + if weights is not None: + ret.set_weight(weights_flatten) + dmatrix_context.put('pack_ids', ret, np.array(pack_ids)) + dmatrix_context.put('group_sizes', ret, group_sizes) + return ret + +LOSS_TYPE = 3 + +# Type 0 +# The model predicts cost. Use square error of throughput as loss +# loss = 1/2 * (1 / sum(x_i) - y) ^ 2 +# +# Type 1 +# The model predicts cost. Use square error of cost as loss +# loss = 1/2 * (sum(x_i) - 1 / y) ^ 2 +# +# Type 2 +# The model predicts throughput. Use square error of throughput as loss. +# loss = 1/2 * (1 / sum(1 / x_i) - y) ^ 2 +# +# Type 3 +# The model predicts throughput. Use square error of throughput as loss. +# But approximate 1 / (1 / a_1 + 1 / a_2 + ... + 1 / a_n) with -(b_1 + b_2 + b_3) +# loss = 1/2 * (-sum(x_i) - y) ^ 2 +# +# Type 4 +# The model predicts throughput. Use square error of throughput as loss. +# But approximate 1 / (1 / a_1 + 1 / a_2 + ... 
+ 1 / a_n) with -(b_1 + b_2 + b_3)
+# Also add a sigmoid to force the prediction to be within the range of (0, 1)
+# loss = 1/2 * (sigmoid(-sum(x_i)) - y) ^ 2
+#
+
+def pack_sum_predict_throughput(raw_preds, pack_ids):
+    if LOSS_TYPE == 0:
+        sum_pred = np.bincount(pack_ids, weights=raw_preds)
+        return 1 / sum_pred
+    elif LOSS_TYPE == 1:
+        sum_pred = np.bincount(pack_ids, weights=raw_preds)
+        return 1 / sum_pred
+    elif LOSS_TYPE == 2:
+        sum_inverse_preds = np.bincount(pack_ids, weights=1 / raw_preds)
+        return 1 / sum_inverse_preds
+    elif LOSS_TYPE == 3:
+        sum_pred = np.bincount(pack_ids, weights=raw_preds)
+        return - sum_pred  # pylint: disable=invalid-unary-operand-type
+    elif LOSS_TYPE == 4:
+        sum_pred = np.bincount(pack_ids, weights=raw_preds)
+        return 1 / (1 + np.exp(sum_pred))
+    else:
+        raise ValueError("Invalid loss type: " + str(LOSS_TYPE))
+
+def pack_sum_square_error(preds, dtrain):
+    pack_ids = dmatrix_context.get("pack_ids", dtrain)
+    weight = dtrain.get_weight()
+
+    if LOSS_TYPE == 0:
+        sum_pred = np.bincount(pack_ids, weights=preds)
+        x = sum_pred[pack_ids]
+        y = dtrain.get_label()
+        gradient = (x * y - 1) / np.power(x, 3)
+        hessian = (3 - 2 * x * y) / np.power(x, 4)
+    elif LOSS_TYPE == 1:
+        sum_pred = np.bincount(pack_ids, weights=preds)
+        x = sum_pred[pack_ids]
+        y = dtrain.get_label()
+        gradient = x - 1 / np.minimum(y, 1e6)
+        hessian = np.ones_like(gradient)
+    elif LOSS_TYPE == 2:
+        sum_inverse_preds = np.bincount(pack_ids, weights=1 / preds)[pack_ids]
+        y = dtrain.get_label()
+        gradient = (1 / sum_inverse_preds - y) / (np.power(preds * sum_inverse_preds, 2))
+        hessian = (2 * preds * y * np.power(sum_inverse_preds, 2) - 2 * y * sum_inverse_preds - 2 * preds * sum_inverse_preds + 3) / (np.power(preds * sum_inverse_preds, 4))
+    elif LOSS_TYPE == 3:
+        sum_pred = np.bincount(pack_ids, weights=preds)
+        x = sum_pred[pack_ids]
+        y = dtrain.get_label()
+        gradient = x + y
+        hessian = np.ones_like(gradient)
+    elif LOSS_TYPE == 4:
+        sum_pred = np.bincount(pack_ids, weights=preds)
+        exp_x = np.exp(sum_pred[pack_ids])
+        exp_2x = np.power(exp_x, 2)
+        y = dtrain.get_label()
+        gradient = exp_x * (exp_x * y + y - 1) / np.power(exp_x + 1, 3)
+        hessian = exp_x * (-exp_2x * y + 2 * exp_x + y - 1) / np.power(exp_x + 1, 4)
+    else:
+        raise ValueError("Invalid loss type: " + str(LOSS_TYPE))
+
+    if len(weight) == 0:
+        return gradient, hessian
+    else:
+        return gradient * weight, hessian * weight
+
+def pack_sum_rmse(raw_preds, dtrain):
+    pack_ids = dmatrix_context.get("pack_ids", dtrain)
+    preds = pack_sum_predict_throughput(raw_preds, pack_ids)[pack_ids]
+    return 'p-rmse', np.sqrt(np.mean(np.square((preds - dtrain.get_label()))))
+
+def pack_sum_average_peak_score(N):
+    """Evaluate pack sum average peak score for xgb"""
+
+    def feval(preds, labels):
+        group_sizes = dmatrix_context.get('group_sizes', labels, [len(preds)])
+        pack_ids = dmatrix_context.get("pack_ids", labels)
+
+        preds = pack_sum_predict_throughput(preds, pack_ids)
+        labels = (np.bincount(pack_ids, weights=labels.get_label())
+                  / np.unique(pack_ids, return_counts=True)[1])
+
+        scores = []
+        offset = 0
+        for size in group_sizes:
+            preds_group = preds[offset:offset + size]
+            labels_group = labels[offset:offset + size]
+            offset += size
+
+            trials = np.argsort(preds_group)[::-1][:N]
+            trial_scores = labels_group[trials]
+            curve = max_curve(trial_scores) / np.max(labels_group)
+            scores.append(np.mean(curve))
+        return "a-peak@%d" % N, np.mean(scores)
+    return feval
+
+def pack_sum_average_recall_score(N):
+    """evaluate average recall score for xgb"""
+
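+    # A minimal sketch of the pack-sum grouping these feval helpers rely on
+    # (illustrative values only): with per-stage predictions preds = [0.1, 0.2, 0.3]
+    # and pack_ids = [0, 0, 1], np.bincount(pack_ids, weights=preds) yields the
+    # per-program sums [0.3, 0.3], which pack_sum_predict_throughput then maps
+    # to throughput scores according to LOSS_TYPE.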
+ def feval(preds, labels): + group_sizes = dmatrix_context.get('group_sizes', labels, [len(preds)]) + pack_ids = dmatrix_context.get("pack_ids", labels) + + preds = pack_sum_predict_throughput(preds, pack_ids) + labels = (np.bincount(pack_ids, weights=labels.get_label()) + / np.unique(pack_ids, return_counts=True)[1]) + + scores = [] + offset = 0 + for size in group_sizes: + preds_group = preds[offset:offset + size] + labels_group = labels[offset:offset + size] + offset += size + + trials = np.argsort(preds_group)[::-1] + ranks = get_rank(labels_group[trials])[:N] + curve = recall_curve(ranks) + scores.append(np.mean(curve)) + return "a-recall@%d" % N, np.mean(scores) + return feval + + +def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None, + maximize=False, verbose_eval=True, skip_every=2): + """Callback function for xgboost to support multiple custom evaluation functions""" + from xgboost.core import EarlyStopException + from xgboost.callback import _fmt_metric + from xgboost.training import aggcv + + state = {} + metric_shortname = metric.split("-")[1] + + def init(env): + """internal function""" + bst = env.model + + state['maximize_score'] = maximize + state['best_iteration'] = 0 + if maximize: + state['best_score'] = float('-inf') + else: + state['best_score'] = float('inf') + + if bst is not None: + if bst.attr('best_score') is not None: + state['best_score'] = float(bst.attr('best_score')) + state['best_iteration'] = int(bst.attr('best_iteration')) + state['best_msg'] = bst.attr('best_msg') + else: + bst.set_attr(best_iteration=str(state['best_iteration'])) + bst.set_attr(best_score=str(state['best_score'])) + else: + assert env.cvfolds is not None + + def callback(env): + """internal function""" + if not state: + init(env) + + bst = env.model + i = env.iteration + cvfolds = env.cvfolds + + res_dict = {} + + if i % skip_every == 1: + return + + ##### evaluation ##### + if cvfolds is not None: + for feval in fevals: + tmp = aggcv([f.eval(i, feval) for f in cvfolds]) + for k, mean, std in tmp: + res_dict[k] = [mean, std] + else: + for feval in fevals: + bst_eval = bst.eval_set(evals, i, feval) + res = [x.split(':') for x in bst_eval.split()] + for kv in res[1:]: + res_dict[kv[0]] = [float(kv[1])] + + eval_res = [] + keys = list(res_dict.keys()) + keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x) + for key in keys: + v = res_dict[key] + eval_res.append([key] + v) + + ##### print eval result ##### + if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0: + infos = ["XGB iter: %3d" % i] + for item in eval_res: + if 'null' in item[0]: + continue + infos.append("%s: %.6f" % (item[0], item[1])) + + logger.debug("\t".join(infos)) + if log_file: + with open(log_file, "a") as fout: + fout.write("\t".join(infos) + '\n') + + ##### choose score and do early stopping ##### + score = None + for item in eval_res: + if item[0] == metric: + score = item[1] + break + assert score is not None + + best_score = state['best_score'] + best_iteration = state['best_iteration'] + maximize_score = state['maximize_score'] + if (maximize_score and score > best_score) or \ + (not maximize_score and score < best_score): + msg = '[%d] %s' % ( + env.iteration, + '\t'.join([_fmt_metric(x) for x in eval_res])) + state['best_msg'] = msg + state['best_score'] = score + state['best_iteration'] = env.iteration + # save the property to attributes, so they will occur in checkpoint. 
+ if env.model is not None: + env.model.set_attr(best_score=str(state['best_score']), + best_iteration=str(state['best_iteration']), + best_msg=state['best_msg']) + elif env.iteration - best_iteration >= stopping_rounds: + best_msg = state['best_msg'] + if verbose_eval and env.rank == 0: + logger.debug("XGB stopped. Best iteration: %s ", best_msg) + raise EarlyStopException(best_iteration) + + return callback diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index e10da09e4b5a..e35a73148f3a 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -35,6 +35,8 @@ from tvm.runtime import Object, module, ndarray from tvm.driver import build_module from tvm.ir import transform +from tvm.rpc.tracker import Tracker +from tvm.rpc.server import Server from ..contrib import tar, ndk from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, check_remote from .compute_dag import LayoutRewriteLevel @@ -190,6 +192,52 @@ def __init__(self, key, host, port, priority=1, "and make sure you have free devices on the queue status.") +class RPCRunnerWarpper: + def __init__(self, target=None, priority=1, + n_parallel=1, + timeout=10, + number=3, + repeat=1, + min_repeat_ms=0, + cooldown_interval=0.0): + self.target = target + self.priority = priority + self.n_parallel = n_parallel + self.timeout = timeout + self.number = number + self.repeat = repeat + self.min_repeat_ms = min_repeat_ms + self.cooldown_interval = cooldown_interval + + self.tracker = None + self.server = None + self.runner = None + + def __enter__(self): + if self.target == "cuda": + ctx = tvm.context("cuda", 0) + cuda_arch = "sm_" + "".join(ctx.compute_version.split('.')) + tvm.autotvm.measure.measure_methods.set_cuda_target_arch(cuda_arch) + host = '0.0.0.0' + self.tracker = Tracker(host, port=9000, port_end=10000, silent=True) + device_key = '$local$device$%d' % self.tracker.port + self.server = Server(host, port=self.tracker.port, port_end=10000, + key=device_key, + use_popen=True, silent=True, + tracker_addr=(self.tracker.host, self.tracker.port)) + self.runner = RPCRunner(device_key, host, self.tracker.port, self.priority, + self.n_parallel, self.timeout, self.number, self.repeat, + self.min_repeat_ms, self.cooldown_interval) + + return self + + def __exit__(self, type, value, trace): + if value: + raise value + + self.tracker.terminate() + self.server.terminate() + MAX_ERROR_MSG_LEN = 512 diff --git a/src/ansor/cost_model/cost_model.cc b/src/ansor/cost_model/cost_model.cc index 8e0936071774..bbf15a241974 100644 --- a/src/ansor/cost_model/cost_model.cc +++ b/src/ansor/cost_model/cost_model.cc @@ -37,7 +37,7 @@ using ::tvm::runtime::NDArray; TVM_REGISTER_OBJECT_TYPE(CostModelNode); TVM_REGISTER_OBJECT_TYPE(RandomModelNode); TVM_REGISTER_OBJECT_TYPE(MeasureModelNode); -TVM_REGISTER_OBJECT_TYPE(PythonBasedCostModelNode); +TVM_REGISTER_OBJECT_TYPE(PythonBasedModelNode); void RandomNumber(TVMArgs args, TVMRetValue* rv) { int n = args[0]; @@ -101,30 +101,30 @@ void MeasureModelNode::Predict(const SearchTask& task, } } -CostModel PythonBasedCostModelNode::make(PackedFunc update_func, - PackedFunc predict_func, - PackedFunc predict_stage_func) { - auto node = make_object(); +CostModel PythonBasedModelNode::make(PackedFunc update_func, + PackedFunc predict_func, + PackedFunc predict_stage_func) { + auto node = make_object(); node->update_func = std::move(update_func); node->predict_func = std::move(predict_func); node->predict_stage_func = std::move(predict_stage_func); return 
CostModel(node); } -void PythonBasedCostModelNode::Update(const Array& inputs, - const Array& results) { +void PythonBasedModelNode::Update(const Array& inputs, + const Array& results) { update_func(inputs, results); } -void PythonBasedCostModelNode::Predict(const SearchTask& task, - const std::vector& states, - std::vector* scores) { +void PythonBasedModelNode::Predict(const SearchTask& task, + const std::vector& states, + std::vector* scores) { scores->resize(states.size()); predict_func(task, Array(states.begin(), states.end()), static_cast(scores->data())); } -void PythonBasedCostModelNode::PredictStages( +void PythonBasedModelNode::PredictStages( const SearchTask& task, const std::vector& states, std::vector* state_scores, std::vector>* stage_scores) { @@ -188,5 +188,12 @@ TVM_REGISTER_GLOBAL("ansor.RandomModel").set_body_typed([]() { return RandomModelNode::make(); }); +TVM_REGISTER_GLOBAL("ansor.PythonBasedModel") +.set_body_typed([](PackedFunc update_func, PackedFunc predict_func, + PackedFunc predict_stage_func) { + return PythonBasedModelNode::make(update_func, predict_func, + predict_stage_func); +}); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/cost_model/cost_model.h b/src/ansor/cost_model/cost_model.h index 9daf01197bbf..472a3c201068 100644 --- a/src/ansor/cost_model/cost_model.h +++ b/src/ansor/cost_model/cost_model.h @@ -92,7 +92,7 @@ class MeasureModelNode : public CostModelNode { /*! \brief A wrapper for cost model defined by python code * This class will call python's function */ -class PythonBasedCostModelNode: public CostModelNode { +class PythonBasedModelNode: public CostModelNode { public: PackedFunc update_func; PackedFunc predict_func; @@ -108,8 +108,8 @@ class PythonBasedCostModelNode: public CostModelNode { std::vector* state_scores, std::vector>* stage_scores) final; - static constexpr const char *_type_key = "ansor.PythonBasedCostModel"; - TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedCostModelNode, CostModelNode); + static constexpr const char *_type_key = "ansor.PythonBasedModel"; + TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedModelNode, CostModelNode); }; } // namespace ansor diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc index 86a7eba1da3a..f086a8879abb 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc @@ -1397,24 +1397,27 @@ void MetaTileRewritePolicyNode::EvolutionarySearch( int id = RandomChoose(prefix_sum_probs, &rand_gen_); if (dis(rand_gen_) < mutation_prob) { - const std::vector rule_prefix_sum_probs{0.9, 0.95, 1.0}; + const std::vector rule_prefix_sum_probs{0.9, 1.0}; int rule_id = RandomChoose(rule_prefix_sum_probs, &rand_gen_); - State tmp_s; if (rule_id == 0) { - tmp_s = RandomMutateTileSize((*pnow)[id], &split_memo_, &rand_gen_, + // Mutate Tile Size + State tmp_s = RandomMutateTileSize((*pnow)[id], &split_memo_, &rand_gen_, cur_task_->hardware_params->max_innermost_split_factor); + if (tmp_s.defined()) { + pnext->push_back(std::move(tmp_s)); + } else { + mutation_fail_ct++; + } } else if (rule_id == 1) { - tmp_s = RandomMutateMaxUnrollStep((*pnow)[id], &rand_gen_, auto_unroll_configs); - } else if (rule_id == 2) { - tmp_s = MutataParallel((*pnow)[id], &split_memo_, &rand_gen_, cur_task_); - } - - if (tmp_s.defined()) { - pnext->push_back(std::move(tmp_s)); - } else { - mutation_fail_ct++; + // Mutate auto-unroll max step. 
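+        // (With rule_prefix_sum_probs = {0.9, 1.0} above, the tile-size rule is
+        // drawn with probability 0.9 and this rule with probability 0.1.)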
+ State tmp_s = RandomMutateMaxUnrollStep((*pnow)[id], &rand_gen_, auto_unroll_configs); + if (tmp_s.defined()) { + pnext->push_back(std::move(tmp_s)); + } else { + mutation_fail_ct++; + } } } else { pnext->push_back((*pnow)[id]); diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index 9a57691aba22..6636787e807f 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -24,28 +24,25 @@ import tvm from tvm import ansor -from tvm.rpc.tracker import Tracker -from tvm.rpc.server import Server from test_ansor_common import matmul_nkkm -def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local'): +def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local', + cost_model=ansor.RandomModel(), n_trials=2): print("Test %s schedule search with the default search policy" % (target)) + random.seed(seed) N = 128 A, B, C = matmul_nkkm(N, N, N) dag = ansor.ComputeDAG([A, B, C]) tgt = tvm.target.create(target) task = ansor.SearchTask(dag, "test", tgt) - random.seed(seed) - with tempfile.NamedTemporaryFile() as fp: log_file = fp.name - cost_model = ansor.RandomModel() search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) - tune_option = ansor.TuneOption(n_trials=2, runner=runner, + tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner, callbacks=[ansor.LogToFile(log_file)]) state = ansor.auto_schedule(task, search_policy, tune_option=tune_option) @@ -83,48 +80,30 @@ def test_search_basic(): search_common(seed=944563397) +def test_search_xgb_model_rpc_runner(): + with ansor.RPCRunnerWarpper() as rpc_runner: + search_common(seed=456787236, cost_model=ansor.XGBModel(), + runner=rpc_runner.runner) + + def test_search_opencl(): if tvm.context("opencl", 0).exist: - host = '0.0.0.0' - tracker = Tracker(host, port=9000, port_end=10000, silent=True) - device_key = '$local$device$%d' % tracker.port - server = Server(host, port=tracker.port, port_end=10000, - key=device_key, - use_popen=True, silent=True, - tracker_addr=(tracker.host, tracker.port)) - rpc_runner = ansor.RPCRunner(device_key, host, tracker.port) - - search_common("opencl", 380344973, rpc_runner) - - tracker.terminate() - server.terminate() + with ansor.RPCRunnerWarpper() as rpc_runner: + search_common("opencl", 380344973, rpc_runner.runner) else: print("OpenCL device not found, skip this test.") def test_search_cuda(): - ctx = tvm.context("cuda", 0) - if ctx.exist: - cuda_arch = "sm_" + "".join(ctx.compute_version.split('.')) - tvm.autotvm.measure.measure_methods.set_cuda_target_arch(cuda_arch) - host = '0.0.0.0' - tracker = Tracker(host, port=9000, port_end=10000, silent=True) - device_key = '$local$device$%d' % tracker.port - server = Server(host, port=tracker.port, port_end=10000, - key=device_key, - use_popen=True, silent=True, - tracker_addr=(tracker.host, tracker.port)) - rpc_runner = ansor.RPCRunner(device_key, host, tracker.port) - - search_common("cuda", 903667810, rpc_runner) - - tracker.terminate() - server.terminate() + if tvm.context("cuda", 0).exist: + with ansor.RPCRunnerWarpper("cuda") as rpc_runner: + search_common("cuda", 903667810, rpc_runner.runner) else: print("CUDA device not found, skip this test.") if __name__ == "__main__": test_search_basic() + test_search_xgb_model_rpc_runner() test_search_opencl() test_search_cuda() From cfe58d7829cd649f4b1a4af8f4af3200dbc5174f Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 9 Jun 2020 
01:16:58 -0700 Subject: [PATCH 16/78] Migrate workload_registry.py (#16) * add workload registry * update * update --- python/tvm/ansor/__init__.py | 8 +- .../tvm/ansor/{task.py => auto_schedule.py} | 0 python/tvm/ansor/feature.py | 2 +- python/tvm/ansor/measure.py | 5 +- python/tvm/ansor/serialization.py | 5 + python/tvm/ansor/workload_registry.py | 190 ++++++++++++++++++ src/ansor/feature.cc | 2 + src/ansor/serialization.cc | 62 +++++- src/tir/analysis/verify_gpu_code.cc | 44 +++- tests/python/unittest/test_ansor_common.py | 11 +- tests/python/unittest/test_ansor_feature.py | 62 +++++- .../python/unittest/test_ansor_loop_state.py | 8 +- .../unittest/test_ansor_search_policy.py | 4 +- 13 files changed, 364 insertions(+), 39 deletions(-) rename python/tvm/ansor/{task.py => auto_schedule.py} (100%) create mode 100644 python/tvm/ansor/workload_registry.py diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 2d27995e328e..bb4822409757 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -21,15 +21,17 @@ from . import measure from . import serialization from . import loop_state -from . import task +from . import auto_schedule from . import utils from . import feature +from . import workload_registry # Shortcut from .compute_dag import ComputeDAG -from .task import SearchTask, MetaTileRewritePolicy, TuneOption -from .task import auto_schedule +from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams +from .auto_schedule import auto_schedule from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, RPCRunnerWarpper from .cost_model import RandomModel from .cost_model.xgb_model import XGBModel from .serialization import LogToFile, LogReader, best_measure_pair_in_file +from .workload_registry import register_auto_scheduler_workload_func, workload_key_to_dag diff --git a/python/tvm/ansor/task.py b/python/tvm/ansor/auto_schedule.py similarity index 100% rename from python/tvm/ansor/task.py rename to python/tvm/ansor/auto_schedule.py diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py index fb5fadf16296..a0885aabdc20 100644 --- a/python/tvm/ansor/feature.py +++ b/python/tvm/ansor/feature.py @@ -24,7 +24,7 @@ import numpy as np from .loop_state import StateObject -from .task import SearchTask +from .auto_schedule import SearchTask from .measure import MeasureInput, MeasureResult from . 
import _ffi_api
diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py
index e35a73148f3a..0209a717cf0e 100644
--- a/python/tvm/ansor/measure.py
+++ b/python/tvm/ansor/measure.py
@@ -44,6 +44,8 @@
 
 logger = logging.getLogger('ansor')
 
+MAX_ERROR_MSG_LEN = 512
+
 
 @tvm._ffi.register_object("ansor.MeasureCallback")
 class MeasureCallback(Object):
@@ -238,8 +240,6 @@ def __exit__(self, type, value, trace):
         self.tracker.terminate()
         self.server.terminate()
 
-MAX_ERROR_MSG_LEN = 512
-
 
 class MeasureErrorNo(object):
     """Error type for MeasureResult"""
@@ -505,3 +505,4 @@ def timed_func(inp, build_res):
             print("")
 
     return measure_results
+
diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py
index bd9a69944057..387825034a09 100644
--- a/python/tvm/ansor/serialization.py
+++ b/python/tvm/ansor/serialization.py
@@ -69,6 +69,11 @@ def write_measure_records_to_file(filename, inputs, results):
     _ffi_api.WriteMeasureRecordsToFile(filename, inputs, results)
 
 
+def get_states_from_measure_inputs(inputs, task):
+    """Get states from measure inputs"""
+    return _ffi_api.GetStatesFromMeasureInputs(inputs, task)
+
+
 def best_measure_pair_in_file(filename, workload_key=None, target=None):
     """ Return best results from a log file
diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py
new file mode 100644
index 000000000000..c8b12f0244b2
--- /dev/null
+++ b/python/tvm/ansor/workload_registry.py
@@ -0,0 +1,190 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+"""
+Workload registration and serialization.
+
+We use a json string to represent a workload (a compute dag).
+The format of the string is `[func_name, [args...]]`.
+The dag should be the return value of this `func_name(*args)`.
+
+Rationale: The workload is actually a compute dag defined by the TVM DSL. But serializing compute dags
+and matching them efficiently is not easy. Therefore, we use the above string to encode a compute dag.
+These strings are efficient for serialization/matching and won't be too long.
+When we need the dag, we decode the string and call the function, which will return the dag.
+"""
+
+from typing import List, Tuple, Callable, Union
+from collections.abc import Hashable
+import pickle
+import json
+import hashlib
+
+import tvm._ffi
+from ..te import Tensor, PlaceholderOp, ComputeOp, placeholder
+from .utils import get_const_tuple
+from .compute_dag import ComputeDAG
+
+WORKLOAD_FUNC_REGISTRY = {}
+
+
+def register_auto_scheduler_workload_func(func: Callable):
+    """Register a workload generation function
+    The input function should take hashable and jsonable arguments
+    (int, float, tuple of int, tvm.tensor.Tensor, ...) and return a list of tvm.tensor.Tensor.
+
+    Examples
+    --------
+    @register_auto_scheduler_workload_func
+    def matmul(N, M, K):
+        A = tvm.placeholder((N, K), name='A')
+        B = tvm.placeholder((K, M), name='B')
+        k = tvm.reduce_axis((0, K), name='k')
+        C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C')
+        return [A, B, C]
+    """
+    func_name = func.__name__
+    if func_name in WORKLOAD_FUNC_REGISTRY:
+        raise RuntimeError('%s has been registered already' % func_name)
+    WORKLOAD_FUNC_REGISTRY[func_name] = func
+    return func
+
+
+def compute_dag_hash(dag: ComputeDAG):
+    # todo: implement this more carefully and move this to c++ as a member function of ComputeDAG
+    str_key = ''
+    for op in dag.ops:
+        t = op.output(0)
+        if isinstance(op, PlaceholderOp):
+            str_key += 'placeholder,'
+            str_key += str(get_const_tuple(t.shape)) + ','
+            str_key += t.dtype + ';'
+        elif isinstance(op, ComputeOp):
+            str_key += str(t.op.body) + ','
+            str_key += str(get_const_tuple(t.shape)) + ','
+            str_key += t.dtype + ';'
+        else:
+            raise ValueError("Invalid op: " + str(op))
+
+    str_key = str_key.encode(encoding='utf-8')
+    return hashlib.md5(str_key).hexdigest()
+
+
+def register_auto_scheduler_workload_bufs(bufs: List[Tensor]) -> str:
+    """Directly register buffers of a workload and return the workload_key
+    The buffers can be looked up with workload_key_to_tensors by the workload_key
+    """
+    dag = ComputeDAG(bufs)
+    key = compute_dag_hash(dag)
+    WORKLOAD_FUNC_REGISTRY[key] = bufs
+    return json.dumps((key,))
+
+
+def list_to_tuple(x: List) -> Tuple:
+    """Convert a list to a tuple recursively"""
+    assert isinstance(x, list)
+    return tuple(list_to_tuple(y) if isinstance(y, list) else y for y in x)
+
+
+def serialize_args(args: Tuple) -> Tuple:
+    """
+    Serialize arguments of a function to a hashable and jsonable tuple.
+    Currently this is mainly used for tvm.tensor.Tensor
+    """
+    ret = []
+    for t in args:
+        if isinstance(t, Tensor):
+            t = ('TENSOR', get_const_tuple(t.shape), t.dtype)
+        elif isinstance(t, list):
+            t = list_to_tuple(t)
+
+        assert isinstance(t, Hashable), str(t) + " is not hashable"
+        ret.append(t)
+
+    return tuple(ret)
+
+
+def deserialize_args(args: Tuple) -> List:
+    """The inverse function of :code:`serialize_args`"""
+    ret = []
+    for t in args:
+        if isinstance(t, (tuple, list)) and t[0] == 'TENSOR':
+            ret.append(placeholder(shape=t[1], dtype=t[2]))
+        else:
+            ret.append(t)
+    return ret
+
+
+@tvm._ffi.register_func("auto_scheduler.workload_key_to_tensors")
+def workload_key_to_tensors(workload_key: str) -> List[Tensor]:
+    """Decode a workload key to the input/output tensors"""
+    workload = json.loads(workload_key)
+    name = workload[0]
+    lookup = WORKLOAD_FUNC_REGISTRY[name]
+
+    if callable(lookup):
+        args = deserialize_args(workload[1:])
+        return lookup(*args)
+    else:
+        return lookup
+
+
+@tvm._ffi.register_func("auto_scheduler.workload_key_to_dag")
+def workload_key_to_dag(workload_key: str) -> ComputeDAG:
+    """Decode a workload key to a compute dag"""
+    tensors = workload_key_to_tensors(workload_key)
+    return ComputeDAG(tensors)
+
+
+def make_workload_key_func(func: Union[str, Callable], args: Tuple) -> str:
+    """Make a workload key from a function and its arguments"""
+    args = serialize_args(args)
+
+    if callable(func):
+        func_name = func.__name__
+    elif isinstance(func, str):
+        func_name = func
+    else:
+        raise ValueError("Invalid function: " + str(func))
+
+    assert func_name in WORKLOAD_FUNC_REGISTRY, \
+        "%s is not registered. 
Please register it with register_auto_scheduler_workload_func" % func + + return json.dumps((func_name,) + args) + + +def make_workload_key_bufs(bufs: List[Tensor]) -> str: + """make a workload key from bufs""" + dag = ComputeDAG(bufs) + key = compute_dag_hash(dag) + return json.dumps((key,)) + + +def dump_workload_func_registry(filename: str): + """Dump workload function registry to a pickle binary file""" + global WORKLOAD_FUNC_REGISTRY + + pickle.dump(WORKLOAD_FUNC_REGISTRY, open(filename, 'wb')) + + +def load_workload_func_registry(filename: str): + """Load workload function registry from a pickle binary file""" + global WORKLOAD_FUNC_REGISTRY + + WORKLOAD_FUNC_REGISTRY = pickle.load(open(filename, 'rb')) + diff --git a/src/ansor/feature.cc b/src/ansor/feature.cc index 16ddb73ebf47..497a3ac4222b 100644 --- a/src/ansor/feature.cc +++ b/src/ansor/feature.cc @@ -1241,6 +1241,8 @@ void GetPerStmtFeaturesFromStates(const Array& states, for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) { pool.Enqueue(GetPerStmtFeaturesWorkerFunc, task, states[i], max_n_bufs, &(*features)[i], &error_ct); + //GetPerStmtFeaturesWorkerFunc(task, states[i], + // max_n_bufs, &(*features)[i], &error_ct); } pool.WaitBatch(); diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index 53c75a13f197..76f5d4449001 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -507,13 +507,7 @@ bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { // skip comment lines begin with '#' or ' ' continue; } - - try { - ReadMeasureRecord(cur_line, inp, res, &log_version); - } catch (...) { - return false; - } - + ReadMeasureRecord(cur_line, inp, res, &log_version); return true; } @@ -607,5 +601,59 @@ TVM_REGISTER_GLOBAL("ansor.LogReaderReadNext") } }); +TVM_REGISTER_GLOBAL("ansor.GetStatesFromMeasureInputs") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Array inputs = args[0]; + SearchTask external_task; + + if (args.size() > 1) { + external_task = args[1]; + } + + Array states; + states.reserve(inputs.size()); + + // (workload_key, target) -> (search_task) + std::unordered_map, SearchTask> task_cache; + + for (const auto& inp : inputs) { + const std::string& workload_key = inp->task->workload_key; + std::pair key(workload_key, inp->task->target->str()); + + const SearchTaskNode* ptask; + if (external_task.defined()) { + ptask = external_task.operator->(); + } else { + auto find_res = task_cache.find(key); + if (find_res == task_cache.end()) { + if (inp->task->compute_dag.defined()) { // the measure input is complete + ptask = inp->task.operator->(); + } else { // the measure input is incomplete + // rebuild task for incomplete measure pairs read from file + SearchTask new_task = SearchTaskNode::make( + ComputeDAGNode::make_by_workload_key(workload_key), + workload_key, + inp->task->target, + inp->task->target_host, + inp->task->hardware_params); + task_cache.insert(std::make_pair(key, new_task)); + ptask = new_task.operator->(); + } + } else { + ptask = find_res->second.operator->(); + } + } + + State tmp_s = ptask->compute_dag.GetInitState(); + StateNode *ps = tmp_s.CopyOnWrite(); + ps->transform_steps = inp->state->transform_steps; + tmp_s.DoSteps(ps->transform_steps, ptask->compute_dag); + states.push_back(std::move(tmp_s)); + } + + *ret = states; +}); + + } // namespace ansor } // namespace tvm diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index 1fbae0fd2dcd..f6a8ad034aa5 100644 --- 
a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -33,20 +33,22 @@ namespace tvm { namespace tir { -class GPUCodeVerifier : public StmtVisitor { +class GPUCodeVerifier : public StmtExprVisitor { public: bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block, int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y, - int64_t max_thread_z) { + int64_t max_thread_z, int64_t max_vector_bytes) { max_local_memory_per_block_ = static_cast(max_local_memory_per_block); max_shared_memory_per_block_ = static_cast(max_shared_memory_per_block); max_threads_per_block_ = static_cast(max_threads_per_block); max_thread_x_ = static_cast(max_thread_x); max_thread_y_ = static_cast(max_thread_y); max_thread_z_ = static_cast(max_thread_z); + max_vector_bytes_ = static_cast(max_vector_bytes); Reset_(); + // TODO(jcf94): Add support of detecting CUDA Misaligned Address error this->VisitStmt(stmt); return valid_; @@ -62,6 +64,10 @@ class GPUCodeVerifier : public StmtVisitor { size_t size = static_cast(op->constant_allocation_size()); shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } + + if (op->dtype.lanes() > 1) { + valid_ &= op->dtype.lanes() * op->dtype.bytes() <= static_cast(max_vector_bytes_); + } } void VisitStmt_(const AttrStmtNode* op) final { @@ -129,6 +135,18 @@ class GPUCodeVerifier : public StmtVisitor { } } + void VisitExpr_(const LoadNode* op) { + // Currently not able to check: + // if the index expression failed to be simplified to a Ramp + if (op->index->IsInstance()) { + if (op->dtype.lanes() > 1) { + valid_ &= op->dtype.lanes() * op->dtype.bytes() <= + static_cast(max_vector_bytes_); + } + } + ExprVisitor::VisitExpr_(op); + } + private: int nest_level_{0}; @@ -146,6 +164,7 @@ class GPUCodeVerifier : public StmtVisitor { size_t max_shared_memory_per_block_; size_t max_threads_per_block_; size_t max_thread_x_, max_thread_y_, max_thread_z_; + size_t max_vector_bytes_; bool valid_{true}; @@ -169,27 +188,32 @@ bool VerifyGPUCode(const PrimFunc& func, Map constraints) { int64_t max_thread_x = INT64_MAX; int64_t max_thread_y = INT64_MAX; int64_t max_thread_z = INT64_MAX; + int64_t max_vector_bytes = INT64_MAX; for (auto iter : constraints) { const IntImmNode* val = iter.second.as(); - if (iter.first == "max_local_memory_per_block") + if (iter.first == "max_local_memory_per_block") { max_local_memory_per_block = val->value; - else if (iter.first == "max_shared_memory_per_block") + } else if (iter.first == "max_shared_memory_per_block") { max_shared_memory_per_block = val->value; - else if (iter.first == "max_threads_per_block") + } else if (iter.first == "max_threads_per_block") { max_threads_per_block = val->value; - else if (iter.first == "max_thread_x") + } else if (iter.first == "max_thread_x") { max_thread_x = val->value; - else if (iter.first == "max_thread_y") + } else if (iter.first == "max_thread_y") { max_thread_y = val->value; - else if (iter.first == "max_thread_z") + } else if (iter.first == "max_thread_z") { max_thread_z = val->value; - else + } else if (iter.first == "max_vector_bytes") { + max_vector_bytes = val->value; + } else { LOG(FATAL) << "Invalid check item: " << iter.first; + } } return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block, - max_threads_per_block, max_thread_x, max_thread_y, max_thread_z); + max_threads_per_block, max_thread_x, max_thread_y, max_thread_z, + max_vector_bytes); } 
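 
 // A usage sketch (hypothetical values) of the new max_vector_bytes constraint,
 // through the packed function registered below; constraint keys that are not
 // given keep their INT64_MAX defaults:
 //   f = tvm.get_global_func("tir.analysis.verify_gpu_code")
 //   ok = f(prim_func, {"max_threads_per_block": 1024, "max_vector_bytes": 16})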
TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode); diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index cd8a1eedb162..1790b06bcb60 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -17,18 +17,16 @@ """Common functions for ansor test cases""" - from tvm import te, ansor import topi -def matmul_nkkm(N, M, K): +@ansor.register_auto_scheduler_workload_func +def matmul_ansor_test(N, M, K): A = te.placeholder((N, K), name='A') B = te.placeholder((K, M), name='B') k = te.reduce_axis((0, K), name='k') - C = te.compute((N, M), lambda i, j: te.sum( - A[i][k] * B[k][j], axis=[k]), name='C') - + C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C') return [A, B, C] @@ -58,7 +56,7 @@ def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation def get_tiled_matmul(): - dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) s0 = dag.get_init_state() A, B, C = 0, 1, 2 @@ -80,3 +78,4 @@ def get_tiled_matmul(): C_global += 1 s0.compute_at(A_global, C_global, s0.stages[C_global].iters[2]) return dag, s0.state_object + diff --git a/tests/python/unittest/test_ansor_feature.py b/tests/python/unittest/test_ansor_feature.py index abd304a9c2d7..3da1c7aa332e 100644 --- a/tests/python/unittest/test_ansor_feature.py +++ b/tests/python/unittest/test_ansor_feature.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ """Test feature extraction""" import math @@ -6,7 +23,7 @@ import tvm from tvm import te, ansor -from test_ansor_common import matmul_nkkm +from test_ansor_common import matmul_ansor_test def fequal(a, b): @@ -14,7 +31,7 @@ def fequal(a, b): def test_cpu_matmul(): - dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) s = dag.get_init_state() C = 2 @@ -87,11 +104,48 @@ def fusion_test(N, M): def test_gpu_feature(): - # todo(lmzheng) - pass + ctx = tvm.context("cuda", 0) + if not ctx.exist: + return + + json_records = "\n".join(( + """{"i": [["[\\"matmul_ansor_test\\", 512, 512, 512]", "cuda"], [[], [["CHW", 2, "local"], ["SP", 2, 0, 512, [1, 16, 32, 1], 1], ["SP", 2, 5, 512, [4, 1, 1, 16], 1], ["SP", 2, 10, 512, [1, 2], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 3, 0, 1, 3], ["FSP", 3, 4, 2, 3], ["RE", 3, [0, 4, 1, 5, 2, 6, 3, 7]], ["FU", 2, [0, 1]], ["FU", 3, [0, 1]], ["FU", 2, [1, 2]], ["FU", 3, [1, 2]], ["FU", 2, [2, 3]], ["FU", 3, [2, 3]], ["CA", 2, 3, 2], ["CHR", 1, "shared", [2]], ["CA", 2, 3, 3], ["FU", 2, [0, 1]], ["FFSP", 2, 0, [1, 2], 1, 1], ["AN", 2, 1, 6], ["CHR", 0, "shared", [3]], ["CA", 1, 4, 3], ["FU", 1, [0, 1]], ["FFSP", 1, 0, [1, 2], 1, 1], ["AN", 1, 1, 6], ["AN", 5, 0, 5], ["AN", 5, 1, 4], ["AN", 5, 2, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"]]]], "r": [[0.00536798], 0, 2.49277, 1585564852], "v": "v0.1"}""", + )) + + # load states + with tempfile.NamedTemporaryFile(mode='w') as f: + f.write(json_records) + f.flush() + inputs, results = ansor.LogReader(f.name).read_lines() + + inp = inputs[0] + dag = ansor.workload_key_to_dag(inp.task.workload_key) + task = ansor.SearchTask(dag, inp.task.workload_key, inp.task.target, None, ansor.HardwareParams(100000, 16, 64, 4, 64)) + + state = ansor.serialization.get_states_from_measure_inputs(inputs, task)[0] + state = dag.infer_bound_from_state(state) + fea = ansor.feature.get_per_stmt_features_from_states([state], task)[0] + names = ansor.feature.get_per_stmt_feature_names() + + # build feature dict + fea_dicts = [] + for i in range(len(fea)): + tmp_dict = {} + for j in range(len(names)): + tmp_dict[names[j]] = fea[i][j] + fea_dicts.append(tmp_dict) + + # check values + assert fequal(fea_dicts[0]['blockIdx_x_len'], math.log2(8 + 1)) + assert fequal(fea_dicts[0]['vthread_len'], math.log2(4 + 1)) + assert fequal(fea_dicts[1]['threadIdx_x_len'], math.log2(16 + 1)) + assert fequal(fea_dicts[0]['threadIdx_y_len'], math.log2(1 + 1)) + assert fequal(fea_dicts[2]['blockIdx_z_len'], math.log2(1 + 1)) + assert fequal(fea_dicts[0]['is_gpu'], 1.0) if __name__ == "__main__": test_cpu_matmul() test_cpu_fusion() test_gpu_feature() + diff --git a/tests/python/unittest/test_ansor_loop_state.py b/tests/python/unittest/test_ansor_loop_state.py index 287a1b773395..612d320036d8 100644 --- a/tests/python/unittest/test_ansor_loop_state.py +++ b/tests/python/unittest/test_ansor_loop_state.py @@ -20,11 +20,11 @@ from tvm import ansor, te import topi -from test_ansor_common import matmul_nkkm, conv2d_nchw_bn_relu +from test_ansor_common import matmul_ansor_test, conv2d_nchw_bn_relu def test_split_fuse_reorder_annotation(): - dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) s0 = dag.get_init_state() C = 2 i, j, k = s0.stages[C].iters @@ -67,7 +67,7 @@ def test_split_fuse_reorder_annotation(): def test_follow_split_follow_fused_split(): - dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + dag = 
ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) s0 = dag.get_init_state() C = 2 @@ -433,7 +433,7 @@ def test_cache_read_write(): def test_rfactor(): - dag = ansor.ComputeDAG(matmul_nkkm(8, 8, 512)) + dag = ansor.ComputeDAG(matmul_ansor_test(8, 8, 512)) s0 = dag.get_init_state() C = 2 diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index 6636787e807f..a28456574abe 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -25,7 +25,7 @@ import tvm from tvm import ansor -from test_ansor_common import matmul_nkkm +from test_ansor_common import matmul_ansor_test def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local', cost_model=ansor.RandomModel(), n_trials=2): @@ -33,7 +33,7 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' random.seed(seed) N = 128 - A, B, C = matmul_nkkm(N, N, N) + A, B, C = matmul_ansor_test(N, N, N) dag = ansor.ComputeDAG([A, B, C]) tgt = tvm.target.create(target) task = ansor.SearchTask(dag, "test", tgt) From 143ea451bfe848ed7ef5424ffaf468344c38ea4c Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 9 Jun 2020 02:10:48 -0700 Subject: [PATCH 17/78] add task scheduler (#17) --- python/tvm/ansor/__init__.py | 2 + python/tvm/ansor/auto_schedule.py | 7 +- python/tvm/ansor/cost_model/__init__.py | 1 + python/tvm/ansor/cost_model/xgb_model.py | 10 +- python/tvm/ansor/feature.py | 5 +- python/tvm/ansor/measure.py | 7 + python/tvm/ansor/task_scheduler.py | 274 ++++++++++++++++++ python/tvm/ansor/workload_registry.py | 1 - src/ansor/measure.cc | 39 +-- src/ansor/search_policy/search_policy.cc | 17 ++ .../unittest/test_ansor_task_scheduler.py | 43 +++ 11 files changed, 365 insertions(+), 41 deletions(-) create mode 100644 python/tvm/ansor/task_scheduler.py create mode 100644 tests/python/unittest/test_ansor_task_scheduler.py diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index bb4822409757..4e57c16d18a5 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -25,6 +25,7 @@ from . import utils from . import feature from . import workload_registry +from . import task_scheduler # Shortcut from .compute_dag import ComputeDAG @@ -35,3 +36,4 @@ from .cost_model.xgb_model import XGBModel from .serialization import LogToFile, LogReader, best_measure_pair_in_file from .workload_registry import register_auto_scheduler_workload_func, workload_key_to_dag +from .task_scheduler import TaskScheduler, SimpleTaskScheduler diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index affcf4a6e195..5f4b7946b087 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -22,7 +22,7 @@ import tvm._ffi from tvm.runtime import Object from .measure import LocalBuilder, LocalRunner -from .cost_model import RandomModel +from .cost_model import RandomModel, XGBModel from . 
import _ffi_api @@ -67,11 +67,12 @@ def __init__(self, dag, workload_key, target, target_host=None, @tvm._ffi.register_object("ansor.SearchPolicy") class SearchPolicy(Object): - pass + def continue_search(self, task, num_measure, verbose, measurer): + return _ffi_api.SearchPolicyContinueSearchOneRound(self, task, num_measure, verbose, measurer) @tvm._ffi.register_object("ansor.MetaTileRewritePolicy") -class MetaTileRewritePolicy(Object): +class MetaTileRewritePolicy(SearchPolicy): """ The search policy that searches with meta tiling and random rewrite Parameters diff --git a/python/tvm/ansor/cost_model/__init__.py b/python/tvm/ansor/cost_model/__init__.py index fc3821cf7998..56e4a5f9128b 100644 --- a/python/tvm/ansor/cost_model/__init__.py +++ b/python/tvm/ansor/cost_model/__init__.py @@ -18,3 +18,4 @@ """ Cost model that estimates the performance of programs """ from .cost_model import RandomModel +from .xgb_model import XGBModel diff --git a/python/tvm/ansor/cost_model/xgb_model.py b/python/tvm/ansor/cost_model/xgb_model.py index e61acfbd168f..fce3f16d18ba 100644 --- a/python/tvm/ansor/cost_model/xgb_model.py +++ b/python/tvm/ansor/cost_model/xgb_model.py @@ -92,14 +92,15 @@ def update(self, inputs, results): # extract feature n_cached = len(self.inputs_feature_cache) features, normalized_throughputs, task_ids = \ - get_per_stmt_features_from_measure_pairs(self.inputs, self.results, - skip_first_n_feature_extraction=n_cached) + get_per_stmt_features_from_measure_pairs(self.inputs, self.results, + skip_first_n_feature_extraction=n_cached) if n_cached > 0: features = list(features) features[:n_cached] = self.inputs_feature_cache features = np.array(features) self.inputs_feature_cache = features - dtrain = pack_sum_xgbmatrix(features, normalized_throughputs, task_ids, normalized_throughputs) + dtrain = pack_sum_xgbmatrix(features, normalized_throughputs, + task_ids, normalized_throughputs) # train xgb model self.bst = xgb.train(self.xgb_params, dtrain, @@ -133,7 +134,6 @@ def predict(self, task, states): def predict_stages(self, task, states): # Format: (s0 score, ..., sN score, s0 n_stage, s0 stage 0, ..., s1 n_stage, s1 stage 0,) - features = get_per_stmt_features_from_states(states, task) if self.bst is not None and len(self.inputs) > self.num_warmup_sample: dtest, pack_ids = pack_sum_xgbmatrix_for_prediction(features) @@ -339,7 +339,7 @@ def feval(preds, labels): return feval def pack_sum_average_recall_score(N): - """evaluate average recall score for xgb""" + """Evaluate average recall score for xgb""" def feval(preds, labels): group_sizes = dmatrix_context.get('group_sizes', labels, [len(preds)]) diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py index a0885aabdc20..f91d7da169f5 100644 --- a/python/tvm/ansor/feature.py +++ b/python/tvm/ansor/feature.py @@ -24,7 +24,6 @@ import numpy as np from .loop_state import StateObject -from .auto_schedule import SearchTask from .measure import MeasureInput, MeasureResult from . 
import _ffi_api @@ -124,7 +123,7 @@ def get_per_stmt_features_from_file(filename: str, def get_per_stmt_features_from_measure_pairs(inputs: List[MeasureInput], results: List[MeasureResult], skip_first_n_feature_extraction: int = 0, - max_n_bufs: int = None,) \ + max_n_bufs: int = None) \ -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Get per_stmt features from measurement pairs""" byte_arr = _ffi_api.GetPerStmtFeaturesFromMeasurePairs( @@ -133,7 +132,7 @@ def get_per_stmt_features_from_measure_pairs(inputs: List[MeasureInput], def get_per_stmt_features_from_states(states: List[StateObject], - task: SearchTask, + task: "SearchTask", max_n_bufs: int = None) -> List[np.ndarray]: """Get per_stmt features from states""" byte_arr = _ffi_api.GetPerStmtFeaturesFromStates( diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 0209a717cf0e..b062eb585d12 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -171,6 +171,13 @@ def __init__(self, self.__init_handle_by_constructor__( _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval) +@tvm._ffi.register_object("ansor.ProgramMeasurer") +class ProgramMeasurer(Object): + def __init__(self, builder: Builder, runner: Runner, + callbacks: List[MeasureCallback], + verbose: int, max_continuous_error: int = -1): + self.__init_handle_by_constructor__( + _ffi_api.ProgramMeasurer, builder, runner, callbacks, verbose, max_continuous_error) @tvm._ffi.register_object("ansor.RPCRunner") class RPCRunner(Runner): diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py new file mode 100644 index 000000000000..5144591d4f98 --- /dev/null +++ b/python/tvm/ansor/task_scheduler.py @@ -0,0 +1,274 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""TaskScheduler that allocates the time resources when tuning multiple tasks together""" +from typing import List, Union, Callable +import time + +import numpy as np + +from .auto_schedule import SearchTask, SearchPolicy, MetaTileRewritePolicy, TuneOption +from .cost_model import RandomModel, XGBModel +from .measure import ProgramMeasurer +from .utils import array_mean, to_str_round + + +class TaskScheduler: + """Allocate the time resources when tuning multiple tasks together""" + def __init__(self, + tasks: List[SearchTask], + objective_func: Callable = None): + self.tasks = tasks + self.objective_func = objective_func or sum + + def compute_score(self, costs: List[float]) -> float: + return self.objective_func(costs) + + +def get_search_policies(search_policy: Union[str, List[SearchPolicy]], tasks: List[SearchTask], + num_measure_per_iter, load_model_file=None, load_log_file=None): + if search_policy == 'default': + search_policy = 'meta-rewrite.xgb' + + if isinstance(search_policy, str): + policy_type, model_type = search_policy.split('.') + if model_type == 'xgb': + cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measure_per_iter) + if load_model_file: + print("Load pretrained model...") + cost_model.load(load_model_file) + elif load_log_file: + cost_model.load_log_file(load_log_file) + elif model_type == 'random': + cost_model = RandomModel() + else: + raise ValueError("Invalid search policy: " + search_policy) + + if policy_type == 'meta-rewrite': + search_policies = [MetaTileRewritePolicy(cost_model) for _ in range(len(tasks))] + elif policy_type == 'limit-space': + search_policies = [MetaTileRewritePolicy(cost_model, + params={'cpu_multi_level_tiling_structure': 'SRS', + 'disable_change_compute_location': 1}) + for _ in range(len(tasks))] + elif policy_type == 'beam-search': + search_policies = [MetaTileRewritePolicy(cost_model, + params={'use_beam_search': 1}) + for _ in range(len(tasks))] + else: + raise ValueError("Invalid search policy: " + search_policy) + else: + # check type + assert isinstance(search_policy, (tuple, list)) + for item in search_policy: + assert isinstance(item, SearchPolicy) + search_policies = search_policy + + return search_policies + + +class SimpleTaskScheduler(TaskScheduler): + """The default task scheduler with several strategies + + Parameters + ---------- + tasks: List[SearchTask] + All workloads to tune + weights: List[float] + Weights of tasks (i.e. the number of occurrence of a task in the whole network) + strategy: str + The joint tuning strategy. + "sequential" : Tune tasks sequentially. Divide n_trials equally to every task. + "round-robin": Tune tasks in round robin order. + "gradient" : Tune tasks with gradient descent. + load_log_file: str + Load history log file to pre-train cost model + eps-random: float + Always allocate this percent of n_trials to select tasks randomly. This is for encouraging exploration. + verbose: int + The level of verbosity. 0 means silent. 
+    alpha: float
+        The parameter used for 'gradient' strategy
+    beta: float
+        The parameter used for 'gradient' strategy
+    backward_window_size: int
+        The parameter used for 'gradient' strategy
+    """
+    def __init__(self,
+                 tasks: List[SearchTask],
+                 objective_func: Callable = None,
+                 strategy: str = 'gradient',
+                 load_log_file: str = None,
+                 load_model_file: str = None,
+                 eps_random: float = 0.05,
+                 verbose: int = 1,
+                 alpha: float = 0.2,
+                 beta: float = 2,
+                 gamma: float = 0.5,
+                 backward_window_size: int = 3,
+                 use_debug_measurement_simulator=None):
+        super().__init__(tasks, objective_func)
+        self.strategy = strategy
+        self.eps_random = eps_random
+        self.verbose = verbose
+        self.load_log_file = load_log_file
+        self.load_model_file = load_model_file
+        self.alpha = alpha
+        self.beta = beta
+        self.gamma = gamma
+        self.backward_window_size = backward_window_size
+        self.use_debug_measurement_simulator = use_debug_measurement_simulator
+
+        assert self.strategy in ['sequential', 'round-robin', 'gradient']
+
+        self.task_cts = []
+        self.task_costs_history = []
+        self.best_costs = self.cur_score = None
+        self.tune_option = self.measurer = self.search_policies = self.ct = self.tic = None
+        self.num_measure_per_iter = None
+        self.dead_tasks = set()
+        self.sequential_now_task_idx = 0
+        self.sequential_now_task_begin_ct = 0
+
+    def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPolicy]] = 'default'):
+        # init members
+        self.task_cts = [0 for _ in range(len(self.tasks))]
+        self.task_costs_history = [[] for _ in range(len(self.tasks))]
+        self.best_costs = 1e10 * np.ones(len(self.tasks))
+        self.cur_score = self.compute_score(self.best_costs)
+        self.tune_option = tune_option
+        if self.use_debug_measurement_simulator is None:
+            self.measurer = ProgramMeasurer(tune_option.builder, tune_option.runner,
+                                            tune_option.callbacks, tune_option.verbose)
+        self.ct = 0
+        self.tic = time.time()
+        # reset num_measure_per_iter to make sure every task is tuned at least once
+        self.num_measure_per_iter = min(tune_option.num_measure_per_iter,
+                                        tune_option.n_trials // len(self.tasks))
+        self.search_policies = get_search_policies(search_policy, self.tasks,
+                                                   self.num_measure_per_iter,
+                                                   self.load_model_file,
+                                                   self.load_log_file)
+        self.dead_tasks = set()
+        self.sequential_now_task_idx = 0
+        self.sequential_now_task_begin_ct = 0
+
+        # do a round robin first
+        if self.strategy != 'sequential':
+            for i in range(len(self.tasks)):
+                self.tune_task(i)
+
+        # use the specific strategy to choose workload to tune
+        task_idx = -1
+        while self.ct < tune_option.n_trials and len(self.dead_tasks) < len(self.tasks):
+            if self.strategy == 'sequential':
+                allocated_total_ct = ((tune_option.n_trials - self.sequential_now_task_begin_ct)
+                                      / (len(self.tasks) - self.sequential_now_task_idx))
+                used_ct = self.ct - self.sequential_now_task_begin_ct
+
+                if self.sequential_now_task_idx in self.dead_tasks or used_ct >= allocated_total_ct:
+                    self.sequential_now_task_idx += 1
+                    self.sequential_now_task_begin_ct = self.ct
+                task_idx = self.sequential_now_task_idx
+                if task_idx >= len(self.tasks):
+                    break
+            elif self.strategy == 'round-robin':
+                task_idx = (task_idx + 1) % len(self.tasks)
+                while task_idx in self.dead_tasks:
+                    task_idx = (task_idx + 1) % len(self.tasks)
+            elif self.strategy == 'gradient':
+                gradients = []
+                for i in range(len(self.tasks)):
+                    if i in self.dead_tasks:
+                        gradients.append(0)
+                        continue
+
+                    # compute gradient from chain rule: (delta f / delta g_i)
+                    delta = 1e-7
+                    new_costs = list(self.best_costs)
+                    new_costs[i]
+                    chain_grad = (self.compute_score(self.best_costs) - self.compute_score(new_costs)) / delta
+
+                    # compute (g_i(t_i) - g(t_i - \Delta t)) / (\Delta t)
+                    if self.task_cts[i] - 1 - self.backward_window_size >= 0:
+                        backward_grad = (self.task_costs_history[i][self.task_cts[i] - 1]
+                                         - self.task_costs_history[i][self.task_cts[i] - 1 - self.backward_window_size]) \
+                                        / self.backward_window_size
+                    else:
+                        backward_grad = 0
+
+                    # compute (g_i(t_i + \Delta t) - g(t_i)) / (\Delta t)
+                    g_next_1 = self.best_costs[i] - (self.best_costs[i] / self.task_cts[i])
+                    # todo(lmzheng): this requires adding an attribute to topi.compute for the similarity check
+                    g_next_2 = self.beta * 1e20
+                    g_next = min(g_next_1, g_next_2)
+                    forward_grad = g_next - self.best_costs[i]
+
+                    # combine all grads
+                    grad = chain_grad * (self.alpha * backward_grad + (1 - self.alpha) * forward_grad)
+                    assert grad <= 0
+                    gradients.append(grad)
+
+                if max(gradients) == min(gradients):
+                    task_idx = np.random.choice(len(gradients))
+                else:
+                    task_idx = np.argmin(gradients)
+            else:
+                raise ValueError("Invalid strategy: " + self.strategy)
+
+            self.tune_task(task_idx)
+
+    def tune_task(self, task_idx):
+        if self.use_debug_measurement_simulator is not None:
+            measure_inputs, measure_results = \
+                self.use_debug_measurement_simulator.get_next_batch(
+                    self.tasks[task_idx],
+                    self.num_measure_per_iter,
+                )
+        else:
+            measure_inputs, measure_results = \
+                self.search_policies[task_idx].continue_search(
+                    self.tasks[task_idx],
+                    self.num_measure_per_iter,
+                    self.tune_option.verbose,
+                    self.measurer)
+
+        for inp, res in zip(measure_inputs, measure_results):
+            cost = array_mean(res.costs)
+            if cost < self.best_costs[task_idx]:
+                self.best_costs[task_idx] = cost
+
+        if len(measure_inputs) == 0:
+            self.dead_tasks.add(task_idx)
+
+        self.task_cts[task_idx] += 1
+        self.task_costs_history[task_idx].append(self.best_costs[task_idx])
+
+        self.ct += len(measure_inputs)
+        self.cur_score = self.compute_score(self.best_costs)
+
+        if self.verbose >= 1:
+            print(("TaskScheduler\tct: %d\testimated cost (ms): %.3f\ttime elapsed: %.2f\t" +
+                   "best_costs (ms): %s\ttask_ct: %s") %
+                  (self.ct, self.cur_score * 1e3, time.time() - self.tic,
+                   to_str_round(self.best_costs * 1e3, decimal=3),
+                   self.task_cts))
+
+    def remove_dead_task(self, prob):
+        for idx in self.dead_tasks:
+            prob[idx] = 0
+        return prob / prob.sum()
diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py
index c8b12f0244b2..381e6009eea8 100644
--- a/python/tvm/ansor/workload_registry.py
+++ b/python/tvm/ansor/workload_registry.py
@@ -187,4 +187,3 @@ def load_workload_func_registry(filename: str):
 
     global WORKLOAD_FUNC_REGISTRY
     WORKLOAD_FUNC_REGISTRY = pickle.load(open(filename, 'rb'))
-
diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc
index e3593753d3ff..73bbade241c5 100644
--- a/src/ansor/measure.cc
+++ b/src/ansor/measure.cc
@@ -324,24 +324,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 });
 
 TVM_REGISTER_GLOBAL("ansor.MeasureInput")
-.set_body_typed([](SearchTask task, State state) {
-  return MeasureInputNode::make(task, state);
-});
+.set_body_typed(MeasureInputNode::make);
 
 TVM_REGISTER_GLOBAL("ansor.BuildResult")
-.set_body_typed([](std::string filename, Array<te::Tensor> args,
-                   int error_no, std::string error_msg, double time_cost) {
-  return BuildResultNode::make(filename, args, error_no, error_msg,
-                               time_cost);
-});
+.set_body_typed(BuildResultNode::make);
 
 TVM_REGISTER_GLOBAL("ansor.MeasureResult")
-.set_body_typed([](Array<PrimExpr> costs, int error_no,
-                   std::string error_msg, double all_cost,
-                   double timestamp) {
-  return MeasureResultNode::make(costs, error_no, error_msg, all_cost,
-                                 timestamp);
-});
+.set_body_typed(MeasureResultNode::make);
 
 TVM_REGISTER_GLOBAL("ansor.BuilderBuild")
 .set_body_typed([](const Builder& builder,
@@ -356,25 +345,17 @@ TVM_REGISTER_GLOBAL("ansor.RunnerRun")
 });
 
 TVM_REGISTER_GLOBAL("ansor.LocalBuilder")
-.set_body_typed([](int timeout, int n_parallel,
-                   const std::string& build_func) {
-  return LocalBuilderNode::make(timeout, n_parallel, build_func);
-});
+.set_body_typed(LocalBuilderNode::make);
 
 TVM_REGISTER_GLOBAL("ansor.LocalRunner")
-.set_body_typed([](int timeout, int number, int repeat, int min_repeat_ms,
-                   double cooldown_interval) {
-  return LocalRunnerNode::make(timeout, number, repeat, min_repeat_ms,
-                               cooldown_interval);
-});
+.set_body_typed(LocalRunnerNode::make);
 
 TVM_REGISTER_GLOBAL("ansor.RPCRunner")
-.set_body_typed([](const std::string& key, const std::string& host, int port,
-                   int priority, int timeout, int n_parallel, int number,
-                   int repeat, int min_repeat_ms, double cooldown_interval) {
-  return RPCRunnerNode::make(key, host, port, priority, timeout, n_parallel,
-                             number, repeat, min_repeat_ms, cooldown_interval);
-});
+.set_body_typed(RPCRunnerNode::make);
+
+TVM_REGISTER_GLOBAL("ansor.ProgramMeasurer")
+.set_body_typed(ProgramMeasurerNode::make);
+
 }  // namespace ansor
 }  // namespace tvm
diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc
index 866922d0001e..f3072fda4956 100644
--- a/src/ansor/search_policy/search_policy.cc
+++ b/src/ansor/search_policy/search_policy.cc
@@ -23,11 +23,28 @@
  */
 
 #include "search_policy.h"
+#include <tvm/runtime/registry.h>
 
 namespace tvm {
 namespace ansor {
 
 TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode);
 
+// Search Policy
+TVM_REGISTER_GLOBAL("ansor.SearchPolicyContinueSearchOneRound")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  SearchPolicy policy = args[0];
+  SearchTask task = args[1];
+  int num_measure = args[2];
+  int verbose = args[3];
+  ProgramMeasurer measurer = args[4];
+
+  Array<MeasureInput> inputs;
+  Array<MeasureResult> results;
+  std::tie(inputs, results) = policy->ContinueSearchOneRound(task, num_measure, verbose, measurer);
+
+  *ret = Array<ObjectRef>{inputs, results};
+});
+
 }  // namespace ansor
 }  // namespace tvm
diff --git a/tests/python/unittest/test_ansor_task_scheduler.py b/tests/python/unittest/test_ansor_task_scheduler.py
new file mode 100644
index 000000000000..e95d65d4b5ce
--- /dev/null
+++ b/tests/python/unittest/test_ansor_task_scheduler.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test the task scheduler"""
+
+import tvm
+from tvm import ansor
+
+from test_ansor_common import matmul_ansor_test
+
+def test_task_scheduler_basic():
+    N = 128
+    A, B, C = matmul_ansor_test(N, N, N)
+    dag = ansor.ComputeDAG([A, B, C])
+    tgt = tvm.target.create("llvm")
+    task1 = ansor.SearchTask(dag, "test", tgt)
+    task2 = ansor.SearchTask(dag, "test", tgt)
+
+    def objective(costs):
+        return sum(costs)
+
+    task_scheduler = ansor.SimpleTaskScheduler([task1, task2], objective)
+    tune_option = ansor.TuneOption(n_trials=3, runner='local')
+
+    task_scheduler.tune(tune_option)
+
+
+if __name__ == "__main__":
+    test_task_scheduler_basic()

From ed075c276c3fecc3ed3ff16b87a707b5482ff6f9 Mon Sep 17 00:00:00 2001
From: Chenfan
Date: Tue, 9 Jun 2020 17:53:06 +0800
Subject: [PATCH 18/78] Add conv2d cuda tutorial with workload registry (#18)

---
 docs/conf.py                            |   2 +-
 python/tvm/ansor/__init__.py            |   3 +-
 tutorials/ansor/tune_conv2d_cuda.py     | 164 ++++++++++++++++++++++++
 tutorials/ansor/tune_simple_subgraph.py |   2 +
 4 files changed, 169 insertions(+), 2 deletions(-)
 create mode 100644 tutorials/ansor/tune_conv2d_cuda.py

diff --git a/docs/conf.py b/docs/conf.py
index 5cbaab7f7b6d..5826526d55b0 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -197,8 +197,8 @@
 ['../tutorials/frontend',
 '../tutorials/language',
 '../tutorials/optimize',
- '../tutorials/ansor',
 '../tutorials/autotvm',
+ '../tutorials/ansor',
 '../tutorials/dev',
 '../tutorials/topi',
 '../tutorials/deployment',
diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py
index 4e57c16d18a5..bfdbaf9c8c8c 100644
--- a/python/tvm/ansor/__init__.py
+++ b/python/tvm/ansor/__init__.py
@@ -35,5 +35,6 @@
 from .cost_model import RandomModel
 from .cost_model.xgb_model import XGBModel
 from .serialization import LogToFile, LogReader, best_measure_pair_in_file
-from .workload_registry import register_auto_scheduler_workload_func, workload_key_to_dag
+from .workload_registry import register_auto_scheduler_workload_func, workload_key_to_dag, \
+    make_workload_key_func
 from .task_scheduler import TaskScheduler, SimpleTaskScheduler
diff --git a/tutorials/ansor/tune_conv2d_cuda.py b/tutorials/ansor/tune_conv2d_cuda.py
new file mode 100644
index 000000000000..82a5e8572ba2
--- /dev/null
+++ b/tutorials/ansor/tune_conv2d_cuda.py
@@ -0,0 +1,164 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Auto-scheduling High Performance Convolution on NVIDIA GPUs
+===========================================================
+**Author**: `Lianmin Zheng `_, \
+            `Chengfan Jia `_, \
+            `Minmin Sun `_, \
+            `Zhao Wu `_
+
+This is a tutorial on searching for high-performance schedules for NVIDIA GPUs
+using the Ansor auto-scheduler. By running Ansor on this workload, we can
+outperform the vendor-provided library cuDNN in many cases.
+"""
+
+######################################################################
+# Install dependencies
+# --------------------
+# To use Ansor in TVM, we need to install some extra dependencies.
+# (change "3" to "2" if you use python2):
+#
+# .. code-block:: bash
+#
+#   pip3 install --user psutil xgboost tornado
+#
+# To make TVM run faster in tuning, it is recommended to use Cython
+# as the FFI of TVM. In the root directory of TVM, execute
+#
+# .. code-block:: bash
+#
+#   pip3 install --user cython
+#   sudo make cython3
+#
+# Now return to python code. Import packages.
+
+import random
+import sys
+
+import numpy as np
+import tvm
+import topi
+from topi.testing import conv2d_nchw_python
+from tvm import te
+
+# the module is called `ansor`
+from tvm import ansor
+
+######################################################################
+# Step 1:  Define the search task
+# -------------------------------
+# There are plenty of useful schedule primitives in tvm. You can also find
+# some tutorials that describe them in more detail, such as
+# (1). :ref:`opt-conv-gpu`
+# (2). `Optimizing DepthwiseConv on NVIDIA GPU `_
+#
+# Getting a high-performance schedule for a specific workload is usually a
+# hard job. Even writing an AutoTVM tunable template requires the user to have
+# expertise in how each schedule primitive works as well as how they finally
+# map to the hardware architecture.
+#
+# However, with Ansor this is quite simple. First, define the target workload.
+# Either the :code:`tvm.te` API or the topi op API can be used.
+#
+# We can use the returned :code:`Tensors` to create a ComputeDAG just like what
+# we do in :ref:`ansor-simple-subgraph` (see the sketch below), while using the
+# workload registry is the recommended way.
+
+# Use an extra function decorator to register this workload
+@ansor.register_auto_scheduler_workload_func
+def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding):
+    data = te.placeholder((N, CI, H, W), name='data')
+    kernel = te.placeholder((CO, CI, KH, KW), name='kernel')
+    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype='float32')
+
+    return [data, kernel, conv]
+
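+######################################################################
+# For reference, creating the DAG directly from the returned tensors would
+# look like this (an illustrative sketch only; this tutorial uses the workload
+# registry path instead):
+#
+# .. code-block:: python
+#
+#     data, kernel, conv = conv2d_nchw(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1))
+#     dag = ansor.ComputeDAG([data, kernel, conv])
+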
+""" + +###################################################################### +# Install dependencies +# -------------------- +# To use autotvm package in tvm, we need to install some extra dependencies. +# (change "3" to "2" if you use python2): +# +# .. code-block:: bash +# +# pip3 install --user psutil xgboost tornado +# +# To make TVM run faster in tuning, it is recommended to use cython +# as FFI of tvm. In the root directory of tvm, execute +# +# .. code-block:: bash +# +# pip3 install --user cython +# sudo make cython3 +# +# Now return to python code. Import packages. + +import random +import sys + +import numpy as np +import tvm +import topi +from topi.testing import conv2d_nchw_python +from tvm import te + +# the module is called `ansor` +from tvm import ansor + +###################################################################### +# Step 1: Define the search task +# ------------------------------- +# There are plenty of useful schedule primitives in tvm. You can also find +# some tutorials that describe them in more details, such as +# (1). :ref:`opt-conv-gpu` +# (2). `Optimizing DepthwiseConv on NVIDIA GPU `_ +# +# It's usually a hard job if one wants to get a high performance schedule for a +# specific workload. Even writing an AutoTVM tunable template needs user to have +# expertises on how each schedule primitive works as well as how they finally +# reflect on the hardward architecture. +# +# However, with Ansor this will be quite simple. Firstly, define the target workload. +# Both :code:`tvm.te` API or topi op API are fine to be used. +# +# We can use the retuned :code:`Tensors` to create a ComputeDAG just like what we do +# in :ref:`ansor-simple-subgraph`, while the way to use workload registry is more +# recommended. + +# Use an extra function decorator to regist this workload +@ansor.register_auto_scheduler_workload_func +def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding): + data = te.placeholder((N, CI, H, W), name='data') + kernel = te.placeholder((CO, CI, KH, KW), name='kernel') + conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype='float32') + + return [data, kernel, conv] + +###################################################################### +# Step 2: Search through the schedule space +# ------------------------------------------ +# We pick the last layer on resnet as test case. +# Since our space is very large, :code:`XGBModel` is most suitable +# for our case. Here we only do 20 trials for demonstration. +# In practice, making 1000 trials usually can find some good kernels +# for this workload. + +tgt = tvm.target.cuda() + +# The last layer in resnet +N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1) +# Generate workload key with the ansor API +wkl_key = ansor.make_workload_key_func(conv2d_nchw, (N, H, W, CO, CI, KH, KW, strides, padding)) +# Generate ComputeDAG using the workload key +dag = ansor.workload_key_to_dag(wkl_key) +task = ansor.SearchTask(dag, wkl_key, target=tgt) + +log_file = "conv2d_nchw.json" +seed = 0 +random.seed(seed) +cost_model = ansor.XGBModel() +search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) + +######################################################################### +# The :code:`ansor.RPCRunnerWarpper` is used to create a RPC runner environment, +# +# Use local gpu, measure 10 times for every schedule to reduce variance. The timeout +# for each running is set to 4 seconds. 
+# +# During the searching process, we may generate several invalid schedules and they +# will be filtered out. It's fine to see "Encountered errors during feature extraction." +# in the tuning logs. + +with ansor.RPCRunnerWarpper("cuda", repeat=3, min_repeat_ms=100, timeout=4) as rpc_runner: + tune_option = ansor.TuneOption(n_trials=20, + runner=rpc_runner.runner, + callbacks=[ansor.LogToFile(log_file)]) + state = ansor.auto_schedule(task, search_policy, + tune_option=tune_option) + print(state) + +######################################################################### +# Finally we can directly use the returned result to get the generated schedule, +# while in the following tutorial we'll show how to inspect the best config from +# log file, check correctness, and measure running time. + +# Get history best from log file +inp, res = ansor.best_measure_pair_in_file(log_file) +# Get the task ComputeDAG from log result +dag = ansor.workload_key_to_dag(inp.task.workload_key) +# Apply log result to TVM schedule +s, arg_bufs = dag.apply_steps_from_state(inp.state) +func = tvm.build(s, arg_bufs, target=tgt) + +# check correctness +a_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32) +w_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32) +c_np = conv2d_nchw_python(a_np, w_np, strides, padding) + +ctx = tvm.gpu() +a_tvm = tvm.nd.array(a_np, ctx=ctx) +w_tvm = tvm.nd.array(w_np, ctx=ctx) +c_tvm = tvm.nd.empty(c_np.shape, ctx=ctx) +func(a_tvm, w_tvm, c_tvm) + +tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2) + +# Evaluate running time. Here we choose a large repeat number (400) to reduce the noise +# and the overhead of kernel launch. You can also use nvprof to validate the result. +evaluator = func.time_evaluator(func.entry_name, ctx, number=400) +print('Time cost of this operator: %f' % evaluator(a_tvm, w_tvm, c_tvm).mean) + diff --git a/tutorials/ansor/tune_simple_subgraph.py b/tutorials/ansor/tune_simple_subgraph.py index 8555d6163c32..2af33c1e88ba 100644 --- a/tutorials/ansor/tune_simple_subgraph.py +++ b/tutorials/ansor/tune_simple_subgraph.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """ +.. 
_ansor-simple-subgraph: + Writing compute expression and Using Ansor auto-scheduler ========================================================= **Author**: `Lianmin Zheng `_, \ From 74ec7d0b792c31d04993d4b0e4ae1ea912e4e792 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 9 Jun 2020 04:40:01 -0700 Subject: [PATCH 19/78] add tune_test.py (the old tune_wkl.py) (#19) * add tune_test.py (the old tune_wkl.py) * update * fix measure * fix for gpu --- .gitignore | 3 + python/tvm/ansor/__init__.py | 8 +- python/tvm/ansor/auto_schedule.py | 21 +- python/tvm/ansor/measure.py | 52 +- python/tvm/ansor/workload_registry.py | 4 +- scripts/common.py | 1017 +++++++++++++++++ scripts/tune_test.py | 195 ++++ src/ansor/auto_schedule.cc | 16 +- src/ansor/auto_schedule.h | 4 +- tests/python/unittest/test_ansor_measure.py | 17 +- .../unittest/test_ansor_search_policy.py | 51 +- 11 files changed, 1285 insertions(+), 103 deletions(-) create mode 100644 scripts/common.py create mode 100644 scripts/tune_test.py diff --git a/.gitignore b/.gitignore index 506e54d93067..3c03e8ecda7a 100644 --- a/.gitignore +++ b/.gitignore @@ -234,3 +234,6 @@ conda/pkg # antlr files *.tokens *.interp + +# ansor tuning logs +scripts/*.json diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index bfdbaf9c8c8c..2e3553cf725c 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -31,10 +31,10 @@ from .compute_dag import ComputeDAG from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams from .auto_schedule import auto_schedule -from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, RPCRunnerWarpper +from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext from .cost_model import RandomModel from .cost_model.xgb_model import XGBModel -from .serialization import LogToFile, LogReader, best_measure_pair_in_file -from .workload_registry import register_auto_scheduler_workload_func, workload_key_to_dag, \ - make_workload_key_func +from .serialization import LogToFile, LogReader, best_measure_pair_in_file, write_measure_records_to_file +from .workload_registry import register_auto_scheduler_workload_func, \ + workload_key_to_dag, make_workload_key_func from .task_scheduler import TaskScheduler, SimpleTaskScheduler diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 5f4b7946b087..1192e6d551e5 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -160,12 +160,12 @@ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, verbose, builder, runner, callbacks) -def auto_schedule(workload, search_policy='default', target=None, - target_host=None, hardware_params=None, - tune_option=None): +def auto_schedule(workload, target=None, + target_host=None, search_policy='default', + hardware_params=None, tune_option=None): """ Do auto schedule for a compute declaration. - The workload paramter can be a `string` as workload_key, or directly + The workload parameter can be a `string` as workload_key, or directly passing a `SearchTask` as input. 
Parameters @@ -174,8 +174,6 @@ def auto_schedule(workload, search_policy='default', target=None, target : Target - task : SearchTask - target_host : Target = None search_policy : Union[SearchPolicy, str] @@ -203,13 +201,12 @@ def auto_schedule(workload, search_policy='default', target=None, if isinstance(workload, str): sch, tensors = _ffi_api.AutoScheduleByWorkloadKey( - workload, target, target_host, search_policy, hardware_params, - tune_option) + workload, target, target_host, search_policy, hardware_params, tune_option) return sch, tensors elif isinstance(workload, SearchTask): - state = _ffi_api.AutoScheduleBySearchTask(workload, search_policy, - tune_option) - return state + sch, tensors = _ffi_api.AutoScheduleBySearchTask(workload, search_policy, tune_option) + return sch, tensors else: raise ValueError("Invalid workload: " + workload + - ", should be String or SearchTask") + ". Expect a string or SearchTask") + diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index b062eb585d12..299c004f756d 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -37,6 +37,7 @@ from tvm.ir import transform from tvm.rpc.tracker import Tracker from tvm.rpc.server import Server +from tvm.autotvm.measure.measure_methods import set_cuda_target_arch from ..contrib import tar, ndk from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, check_remote from .compute_dag import LayoutRewriteLevel @@ -78,7 +79,7 @@ class BuildResult(Object): def __init__(self, filename, args, error_no, error_msg, time_cost): self.__init_handle_by_constructor__( - _ffi_api.BuildResult, filename, args, error_no, + _ffi_api.BuildResult, filename if filename else "", args, error_no, error_msg if error_msg else "", time_cost) @@ -201,49 +202,32 @@ def __init__(self, key, host, port, priority=1, "and make sure you have free devices on the queue status.") -class RPCRunnerWarpper: - def __init__(self, target=None, priority=1, +class LocalRPCMeasureContext: + def __init__(self, + priority=1, n_parallel=1, timeout=10, - number=3, + number=10, repeat=1, min_repeat_ms=0, cooldown_interval=0.0): - self.target = target - self.priority = priority - self.n_parallel = n_parallel - self.timeout = timeout - self.number = number - self.repeat = repeat - self.min_repeat_ms = min_repeat_ms - self.cooldown_interval = cooldown_interval - - self.tracker = None - self.server = None - self.runner = None - - def __enter__(self): - if self.target == "cuda": - ctx = tvm.context("cuda", 0) + ctx = tvm.context("cuda", 0) + if ctx.exist: cuda_arch = "sm_" + "".join(ctx.compute_version.split('.')) - tvm.autotvm.measure.measure_methods.set_cuda_target_arch(cuda_arch) + set_cuda_target_arch(cuda_arch) host = '0.0.0.0' self.tracker = Tracker(host, port=9000, port_end=10000, silent=True) device_key = '$local$device$%d' % self.tracker.port self.server = Server(host, port=self.tracker.port, port_end=10000, - key=device_key, - use_popen=True, silent=True, - tracker_addr=(self.tracker.host, self.tracker.port)) - self.runner = RPCRunner(device_key, host, self.tracker.port, self.priority, - self.n_parallel, self.timeout, self.number, self.repeat, - self.min_repeat_ms, self.cooldown_interval) - - return self - - def __exit__(self, type, value, trace): - if value: - raise value - + key=device_key, use_popen=True, silent=True, + tracker_addr=(self.tracker.host, self.tracker.port)) + self.runner = RPCRunner(device_key, host, self.tracker.port, priority, + n_parallel, timeout, number, repeat, + 
min_repeat_ms, cooldown_interval) + # wait for the processes to start + time.sleep(0.5) + + def __del__(self): self.tracker.terminate() self.server.terminate() diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py index 381e6009eea8..fccdcf8864be 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/ansor/workload_registry.py @@ -130,7 +130,7 @@ def deserialize_args(args: Tuple) -> List: return ret -@tvm._ffi.register_func("auto_scheduler.workload_key_to_tensors") +@tvm._ffi.register_func("ansor.workload_key_to_tensors") def workload_key_to_tensors(workload_key: str) -> List[Tensor]: """Decode a workload key to the input/output tensors""" workload = json.loads(workload_key) @@ -144,7 +144,7 @@ def workload_key_to_tensors(workload_key: str) -> List[Tensor]: return lookup -@ tvm._ffi.register_func("auto_scheduler.workload_key_to_dag") +@ tvm._ffi.register_func("ansor.workload_key_to_dag") def workload_key_to_dag(workload_key: str) -> ComputeDAG: """Decode a workload key to a compute dag""" tensors = workload_key_to_tensors(workload_key) diff --git a/scripts/common.py b/scripts/common.py new file mode 100644 index 000000000000..4400104bdfe6 --- /dev/null +++ b/scripts/common.py @@ -0,0 +1,1017 @@ +"""Common utility for scripts""" +import argparse +import math +import os +import re +import time +from collections import defaultdict, namedtuple +from typing import Dict, List, Tuple + +import numpy as np +import matplotlib.pyplot as plt + +import topi +import tvm +from tvm import te +from tvm.ansor import (LogReader, make_workload_key_func, + register_auto_scheduler_workload_func, + write_measure_records_to_file) +from tvm.contrib import ndk, util + +############################################################ +###################### Test Workloads #################### +############################################################ + +@register_auto_scheduler_workload_func +def min_mn(M, N): + A = te.placeholder((M, N), name='A') + B = topi.min(A, axis=1) + + return [A, B] + +@register_auto_scheduler_workload_func +def argmin_mn(M, N): + A = te.placeholder((M, N), name='A') + B = topi.argmin(A, axis=1) + + return [A, B] + +@register_auto_scheduler_workload_func +def softmax_mn(M, N): + A = te.placeholder((M, N), name='A') + B = topi.nn.softmax(A, axis=1) + + return [A, B] + +@register_auto_scheduler_workload_func +def norm_bmn(B, M, N): + A = te.placeholder((B, M, N), name='A') + i = te.reduce_axis((0, M)) + j = te.reduce_axis((0, N)) + C = te.compute((B,), lambda b: te.sum(A[b][i][j] * A[b][i][j], axis=[i, j]), name='C') + D = te.compute((B,), lambda b: te.sqrt(C[b]), name='D') + + return [A, D] + +@register_auto_scheduler_workload_func +def add_mn(M, N): + A = te.placeholder((M, N), name='A') + B = te.placeholder((M, N), name='B') + C = te.compute((M, N), lambda i, j: A[i][j] + B[i][j], name='C') + + return [A, B, C] + +@register_auto_scheduler_workload_func +def matmul_nkkm(N, M, K, in_type='float32', out_type='float32', + tensor_core_support=False): + A = te.placeholder((N, K), name='A', dtype=in_type) + B = te.placeholder((K, M), name='B', dtype=in_type) + k = te.reduce_axis((0, K), name='k') + if in_type == out_type: + if not (in_type == 'float16' and out_type == 'float16'): + tensor_core_support = False + C = te.compute((N, M), + lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), + name='C', + attrs={"auto_scheduler_tensor_core_support": "True" if tensor_core_support else "False"}) + else: + if not ((in_type == 'float16' and out_type == 
'float32') or \ + (in_type == 'int8' and out_type == 'int32')): + tensor_core_support = False + C = te.compute((N, M), + lambda i, j: te.sum(A[i][k].astype(out_type) * B[k][j].astype(out_type), + axis=[k]), + name='C', + attrs={"auto_scheduler_tensor_core_support": "True" if tensor_core_support else "False"}) + + return [A, B, C] + +@register_auto_scheduler_workload_func +def dense_layer(batch, in_dim, out_dim): + A = te.placeholder((batch, in_dim), name='A') + B = te.placeholder((out_dim, in_dim), name='B') + k = te.reduce_axis((0, in_dim), name='k') + C = te.compute((batch, out_dim), lambda i, j: te.sum(A[i][k] * B[j][k], axis=[k]), name='C') + + return [A, B, C] + +@register_auto_scheduler_workload_func +def max_pool_2d_nchw(N, C, H, W): + data = te.placeholder((N, C, H, W), name='data') + out = topi.nn.pool(data, (2, 2), (1, 1), (0, 0, 0, 0), pool_type='max', ceil_mode=True, + layout="NCHW", count_include_pad=True) + + return [data, out] + +@register_auto_scheduler_workload_func +def add_min_relu(M, N): + A = te.placeholder((M, N), name='A') + B = te.placeholder((M, N), name='B') + C = topi.add(A, B) + D = topi.min(C, axis=1) + out = topi.nn.relu(D) + return [A, B, out] + +@register_auto_scheduler_workload_func +def conv2d_relu_softmax_min(N, H, W, CI, CO, KH, KW, strides, padding, dilation): + data = te.placeholder((N, CI, H, W), name='data') + kernel = te.placeholder((CO, CI, KH, KW), name='kernel') + conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation) + relu = topi.nn.relu(conv) + softmax = topi.nn.softmax(relu, axis=1) + out = topi.min(softmax, axis=1) + + return [data, kernel, out] + +@register_auto_scheduler_workload_func +def conv2d_nchw_bias(N, H, W, CI, CO, KH, KW, strides, padding, dilation): + data = te.placeholder((N, CI, H, W), name='data') + kernel = te.placeholder((CO, CI, KH, KW), name='kernel') + bias = te.placeholder((CO, 1, 1), name='bias') + conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation) + #out = topi.nn.relu(conv) + out = topi.add(conv, bias) + return [data, kernel, bias, out] + +def conv2d_nhwc_without_layout_rewrite(Input, Filter, stride, padding, dilation, out_dtype='float32'): + """A copy of `topi.nn.conv2d_nhwc` but without the 'layout_free` attribute. + We use this in single op and subgraph evaluation because we don't want to introduce graph level optimization. 
+ """ + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_height, in_width, in_channel = Input.shape + if len(Filter.shape) == 10: + kernel_h = Filter.shape[2] * Filter.shape[6] + kernel_w = Filter.shape[3] * Filter.shape[7] + channel = Filter.shape[4] * Filter.shape[8] + num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[5] * Filter.shape[9] + #Filter = te.placeholder([kernel_h, kernel_w, channel, num_filter], Filter.dtype, Filter.name) + elif len(Filter.shape) == 11: + kernel_h = Filter.shape[3] * Filter.shape[7] + kernel_w = Filter.shape[4] * Filter.shape[8] + channel = Filter.shape[5] * Filter.shape[9] + num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[2] * Filter.shape[6] * Filter.shape[10] + else: + kernel_h, kernel_w, channel, num_filter = Filter.shape + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = topi.nn.util.get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) + out_channel = num_filter + out_height = topi.util.simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = topi.util.simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + PaddedInput = topi.nn.pad(Input, pad_before, pad_after, name="PaddedInput") + rc = te.reduce_axis((0, in_channel), name='rc') + ry = te.reduce_axis((0, kernel_h), name='ry') + rx = te.reduce_axis((0, kernel_w), name='rx') + Output = te.compute( + (batch, out_height, out_width, out_channel), + lambda nn, yy, xx, ff: te.sum( + PaddedInput[nn, yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * + Filter[ry, rx, rc, ff].astype(out_dtype) + , axis=[ry, rx, rc]), + name="Conv2dOutput", tag="conv2d_nhwc") + return Output + + +@register_auto_scheduler_workload_func +def conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, padding, dilation): + data = te.placeholder((N, H, W, CI), name='data') + kernel = te.placeholder((KH, KW, CI, CO), name='kernel') + bias = te.placeholder((CO, ), name='bias') + conv = topi.nn.conv2d_nhwc(data, kernel, strides, padding, dilation) + out = topi.add(conv, bias) + return [data, kernel, bias, out] + +@register_auto_scheduler_workload_func +def depthwise_conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, padding, dilation): + data = te.placeholder((N, H, W, CI), name='data') + kernel = te.placeholder((KH, KW, CI, 1), name='kernel') + bias = te.placeholder((CO, ), name='bias') + conv = topi.nn.depthwise_conv2d_nhwc(data, kernel, strides, padding, dilation) + out = topi.add(conv, bias) + return [data, kernel, bias, out] + +@register_auto_scheduler_workload_func +def conv2d_nhwc_bias(N, H, W, CI, CO, KH, KW, strides, padding, dilation): + data = te.placeholder((N, H, W, CI), name='data') + kernel = te.placeholder((KH, KW, CI, CO), name='kernel') + bias = te.placeholder((CO, ), name='bias') + conv = conv2d_nhwc_without_layout_rewrite(data, kernel, strides, padding, dilation) + out = topi.add(conv, bias) + return [data, kernel, bias, out] + + +@register_auto_scheduler_workload_func +def 
conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): + data = te.placeholder((N, CI, H, W), name='data') + kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='kernel') + bias = te.placeholder((CO, 1, 1), name='bias') + bn_scale = te.placeholder((CO, 1, 1), name='bn_scale') + bn_offset = te.placeholder((CO, 1, 1), name='bn_offset') + + OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + + conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation) + conv = te.compute((N, CO, OH, OW), + lambda i, j, k, l: conv[i, j, k, l] + bias[j, 0, 0], + name='bias_add') + conv = te.compute((N, CO, OH, OW), + lambda i, j, k, l: conv[i, j, k, l] * bn_scale[j, 0, 0], + name='bn_mul') + conv = te.compute((N, CO, OH, OW), + lambda i, j, k, l: conv[i, j, k, l] + bn_offset[j, 0, 0], + name='bn_add') + out = topi.nn.relu(conv) + + return [data, kernel, bias, bn_offset, bn_scale, out] + +@register_auto_scheduler_workload_func +def conv2d_nhwc_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): + data = te.placeholder((N, H, W, CI), name='data') + kernel = te.placeholder((kernel_size, kernel_size, CI, CO), name='kernel') + bias = te.placeholder((CO,), name='bias') + bn_scale = te.placeholder((CO,), name='bn_scale') + bn_offset = te.placeholder((CO,), name='bn_offset') + + OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + + conv = conv2d_nhwc_without_layout_rewrite(data, kernel, strides, padding, dilation) + conv = te.compute((N, OH, OW, CO), + lambda i, j, k, l: conv[i, j, k, l] + bias[l], + name='bias_add') + conv = te.compute((N, OH, OW, CO), + lambda i, j, k, l: conv[i, j, k, l] * bn_scale[l], + name='bn_mul') + conv = te.compute((N, OH, OW, CO), + lambda i, j, k, l: conv[i, j, k, l] + bn_offset[l], + name='bn_add') + out = topi.nn.relu(conv) + + return [data, kernel, bias, bn_offset, bn_scale, out] + +resnet_conv2d_configs = { + # format : N, H, W, CI, CO, KH, KW, strides, padding, dilation + '18': [ + (1, 224, 224, 3, 64, 7, 7, (2, 2), (3, 3), (1, 1)), + (1, 56, 56, 64, 128, 3, 3, (2, 2), (1, 1), (1, 1)), + (1, 56, 56, 64, 128, 1, 1, (2, 2), (0, 0), (1, 1)), + (1, 56, 56, 64, 64, 3, 3, (1, 1), (1, 1), (1, 1)), + (1, 56, 56, 64, 64, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 28, 28, 128, 256, 3, 3, (2, 2), (1, 1), (1, 1)), + (1, 28, 28, 128, 256, 1, 1, (2, 2), (0, 0), (1, 1)), + (1, 28, 28, 128, 128, 3, 3, (1, 1), (1, 1), (1, 1)), + (1, 14, 14, 256, 512, 3, 3, (2, 2), (1, 1), (1, 1)), + (1, 14, 14, 256, 512, 1, 1, (2, 2), (0, 0), (1, 1)), + (1, 14, 14, 256, 256, 3, 3, (1, 1), (1, 1), (1, 1)), + (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)), + ], + '50': [ + (1, 224, 224, 3, 64, 7, 7, (2, 2), (3, 3), (1, 1)), + (1, 56, 56, 256, 512, 1, 1, (2, 2), (0, 0), (1, 1)), + (1, 56, 56, 256, 128, 1, 1, (2, 2), (0, 0), (1, 1)), + (1, 56, 56, 256, 64, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 56, 56, 64, 256, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 56, 56, 64, 64, 3, 3, (1, 1), (1, 1), (1, 1)), + (1, 56, 56, 64, 64, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 28, 28, 512, 1024, 1, 1, (2, 2), (0, 0), (1, 1)), + (1, 28, 28, 512, 256, 1, 1, (2, 2), (0, 0), (1, 1)), + (1, 28, 28, 512, 128, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 28, 28, 128, 512, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 28, 28, 128, 128, 3, 3, (1, 1), (1, 1), (1, 1)), + (1, 14, 14, 1024, 2048, 1, 1, (2, 2), (0, 0), (1, 
1)), + (1, 14, 14, 1024, 512, 1, 1, (2, 2), (0, 0), (1, 1)), + (1, 14, 14, 1024, 256, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 14, 14, 256, 1024, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 14, 14, 256, 256, 3, 3, (1, 1), (1, 1), (1, 1)), + (1, 7, 7, 2048, 512, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 7, 7, 512, 2048, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)), + ], +} + +# number of appearance for all conv2ds in resnet +resnet_conv2d_weights = { + '18': [1, 1, 1, 4, 1, 1, 1, 3, 1, 1, 3, 3], + '50': [1, 1, 1, 2, 4, 3, 1, 1, 1, 3, 4, 4, 1, 1, 5, 6, 6, 2, 3, 3], +} + + +def parse_workload_name(name: str) -> List[str]: + """Parse workload name with wildcard character and abbreviation to standard names""" + if name.startswith('matmul-'): # e.g. matmul-512, matmul-1024, matmul-+ + N = name.split('-', maxsplit=1)[1] + if N == '+': + cfg_list = [256, 512, 1024] + else: + cfg_list = [N] + return ["matmul-%s" % x for x in cfg_list] + elif name.startswith('dense-'): # e.g. dense-1-512-1024, dense-16-512-512 + N = name.split('-', maxsplit=1)[1] + if N == '+': + cfg_list = ["1-512-512", "16-512-512"] + else: + cfg_list = [N] + return ["dense-%s" % x for x in cfg_list] + elif name.startswith('min-'): # e.g. min-4096 + N = name.split('-', maxsplit=1)[1] + if N == '+': + cfg_list = [4096, 8192, 16384] + else: + cfg_list = [N] + return ["min-%s" % x for x in cfg_list] + elif name.startswith('argmin-'): # e.g. argmin-4096 + N = name.split('-', maxsplit=1)[1] + if N == '+': + cfg_list = [4096, 8192, 16384] + else: + cfg_list = [N] + return ["argmin-%s" % x for x in cfg_list] + elif name.startswith('softmax-'): # e.g. softmax-4096 + N = name.split('-', maxsplit=1)[1] + if N == '+': + cfg_list = [4096, 8192, 16384] + else: + cfg_list = [N] + return ["softmax-%s" % x for x in cfg_list] + elif name.startswith('add-'): # e.g. add-4096 + N = name.split('-', maxsplit=1)[1] + if N == '+': + cfg_list = [4096, 8192, 16384] + else: + cfg_list = [N] + return ["add-%s" % x for x in cfg_list] + elif name.startswith('norm-'): # e.g. norm-1024 + N = name.split('-', maxsplit=1)[1] + if N == '+': + cfg_list = [4096, 8192, 16384] + else: + cfg_list = [N] + return ["norm-%s" % x for x in cfg_list] + elif name.startswith('add-min-relu'): # e.g. add-min-relu-4096 + N = name.split('-', maxsplit=3)[3] + if N == '+': + cfg_list = [4096, 8192, 16384] + else: + cfg_list = [N] + return ["add-min-relu-%s" % x for x in cfg_list] + elif name.startswith('nhwc-resnet-'): # e.g. nhwc-resnet-50.C1 + res = re.match(r'nhwc-resnet-(\d+).C([\d\+]+)(.B(\d+))?', name) + n_layers = res.group(1) + if res.group(2) == '+': + idx_list = range(len(resnet_conv2d_configs[n_layers])) + else: + idx_list = [int(res.group(2))] + + batch_size = 1 if res.group(4) is None else int(res.group(4)) + return ['nhwc-resnet-%s.C%d.B%d' % (n_layers, i, batch_size) for i in idx_list] + elif name.startswith('resnet-'): # e.g. 
resnet-50.C1, resnet-50.C1.B2, resnet-50.C+.B2 + res = re.match(r'resnet-(\d+).C([\d\+]+)(.B(\d+))?', name) + n_layers = res.group(1) + if res.group(2) == '+': + idx_list = range(len(resnet_conv2d_configs[n_layers])) + else: + idx_list = [int(res.group(2))] + + batch_size = 1 if res.group(4) is None else int(res.group(4)) + return ['resnet-%s.C%d.B%d' % (n_layers, i, batch_size) for i in idx_list] + elif name in ['conv2d-bn-relu', 'conv2d-relu-softmax-min', 'max-pool-2d', 'conv2d-rewrite', 'depthwise-conv2d-rewrite']: + return [name] + else: + raise ValueError("Invalid workload " + name) + + +def get_workload_keys(name: str) -> List[str]: + """Parse workload name and return the workload keys""" + normalized_names = parse_workload_name(name) + + ret = [] + for name in normalized_names: + if name.startswith('matmul-'): + name_split = name.split('-') + in_type = out_type = 'float32' + tensor_core_support = False + if len(name_split) == 2: # e.g. matmul-512 + N = K = M = int(name_split[1]) + elif len(name_split) == 4: # e.g. matmul-32-256-512 + N = int(name_split[1]) + K = int(name_split[2]) + M = int(name_split[3]) + elif len(name_split) == 6: # e.g. matmul-32-512-512-float16-float32 + N = int(name_split[1]) + K = int(name_split[2]) + M = int(name_split[3]) + in_type = name_split[4] + out_type = name_split[5] + elif len(name_split) == 7: # e.g. matmul-32-512-512-float16-float32-tc + N = int(name_split[1]) + K = int(name_split[2]) + M = int(name_split[3]) + in_type = name_split[4] + out_type = name_split[5] + tensor_core_support = name_split[6] == "tc" + else: + raise ValueError("Invalid matmul workload") + ret.append(make_workload_key_func(matmul_nkkm, + (N, M, K, in_type, out_type, tensor_core_support))) + elif name.startswith('dense-'): # e.g. dense-1-512-1024, dense-16-512-512 + name_split = name.split('-') + assert len(name_split) == 4 + batch = int(name_split[1]) + in_dim = int(name_split[2]) + out_dim = int(name_split[3]) + ret.append(make_workload_key_func(dense_layer, (batch, in_dim, out_dim))) + elif name.startswith('min-'): # e.g. min-4096 + name_split = name.split('-') + if len(name_split) == 2: + M = 64 + N = int(name_split[1]) + elif len(name_split) == 3: + M = int(name_split[1]) + N = int(name_split[2]) + else: + raise ValueError("Invalid min workload") + ret.append(make_workload_key_func(min_mn, (M, N))) + elif name.startswith('argmin-'): # e.g. argmin-4096 + name_split = name.split('-') + if len(name_split) == 2: + M = 64 + N = int(name_split[1]) + elif len(name_split) == 3: + M = int(name_split[1]) + N = int(name_split[2]) + else: + raise ValueError("Invalid argmin workload") + ret.append(make_workload_key_func(argmin_mn, (M, N))) + elif name.startswith('softmax-'): # e.g. softmax-4096 + name_split = name.split('-') + if len(name_split) == 2: + M = 64 + N = int(name_split[1]) + elif len(name_split) == 3: + M = int(name_split[1]) + N = int(name_split[2]) + else: + raise ValueError("Invalid softmax workload") + ret.append(make_workload_key_func(softmax_mn, (M, N))) + elif name.startswith('add-min-relu'): # e.g. add-min-relu-4096 + name_split = name.split('-') + if len(name_split) == 4: + M = 64 + N = int(name_split[3]) + elif len(name_split) == 5: + M = int(name_split[3]) + N = int(name_split[4]) + else: + raise ValueError("Invalid workload") + ret.append(make_workload_key_func(add_min_relu, (M, N))) + elif name.startswith('add-'): # e.g. 
add-4096 + name_split = name.split('-') + if len(name_split) == 2: + N = M = int(name_split[1]) + elif len(name_split) == 3: + M = int(name_split[1]) + N = int(name_split[2]) + else: + raise ValueError("Invalid add workload") + ret.append(make_workload_key_func(add_mn, (M, N))) + elif name.startswith('norm-'): # e.g. norm-4096 + name_split = name.split('-') + B = 2 + if len(name_split) == 2: + N = M = int(name_split[1]) + elif len(name_split) == 3: + M = int(name_split[1]) + N = int(name_split[2]) + else: + raise ValueError("Invalid norm workload") + ret.append(make_workload_key_func(norm_bmn, (B, M, N))) + elif name.startswith('nhwc-resnet-'): # e.g. nhwc-resnet-50.C1.B2 + res = re.match(r'nhwc-resnet-(\d+).C(\d+).B(\d+)', name) + n_layers = res.group(1) + idx = int(res.group(2)) + batch_size = 1 if res.group(3) is None else int(res.group(3)) + args = list(resnet_conv2d_configs[n_layers][idx]) + args[0] = batch_size + ret.append(make_workload_key_func(conv2d_nhwc_bias, args)) + elif name.startswith('resnet-'): # e.g. resnet-50.C1.B2 + res = re.match(r'resnet-(\d+).C(\d+).B(\d+)', name) + n_layers = res.group(1) + idx = int(res.group(2)) + batch_size = 1 if res.group(3) is None else int(res.group(3)) + args = list(resnet_conv2d_configs[n_layers][idx]) + args[0] = batch_size + ret.append(make_workload_key_func(conv2d_nchw_bias, args)) + elif name == 'max-pool-2d': + return [make_workload_key_func(max_pool_2d_nchw, (2, 512, 7, 7))] + elif name == 'conv2d-bn-relu': + return [make_workload_key_func(conv2d_nhwc_bn_relu, + (1, 7, 7, 512, 512, 3, 1, 1, 1)) ] + elif name == 'conv2d-rewrite': + return [ make_workload_key_func(conv2d_nhwc_bias_with_rewrite, + (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)))] + elif name == 'depthwise-conv2d-rewrite': + return [ make_workload_key_func(depthwise_conv2d_nhwc_bias_with_rewrite, + (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)))] + elif name == 'conv2d-relu-softmax-min': + return [make_workload_key_func(conv2d_relu_softmax_min, + (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)))] + else: + raise ValueError("Invalid workload " + name) + + return ret + + +def get_workload_weights(name: str) -> List[float]: + """Return weights for workload name""" + if name.startswith('resnet-'): + res = re.match(r'resnet-(\d+).C+', name) + n_layers = res.group(1) + return np.array(resnet_conv2d_weights[n_layers]) + else: + return np.ones(len(get_workload_keys(name))) + + +############################################################ +###################### Measure Tools #################### +############################################################ + + +def measure_schedule(s, + bufs, + target, + target_host=None, + remote=None, + ndk_cc=None, + number=10, + repeat=3, + min_repeat_ms=500): + """Measure the time cost of a schedule""" + func = tvm.build(s, bufs, target=target, target_host=target_host) + if remote: + ctx = remote.context(str(target), 0) + temp = util.tempdir() + remote_path = temp.relpath("tmp_deploy_lib.so") + os.environ['TVM_NDK_CC'] = ndk_cc + func.export_library(remote_path, ndk.create_shared) + remote.upload(remote_path) + func = remote.load_module("tmp_deploy_lib.so") + else: + ctx = tvm.context(str(target), 0) + + if os.environ.get('TVM_AUTO_CACHE_FLUSH', '0') == '1': + min_repeat_ms = 0 + number = 1 + + time_f = func.time_evaluator(func.entry_name, + ctx, + number=number, + repeat=repeat, + min_repeat_ms=min_repeat_ms) + + np_args = [np.ones(topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs] + args = [tvm.nd.array(x, ctx=ctx) for x 
in np_args]
+    ctx.sync()
+
+    costs = time_f(*args).results
+
+    return costs
+
+def check_correctness(s, bufs, s_ref, buf_ref, target, target_host=None, remote=None, ndk_cc=None):
+    """Check the correctness of a schedule against a reference schedule"""
+    func = tvm.build(s, bufs, target=target, target_host=target_host)
+    func_ref = tvm.build(s_ref, buf_ref, target='llvm')
+
+    if remote:
+        raise NotImplementedError
+    else:
+        ctx = tvm.context(str(target), 0)
+        ctx_ref = tvm.cpu()
+
+    np_args = [np.ones(topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs]
+    args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
+    args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args]
+    ctx.sync()
+
+    func(*args)
+    func_ref(*args_ref)
+
+    for arr, arr_ref in zip(args, args_ref):
+        np.testing.assert_allclose(arr.asnumpy(), arr_ref.asnumpy())
+
+
+############################################################
+#####################  Other Utilities  ####################
+############################################################
+
+
+def geomean(xs):
+    """Compute geometric mean"""
+    return math.exp(math.fsum(math.log(x) for x in xs) / len(xs))
+
+
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+last_tic = None
+
+
+def PRINT_TIME(msg):
+    """Print the time interval between different calls. This is for debugging,
+    so we capitalize the name."""
+    global last_tic
+    now = time.time()
+
+    if last_tic is None:
+        last_tic = now
+
+    print(msg, now - last_tic)
+    last_tic = now
+
+
+############################################################
+######################  I/O Utilities  #####################
+############################################################
+
+# The format of a line in the results file
+BenchmarkRecord = namedtuple("BenchmarkRecord", [
+    'device', 'backend', 'workload_type', 'workload_name', 'library', 'algorithm', 'value',
+    'time_stamp'
+])
+
+
+class BaselineDatabase:
+    """A class for querying records in the baseline database"""
+    def __init__(self, filename):
+        self.filename = filename
+
+        self.lines = []
+        for line in open(filename):
+            if line.startswith('#') or line.isspace():
+                continue
+            self.lines.append(line.split('\t'))
+
+    def filter_records(self, devices=None, backends=None, wkl_names=None, libraries=None):
+        ret = []
+        for line in self.lines:
+            line = BenchmarkRecord(*line)
+
+            if devices is not None and line.device not in devices:
+                continue
+            if backends is not None and line.backend not in backends:
+                continue
+            if wkl_names is not None and line.workload_name not in wkl_names:
+                continue
+            if libraries is not None and line.library not in libraries:
+                continue
+
+            ret.append(line)
+        return ret
+
+    def get_data_dict(self, device, target, wkl_names) -> Tuple[Dict, List]:
+        """Return a data dict s.t. 
data[wkl][library] = cost""" + data = defaultdict(lambda: defaultdict(lambda: 1e10)) + + all_libraries = set() + + if "cpu" in target.keys: + backends = ['cpu'] + elif "gpu" in target.keys: + backends = ['gpu'] + else: + raise ValueError("Invalid target: " + target) + + # Read costs for baselines + records = self.filter_records(devices=[device], backends=backends, wkl_names=wkl_names) + for record in records: + # use min over (possible) multiple algorithms + all_libraries.add(record.library) + data[record.workload_name][record.library] = \ + min(data[record.workload_name][record.library], + np.mean(eval(record.value)['costs'])) + + return data, list(all_libraries) + + +class LogFileDatabase: + """A class for indexing best records in a log file""" + def __init__(self, filename: str, n_lines: int = -1): + inputs, results = LogReader(filename).read_lines(n_lines) + + # best records, search by (target_key, workload_key). e.g. ('gpu', 'conv2d...') + self.best_by_targetkey = {} + + # best according to (model, workload_key). e.g. ('1080ti', 'conv2d...')) + self.best_by_model = {} + + # find best records and build the index + for inp, res in zip(inputs, results): + if res.error_no != 0: + continue + + # use target keys in tvm target system as key to build best map + for target_key in inp.task.target.keys: + key = (target_key, inp.task.workload_key) + if key not in self.best_by_targetkey: + self.best_by_targetkey[key] = (inp, res) + else: + _, other_res = self.best_by_targetkey[key] + if np.mean([x.value for x in other_res.costs]) > \ + np.mean([x.value for x in res.costs]): + self.best_by_targetkey[key] = (inp, res) + + # use model as key to build best map + key = (inp.task.target.model, inp.task.workload_key) + if key not in self.best_by_model: + if inp.task.target.model != 'unknown': + self.best_by_model[key] = (inp, res) + else: + _, other_res = self.best_by_model[key] + if np.mean([x.value for x in other_res.costs]) > \ + np.mean([x.value for x in res.costs]): + self.best_by_model[key] = (inp, res) + + def write_best(self, filename: str): + best_records = list(self.best_by_targetkey.values()) + inputs = [x[0] for x in best_records] + results = [x[1] for x in best_records] + write_measure_records_to_file(filename, inputs, results) + + +############################################################ +###################### Plot Utilities #################### +############################################################ + +def max_curve(raw_curve): + """Return b[i] = max(a[:i]) """ + ret = [] + cur_max = -np.inf + for x in raw_curve: + cur_max = max(cur_max, x) + ret.append(cur_max) + return ret + +def min_curve(raw_curve): + """Return b[i] = min(a[:i]) """ + ret = [] + cur_min = np.inf + for x in raw_curve: + cur_min = min(cur_min, x) + ret.append(cur_min) + return ret + +def mean_curve(raw_curve, window_size=None): + """Return b[i] = mean(a[:i]) """ + ret = [] + mean = 0 + if window_size is None: + for i, x in enumerate(raw_curve): + mean = (mean * i + x) / (i + 1) + ret.append(mean) + else: + for i, x in enumerate(raw_curve): + if i >= window_size: + mean = (mean * window_size + x - raw_curve[i - window_size]) / window_size + else: + mean = (mean * i + x) / (i + 1) + ret.append(mean) + return ret + + +def enhance_color(color, h=1, l=1, s=1): + """Make color looks better for pyplot""" + import matplotlib.colors as mc + import colorsys + try: + c = mc.cnames[color] + except: + c = color + c = np.array(colorsys.rgb_to_hls(*mc.to_rgb(c))) + + h, l, s = h * c[0], l * c[1], s * c[2] + h, l, s = 
[max(min(x, 1), 0) for x in [h, l, s]] + + return colorsys.hls_to_rgb(h, l, s) + + +method_color_dict = { + 'ours': 'C0', + 'AutoTVM': 'C1', + + 'tensorflow': 'C2', + 'tensorflow-tensorrt': 'C9', + 'tflite': 'C2', + + 'pytorch': enhance_color('C3', l=1.1, s=0.9), + + 'FlexTensor': enhance_color('C5'), + 'halide': enhance_color('teal', l=1.25), + + 'Limit space': 'C7', + 'No fine-tuning': 'C8', + 'No task scheduler': 'C1', +} + +def method2color(method): + if '-batch-' in method: + method, batch_size = method.split('-batch-') + #return enhance_color(method_color_dict[method], s=1.1, l=1.5) + return method_color_dict[method] + else: + return method_color_dict[method] + +method_order_list = [ + 'pytorch', 'tensorflow', 'tensorflow-xla', 'tensorflow-tensorrt', + 'tflite', 'halide', 'FlexTensor', 'AutoTVM', + + 'Limit space', 'No fine-tuning', + 'ours', +] + +def method2order(method): + if '-batch-' in method: + method, batch_size = method.split('-batch-') + batch_size = int(batch_size) + return method_order_list.index(method) + batch_size / 100 + else: + return method_order_list.index(method) + +show_name_replace_dict = { + 'pytorch': "PyTorch", + 'tensorflow-tensorrt': 'TensorRT-TF', + 'tensorflow': 'TensorFlow', + 'tflite': 'TensorFlow Lite', + 'halide': 'Halide', + + 'ours': 'Ansor (ours)', + 'batch-16': 'batch', + + 'resnet_50': 'ResNet-50', + 'mobilenet_v2': 'Mobilenet V2', + 'resnet_18_3d': '3D-ResNet', + 'dcgan': 'DCGAN', + 'dqn': 'DQN', + 'bert': 'BERT', +} + +def show_name(name): + # if name.startswith('resnet-'): + # return name.split('.')[1] + for key, value in show_name_replace_dict.items(): + name = name.replace(key, value) + + return name + +def draw_grouped_bar_chart(data, baseline='pytorch', output='out.png', + yscale_log=False, yticks=None, y_max=None, + legend_bbox_to_anchor=None, legend_nrow=None, + figure_size=None, figax=None, draw_ylabel=True, draw_legend=True): + width = 1 + gap = 1.5 + fontsize = 19 + xticks_font_size = fontsize - 2 + + figure_size = figure_size or (11, 4) + legend_bbox_to_anchor = legend_bbox_to_anchor or (0.45, 1.35) + + all_methods = set() + legend_set = {} + + if figax is None: + fig, ax = plt.subplots() + axes = [] + axes.append(ax) + else: + ax = figax + + x0 = 0 + xticks = [] + xlabels = [] + + workloads = list(data.keys()) + for wkl in workloads: + ys = [] + colors = [] + + methods = list(data[wkl].keys()) + + if baseline in data[wkl]: + baseline_cost = data[wkl][baseline] + else: + # normalize to best library + baseline_cost = 1e10 + for method in methods: + if data[wkl][method] < baseline_cost: + baseline_cost = data[wkl][method] + + methods.sort(key=lambda x: method2order(x)) + for method in methods: + relative_speedup = baseline_cost / data[wkl][method] + if yticks is None: + ys.append(relative_speedup) + else: + ys.append(max(relative_speedup, yticks[0] * 1.1)) + colors.append(method2color(method)) + + # draw the bars + xs = np.arange(x0, x0 + len(ys)) + bars = ax.bar(xs, ys, width=width, color=colors) + + for method, bar_obj in zip(methods, bars): + all_methods.add(method) + if method not in legend_set: + legend_set[method] = bar_obj + + # tick and label + x0 += len(ys) + gap + + xticks.append(x0 - gap - len(ys)*width/2.0 - width/2.0) + xlabels.append(show_name(wkl)) + + ax.set_xticks(xticks) + ax.set_xticklabels(xlabels, fontsize=xticks_font_size) + plt.tick_params(axis='x', which='both', bottom='off', top='off') + + if draw_ylabel is True: + ax.set_ylabel('Relative Speedup', fontsize=fontsize) + elif isinstance(draw_ylabel, str): + 
ax.set_ylabel(draw_ylabel, fontsize=fontsize) + + if yscale_log: + ax.set_yscale('log', basey=2) + if yticks is not None: + ax.set_yticks(yticks) + if y_max: + ax.set_ylim(top=y_max) + + from matplotlib.ticker import FormatStrFormatter + ax.set_yticklabels(ax.get_yticks(), fontsize=fontsize) + ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f')) + ax.yaxis.grid(linewidth=0.4, linestyle='dotted') # draw grid line + ax.set_axisbelow(True) # grid lines are behind the rest + ax.tick_params(bottom=False, top=False, right=False) + + # put legend outside the plot + all_methods = list(all_methods) + all_methods.sort(key=lambda x : method2order(x)) + + if draw_legend: + legend_nrow = legend_nrow or 2 + ncol = (len(all_methods) + legend_nrow - 1)// legend_nrow + ax.legend([legend_set[x] for x in all_methods], + [show_name(x) for x in all_methods], + fontsize=fontsize-1, + loc='upper center', + bbox_to_anchor=legend_bbox_to_anchor, + ncol=ncol, + handlelength=1.0, + handletextpad=0.5, + columnspacing=1.1) + + if figax is None: + fig.set_size_inches(figure_size) + fig.savefig(output, bbox_inches='tight') + print("Output the plot to %s" % output) + + +def to_str_round(x, decimal=6): + if isinstance(x, str): + return x + if isinstance(x, (list, tuple)) or isinstance(x, np.ndarray): + return "[" + ", ".join([to_str_round(y, decimal=decimal) + for y in x]) + "]" + if isinstance(x, dict): + return str({k: eval(to_str_round(v)) for k, v in x.items()}) + if isinstance(x, int): + return str(x) + if isinstance(x, float): + format_str = "%%.%df" % decimal + return format_str % x + raise ValueError("Invalid value: " + str(x)) + diff --git a/scripts/tune_test.py b/scripts/tune_test.py new file mode 100644 index 000000000000..68f9dfadb8d4 --- /dev/null +++ b/scripts/tune_test.py @@ -0,0 +1,195 @@ +"""Use auto scheduler to tune workloads""" +import argparse +import logging +import os +import random + +import numpy as np + +import tvm +from tvm import ansor +from tvm.ansor.utils import request_remote + +from common import get_workload_keys, get_workload_weights, measure_schedule, str2bool + + +def make_cost_model(model_type, load_model_file, load_log_file): + if model_type == 'xgb': + model = ansor.XGBModel() + if load_model_file: + print("Load pretrained model...") + model.load(load_model_file) + elif load_log_file: + model.load_log_file(load_log_file) + elif model_type == "random": + model = ansor.RandomModel() + else: + raise ValueError("Invalid model: " + model_type) + return model + + +def tune_workload(wkl_key, target, target_host, n_trials, num_measure_per_iter, + policy, log_file, verbose, + model_type, load_model_file, load_log_file, + build_timeout, local_measure=True, device_key=None, host="0.0.0.0", + port=9190, n_parallel=1, ndk_cc=None, remeasure=True): + """Tune a workload""" + + if False: + # Debug info. 
Print static analysis results from the access analyzer + dag = auto_scheduler.workload_key_to_dag(wkl_key) + print(dag.access_analyzer) + exit() + + model = make_cost_model(model_type, load_model_file, load_log_file) + + if policy == 'meta-rewrite': + policy = ansor.MetaTileRewritePolicy(program_cost_model=model) + elif policy == 'beam-search': + policy = ansor.MetaTileRewritePolicy(program_cost_model=model, + params={'use_beam_search': 1}) + else: + raise ValueError("Invalid search policy: " + policy) + + if local_measure: + builder = ansor.LocalBuilder(build_timeout) + if target.target_name == "cuda": + measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) + runner = measure_ctx.runner + else: + runner = ansor.LocalRunner(repeat=1, min_repeat_ms=400) + else: + os.environ['TVM_NDK_CC'] = ndk_cc + builder = ansor.LocalBuilder(build_timeout, build_func='ndk') + runner = ansor.RPCRunner(device_key, host=host, port=port, + repeat=1, min_repeat_ms=400, + n_parallel=n_parallel) + + tune_option = ansor.TuneOption(n_trials=n_trials, + num_measure_per_iter=num_measure_per_iter, + verbose=verbose, + builder=builder, + runner=runner, + callbacks=[ansor.LogToFile(log_file)]) + s, bufs = ansor.auto_schedule(wkl_key, + target=target, target_host=target_host, + search_policy=policy, + tune_option=tune_option) + + if remeasure: + print("Found schedule:") + print(tvm.lower(s, bufs, simple_mode=True)) + print("Redo measurement for double check...") + if local_measure: + remote = None + else: + remote = request_remote(device_key, host, port, 1) + cost = np.mean((measure_schedule(s, bufs, target, remote=remote, ndk_cc=ndk_cc))) + print("Best schedule: %.2f GFLOPS\tcost: %.3f ms" % + (ansor.ComputeDAG(bufs).flop_ct / cost / 1e9, cost * 1e3)) + + +def tune_workloads_jointly(wkl_keys, weights, joint_tuner, target, target_host, + n_trials, num_measure_per_iter, + search_policy, log_file, verbose, + model_type, load_model_file, load_log_file, + build_timeout, local_measure=True, device_key=None, + host="0.0.0.0", port=9190, n_parallel=1, ndk_cc=None): + """Tune for multiple workloads jointly""" + if local_measure: + builder = ansor.LocalBuilder(timeout=build_timeout) + if target.target_name == "cuda": + measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) + runner = measure_ctx.runner + else: + runner = ansor.LocalRunner(repeat=1, min_repeat_ms=400) + else: + os.environ['TVM_NDK_CC'] = ndk_cc + builder = ansor.LocalBuilder(build_func='ndk', timeout=build_timeout) + runner = ansor.RPCRunner(device_key, host=host, port=port, + repeat=1, min_repeat_ms=400, + n_parallel=n_parallel) + + tasks = [] + for wkl_key in wkl_keys: + dag = ansor.workload_key_to_dag(wkl_key) + tasks.append(ansor.SearchTask(dag, wkl_key, target, target_host)) + + def objective_func(costs): + return sum(c * w for c, w in zip(costs, weights)) + + tuner = ansor.SimpleTaskScheduler(tasks, objective_func, strategy=joint_tuner, + load_log_file=load_log_file, load_model_file=load_model_file) + + search_policy = "%s.%s" % (search_policy, model_type) + tune_option = ansor.TuneOption(n_trials=n_trials, + num_measure_per_iter=num_measure_per_iter, + builder=builder, + verbose=verbose, + runner=runner, + callbacks=[ansor.LogToFile(log_file)]) + tuner.tune(tune_option, search_policy) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--wkl", type=str, required=True) + parser.add_argument("--n-trials", type=int, default=1000) + parser.add_argument("--target", type=str, 
default='llvm -mcpu=core-avx2') + parser.add_argument("--target-host", type=str, default=None) + parser.add_argument("--policy", type=str, choices=['meta-rewrite', 'beam-search'], default='meta-rewrite') + parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") + parser.add_argument("--build-timeout", type=int, default=10) + parser.add_argument("--run-timeout", type=int, default=60) + parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') + parser.add_argument("--load-model", type=str) + parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") + parser.add_argument("--seed", type=int, default=0, help='random seed') + parser.add_argument("--verbose", type=int, default=1) + parser.add_argument("--task-scheduler", type=str, default='no', + choices=['no', 'gradient', 'round-robin'], + help='The strategy of task scheduler') + parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) + parser.add_argument("--device-key", type=str, default=None) + parser.add_argument("--host", type=str, default='0.0.0.0') + parser.add_argument("--port", type=int, default=9190) + parser.add_argument("--n-parallel", type=int, default=1) + parser.add_argument("--ndk-cc", type=str, default=None) + parser.add_argument("--num-measure-per-iter", type=int, default=48, + help="The number of programs to be measured at each iteration") + args = parser.parse_args() + + np.random.seed(args.seed) + random.seed(args.seed) + + logging.basicConfig() + logging.getLogger('auto_scheduler').setLevel(logging.DEBUG) + + log_file = args.log_file or args.wkl + ".json" + load_log_file = args.load_log or log_file + + target = tvm.target.create(args.target) + wkl_keys = get_workload_keys(args.wkl) + weights = get_workload_weights(args.wkl) + if args.task_scheduler == 'no': + # tune workloads one by one + for wkl_key in wkl_keys: + tune_workload(wkl_key, target, args.target_host, args.n_trials, + args.num_measure_per_iter, + args.policy, log_file, args.verbose, + args.model_type, args.load_model, load_log_file, + args.build_timeout, + args.local_measure, args.device_key, args.host, + args.port, args.n_parallel, args.ndk_cc, + remeasure=len(wkl_keys) == 1) + else: + # tune workloads jointly using JointTuner + tune_workloads_jointly(wkl_keys, weights, args.joint_tuner, + target, args.target_host, + args.n_trials, args.num_measure_per_iter, + args.policy, log_file, args.verbose, + args.model_type, args.load_model, args.load_log, + args.build_timeout, + args.local_measure, args.device_key, args.host, + args.port, args.n_parallel, args.ndk_cc) + diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index a0fa18874a69..3c793e5957f5 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -48,16 +48,18 @@ TuneOption TuneOptionNode::make(int n_trials, int early_stopping, return TuneOption(node); } -State AutoSchedule(SearchTask task, SearchPolicy search_policy, +std::pair > AutoSchedule(SearchTask task, SearchPolicy search_policy, TuneOption tune_option) { // Search for the best schedule ProgramMeasurer measurer = ProgramMeasurerNode::make(tune_option->builder, tune_option->runner, tune_option->callbacks, tune_option->verbose); - return search_policy->Search( + State state = search_policy->Search( task, tune_option->n_trials, tune_option->early_stopping, tune_option->num_measure_per_iter, tune_option->verbose, measurer); + + return 
task->compute_dag.ApplySteps(state->transform_steps); } std::pair > AutoSchedule( @@ -68,10 +70,8 @@ std::pair > AutoSchedule( SearchTask task = SearchTaskNode::make( std::move(dag), std::move(workload_key), std::move(target), std::move(target_host), std::move(hardware_params)); - State state = AutoSchedule(std::move(task), std::move(search_policy), + return AutoSchedule(std::move(task), std::move(search_policy), std::move(tune_option)); - - return task->compute_dag.ApplySteps(state->transform_steps); } TVM_REGISTER_GLOBAL("ansor.TuneOption") @@ -86,7 +86,11 @@ TVM_REGISTER_GLOBAL("ansor.TuneOption") TVM_REGISTER_GLOBAL("ansor.AutoScheduleBySearchTask") .set_body_typed([](SearchTask task, SearchPolicy search_policy, TuneOption tune_option) { - return AutoSchedule(task, search_policy, tune_option); + te::Schedule sch; + Array return_tensors; + std::tie(sch, return_tensors) = AutoSchedule(task, search_policy, tune_option); + + return Array{sch, return_tensors}; }); TVM_REGISTER_GLOBAL("ansor.AutoScheduleByWorkloadKey") diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index f68e844ba776..3737f8c5d096 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -67,8 +67,8 @@ class TuneOptionNode : public Object { TVM_DEFINE_COW_OBJECT_REF(TuneOption, ObjectRef, TuneOptionNode); /*! \brief Auto schedule for a compute declaration */ -State AutoSchedule(SearchTask task, SearchPolicy search_policy, - TuneOption tune_option); +std::pair > AutoSchedule( + SearchTask task, SearchPolicy search_policy, TuneOption tune_option); std::pair > AutoSchedule( std::string workload_key, Target target, Target target_host, diff --git a/tests/python/unittest/test_ansor_measure.py b/tests/python/unittest/test_ansor_measure.py index 0385568894fe..2ac54d3c765b 100644 --- a/tests/python/unittest/test_ansor_measure.py +++ b/tests/python/unittest/test_ansor_measure.py @@ -19,8 +19,6 @@ import tvm from tvm import ansor -from tvm.rpc.tracker import Tracker -from tvm.rpc.server import Server import tempfile from test_ansor_common import get_tiled_matmul @@ -69,26 +67,17 @@ def test_measure_local_builder_rpc_runner(): tgt = tvm.target.create("llvm") task = ansor.SearchTask(dag, "test", tgt) - minp = ansor.MeasureInput(task, s0) + local_builder = ansor.LocalBuilder() - host = '0.0.0.0' - tracker = Tracker(host, port=9000, port_end=10000, silent=True) - device_key = '$local$device$%d' % tracker.port - server = Server(host, port=tracker.port, port_end=10000, - key=device_key, - use_popen=True, silent=True, - tracker_addr=(tracker.host, tracker.port)) - rpc_runner = ansor.RPCRunner(device_key, host, tracker.port) + measure_ctx = ansor.LocalRPCMeasureContext() + rpc_runner = measure_ctx.runner bress = local_builder.build([minp]) assert bress[0].error_no == 0 mress = rpc_runner.run([minp], bress) assert mress[0].error_no == 0 - tracker.terminate() - server.terminate() - if __name__ == "__main__": test_serialization() diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index a28456574abe..5cb67dba39fe 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -18,7 +18,6 @@ """Test search policy""" import random -import os import numpy as np import tempfile @@ -33,10 +32,10 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' random.seed(seed) N = 128 - A, B, C = matmul_ansor_test(N, N, N) - dag = ansor.ComputeDAG([A, B, C]) - tgt = 
tvm.target.create(target) - task = ansor.SearchTask(dag, "test", tgt) + workload_key = ansor.make_workload_key_func(matmul_ansor_test, (N, N, N)) + dag = ansor.workload_key_to_dag(workload_key) + target = tvm.target.create(target) + task = ansor.SearchTask(dag, workload_key, target) with tempfile.NamedTemporaryFile() as fp: log_file = fp.name @@ -44,35 +43,29 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner, callbacks=[ansor.LogToFile(log_file)]) - state = ansor.auto_schedule(task, search_policy, + sch, args = ansor.auto_schedule(task, search_policy, tune_option=tune_option) - sch, args = dag.apply_steps_from_state(state) + inp, res = ansor.best_measure_pair_in_file(log_file, workload_key, target) - print("==== Get State ====") - print(state) - print("==== Get Python Code ====") - print(dag.print_python_code_from_state(state)) + print("==== Python Code ====") + print(dag.print_python_code_from_state(inp.state)) try: - print("==== Get Lowered Stmt ====") + print("==== Lowered Stmt ====") print(tvm.lower(sch, args, simple_mode=True)) - mod = tvm.build(sch, args, tgt) + mod = tvm.build(sch, args, target) - ctx = tvm.context(target, 0) - a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(A.dtype), ctx) - b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(B.dtype), ctx) - c = tvm.nd.array(np.zeros((N, N), dtype=C.dtype), ctx) + ctx = tvm.context(str(target), 0) + dtype = dag.tensors[0].dtype + a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx) + c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx) mod(a, b, c) tvm.testing.assert_allclose(c.asnumpy(), np.dot( a.asnumpy(), b.asnumpy()), rtol=1e-5) print("==== Verification passed ====") except Exception: raise Exception("Error encountered with seed: %d" % (seed)) - - inp, res = ansor.best_measure_pair_in_file(log_file) - s0 = dag.infer_bound_from_state(state) - s1 = dag.infer_bound_from_state(inp.state) - assert s0 == s1 print() @@ -81,23 +74,23 @@ def test_search_basic(): def test_search_xgb_model_rpc_runner(): - with ansor.RPCRunnerWarpper() as rpc_runner: - search_common(seed=456787236, cost_model=ansor.XGBModel(), - runner=rpc_runner.runner) + measure_ctx = ansor.LocalRPCMeasureContext() + search_common(seed=456787236, cost_model=ansor.XGBModel(), + runner=measure_ctx.runner) def test_search_opencl(): if tvm.context("opencl", 0).exist: - with ansor.RPCRunnerWarpper() as rpc_runner: - search_common("opencl", 380344973, rpc_runner.runner) + measure_ctx = ansor.LocalRPCMeasureContext() + search_common("opencl", 380344973, measure_ctx.runner) else: print("OpenCL device not found, skip this test.") def test_search_cuda(): if tvm.context("cuda", 0).exist: - with ansor.RPCRunnerWarpper("cuda") as rpc_runner: - search_common("cuda", 903667810, rpc_runner.runner) + measure_ctx = ansor.LocalRPCMeasureContext() + search_common("cuda", 903667810, measure_ctx.runner) else: print("CUDA device not found, skip this test.") From cd0a516271c2d7b5f239fa601247f969929a90d3 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Wed, 10 Jun 2020 21:01:11 +0800 Subject: [PATCH 20/78] Code refine for tune_test.py & Add a pre load callback (#20) * Bug fix for tutorials * Add PreLoadMeasuredStates * Add search_callback support for task tuner * Code refine for tune_test.py * Update * Update * Update * Update * Bug fix --- 
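A minimal usage sketch of the options this patch introduces (not part of the diff below; `wkl_key` is a hypothetical placeholder, and the other names come from the `ansor` Python API in this series). `TuneOption` now takes separate `measure_callbacks` and `pre_search_callbacks` lists, and a `PreLoadMeasuredStatesCallback` can be passed so that states already measured in a previous run are skipped:

    import tvm
    from tvm import ansor

    wkl_key = ...  # hypothetical: a workload key registered with ansor
    target = tvm.target.create("llvm")
    log_file = "tune.json"

    cost_model = ansor.XGBModel()
    search_policy = ansor.MetaTileRewritePolicy(cost_model)
    tune_option = ansor.TuneOption(
        n_trials=200,
        # Log every measurement result so the run can be resumed or replayed.
        measure_callbacks=[ansor.LogToFile(log_file)],
        # Before the search starts, load hashes of already measured states
        # from the log so they are not measured again.
        pre_search_callbacks=[ansor.PreLoadMeasuredStatesCallback(log_file)])

    # auto_schedule now returns a te.Schedule and its tensors directly.
    s, bufs = ansor.auto_schedule(wkl_key, target=target,
                                  search_policy=search_policy,
                                  tune_option=tune_option)
    print(tvm.lower(s, bufs, simple_mode=True))
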
python/tvm/ansor/__init__.py | 2 +- python/tvm/ansor/auto_schedule.py | 38 +++- python/tvm/ansor/measure.py | 40 +++- python/tvm/ansor/task_scheduler.py | 9 +- scripts/tune_test.py | 212 ++++++++---------- src/ansor/auto_schedule.cc | 23 +- src/ansor/auto_schedule.h | 11 +- .../search_policy/meta_tile_rewrite_policy.cc | 5 +- .../search_policy/meta_tile_rewrite_policy.h | 9 +- src/ansor/search_policy/search_policy.cc | 82 ++++++- src/ansor/search_policy/search_policy.h | 36 ++- src/ansor/serialization.cc | 4 + src/ansor/serialization.h | 1 + .../unittest/test_ansor_search_policy.py | 6 +- tutorials/ansor/tune_conv2d_cuda.py | 29 ++- tutorials/ansor/tune_simple_subgraph.py | 41 +--- 16 files changed, 355 insertions(+), 193 deletions(-) diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 2e3553cf725c..1029875917aa 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -29,7 +29,7 @@ # Shortcut from .compute_dag import ComputeDAG -from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams +from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams, PreLoadMeasuredStatesCallback from .auto_schedule import auto_schedule from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext from .cost_model import RandomModel diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 1192e6d551e5..5b5eb4fe183b 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -69,7 +69,15 @@ def __init__(self, dag, workload_key, target, target_host=None, class SearchPolicy(Object): def continue_search(self, task, num_measure, verbose, measurer): return _ffi_api.SearchPolicyContinueSearchOneRound(self, task, num_measure, verbose, measurer) + + def set_task(self, task): + _ffi_api.SearchPolicySetTask(self, task); + def set_verbose(self, verbose): + _ffi_api.SearchPolicySetVerbose(self, verbose); + + def run_callbacks(self, callbacks): + _ffi_api.SearchPolicyRunCallbacks(self, callbacks) @tvm._ffi.register_object("ansor.MetaTileRewritePolicy") class MetaTileRewritePolicy(SearchPolicy): @@ -117,6 +125,21 @@ def __init__(self, seed or random.randint(1, 1 << 30)) +@tvm._ffi.register_object("ansor.SearchCallback") +class SearchCallback(Object): + pass + + +@tvm._ffi.register_object("ansor.PreLoadMeasuredStatesCallback") +class PreLoadMeasuredStatesCallback(SearchCallback): + """ A SearchCallback that used for search policy to load measured hash + from the log file. 
+ """ + def __init__(self, filename: str): + self.__init_handle_by_constructor__( + _ffi_api.PreLoadMeasuredStatesCallback, filename) + + @tvm._ffi.register_object("ansor.TuneOption") class TuneOption(Object): """ The options for tuning @@ -135,11 +158,13 @@ class TuneOption(Object): Builder which builds the program runner: Runner Runner which runs the program and measure time costs - callbacks: List[MeasureCallback] + measure_callbacks: List[MeasureCallback] Callback functions + pre_search_callbacks: List[SearchCallback] """ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, - verbose=1, builder='local', runner='local', callbacks=None): + verbose=1, builder='local', runner='local', measure_callbacks=None, + pre_search_callbacks=None): if isinstance(builder, str): if builder == 'local': builder = LocalBuilder() @@ -152,12 +177,15 @@ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, else: raise ValueError("Invalid builder: " + runner) - if callbacks is None: - callbacks = [] + if measure_callbacks is None: + measure_callbacks = [] + + if pre_search_callbacks is None: + pre_search_callbacks = [] self.__init_handle_by_constructor__( _ffi_api.TuneOption, n_trials, early_stopping, num_measure_per_iter, - verbose, builder, runner, callbacks) + verbose, builder, runner, measure_callbacks, pre_search_callbacks) def auto_schedule(workload, target=None, diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 299c004f756d..610e9529090f 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -174,6 +174,16 @@ def __init__(self, @tvm._ffi.register_object("ansor.ProgramMeasurer") class ProgramMeasurer(Object): + """ + Parameters + ---------- + builder : Builder + runner : Runner + callbacks : List[MeasureCallback] + verbose : Int + max_continuous_error : Float + """ + def __init__(self, builder: Builder, runner: Runner, callbacks: List[MeasureCallback], verbose: int, max_continuous_error: int = -1): @@ -182,6 +192,21 @@ def __init__(self, builder: Builder, runner: Runner, @tvm._ffi.register_object("ansor.RPCRunner") class RPCRunner(Runner): + """ + Parameters + ---------- + key : Str + host : Str + port : Int + priority : Int + n_parallel : Int + timeout : Int + number : Int + repeat : Int + min_repeat_ms : Int + cooldown_interval : Float + """ + def __init__(self, key, host, port, priority=1, n_parallel=1, timeout=10, @@ -203,6 +228,19 @@ def __init__(self, key, host, port, priority=1, class LocalRPCMeasureContext: + """ A context wrapper for RPCRunner. 
+ + Parameters + ---------- + priority : Int + n_parallel : Int + timeout : Int + number : Int + repeat : Int + min_repeat_ms : Int + cooldown_interval : Float + """ + def __init__(self, priority=1, n_parallel=1, @@ -228,8 +266,8 @@ def __init__(self, time.sleep(0.5) def __del__(self): - self.tracker.terminate() self.server.terminate() + self.tracker.terminate() class MeasureErrorNo(object): diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py index 5144591d4f98..082b2d265140 100644 --- a/python/tvm/ansor/task_scheduler.py +++ b/python/tvm/ansor/task_scheduler.py @@ -153,7 +153,7 @@ def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPol self.tune_option = tune_option if self.use_debug_measurement_simulator is None: self.measurer = ProgramMeasurer(tune_option.builder, tune_option.runner, - tune_option.callbacks, tune_option.verbose) + tune_option.measure_callbacks, tune_option.verbose) self.ct = 0 self.tic = time.time() # reset num_measure_per_iter to make sure every task is tuned at least once @@ -167,6 +167,13 @@ def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPol self.sequential_now_task_idx = 0 self.sequential_now_task_begin_ct = 0 + for i in range(len(self.tasks)): + search_policy = self.search_policies[i] + task = self.tasks[i] + search_policy.set_task(task) + search_policy.set_verbose(tune_option.verbose) + search_policy.run_callbacks(tune_option.pre_search_callbacks) + # do a round robin first if self.strategy != 'sequential': for i in range(len(self.tasks)): diff --git a/scripts/tune_test.py b/scripts/tune_test.py index 68f9dfadb8d4..1f75f0dd583e 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -13,102 +13,67 @@ from common import get_workload_keys, get_workload_weights, measure_schedule, str2bool -def make_cost_model(model_type, load_model_file, load_log_file): - if model_type == 'xgb': - model = ansor.XGBModel() - if load_model_file: - print("Load pretrained model...") - model.load(load_model_file) - elif load_log_file: - model.load_log_file(load_log_file) - elif model_type == "random": - model = ansor.RandomModel() +def replay_workload(wkl_key, target, target_host, log_file, + local_measure=True, device_key=None, host="0.0.0.0", + port=9190, ndk_cc=None): + inp, res = ansor.best_measure_pair_in_file(log_file, wkl_key, target) + if inp is None: + print("Cannot find log for: %s" % (wkl_key)) else: - raise ValueError("Invalid model: " + model_type) - return model + dag = ansor.workload_key_to_dag(inp.task.workload_key) + s, bufs = dag.apply_steps_from_state(inp.state) + + print("Found schedule for: %s" % (wkl_key)) + print(tvm.lower(s, bufs, simple_mode=True)) + if local_measure: + remote = None + else: + remote = request_remote(device_key, host, port, 1) + cost = np.mean((measure_schedule(s, bufs, target, remote=remote, ndk_cc=ndk_cc))) + print("Best schedule: %.2f GFLOPS\tcost: %.3f ms" % + (ansor.ComputeDAG(bufs).flop_ct / cost / 1e9, cost * 1e3)) -def tune_workload(wkl_key, target, target_host, n_trials, num_measure_per_iter, - policy, log_file, verbose, - model_type, load_model_file, load_log_file, - build_timeout, local_measure=True, device_key=None, host="0.0.0.0", - port=9190, n_parallel=1, ndk_cc=None, remeasure=True): +def tune_workload(wkl_key, target, target_host, policy, model_type, load_model_file, + load_log_file, tune_option): """Tune a workload""" if False: # Debug info. 
Print static analysis results from the access analyzer - dag = auto_scheduler.workload_key_to_dag(wkl_key) + dag = ansor.workload_key_to_dag(wkl_key) print(dag.access_analyzer) exit() - model = make_cost_model(model_type, load_model_file, load_log_file) + if model_type == 'xgb': + model = ansor.XGBModel() + if load_model_file: + print("Load pretrained model...") + model.load(load_model_file) + elif load_log_file: + model.load_log_file(load_log_file) + elif model_type == "random": + model = ansor.RandomModel() + else: + raise ValueError("Invalid model: " + model_type) if policy == 'meta-rewrite': policy = ansor.MetaTileRewritePolicy(program_cost_model=model) elif policy == 'beam-search': policy = ansor.MetaTileRewritePolicy(program_cost_model=model, - params={'use_beam_search': 1}) + params={'use_beam_search': 1}) else: raise ValueError("Invalid search policy: " + policy) - if local_measure: - builder = ansor.LocalBuilder(build_timeout) - if target.target_name == "cuda": - measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) - runner = measure_ctx.runner - else: - runner = ansor.LocalRunner(repeat=1, min_repeat_ms=400) - else: - os.environ['TVM_NDK_CC'] = ndk_cc - builder = ansor.LocalBuilder(build_timeout, build_func='ndk') - runner = ansor.RPCRunner(device_key, host=host, port=port, - repeat=1, min_repeat_ms=400, - n_parallel=n_parallel) - - tune_option = ansor.TuneOption(n_trials=n_trials, - num_measure_per_iter=num_measure_per_iter, - verbose=verbose, - builder=builder, - runner=runner, - callbacks=[ansor.LogToFile(log_file)]) s, bufs = ansor.auto_schedule(wkl_key, target=target, target_host=target_host, search_policy=policy, tune_option=tune_option) - if remeasure: - print("Found schedule:") - print(tvm.lower(s, bufs, simple_mode=True)) - print("Redo measurement for double check...") - if local_measure: - remote = None - else: - remote = request_remote(device_key, host, port, 1) - cost = np.mean((measure_schedule(s, bufs, target, remote=remote, ndk_cc=ndk_cc))) - print("Best schedule: %.2f GFLOPS\tcost: %.3f ms" % - (ansor.ComputeDAG(bufs).flop_ct / cost / 1e9, cost * 1e3)) - -def tune_workloads_jointly(wkl_keys, weights, joint_tuner, target, target_host, - n_trials, num_measure_per_iter, - search_policy, log_file, verbose, - model_type, load_model_file, load_log_file, - build_timeout, local_measure=True, device_key=None, - host="0.0.0.0", port=9190, n_parallel=1, ndk_cc=None): +def tune_workloads_jointly(wkl_keys, weights, task_scheduler, target, target_host, + search_policy, model_type, load_model_file, load_log_file, + tune_option): """Tune for multiple workloads jointly""" - if local_measure: - builder = ansor.LocalBuilder(timeout=build_timeout) - if target.target_name == "cuda": - measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) - runner = measure_ctx.runner - else: - runner = ansor.LocalRunner(repeat=1, min_repeat_ms=400) - else: - os.environ['TVM_NDK_CC'] = ndk_cc - builder = ansor.LocalBuilder(build_func='ndk', timeout=build_timeout) - runner = ansor.RPCRunner(device_key, host=host, port=port, - repeat=1, min_repeat_ms=400, - n_parallel=n_parallel) tasks = [] for wkl_key in wkl_keys: @@ -118,78 +83,99 @@ def tune_workloads_jointly(wkl_keys, weights, joint_tuner, target, target_host, def objective_func(costs): return sum(c * w for c, w in zip(costs, weights)) - tuner = ansor.SimpleTaskScheduler(tasks, objective_func, strategy=joint_tuner, + tuner = ansor.SimpleTaskScheduler(tasks, objective_func, strategy=task_scheduler, 
load_log_file=load_log_file, load_model_file=load_model_file) - search_policy = "%s.%s" % (search_policy, model_type) - tune_option = ansor.TuneOption(n_trials=n_trials, - num_measure_per_iter=num_measure_per_iter, - builder=builder, - verbose=verbose, - runner=runner, - callbacks=[ansor.LogToFile(log_file)]) tuner.tune(tune_option, search_policy) if __name__ == "__main__": parser = argparse.ArgumentParser() + # Task related options parser.add_argument("--wkl", type=str, required=True) - parser.add_argument("--n-trials", type=int, default=1000) parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') parser.add_argument("--target-host", type=str, default=None) + parser.add_argument("--n-trials", type=int, default=1000) + parser.add_argument("--num-measure-per-iter", type=int, default=48, + help="The number of programs to be measured at each iteration") + parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) + # Strategy related options + parser.add_argument("--seed", type=int, default=0, help='random seed') parser.add_argument("--policy", type=str, choices=['meta-rewrite', 'beam-search'], default='meta-rewrite') - parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") - parser.add_argument("--build-timeout", type=int, default=10) - parser.add_argument("--run-timeout", type=int, default=60) parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') - parser.add_argument("--load-model", type=str) - parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") - parser.add_argument("--seed", type=int, default=0, help='random seed') - parser.add_argument("--verbose", type=int, default=1) parser.add_argument("--task-scheduler", type=str, default='no', choices=['no', 'gradient', 'round-robin'], help='The strategy of task scheduler') + # File related options + parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") + parser.add_argument("--load-model", type=str, help="Load pre trained cost model file") + parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") + # Detailed control options + parser.add_argument("--build-timeout", type=int, default=10) + parser.add_argument("--run-timeout", type=int, default=60) + parser.add_argument("--verbose", type=int, default=1) parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) parser.add_argument("--device-key", type=str, default=None) parser.add_argument("--host", type=str, default='0.0.0.0') parser.add_argument("--port", type=int, default=9190) parser.add_argument("--n-parallel", type=int, default=1) parser.add_argument("--ndk-cc", type=str, default=None) - parser.add_argument("--num-measure-per-iter", type=int, default=48, - help="The number of programs to be measured at each iteration") args = parser.parse_args() np.random.seed(args.seed) random.seed(args.seed) logging.basicConfig() - logging.getLogger('auto_scheduler').setLevel(logging.DEBUG) + logging.getLogger('ansor').setLevel(logging.DEBUG) log_file = args.log_file or args.wkl + ".json" - load_log_file = args.load_log or log_file target = tvm.target.create(args.target) wkl_keys = get_workload_keys(args.wkl) - weights = get_workload_weights(args.wkl) - if args.task_scheduler == 'no': - # tune workloads one by one - for wkl_key in wkl_keys: - tune_workload(wkl_key, target, args.target_host, args.n_trials, - 
args.num_measure_per_iter, - args.policy, log_file, args.verbose, - args.model_type, args.load_model, load_log_file, - args.build_timeout, - args.local_measure, args.device_key, args.host, - args.port, args.n_parallel, args.ndk_cc, - remeasure=len(wkl_keys) == 1) - else: - # tune workloads jointly using JointTuner - tune_workloads_jointly(wkl_keys, weights, args.joint_tuner, - target, args.target_host, - args.n_trials, args.num_measure_per_iter, - args.policy, log_file, args.verbose, - args.model_type, args.load_model, args.load_log, - args.build_timeout, - args.local_measure, args.device_key, args.host, - args.port, args.n_parallel, args.ndk_cc) + if args.tune: + load_log_file = args.load_log or log_file + weights = get_workload_weights(args.wkl) + + builder = runner = measure_ctx = None + if args.local_measure: + builder = ansor.LocalBuilder(timeout=args.build_timeout) + if target.target_name == "cuda": + measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) + runner = measure_ctx.runner + else: + runner = ansor.LocalRunner(repeat=1, min_repeat_ms=400) + else: + os.environ['TVM_NDK_CC'] = args.ndk_cc + builder = ansor.LocalBuilder(timeout=args.build_timeout, build_func='ndk') + runner = ansor.RPCRunner(args.device_key, host=args.host, port=args.port, + repeat=1, min_repeat_ms=400, n_parallel=args.n_parallel) + + tune_option = ansor.TuneOption(n_trials=args.n_trials, + num_measure_per_iter=args.num_measure_per_iter, + verbose=args.verbose, + builder=builder, + runner=runner, + measure_callbacks=[ansor.LogToFile(log_file)], + pre_search_callbacks=[ansor.PreLoadMeasuredStatesCallback(log_file)]) + + if args.task_scheduler == 'no': + # tune workloads one by one + for wkl_key in wkl_keys: + tune_workload(wkl_key, target, args.target_host, args.policy, + args.model_type, args.load_model, load_log_file, + tune_option) + else: + # tune workloads jointly using JointTuner + tune_workloads_jointly(wkl_keys, weights, args.task_scheduler, + target, args.target_host, args.policy, + args.model_type, args.load_model, load_log_file, + tune_option) + if measure_ctx: + del measure_ctx + + if not args.tune or len(wkl_keys) == 1: + for wkl_key in wkl_keys: + replay_workload(wkl_key, target, args.target_host, log_file, + args.local_measure, args.device_key, args.host, + args.port, args.ndk_cc) diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index 3c793e5957f5..200118cf708b 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -36,7 +36,8 @@ TVM_REGISTER_NODE_TYPE(TuneOptionNode); TuneOption TuneOptionNode::make(int n_trials, int early_stopping, int num_measure_per_iter, int verbose, Builder builder, Runner runner, - Array callbacks) { + Array measure_callbacks, + Array pre_search_callbacks) { auto node = make_object(); node->n_trials = n_trials; node->early_stopping = early_stopping; @@ -44,20 +45,23 @@ TuneOption TuneOptionNode::make(int n_trials, int early_stopping, node->verbose = verbose; node->builder = std::move(builder); node->runner = std::move(runner); - node->callbacks = std::move(callbacks); + node->measure_callbacks = std::move(measure_callbacks); + node->pre_search_callbacks = std::move(pre_search_callbacks); return TuneOption(node); } -std::pair > AutoSchedule(SearchTask task, SearchPolicy search_policy, - TuneOption tune_option) { +std::pair > AutoSchedule(SearchTask task, + SearchPolicy search_policy, TuneOption tune_option) { // Search for the best schedule ProgramMeasurer measurer = ProgramMeasurerNode::make(tune_option->builder, 
tune_option->runner, - tune_option->callbacks, tune_option->verbose); + tune_option->measure_callbacks, + tune_option->verbose); State state = search_policy->Search( task, tune_option->n_trials, tune_option->early_stopping, - tune_option->num_measure_per_iter, tune_option->verbose, measurer); + tune_option->num_measure_per_iter, tune_option->verbose, measurer, + tune_option->pre_search_callbacks); return task->compute_dag.ApplySteps(state->transform_steps); } @@ -71,16 +75,17 @@ std::pair > AutoSchedule( std::move(dag), std::move(workload_key), std::move(target), std::move(target_host), std::move(hardware_params)); return AutoSchedule(std::move(task), std::move(search_policy), - std::move(tune_option)); + std::move(tune_option)); } TVM_REGISTER_GLOBAL("ansor.TuneOption") .set_body_typed([](int n_trials, int early_stopping, int num_measure_per_iter, int verbose, Builder builder, - Runner runner, Array callbacks) { + Runner runner, Array measure_callbacks, + Array pre_search_callbacks) { return TuneOptionNode::make(n_trials, early_stopping, num_measure_per_iter, verbose, builder, - runner, callbacks); + runner, measure_callbacks, pre_search_callbacks); }); TVM_REGISTER_GLOBAL("ansor.AutoScheduleBySearchTask") diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index 3737f8c5d096..4e70ac0b577a 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -28,6 +28,7 @@ #include #include #include "measure.h" +#include "search_policy/search_policy.h" namespace tvm { namespace ansor { @@ -45,7 +46,9 @@ class TuneOptionNode : public Object { Builder builder; // Builder which builds the program Runner runner; // Runner which runs the program and measure time // costs - Array callbacks; // Callback functions + Array measure_callbacks; // MeasureCallback functions + Array pre_search_callbacks; // SearchCallback functions + // run before search void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("n_trials", &n_trials); @@ -54,12 +57,14 @@ class TuneOptionNode : public Object { v->Visit("verbose", &verbose); v->Visit("builder", &builder); v->Visit("runner", &runner); - v->Visit("callbacks", &callbacks); + v->Visit("measure_callbacks", &measure_callbacks); + v->Visit("pre_search_callbacks", &pre_search_callbacks); } static TuneOption make(int n_trials, int early_stopping, int num_measure_per_iter, int verbose, Builder builder, - Runner runner, Array callbacks); + Runner runner, Array measure_callbacks, + Array pre_search_callbacks); static constexpr const char* _type_key = "ansor.TuneOption"; TVM_DECLARE_FINAL_OBJECT_INFO(TuneOptionNode, Object); diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc index f086a8879abb..0a9f97ab9170 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc @@ -58,12 +58,15 @@ SearchPolicy MetaTileRewritePolicyNode::make(CostModel program_cost_model, State MetaTileRewritePolicyNode::Search(SearchTask task, int n_trials, int early_stopping, int num_measure_per_iter, - int verbose, ProgramMeasurer measurer) { + int verbose, ProgramMeasurer measurer, + Array pre_search_callbacks) { std::vector best_states, random_states; cur_task_ = task; verbose_ = verbose; num_measure_per_iter_ = num_measure_per_iter; + RunCallbacks(pre_search_callbacks); + if (n_trials <= 1) { // no measurement is allowed SearchOneRound(&best_states, 0, &random_states); CHECK_GT(best_states.size(), 0); diff --git 
a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/meta_tile_rewrite_policy.h index f92813b11273..8cf61b4d1e11 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.h +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.h @@ -63,7 +63,8 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { // Return the best state State Search(SearchTask task, int n_trials, int early_stopping, int num_measure_per_iter, - int verbose, ProgramMeasurer measurer) final; + int verbose, ProgramMeasurer measurer, + Array pre_search_callbacks) final; // Continue search. This is used by JointTuner std::pair, Array > ContinueSearchOneRound( @@ -74,8 +75,6 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { TVM_DECLARE_FINAL_OBJECT_INFO(MetaTileRewritePolicyNode, SearchPolicyNode); - SearchTask cur_task_; // The current task - protected: // Pick states from best states and random states with eps-greedy policy void PickStatesWithEpsGreedy(std::vector* inputs, @@ -100,12 +99,8 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { SplitFactorizationMemo split_memo_; // Memorize split space for Split std::mt19937 rand_gen_; // Random generator - int verbose_; // Verbose level (0 means silent) int num_measure_per_iter_; // The number of states to measure per iteration - // The set of the already measured states. We store the string format for redundancy check - std::unordered_set measured_states_set_; - // The array of already measured states. std::vector measured_states_vector_; diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index f3072fda4956..b2ba27bfc6ba 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -23,27 +23,91 @@ */ #include "search_policy.h" + #include +#include "../serialization.h" + namespace tvm { namespace ansor { TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); +TVM_REGISTER_OBJECT_TYPE(PreLoadMeasuredStatesCallbackNode); + +void SearchPolicyNode::PreLoadMeasuredStates(const std::string& log_file) { + LogReader reader = LogReaderNode::make(log_file); + const auto& res = reader->ReadLines(-1); + if (res.first.size()) { + std::vector measured_states; + for (const auto& inp : res.first) { + if (inp->task->workload_key == cur_task_->workload_key && + inp->task->target->target_name.compare( + cur_task_->target->target_name) == 0) { + State state = cur_task_->compute_dag.GetInitState(); + state.CopyOnWrite()->transform_steps = inp->state->transform_steps; + state.DoSteps(inp->state->transform_steps, cur_task_->compute_dag); + measured_states.push_back(std::move(state)); + } + } + cur_task_->compute_dag.InferBound(&measured_states); + for (auto state : measured_states) { + measured_states_set_.insert(state.ToStr()); + } + + StdCout(verbose_) << "Measured States Set: " + << measured_states_set_.size() + << " state hashes loaded from " << log_file << std::endl; + } +} + +void SearchPolicyNode::RunCallbacks(const Array& callbacks) { + if (callbacks.defined() && callbacks.size()) { + PrintTitle("Process search callbacks", verbose_); + for (const auto& callback : callbacks) { + callback->callback(this); + } + } +} + +SearchCallback PreLoadMeasuredStatesCallbackNode::make(std::string filename) { + auto node = make_object(); + node->filename = std::move(filename); + return SearchCallback(node); +} + +void PreLoadMeasuredStatesCallbackNode::callback(SearchPolicyNode* policy) { + policy->PreLoadMeasuredStates(filename); +} // Search Policy 
TVM_REGISTER_GLOBAL("ansor.SearchPolicyContinueSearchOneRound") -.set_body([](TVMArgs args, TVMRetValue *ret) { - SearchPolicy policy = args[0]; - SearchTask task = args[1]; - int num_measure = args[2]; - int verbose = args[3]; - ProgramMeasurer measurer = args[4]; - +.set_body_typed([](SearchPolicy policy, SearchTask task, int num_measure, + int verbose, ProgramMeasurer measurer) { Array inputs; Array results; - std::tie(inputs, results) = policy->ContinueSearchOneRound(task, num_measure, verbose, measurer); + std::tie(inputs, results) = policy->ContinueSearchOneRound(task, num_measure, + verbose, measurer); + return Array{inputs, results}; +}); + +TVM_REGISTER_GLOBAL("ansor.SearchPolicyRunCallbacks") +.set_body_typed([](SearchPolicy policy, Array callbacks) { + policy->RunCallbacks(callbacks); +}); + +TVM_REGISTER_GLOBAL("ansor.SearchPolicySetTask") +.set_body_typed([](SearchPolicy policy, SearchTask task) { + policy->cur_task_ = task; +}); + +TVM_REGISTER_GLOBAL("ansor.SearchPolicySetVerbose") +.set_body_typed([](SearchPolicy policy, int verbose) { + policy->verbose_ = verbose; +}); - *ret = Array{inputs, results}; +TVM_REGISTER_GLOBAL("ansor.PreLoadMeasuredStatesCallback") +.set_body_typed([](std::string filename) { + return PreLoadMeasuredStatesCallbackNode::make(filename); }); } // namespace ansor diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index f2071deab447..0d7ebe94c14f 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -26,6 +26,7 @@ #define TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ #include +#include #include #include #include @@ -36,17 +37,45 @@ namespace tvm { namespace ansor { class SearchPolicy; +class SearchPolicyNode; + +class SearchCallbackNode : public Object { + public: + virtual void callback(SearchPolicyNode* policy) = 0; + static constexpr const char *_type_key = "ansor.SearchCallback"; + TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object); +}; +TVM_DEFINE_MUTABLE_OBJECT_REF(SearchCallback, SearchCallbackNode); + +class PreLoadMeasuredStatesCallbackNode : public SearchCallbackNode { + public: + std::string filename; + + static SearchCallback make(std::string filename); + + void callback(SearchPolicyNode* policy) final; + + static constexpr const char *_type_key = "ansor.PreLoadMeasuredStatesCallback"; + TVM_DECLARE_FINAL_OBJECT_INFO(PreLoadMeasuredStatesCallbackNode, SearchCallbackNode); +}; /*! \brief The base class for search policy */ class SearchPolicyNode : public Object { public: virtual State Search(SearchTask task, int n_trials, int early_stopping, int num_measure_per_iter, - int verbose, ProgramMeasurer measurer) = 0; + int verbose, ProgramMeasurer measurer, + Array pre_search_callbacks) = 0; virtual std::pair, Array > ContinueSearchOneRound( SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) = 0; + void PreLoadMeasuredStates(const std::string& log_file); + void RunCallbacks(const Array& callbacks); + + SearchTask cur_task_; // The current task + int verbose_; // Verbose level (0 means silent) + // Dict keys static constexpr const char* always_unroll_inner_key = "ansor_always_unroll_inner"; static constexpr const char* always_unroll_key = "ansor_always_unroll"; @@ -63,6 +92,11 @@ class SearchPolicyNode : public Object { static constexpr const char *_type_key = "ansor.SearchPolicy"; TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); + + protected: + // The set of the already measured states. 
+ // We store the string format for redundancy check + std::unordered_set measured_states_set_; }; TVM_DEFINE_MUTABLE_OBJECT_REF(SearchPolicy, SearchPolicyNode); diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index 76f5d4449001..b03acb1edc3c 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -499,6 +499,10 @@ LogReader LogReaderNode::make(std::string filename) { return LogReader(node); } +LogReaderNode::~LogReaderNode() { + infile.close(); +} + bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { std::string log_version; diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h index a12760bb3acc..d877717db9cb 100644 --- a/src/ansor/serialization.h +++ b/src/ansor/serialization.h @@ -58,6 +58,7 @@ class LogReaderNode : public Object { std::ifstream infile; static LogReader make(std::string filename); + ~LogReaderNode(); /*! \brief Read next line in the log file * \return Whether the read is successful */ diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index 5cb67dba39fe..6fe1012e6629 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -42,9 +42,9 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner, - callbacks=[ansor.LogToFile(log_file)]) - sch, args = ansor.auto_schedule(task, search_policy, - tune_option=tune_option) + measure_callbacks=[ansor.LogToFile(log_file)]) + sch, args = ansor.auto_schedule(task, search_policy=search_policy, + tune_option=tune_option) inp, res = ansor.best_measure_pair_in_file(log_file, workload_key, target) print("==== Python Code ====") diff --git a/tutorials/ansor/tune_conv2d_cuda.py b/tutorials/ansor/tune_conv2d_cuda.py index 82a5e8572ba2..caa040d1b3bc 100644 --- a/tutorials/ansor/tune_conv2d_cuda.py +++ b/tutorials/ansor/tune_conv2d_cuda.py @@ -110,11 +110,11 @@ def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding): log_file = "conv2d_nchw.json" seed = 0 random.seed(seed) -cost_model = ansor.XGBModel() +cost_model = ansor.XGBModel(seed=seed) search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) ######################################################################### -# The :code:`ansor.RPCRunnerWarpper` is used to create a RPC runner environment, +# The :code:`ansor.LocalRPCMeasureContext` is used to create a RPC runner environment. # # Use local gpu, measure 10 times for every schedule to reduce variance. The timeout # for each running is set to 4 seconds. @@ -123,15 +123,24 @@ def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding): # will be filtered out. It's fine to see "Encountered errors during feature extraction." # in the tuning logs. 
-with ansor.RPCRunnerWarpper("cuda", repeat=3, min_repeat_ms=100, timeout=4) as rpc_runner:
-    tune_option = ansor.TuneOption(n_trials=20,
-                                   runner=rpc_runner.runner,
-                                   callbacks=[ansor.LogToFile(log_file)])
-    state = ansor.auto_schedule(task, search_policy,
-                                tune_option=tune_option)
-    print(state)
+measure_ctx = ansor.LocalRPCMeasureContext(repeat=3, min_repeat_ms=100, timeout=4)
+tune_option = ansor.TuneOption(n_trials=20,
+                               runner=measure_ctx.runner,
+                               measure_callbacks=[ansor.LogToFile(log_file)])
+s, arg_bufs = ansor.auto_schedule(task, search_policy=search_policy, tune_option=tune_option)
+
+print("==== Get Lowered Stmt ====")
+print(tvm.lower(s, arg_bufs, simple_mode=True))
+
+# Release the RPC runner environment
+del measure_ctx
 
 #########################################################################
+# From the lowered result shown above, we can see that Ansor has tried
+# techniques such as `Shared Memory Cooperative Fetching`, `Kernel Fusion`,
+# `Axis unroll`, `Axis Vectorize` and so on. Users do not need to care
+# about these details; Ansor handles them well.
+#
 # Finally we can directly use the returned result to get the generated schedule,
 # while in the following tutorial we'll show how to inspect the best config from
 # log file, check correctness, and measure running time.
@@ -160,5 +169,5 @@ def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding):
 # Evaluate running time. Here we choose a large repeat number (400) to reduce the noise
 # and the overhead of kernel launch. You can also use nvprof to validate the result.
 evaluator = func.time_evaluator(func.entry_name, ctx, number=400)
-print('Time cost of this operator: %f' % evaluator(a_tvm, w_tvm, c_tvm).mean)
+print('Time cost of this operator: %f s' % evaluator(a_tvm, w_tvm, c_tvm).mean)
 
diff --git a/tutorials/ansor/tune_simple_subgraph.py b/tutorials/ansor/tune_simple_subgraph.py
index 2af33c1e88ba..fedbb399d0cf 100644
--- a/tutorials/ansor/tune_simple_subgraph.py
+++ b/tutorials/ansor/tune_simple_subgraph.py
@@ -113,8 +113,8 @@ def matmul_add(N, L, M, dtype):
 # When proposing the next batch of schedules, Ansor can take different cost models to
 # guide the schedule generating process.
 #
-# * :any:`RandomModel`: Generate and take new schedule randomly
-# * :any:`XGBModel`: Use XGBoost model to estimate the performance of potential schedules, try to pick schedules with better performance in each step
+# * :code:`RandomModel`: Generate and take new schedule randomly
+# * :code:`XGBModel`: Use XGBoost model to estimate the performance of potential schedules, try to pick schedules with better performance in each step
 #
 # XGBModel can explore more efficiently and find better schedules.
@@ -130,7 +130,7 @@ def matmul_add(N, L, M, dtype):
 #
 # Then we create the :code:`tvm.target` and a tuning task.
 
-N, L, M = 64, 64, 64
+N, L, M = 128, 128, 128
 A, B, C, D = matmul_add(N, L, M, 'float32')
 
 dag = ansor.ComputeDAG([A, B, C, D])
@@ -148,9 +148,6 @@ def matmul_add(N, L, M, dtype):
 # you can do more trials according to your time budget.
 # The :code:`ansor.LogToFile` callback will log the tuning results into a
 # log file, which can be used to get the best config later.
-#
-# Then just call :code:`ansor.auto_schedule` and Ansor will try to find a high
-# performance schedule for the target subgraph automatically.
log_file = "matmul_add.json" @@ -160,34 +157,20 @@ def matmul_add(N, L, M, dtype): search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) tune_option = ansor.TuneOption(n_trials=5, - callbacks=[ansor.LogToFile(log_file)]) + measure_callbacks=[ansor.LogToFile(log_file)]) -state = ansor.auto_schedule(task, search_policy, - tune_option=tune_option) -print(state) - -######################################################################### -# Finally we apply the history best to be a TVM schedule. -# -# We can call the function :code:`apply_steps_from_state` directly using the returned -# :code:`state` structure. -# :code:`state` can also be used to print out the user friendly Python code on demand. -# -# And since we've record the runing results to file, we can also use the following -# code to reply the best schedule from the log file: -# .. code-block:: c -# -# inp, res = ansor.best_measure_pair_in_file(log_file) -# state = inp.state -# s, arg_bufs = dag.apply_steps_from_state(state) +################################################################ +# Then just call :code:`ansor.auto_schedule` and Ansor will try to find a high +# performance schedule for the target subgraph automatically. # -# With the :code:`state` above, we have lowered result and its python code: +# The returned result will be a :code:`te.schedule` and a list of :code:`te.Tensor`, +# which can be used as the input of :code:`tvm.lower` or :code:`tvm.build`. + +s, arg_bufs = ansor.auto_schedule(task, search_policy=search_policy, + tune_option=tune_option) -s, arg_bufs = dag.apply_steps_from_state(state) print("==== Get Lowered Stmt ====") print(tvm.lower(s, arg_bufs, simple_mode=True)) -print("==== Get Python Code ====") -print(dag.print_python_code_from_state(state)) ######################################################################### # Check the correctness to make sure we generate a right schedule. 
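Once tuning has written measurement records to a log file, the best record can be replayed without re-running the search. A minimal sketch of this replay path (assuming the `ansor` Python API from the patch above; `wkl_key` and the log file name are hypothetical placeholders):

    import tvm
    from tvm import ansor

    wkl_key = ...  # hypothetical: the workload key used during tuning
    target = tvm.target.create("llvm")

    # Pick the best measured record for this workload from the log.
    inp, res = ansor.best_measure_pair_in_file("tune.json", wkl_key, target)

    # Rebuild the DAG and re-apply the recorded transform steps.
    dag = ansor.workload_key_to_dag(inp.task.workload_key)
    s, bufs = dag.apply_steps_from_state(inp.state)
    print(tvm.lower(s, bufs, simple_mode=True))
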
From 3a24e49ee7b7e5d3b09e2fb6062c45923a95abd3 Mon Sep 17 00:00:00 2001
From: Chenfan
Date: Thu, 11 Jun 2020 19:09:24 +0800
Subject: [PATCH 21/78] Add python custom sketch rule (#21)

* Add custom sketch rule

* Bug fix
---
 python/tvm/ansor/__init__.py                  |   3 +-
 python/tvm/ansor/auto_schedule.py             |  46 +++++--
 scripts/tune_test.py                          |   2 +-
 .../search_policy/meta_tile_rewrite_policy.cc | 116 ++++++++++++++----
 .../search_policy/meta_tile_rewrite_policy.h  |  21 +++-
 src/ansor/search_policy/search_policy.cc      |  12 +-
 src/ansor/search_policy/search_policy.h       |  10 +-
 .../unittest/test_ansor_search_policy.py      |  68 +++++++++-
 8 files changed, 230 insertions(+), 48 deletions(-)

diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py
index 1029875917aa..845d1b5e477d 100644
--- a/python/tvm/ansor/__init__.py
+++ b/python/tvm/ansor/__init__.py
@@ -29,7 +29,8 @@
 
 # Shortcut
 from .compute_dag import ComputeDAG
-from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams, PreLoadMeasuredStatesCallback
+from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams, \
+    PreLoadMeasuredStates, PreAddCustomRule
 from .auto_schedule import auto_schedule
 from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext
 from .cost_model import RandomModel
diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py
index 5b5eb4fe183b..e1a0711a80be 100644
--- a/python/tvm/ansor/auto_schedule.py
+++ b/python/tvm/ansor/auto_schedule.py
@@ -67,14 +67,16 @@ def __init__(self, dag, workload_key, target, target_host=None,
 
 @tvm._ffi.register_object("ansor.SearchPolicy")
 class SearchPolicy(Object):
+    """ The base search policy class
+    """
     def continue_search(self, task, num_measure, verbose, measurer):
         return _ffi_api.SearchPolicyContinueSearchOneRound(self, task, num_measure, verbose, measurer)
-
+
     def set_task(self, task):
-        _ffi_api.SearchPolicySetTask(self, task);
+        _ffi_api.SearchPolicySetTask(self, task)
 
     def set_verbose(self, verbose):
-        _ffi_api.SearchPolicySetVerbose(self, verbose);
+        _ffi_api.SearchPolicySetVerbose(self, verbose)
 
     def run_callbacks(self, callbacks):
         _ffi_api.SearchPolicyRunCallbacks(self, callbacks)
@@ -130,14 +132,39 @@ class SearchCallback(Object):
     pass
 
 
-@tvm._ffi.register_object("ansor.PreLoadMeasuredStatesCallback")
-class PreLoadMeasuredStatesCallback(SearchCallback):
+@tvm._ffi.register_object("ansor.PreLoadMeasuredStates")
+class PreLoadMeasuredStates(SearchCallback):
     """ A SearchCallback that used for search policy to load measured hash
     from the log file.
+
+    Parameters
+    ----------
+    filename: Str
     """
     def __init__(self, filename: str):
         self.__init_handle_by_constructor__(
-            _ffi_api.PreLoadMeasuredStatesCallback, filename)
+            _ffi_api.PreLoadMeasuredStates, filename)
+
+
+@tvm._ffi.register_object("ansor.PreAddCustomRule")
+class PreAddCustomRule(SearchCallback):
+    """
+    A SearchCallback for MetaTileRewritePolicy that allows users to add
+    custom sketch rules.
+
+    Notice: This is an advanced feature; make sure you understand how it
+    works before using it, and it should only be used with MetaTileRewritePolicy.
+ + Parameters + ---------- + meet_condition_func: Function + A function with `(policy, state, stage_id) -> int` + apply_func: Function + A function with `(policy, state, stage_id) -> [[State, int], ...]` + """ + def __init__(self, meet_condition_func, apply_func): + self.__init_handle_by_constructor__( + _ffi_api.PreAddCustomRule, meet_condition_func, apply_func) @tvm._ffi.register_object("ansor.TuneOption") @@ -159,8 +186,13 @@ class TuneOption(Object): runner: Runner Runner which runs the program and measure time costs measure_callbacks: List[MeasureCallback] - Callback functions + Callback functions called after each measure + Candidates: + - ansor.LogToFile pre_search_callbacks: List[SearchCallback] + Callback functions called before the search process + Candidates: + - ansor.PreLoadMeasuredStates """ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, verbose=1, builder='local', runner='local', measure_callbacks=None, diff --git a/scripts/tune_test.py b/scripts/tune_test.py index 1f75f0dd583e..08f0cc19ade2 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -157,7 +157,7 @@ def objective_func(costs): builder=builder, runner=runner, measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=[ansor.PreLoadMeasuredStatesCallback(log_file)]) + pre_search_callbacks=[ansor.PreLoadMeasuredStates(log_file)]) if args.task_scheduler == 'no': # tune workloads one by one diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc index 0a9f97ab9170..5703e17ba29f 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc @@ -41,7 +41,8 @@ namespace tvm { namespace ansor { -TVM_REGISTER_OBJECT_TYPE(MetaTileRewritePolicyNode); +TVM_REGISTER_NODE_TYPE(MetaTileRewritePolicyNode); +TVM_REGISTER_OBJECT_TYPE(PreAddCustomRuleNode); // All possible candidates for auto_unroll const std::vector MetaTileRewritePolicyNode::auto_unroll_configs{0, 16, 64, 512, 1024}; @@ -241,7 +242,7 @@ void MetaTileRewritePolicyNode::SearchOneRound(std::vector* best_states, // Synthesize meta structure std::vector meta_structures; - SynthesizeMetaStructure(&meta_structures); + GenerateMetaSketch(&meta_structures); // PrintAllStates(meta_structures); // exit(0); @@ -272,8 +273,8 @@ void MetaTileRewritePolicyNode::SearchOneRound(std::vector* best_states, RandomSampleStates(init_population, &rand_gen_, num_random_states * 10, random_states); } -// The baseclass of derivation rules used in meta structure synthesis -class StructureSynthesisRule { +// The baseclass of derivation rules used in meta sketch generation +class SketchGenerationRule { public: enum ConditionEnum { kPass, kApply, kApplyAndSkipRest @@ -345,7 +346,7 @@ static inline bool ShouldAlwaysBeInlined( } // The rule that inlines simple elementwise ops -class RuleAlwaysInline : public StructureSynthesisRule { +class RuleAlwaysInline : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { @@ -362,7 +363,7 @@ class RuleAlwaysInline : public StructureSynthesisRule { }; // The rule that simply skip the current stage -class RuleSkipStage : public StructureSynthesisRule { +class RuleSkipStage : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { @@ -387,7 +388,7 @@ class RuleSkipStage : public StructureSynthesisRule { }; 
// The rule that performs multi-level tiling -class RuleMultiLevelTiling : public StructureSynthesisRule { +class RuleMultiLevelTiling : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { @@ -413,7 +414,7 @@ class RuleMultiLevelTiling : public StructureSynthesisRule { }; // The rule that performs multi-level tiling and fuses later consumers -class RuleMultiLevelTilingWithFusion : public StructureSynthesisRule { +class RuleMultiLevelTilingWithFusion : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { @@ -482,7 +483,7 @@ class RuleMultiLevelTilingWithFusion : public StructureSynthesisRule { }; // The rule that adds a cache write stage -class RuleAddCacheWrite : public StructureSynthesisRule { +class RuleAddCacheWrite : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { @@ -515,7 +516,7 @@ class RuleAddCacheWrite : public StructureSynthesisRule { // The rule that adds a cache read stage // Mainly used for GPU cooperative fetching // Currently only support 1 to 1 match cache read -class RuleAddCacheRead : public StructureSynthesisRule { +class RuleAddCacheRead : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { @@ -546,7 +547,7 @@ class RuleAddCacheRead : public StructureSynthesisRule { }; // The rule that adds rfactor stage -class RuleAddRfactor : public StructureSynthesisRule { +class RuleAddRfactor : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { @@ -610,7 +611,7 @@ class RuleAddRfactor : public StructureSynthesisRule { } }; -void MetaTileRewritePolicyNode::SynthesizeMetaStructure( +void MetaTileRewritePolicyNode::GenerateMetaSketch( std::vector* out_states) { State init_state = cur_task_->compute_dag.GetInitState(); std::string cpu_multi_level_tiling_structure = @@ -634,18 +635,22 @@ void MetaTileRewritePolicyNode::SynthesizeMetaStructure( static RuleAddCacheWrite rule_add_cache_write_stage; static RuleAddCacheRead rule_add_cache_read_stage; static RuleAddRfactor rule_add_rfactor; - // We may apply and skip the rest when processing some rules, - // should take care of the rule vector order here - static std::vector all_rules { - &rule_always_inline, &rule_add_cache_write_stage, - &rule_multi_level_tiling_with_fusion, &rule_multi_level_tiling, - &rule_add_rfactor, &rule_skip_stage - }; - if (IS_GPU(cur_task_)) { - // Try cache read first before cache write - all_rules.insert(all_rules.begin() + 1, &rule_add_cache_read_stage); + if (sketch_rules.empty()) { + // We may apply and skip the rest when processing some rules, + // should take care of the rule vector order here + sketch_rules.push_back(&rule_always_inline); + sketch_rules.push_back(&rule_add_cache_write_stage); + sketch_rules.push_back(&rule_multi_level_tiling_with_fusion); + sketch_rules.push_back(&rule_multi_level_tiling); + sketch_rules.push_back(&rule_add_rfactor); + sketch_rules.push_back(&rule_skip_stage); + if (IS_GPU(cur_task_)) { + // Try cache read first before cache write + sketch_rules.insert(sketch_rules.begin() + 1, &rule_add_cache_read_stage); + } + // TODO(xian): Add a new rule to try combination of multi-level + // tiling + 
rfactor
+  }
-  // TODO(xian): Add a new rule to try combination of multi-level tiling + rfactor
 
   // Derivation-rule-based synthesizer
   while (!pnow->empty()) {
@@ -661,15 +666,15 @@
       }
 
       // Try all derivation rules
-      for (const auto& rule : all_rules) {
+      for (const auto& rule : sketch_rules) {
         auto rule_check = rule->MeetCondition(this, state, stage_id);
-        if (rule_check > StructureSynthesisRule::ConditionEnum::kPass) {
+        if (rule_check > SketchGenerationRule::ConditionEnum::kPass) {
           for (const auto& pair : rule->Apply(this, state, stage_id)) {
             cur_stage_id_map[pair.first] = pair.second;
             pnext->push_back(pair.first);
           }
           // Skip the rest of the rules
-          if (rule_check == StructureSynthesisRule::ConditionEnum::kApplyAndSkipRest) {
+          if (rule_check == SketchGenerationRule::ConditionEnum::kApplyAndSkipRest) {
             break;
           }
         }
@@ -1444,6 +1449,60 @@ void MetaTileRewritePolicyNode::EvolutionarySearch(
             << std::fixed << std::setprecision(2) << duration << std::endl;
 }
 
+class RuleCustomSketch : public SketchGenerationRule {
+ public:
+  RuleCustomSketch(PackedFunc meet_condition_func, PackedFunc apply_func) :
+      meet_condition_func_(meet_condition_func), apply_func_(apply_func) {}
+
+  inline ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
+                                     const State& state, int stage_id) final {
+    auto ret = meet_condition_func_(
+        tvm::runtime::GetRef<SearchPolicy>(policy), state, stage_id);
+    if (ret.type_code() == 0) {
+      return ConditionEnum(static_cast<int>(ret));
+    } else {
+      return kApplyAndSkipRest;
+    }
+  }
+
+  inline std::vector<std::pair<State, int>> Apply(
+      const MetaTileRewritePolicyNode* policy,
+      const State& state, int stage_id) final {
+    std::vector<std::pair<State, int>> ret;
+
+    Array<Array<ObjectRef>> apply_ret = apply_func_(
+        tvm::runtime::GetRef<SearchPolicy>(policy), state, stage_id);
+
+    for (const auto& item : apply_ret) {
+      CHECK_EQ(item.size(), 2);
+      State state = Downcast<State>(item[0]);
+      auto next = item[1].as<IntImmNode>();
+      ret.emplace_back(state, next->value);
+    }
+    return ret;
+  }
+
+ private:
+  PackedFunc meet_condition_func_;
+  PackedFunc apply_func_;
+};
+
+SearchCallback PreAddCustomRuleNode::make(PackedFunc meet_condition_func,
+                                          PackedFunc apply_func) {
+  auto node = make_object<PreAddCustomRuleNode>();
+  node->meet_condition_func = meet_condition_func;
+  node->apply_func = apply_func;
+  return SearchCallback(node);
+}
+
+void PreAddCustomRuleNode::callback(SearchPolicyNode* policy) {
+  CHECK(policy->IsInstance<MetaTileRewritePolicyNode>());
+  auto meta_policy = dynamic_cast<MetaTileRewritePolicyNode*>(policy);
+  meta_policy->sketch_rules.emplace_back(
+      new RuleCustomSketch(meet_condition_func, apply_func));
+  StdCout(policy->verbose_) << "Custom sketch rule added." << std::endl;
+}
+
 TVM_REGISTER_GLOBAL("ansor.MetaTileRewritePolicy")
 .set_body_typed([](CostModel program_cost_model, Map<String, ObjectRef> params,
@@ -1451,5 +1510,10 @@ TVM_REGISTER_GLOBAL("ansor.MetaTileRewritePolicy")
   return MetaTileRewritePolicyNode::make(program_cost_model, params, seed);
 });
 
+TVM_REGISTER_GLOBAL("ansor.PreAddCustomRule")
+.set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func) {
+  return PreAddCustomRuleNode::make(meet_condition_func, apply_func);
+});
+
 }  // namespace ansor
 }  // namespace tvm
diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/meta_tile_rewrite_policy.h
index 8cf61b4d1e11..befc002b6aa2 100644
--- a/src/ansor/search_policy/meta_tile_rewrite_policy.h
+++ b/src/ansor/search_policy/meta_tile_rewrite_policy.h
@@ -38,6 +38,8 @@ namespace tvm {
 namespace ansor {
 
+class SketchGenerationRule;
+
 /*!
Multi stage search policy */ class MetaTileRewritePolicyNode: public SearchPolicyNode { public: @@ -54,6 +56,7 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { * str gpu_multi_level_tiling_structure // The structure of multi-level tiling for GPU */ Map params; + std::vector sketch_rules; static SearchPolicy make(CostModel program_cost_model, Map params, @@ -87,7 +90,7 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { int num_random_states, std::vector* random_states); // Synthesize meta tiling structure without tile size - void SynthesizeMetaStructure(std::vector* out_states); + void GenerateMetaSketch(std::vector* out_states); // Sample init population void SampleInitPopulation(const std::vector& meta_structures, @@ -107,6 +110,22 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { // The throughputs of already measured states std::vector measured_states_throughputs_; }; +TVM_DEFINE_MUTABLE_OBJECT_REF(MetaTileRewritePolicy, MetaTileRewritePolicyNode); + +class PreAddCustomRuleNode : public SearchCallbackNode { + public: + // TODO(jcf94): Use tvm::runtime::TypedPackedFunc? + PackedFunc meet_condition_func; + PackedFunc apply_func; + + static SearchCallback make(PackedFunc meet_condition_func, + PackedFunc apply_func); + + void callback(SearchPolicyNode* policy) final; + + static constexpr const char *_type_key = "ansor.PreAddCustomRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(PreAddCustomRuleNode, SearchCallbackNode); +}; } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index b2ba27bfc6ba..d52b868e180d 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -32,7 +32,7 @@ namespace tvm { namespace ansor { TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); -TVM_REGISTER_OBJECT_TYPE(PreLoadMeasuredStatesCallbackNode); +TVM_REGISTER_OBJECT_TYPE(PreLoadMeasuredStatesNode); void SearchPolicyNode::PreLoadMeasuredStates(const std::string& log_file) { LogReader reader = LogReaderNode::make(log_file); @@ -69,13 +69,13 @@ void SearchPolicyNode::RunCallbacks(const Array& callbacks) { } } -SearchCallback PreLoadMeasuredStatesCallbackNode::make(std::string filename) { - auto node = make_object(); +SearchCallback PreLoadMeasuredStatesNode::make(std::string filename) { + auto node = make_object(); node->filename = std::move(filename); return SearchCallback(node); } -void PreLoadMeasuredStatesCallbackNode::callback(SearchPolicyNode* policy) { +void PreLoadMeasuredStatesNode::callback(SearchPolicyNode* policy) { policy->PreLoadMeasuredStates(filename); } @@ -105,9 +105,9 @@ TVM_REGISTER_GLOBAL("ansor.SearchPolicySetVerbose") policy->verbose_ = verbose; }); -TVM_REGISTER_GLOBAL("ansor.PreLoadMeasuredStatesCallback") +TVM_REGISTER_GLOBAL("ansor.PreLoadMeasuredStates") .set_body_typed([](std::string filename) { - return PreLoadMeasuredStatesCallbackNode::make(filename); + return PreLoadMeasuredStatesNode::make(filename); }); } // namespace ansor diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index 0d7ebe94c14f..2dfbd9429648 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -47,7 +47,7 @@ class SearchCallbackNode : public Object { }; TVM_DEFINE_MUTABLE_OBJECT_REF(SearchCallback, SearchCallbackNode); -class PreLoadMeasuredStatesCallbackNode : public SearchCallbackNode { +class PreLoadMeasuredStatesNode : public SearchCallbackNode { public: std::string 
filename; @@ -55,8 +55,8 @@ class PreLoadMeasuredStatesCallbackNode : public SearchCallbackNode { void callback(SearchPolicyNode* policy) final; - static constexpr const char *_type_key = "ansor.PreLoadMeasuredStatesCallback"; - TVM_DECLARE_FINAL_OBJECT_INFO(PreLoadMeasuredStatesCallbackNode, SearchCallbackNode); + static constexpr const char *_type_key = "ansor.PreLoadMeasuredStates"; + TVM_DECLARE_FINAL_OBJECT_INFO(PreLoadMeasuredStatesNode, SearchCallbackNode); }; /*! \brief The base class for search policy */ @@ -76,6 +76,10 @@ class SearchPolicyNode : public Object { SearchTask cur_task_; // The current task int verbose_; // Verbose level (0 means silent) + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("cur_task", &cur_task_); + } + // Dict keys static constexpr const char* always_unroll_inner_key = "ansor_always_unroll_inner"; static constexpr const char* always_unroll_key = "ansor_always_unroll"; diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index 6fe1012e6629..b86dfa95f9bd 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -27,7 +27,8 @@ from test_ansor_common import matmul_ansor_test def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local', - cost_model=ansor.RandomModel(), n_trials=2): + cost_model=ansor.RandomModel(), n_trials=2, params=None, + pre_search_callbacks=None): print("Test %s schedule search with the default search policy" % (target)) random.seed(seed) @@ -40,9 +41,11 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' with tempfile.NamedTemporaryFile() as fp: log_file = fp.name - search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) + search_policy = ansor.MetaTileRewritePolicy(cost_model, params=params, + seed=seed) tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner, - measure_callbacks=[ansor.LogToFile(log_file)]) + measure_callbacks=[ansor.LogToFile(log_file)], + pre_search_callbacks=pre_search_callbacks) sch, args = ansor.auto_schedule(task, search_policy=search_policy, tune_option=tune_option) inp, res = ansor.best_measure_pair_in_file(log_file, workload_key, target) @@ -95,8 +98,67 @@ def test_search_cuda(): print("CUDA device not found, skip this test.") +def test_search_custom_sketch_rule(): + def meet_condition_func(meta_policy, state, stage_id): + # Apply and Skip the Rest if this function does not return + pass + + # Expecting: + # i.0 + # i.1 + # i.2 + # j.0 + # j.1 + # ax0 + # ax1 + # B.global + # j.2 + # k + # C + def apply_func1(meta_policy, state, stage_id): + # Stage by stage way + ret = [] + if stage_id == 2: + state = ansor.loop_state.State(state) + state.split(2, state.stages[2].iters[0], [4, 4]) + state.split(2, state.stages[2].iters[3], [4, 4]) + ret.append([state.state_object, stage_id - 1]) + elif stage_id == 1: + state = ansor.loop_state.State(state) + state.cache_read(1, "global", [2], meta_policy.cur_task.compute_dag) + state.compute_at(2, 3, state.stages[3].iters[4]) + ret.append([state.state_object, stage_id - 1]) + else: + ret.append([state, stage_id - 1]) + return ret + + def apply_func2(meta_policy, state, stage_id): + # More template like way + ret = [] + state = ansor.loop_state.State(state) + + state.split(2, state.stages[2].iters[0], [4, 4]) + state.split(2, state.stages[2].iters[3], [4, 4]) + state.cache_read(1, "global", [2], meta_policy.cur_task.compute_dag) + state.compute_at(2, 3, 
state.stages[3].iters[4]) + + ret.append([state.state_object, -1]) + return ret + + measure_ctx = ansor.LocalRPCMeasureContext() + search_common(seed=887823438, runner=measure_ctx.runner, + pre_search_callbacks=[ansor.PreAddCustomRule(meet_condition_func, + apply_func1)], + params={'disable_change_compute_location': 1}) + search_common(seed=887823438, runner=measure_ctx.runner, + pre_search_callbacks=[ansor.PreAddCustomRule(meet_condition_func, + apply_func2)], + params={'disable_change_compute_location': 1}) + + if __name__ == "__main__": test_search_basic() test_search_xgb_model_rpc_runner() test_search_opencl() test_search_cuda() + test_search_custom_sketch_rule() From a155c1f46fdbbe44d2189b790313ae16cc42ce52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Minmin=20Sun=20=28=E5=AD=99=E6=95=8F=E6=95=8F=29?= Date: Fri, 12 Jun 2020 16:25:26 +0800 Subject: [PATCH 22/78] Ansor Relay Integration (without layout rewrite) (#22) * relay integration --- python/tvm/ansor/__init__.py | 10 +- python/tvm/ansor/compute_dag.py | 11 + python/tvm/ansor/dispatcher.py | 518 ++++++++++++++++++++++++++ python/tvm/ansor/env.py | 8 + python/tvm/ansor/relay_integration.py | 209 +++++++++++ python/tvm/ansor/serialization.py | 3 + python/tvm/ansor/topi_integration.py | 215 +++++++++++ scripts/tune_network.py | 497 ++++++++++++++++++++++++ topi/python/topi/ansor.py | 95 +++++ topi/python/topi/arm_cpu/__init__.py | 5 + topi/python/topi/generic/__init__.py | 5 + topi/python/topi/x86/__init__.py | 5 + 12 files changed, 1579 insertions(+), 2 deletions(-) create mode 100644 python/tvm/ansor/dispatcher.py create mode 100644 python/tvm/ansor/env.py create mode 100644 python/tvm/ansor/relay_integration.py create mode 100644 python/tvm/ansor/topi_integration.py create mode 100644 scripts/tune_network.py create mode 100644 topi/python/topi/ansor.py diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 845d1b5e477d..6ea8a0ce904f 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -28,14 +28,20 @@ from . 
import task_scheduler
 
 # Shortcut
-from .compute_dag import ComputeDAG
+from .compute_dag import ComputeDAG, LayoutRewriteLevel, gen_schedule
 from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams, \
     PreLoadMeasuredStates, PreAddCustomRule
 from .auto_schedule import auto_schedule
 from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext
 from .cost_model import RandomModel
 from .cost_model.xgb_model import XGBModel
-from .serialization import LogToFile, LogReader, best_measure_pair_in_file, write_measure_records_to_file
+from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \
+    load_from_file, write_measure_records_to_file
 from .workload_registry import register_auto_scheduler_workload_func, \
     workload_key_to_dag, make_workload_key_func
 from .task_scheduler import TaskScheduler, SimpleTaskScheduler
+from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest as apply_history_best, \
+    FallbackContext, clear_fallback_cache, ApplyGraphBest, BlockingEmptyContext
+from .topi_integration import register_topi_schedule, TaskExtractEnv
+from .relay_integration import extract_from_program, extract_from_multiple_program, \
+    prepare_layout_rewrite, finish_layout_rewrite
diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py
index 0b51ebb402cc..0c8aa2055482 100644
--- a/python/tvm/ansor/compute_dag.py
+++ b/python/tvm/ansor/compute_dag.py
@@ -19,6 +19,7 @@
 
 import tvm._ffi
 from tvm.runtime import Object
+from tvm import te
 
 from .loop_state import State
 from . import _ffi_api
@@ -88,3 +89,13 @@ def infer_bound_from_state(self, state):
             state : StateObject
         """
         return _ffi_api.ComputeDAGInferBoundFromState(self, state)
+
+def gen_schedule(state, bufs):
+    if not state or not state.complete:
+        return te.create_schedule([x.op for x in bufs])
+    else:
+        dag = ComputeDAG(bufs)
+        # only update the compute body (layout_rewrite_level = LayoutRewriteLevel.COMPUTE_REWRITE),
+        # since the kernel layout has already been rewritten by the relay pass
+        schedule, _ = dag.apply_steps_from_state(state, layout_rewrite_level=LayoutRewriteLevel.COMPUTE_REWRITE)
+        return schedule
diff --git a/python/tvm/ansor/dispatcher.py b/python/tvm/ansor/dispatcher.py
new file mode 100644
index 000000000000..2f00c355d285
--- /dev/null
+++ b/python/tvm/ansor/dispatcher.py
@@ -0,0 +1,518 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Template dispatcher module.
+
+A dispatcher is a function that can contain multiple behaviors.
+Its specific behavior can be controlled by DispatchContext.
+
+DispatchContext is used in two ways, usually via different implementations
+of the DispatchContext base class.
+
+- During search, we can use it to pass the current proposal from the tuner.
+- During evaluation, we can use it to set pick the best policy. +""" +# pylint: disable=invalid-name + +from __future__ import absolute_import as _abs + +import logging + +import numpy as np +from decorator import decorate + +from tvm import target as _target +from tvm.tir.expr import StringImm, FloatImm + +from .loop_state import State, StateObject + +logger = logging.getLogger('auto_scheduler') + + +class DispatchContext(object): + """ + Base class of dispatch context. + + DispatchContext enables the target and workload + specific dispatch mechanism for templates. + """ + current = None + + def __init__(self): + self._old_ctx = DispatchContext.current + + def query(self, target, workload): + """ + Query the context to get the specific config for a template. + If cannot find the result inside this context, this function will query it + from the upper contexts. + + Parameters + ---------- + target: Target + The current target + workload : Workload + The current workload. + + Returns + ------- + cfg : State or str + The specific state for auto scheduler. + """ + ret = self._query_inside(target, workload) + #if ret is None: + # ret = self._old_ctx.query(target, workload) + return ret + + def update(self, target, workload, cfg): + """ + Update context with a specific config. + + Parameters + ---------- + target: Target + The current target + workload : Workload + The current workload. + cfg : State or str + The specific state for auto scheduler. + + Note + ---- + This interface is for cases when TVM decides to replace an operator in the graph. + For example, `AlterOpLayout` pass (enables when `opt_level = 3`) replaces `NCHW` + convolution with `NCHW[x]c` implementation on x86 CPUs. + Thus in TOPI, we first query schedule using original `NCHW` workload, + then update the dispatcher with the new `NCHW[x]c` workload. + So that later on, `NCHW[x]c` convolution can get schedule from the dispatcher using + its own workload directly. + + .. code-block:: python + + @conv2d_alter_layout.register("cpu") + def _alter_conv2d_layout(attrs, inputs, tinfo): + workload = get_conv2d_workload(...) + dispatch_ctx = auto_scheduler.DispatchContext.current + target = tvm.target.current_target() + config = dispatch_ctx.query(target, workload) + + # Get conv2d_NCHWc workload from config + # new_workload = ... + # new_inputs = ... + # new_attrs = ... + + # Store altered operator's config + dispatch_ctx.update(target, new_workload, config) + return sym.contrib.conv2d_NCHWc(*new_inputs, **new_attrs) + + We directly store `config` back because `conv2d_NCHW` and `conv2d_NCHWc` + share the same schedule parameters. + One can construct a new `State` if this is not the case. + """ + raise NotImplementedError() + + def _query_inside(self, target, workload): + """ + Query the context to get the specific config for a template. + This function only query config inside this context. + + Parameters + ---------- + target: Target + The current target + workload : Workload + The current workload. + + Returns + ------- + cfg : State or str + The specific state for auto scheduler. + """ + raise NotImplementedError() + + def __enter__(self): + self._old_ctx = DispatchContext.current + DispatchContext.current = self + return self + + def __exit__(self, ptype, value, trace): + DispatchContext.current = self._old_ctx + + +def dispatcher(fworkload): + """Wrap a workload dispatcher function. + + Parameters + ---------- + fworkload : function + The workload extraction function from arguments. 
+ + Returns + ------- + fdispatcher : function + A wrapped dispatcher function, which will + dispatch based on DispatchContext and + the current workload. + """ + dispatch_dict = {} + func_name = fworkload.__name__ + + def register(key, func=None, override=False): + """Register template function. + + Parameters + ---------- + key : str or List of str + The template key to identify the template + under this dispatcher. + func : function + The function to be registered. + The first argument of the function is always + cfg returned by DispatchContext, + the rest arguments are the same as the fworkload. + override : bool + Whether override existing registration. + + Returns + ------- + The register function if necessary. + """ + if isinstance(key, str): + key = [key] + + def _do_reg(myf): + for x in key: + if x in dispatch_dict and not override: + raise ValueError( + "Key %s is already registered for %s" % (x, func_name)) + dispatch_dict[x] = myf + return myf + + if func: + return _do_reg(func) + return _do_reg + + def dispatch_func(func, *args, **kwargs): + """The wrapped dispatch function""" + tgt = _target.current_target() + workload = func(*args, **kwargs) + cfg = DispatchContext.current.query(tgt, workload) + return dispatch_dict['direct'](cfg, *args, **kwargs) + + fdecorate = decorate(fworkload, dispatch_func) + fdecorate.register = register + return fdecorate + + +class ApplyConfig(DispatchContext): + """Apply a deterministic config entity for all queries. + + Parameters + ---------- + config : State + The specific state for auto scheduler. + """ + def __init__(self, config): + super(ApplyConfig, self).__init__() + self._config = config + self.workload = None + + def _query_inside(self, target, workload): + """Override query""" + self.workload = workload + return self._config + + def update(self, target, workload, cfg): + """Override update""" + self.workload = workload + self._config = cfg + + +class ApplyHistoryBest(DispatchContext): + """ + Apply the history best config + + Parameters + ---------- + records : str or iterator of (MeasureInput, MeasureResult) + Collection of tuning records. + If is str, then it should be the filename of a records log file. + Each row of this file is an encoded record pair. + Otherwise, it is an iterator. + n_lines: int (optional) + if it is not None, only load the first `n_lines` lines of log + """ + def __init__(self, records, n_lines=None): + super(ApplyHistoryBest, self).__init__() + + self.best_by_targetkey = {} + self.best_by_model = {} + self._best_user_defined = {} + + if records: + self.load(records, n_lines) + + def load(self, records, n_lines=None): + """Load records to this dispatch context + + Parameters + ---------- + records : str or iterator of (MeasureInput, MeasureResult) + Collection of tuning records. + If is str, then it should be the filename of a records log file. + Each row of this file is an encoded record pair. + Otherwise, it is an iterator. + n_lines: int (optional) + if it is not None, only load the first `n_lines` lines of log + """ + from pathlib import Path + from . 
import load_from_file
+
+        if isinstance(records, Path):
+            records = str(records)
+
+        if isinstance(records, str):
+            records = load_from_file(records)
+        if not records:
+            return
+
+        best_by_targetkey = self.best_by_targetkey
+        best_by_model = self.best_by_model
+
+        counter = 0
+        for inp, res in records:
+            if n_lines is not None and counter >= n_lines:
+                break
+            counter += 1
+            if res.error_no != 0:
+                continue
+
+            # use target keys in the tvm target system as keys to build the best map
+            for k in inp.task.target.keys:
+                key = (k, inp.task.workload_key)
+                if key not in best_by_targetkey:
+                    best_by_targetkey[key] = (inp, res)
+                else:
+                    _, other_res = best_by_targetkey[key]
+                    other_costs = [x.value for x in other_res.costs if isinstance(x, FloatImm)]
+                    costs = [x.value for x in res.costs if isinstance(x, FloatImm)]
+                    if np.mean(other_costs) > np.mean(costs):
+                        best_by_targetkey[key] = (inp, res)
+
+            # use the model as key to build the best map
+            key = (inp.task.target.model, inp.task.workload_key)
+            if key not in best_by_model:
+                if inp.task.target.model != 'unknown':
+                    best_by_model[key] = (inp, res)
+            else:
+                _, other_res = best_by_model[key]
+                other_costs = [x.value for x in other_res.costs if isinstance(x, FloatImm)]
+                costs = [x.value for x in res.costs if isinstance(x, FloatImm)]
+                if np.mean(other_costs) > np.mean(costs):
+                    best_by_model[key] = (inp, res)
+
+        logger.debug("Finish loading %d records", counter)
+
+    def _query_inside(self, target, workload):
+        if target is None:
+            raise RuntimeError("Need a target context to find the history best. "
+                               "Hint: If your target is llvm, use `with tvm.target.create('llvm'):`"
+                               " above the dispatcher call. The same goes for other targets. ")
+
+        # first try matching by model
+        key = (target.model, workload)
+        if key in self._best_user_defined:
+            return self._best_user_defined[key]
+        if key in self.best_by_model:
+            return self.best_by_model[key][0].state
+
+        # then try matching by target key
+        for k in target.keys:
+            key = (k, workload)
+            if key in self._best_user_defined:
+                return self._best_user_defined[key]
+            if key in self.best_by_targetkey:
+                return self.best_by_targetkey[key][0].state
+
+        return None
+
+    def update(self, target, workload, state):
+        model = target.model
+        key = (model, workload)
+        self._best_user_defined[key] = state
+
+        for k in target.keys:
+            key = (k, workload)
+            self._best_user_defined[key] = state
+
+
+class BlockingEmptyContext(DispatchContext):
+    """
+    An empty context which returns an empty State() for all queries.
+    It also blocks the queries, so they won't affect the global FallbackContext.
+    """
+    def __init__(self):
+        super(BlockingEmptyContext, self).__init__()
+
+    def query(self, target, workload):
+        #return StateObject()
+        return None
+
+
+class FallbackContext(DispatchContext):
+    """
+    A fallback dispatch context.
+
+    Any tunable template can be called under this context.
+    This is the root context.
+    """
+
+    def __init__(self):
+        super(FallbackContext, self).__init__()
+        self.memory = {}
+        self.silent = False
+
+        # a set to prevent printing duplicated messages
+        self.messages = set()
+
+    def _query_inside(self, target, workload):
+        key = (str(target), workload)
+        if key in self.memory:
+            return self.memory[key]
+
+        if not self.silent:
+            msg = "Cannot find config for target=%s, workload=%s. A fallback configuration "\
+                  "is used, which may bring great performance regression."
% (target, workload) + if msg not in self.messages: + self.messages.add(msg) + logger.warning(msg) + #cfg = StateObject() + cfg = None + + # cache this config + self.memory[key] = cfg + return cfg + + def clear_cache(self, target, workload): + """Clear fallback cache. Pass the same argument as _query_inside to this function + to clean the cache. + + Parameters + ---------- + target: Target + The current target + workload : Workload + The current workload. + """ + key = (str(target), workload) + if key in self.memory: + del self.memory[key] + + def update(self, target, workload, cfg): + key = (str(target), workload) + self.memory[key] = cfg + + +DispatchContext.current = FallbackContext() + + +def clear_fallback_cache(target, workload): + """Clear fallback cache. Pass the same argument as _query_inside to this function + to clean the cache. + + Parameters + ---------- + target: Target + The current target + workload : Workload + The current workload. + + Note + ---- + This is used in alter_op_layout to clear the bad cache created before call topi compute function + """ + context = DispatchContext.current + while not isinstance(context, FallbackContext): + context = context._old_ctx + context.clear_cache(target, workload) + + +class ApplyGraphBest(DispatchContext): + """Load the graph level tuning optimal schedules. + + The input records should be in the ascending order of + node index for target operator. Usually this can be obtained + with graph tuner. + + This context maintains an internal counter to indicate the current + node index. + """ + def __init__(self, records): + """ + Parameters + ---------- + records : str or iterator of (MeasureInput, MeasureResult) + Collection of tuning records. + If is str, then it should be the filename of a records log file. + Each row of this file is an encoded record pair. + Otherwise, it is an iterator. + """ + from . import load_from_file + + super(ApplyGraphBest, self).__init__() + if isinstance(records, str): + records = load_from_file(records) + self._records = list(records) + self._counter = 0 + self._global_cfg_dict = {} + + def _query_inside(self, target, workload): + """ + Query the context to get config from records. + + Parameters + ---------- + target : Target + The current target + workload : Workload + The current workload. + + Returns + ------- + cfg : State or str + The specific state for auto scheduler. + """ + if self._counter < len(self._records): + cfg = self._records[self._counter][0].config + self._counter += 1 + self.update(target, workload, cfg) + return cfg + key = (str(target), workload) + if key not in self._global_cfg_dict: + msg = "Config for target=%s, workload=%s is missing in ApplyGraphBest context. " \ + "A fallback configuration is used, which may bring great performance " \ + "regression." 
% (target, workload)
+            logger.warning(msg)
+            cfg = None
+            self._global_cfg_dict[key] = cfg
+        else:
+            cfg = self._global_cfg_dict[key]
+        return cfg
+
+    def update(self, target, workload, cfg):
+        key = (str(target), workload)
+        self._global_cfg_dict[key] = cfg
diff --git a/python/tvm/ansor/env.py b/python/tvm/ansor/env.py
new file mode 100644
index 000000000000..6d2bbd2c92af
--- /dev/null
+++ b/python/tvm/ansor/env.py
@@ -0,0 +1,8 @@
+""" The scope to store global variables in auto_scheduler """
+
+class AutoschedulerGlobalScope(object):
+    def __init__(self):
+        self.topi_in_compute_rewrite_mode = False
+
+GLOBAL_SCOPE = AutoschedulerGlobalScope()
+
diff --git a/python/tvm/ansor/relay_integration.py b/python/tvm/ansor/relay_integration.py
new file mode 100644
index 000000000000..7d7e18a94ddf
--- /dev/null
+++ b/python/tvm/ansor/relay_integration.py
@@ -0,0 +1,209 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-variable,invalid-name
+"""
+Decorator and utilities for the integration with TOPI and Relay
+99.9% copy-paste of implementation by @MerryMercy
+
+"""
+import threading
+import warnings
+import tvm
+
+
+from .topi_integration import TaskExtractEnv
+from .dispatcher import BlockingEmptyContext
+from .env import GLOBAL_SCOPE
+
+def _lower(mod,
+           target,
+           params):
+    """ Helper to lower a relay module for task extraction;
+    handles the VTA target specially.
+    """
+    # pylint: disable=import-outside-toplevel
+    from tvm import relay
+    from tvm.relay.backend import graph_runtime_codegen
+
+    if hasattr(target, 'device_name') and target.device_name == "vta":
+        import vta
+        with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
+            mod, _ = relay.optimize(mod, target, params)
+            grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
+            grc.codegen(mod["main"])
+            return
+
+    # default case
+    # Try graph codegen first to extract tuning tasks.
+    # If that fails to compile, fall back to the VM compiler.
+    # TODO: Currently VM compiler is likely to stack overflow for large models.
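+    # This lowering is run only for its side effects: while a TaskExtractEnv
+    # tracing context is active, each call into a registered topi schedule
+    # records its workload key (see TaskExtractEnv.add_task) instead of
+    # building a real schedule, and extract_from_multiple_program later reads
+    # the collected keys via env.get_wkl_keys().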
+ try: + with relay.build_config(opt_level=3): + opt_mod, _ = relay.optimize(mod, target, params) + grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) + grc.codegen(opt_mod["main"]) + except tvm.TVMError: + compiler = relay.vm.VMCompiler() + if params: + compiler.set_params(params) + compiler.lower(mod, target=target) + +OP_TO_SCHEDULE = {} + +def init_op_to_schedule_map(): + # init the global map OP_TO_SCHEDULE inside a function, this is used to resolve import issues + global OP_TO_SCHEDULE + from tvm import relay + import topi + + if OP_TO_SCHEDULE: + return + + OP_TO_SCHEDULE = { + relay.op.nn.conv2d: [topi.generic.schedule_conv2d_nchw, + topi.generic.schedule_conv2d_nhwc, + topi.generic.schedule_depthwise_conv2d_nchw, + topi.generic.schedule_depthwise_conv2d_nhwc, + topi.generic.schedule_group_conv2d_nchw, + topi.generic.schedule_conv2d_winograd_without_weight_transform], + relay.op.nn.conv2d_transpose: [topi.generic.schedule_conv2d_transpose_nchw], + relay.op.nn.dense: [topi.generic.schedule_dense], + relay.op.nn.softmax: [topi.generic.schedule_softmax], + relay.op.nn.max_pool2d: [topi.generic.schedule_pool], + relay.op.nn.avg_pool2d: [topi.generic.schedule_pool], + relay.op.nn.global_avg_pool2d: [topi.generic.schedule_adaptive_pool], + relay.op.nn.global_max_pool2d: [topi.generic.schedule_adaptive_pool], + relay.op.nn.deformable_conv2d: [topi.generic.schedule_deformable_conv2d_nchw], + relay.op.mean: [topi.generic.schedule_reduce], + relay.op.prod: [topi.generic.schedule_reduce], + relay.op.nn.conv3d: [topi.generic.schedule_conv3d_ncdhw, + topi.generic.schedule_conv3d_ndhwc], + relay.op.nn.adaptive_avg_pool3d: [topi.generic.schedule_adaptive_pool], + relay.op.nn.batch_matmul: [topi.generic.schedule_batch_matmul], + } + +def extract_from_program(mod, params, ops, target, target_host=None): + """ Extract tuning tasks from a relay program. + + This function is the single program version of extract_from_multiple_program. + + Parameters + ---------- + mod : relay.Module + The module to extract. + params: dict of str to numpy array + The associated parameters of the program + ops: List of relay op + List of relay ops to be tuned + target: tvm.target.Target + The compilation target + target_host: tvm.target.Target + The host compilation target + + Returns + ------- + workloads: Array of Tuple(wkl_key, target) + """ + return extract_from_multiple_program([mod], [params], ops, target, target_host) + +def extract_from_multiple_program(mods, params, ops, target, target_host=None): + """ Extract tuning tasks from multiple relay programs. + + This function collects tuning tasks by building a list of programs + with a "tracing" target and tracing all the calls to topi. + + Parameters + ---------- + mods : List of relay.Module + The modules to extract. + params: List of dict of str to numpy array + The associated parameters of the programs + ops: List of relay op + List of relay ops to be tuned + target: tvm.target.Target + The compilation target + target_host: tvm.target.Target + The host compilation target + + Returns + ------- + workloads: Array of Tuple(wkl_key, target) + """ + from tvm import relay + + env = TaskExtractEnv.get() + + init_op_to_schedule_map() + topi_scheds = [] + for op_name in ops: + if op_name in OP_TO_SCHEDULE: + topi_scheds.extend(OP_TO_SCHEDULE[op_name]) + else: + warnings.warn("Op %s is not tunable, ignored." 
% op_name)
+
+    # run the compiler to collect all TOPI calls during compilation
+    env.reset(topi_scheds)
+    with env:
+        for mod, param in zip(mods, params):
+            # wrap the build call in a thread to avoid multiprocessing problems
+            with BlockingEmptyContext():
+                build_thread = threading.Thread(target=_lower,
+                                                args=(mod, target, param))
+                build_thread.start()
+                build_thread.join()
+                relay.backend.compile_engine.get().clear()
+
+    # create tasks for target
+    wkl_keys = []
+    wkl_weights = []
+    for wkl_key, wkl_weight in env.get_wkl_keys().items():
+        wkl_keys.append(wkl_key)
+        wkl_weights.append(wkl_weight)
+
+    return wkl_keys, wkl_weights
+
+def prepare_layout_rewrite(mod, params, ops, target):
+    """Prepare for kernel layout rewrite. This function writes layout information
+    to a global static variable, which is then used by the relay pass
+    `kernel_layout_transform`.
+    """
+    from .. import relay
+
+    env = TaskExtractEnv.get(do_layout_rewrite=True)
+
+    init_op_to_schedule_map()
+    topi_scheds = []
+    for op_name in ops:
+        if op_name in OP_TO_SCHEDULE:
+            topi_scheds.extend(OP_TO_SCHEDULE[op_name])
+        else:
+            warnings.warn("Op %s is not tunable, ignored." % op_name)
+
+    with env:
+        env.reset(topi_scheds)
+
+        # wrap the build call in a thread to avoid multiprocessing problems
+        build_thread = threading.Thread(target=_lower,
+                                        args=(mod, target, params))
+        build_thread.start()
+        build_thread.join()
+        relay.backend.compile_engine.get().clear()
+
+    if env.layout_rewrite_success_ct > 0:
+        GLOBAL_SCOPE.topi_in_compute_rewrite_mode = True
+
+def finish_layout_rewrite():
+    """Clear the global flag for layout rewrite"""
+    GLOBAL_SCOPE.topi_in_compute_rewrite_mode = False
diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py
index 387825034a09..3d7ed7733a78 100644
--- a/python/tvm/ansor/serialization.py
+++ b/python/tvm/ansor/serialization.py
@@ -63,6 +63,9 @@ def __iter__(self):
                 break
             yield ret[0], ret[1]  # (input, result)
 
+def load_from_file(filename: str):
+    return zip(*LogReader(filename).read_lines())
+
 
 def write_measure_records_to_file(filename, inputs, results):
     """Write (append) measure records to a file"""
diff --git a/python/tvm/ansor/topi_integration.py b/python/tvm/ansor/topi_integration.py
new file mode 100644
index 000000000000..b4c15f74ea44
--- /dev/null
+++ b/python/tvm/ansor/topi_integration.py
@@ -0,0 +1,215 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-variable,invalid-name,unused-argument
+"""
+Decorators for registering tunable templates to TOPI.
+
+These decorators let a simple implementation use different configurations
+for different workloads.
+Here we directly use all arguments to the TOPI call as "workload", so make sure all the arguments
+(except tvm.te.Tensor) in your calls are hashable. For tvm.te.Tensor,
+we will serialize it to a hashable tuple.
+
+See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
+"""
+import json
+
+import tvm.te._ffi_api
+from tvm import target as _target
+from tvm.te import tensor
+from tvm.te.tensor import PlaceholderOp, ComputeOp
+
+from .compute_dag import ComputeDAG
+from .dispatcher import DispatchContext, BlockingEmptyContext
+from .workload_registry import register_auto_scheduler_workload_bufs, \
+    make_workload_key_bufs, compute_dag_hash
+
+def traverse_to_get_io_tensors(outs):
+    layout_free_ops = []
+    inputs = []
+
+    visited = set()
+
+    def traverse(t):
+        if t in visited:
+            return
+        if isinstance(t.op, PlaceholderOp):
+            inputs.append(t)
+        elif isinstance(t.op, ComputeOp):
+            if "layout_free_placeholders" in t.op.attrs:
+                layout_free_ops.append(t.op)
+            for x in t.op.input_tensors:
+                traverse(x)
+        visited.add(t)
+
+    for t in outs:
+        traverse(t)
+
+    has_layout_free = (len(layout_free_ops) > 0)
+    return inputs + [t for t in outs], has_layout_free
+
+# Task extractor for relay program
+class TaskExtractEnv:
+    """Global environment for extracting tuning tasks from a graph"""
+    current = None
+    registered = None
+
+    def __init__(self, do_layout_rewrite=False):
+        self.do_layout_rewrite = do_layout_rewrite
+        self.wanted_relay_ops = None
+        self.modified_funcs = []
+        self.tracing = False
+        self.relay_disable_build_cache_ = "false"
+        self.layout_rewrite_success_ct = 0
+        self.wkl_key_collection = {}
+
+    def __enter__(self):
+        self.tracing = True
+        self.wkl_key_collection = {}
+
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.tracing = False
+
+    def reset(self, wanted_relay_ops=None):
+        """Reset task collections
+
+        Parameters
+        ----------
+        wanted_relay_ops: List of tvm.ir.Op
+            The relay ops to be extracted
+        """
+        self.wanted_relay_ops = wanted_relay_ops
+        self.relay_disable_build_cache_ = "false"
+        self.layout_rewrite_success_ct = 0
+        self.wkl_key_collection = {}
+
+    def add_task(self, key):
+        """Add a task to the collection, or increase its weight if it was
+        already collected
+
+        Parameters
+        ----------
+        key: str
+            The auto-scheduler workload key of the task.
+        """
+        if key in self.wkl_key_collection:
+            self.wkl_key_collection[key] += 1
+        else:
+            self.wkl_key_collection[key] = 1
+
+    def get_tasks(self):
+        """Get collected tasks
+
+        Returns
+        -------
+        tasks: Dict mapping workload key to the number of times it was seen
+        """
+        return self.wkl_key_collection
+
+    def get_wkl_keys(self):
+        """Get collected workload keys
+
+        Returns
+        -------
+        wkl_keys: Dict mapping auto-scheduler workload_key to its weight
+        """
+        return self.wkl_key_collection
+
+    @staticmethod
+    def get(do_layout_rewrite=False):
+        """Get the single instance of TaskExtractEnv
+
+        Parameters
+        ----------
+        do_layout_rewrite: bool
+            Whether kernel layout rewrite should be enabled
+
+        Returns
+        -------
+        env: TaskExtractEnv
+            The single instance of TaskExtractEnv
+        """
+        if not TaskExtractEnv.current:
+            TaskExtractEnv.current = TaskExtractEnv(do_layout_rewrite)
+        else:
+            TaskExtractEnv.current.do_layout_rewrite = do_layout_rewrite
+        return TaskExtractEnv.current
+
+def register_topi_schedule(func=None):
+    """Register a tunable template for a topi schedule function.
+
+    The registration will wrap this topi schedule to take `cfg` as the first argument,
+    followed by the original argument list.
+
+    Note that this function will try to find "workload" from all the ComputeOp in the input.
+    You can attach "workload" to your compute op by using :any:`register_topi_compute`.
+ + The task name has to be the same as that of the corresponding topi compute function. + + Parameters + ---------- + task_name: str + The AutoTVM task name + + func: None or callable + If it is None, return a decorator. + If is callable, decorate this function. + + Returns + ------- + decorator: callable + A decorator + + Examples + -------- + See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage. + """ + def _decorate(topi_schedule): + def wrapper(outs, *args, **kwargs): + io_tensors, has_layout_free = traverse_to_get_io_tensors(outs) + key = register_auto_scheduler_workload_bufs(io_tensors) + task_env = TaskExtractEnv.current + if task_env is not None and task_env.tracing: + if task_env.do_layout_rewrite and has_layout_free: + # Rewrite the dag and update the transform history for + # the new dag in DispatchContext + dispatch_ctx = DispatchContext.current + tgt = _target.current_target() + state = dispatch_ctx.query(tgt, key) + dag = ComputeDAG(outs) + new_dag = dag.rewrite_layout_from_state(state) + new_key = json.dumps((compute_dag_hash(new_dag),)) + dispatch_ctx.update(tgt, new_key, state) + + if new_key != key: + task_env.layout_rewrite_success_ct += 1 + + # Call schedule_func under FallbackContext() to avoid layout rewrite + tgt = _target.Target.current() + cfg = BlockingEmptyContext().query(tgt, key) + return topi_schedule(cfg, outs) + + task_env.add_task(key) + + """wrapper function for topi schedule""" + tgt = _target.Target.current() + cfg = DispatchContext.current.query(tgt, key) + return topi_schedule(cfg, outs) + return wrapper + if func: + return _decorate(func) + return _decorate diff --git a/scripts/tune_network.py b/scripts/tune_network.py new file mode 100644 index 000000000000..3d858ce60ab0 --- /dev/null +++ b/scripts/tune_network.py @@ -0,0 +1,497 @@ +"""Tune all workloads in a network""" +import argparse +import logging +import random +import os +import time +import numpy as np + +import tvm +from tvm.rpc.tracker import Tracker +from tvm.rpc.server import Server +from tvm import ansor as auto_scheduler +from tvm import relay +from tvm.rpc.tracker import Tracker +from tvm.rpc.server import Server +from tvm.relay import testing +#from tvm._ffi.function import get_global_func +import tvm.contrib.graph_runtime as runtime +from tvm.contrib.debugger import debug_runtime +from tvm.contrib import util, ndk +from common import str2bool +from tvm.ansor import LocalRunner, LogToFile, TuneOption, SimpleTaskScheduler, \ + RPCRunner, LocalBuilder +from tvm.ansor.utils import request_remote +#from baseline.utils import log_line, BenchmarkRecord + +dtype = "float32" + +def get_network(name, model_path, batch_size, layout): + """Get the symbol definition and random weight of a network""" + input_shape = (batch_size, 3, 224, 224) + output_shape = (batch_size, 1000) + input_name = 'data' + + if name.startswith("resnet3d"): + n_layer = int(name.split('-')[1]) + layout = "NDHWC" + image_shape = (16, 112, 112, 3) + input_shape = (batch_size, *image_shape) + mod, params = relay.testing.resnet3d.get_workload(num_layers=n_layer, batch_size=batch_size, image_shape=image_shape, dtype=dtype, layout=layout) + elif name.startswith("resnet"): + n_layer = int(name.split('-')[1]) + image_shape = (224, 224, 3) if layout == 'NHWC' else (3, 224, 224) + input_shape = (batch_size, *image_shape) + mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, layout=layout, image_shape=image_shape, dtype=dtype) + print(mod) + elif "lstm" in name: + mod, params 
= relay.testing.lstm.get_workload(iterations=10, num_hidden=512, batch_size=batch_size, dtype=dtype) + elif "mlp" in name: + input_shape = (batch_size, 1, 28, 28) + mod, params = relay.testing.mlp.get_workload(batch_size=batch_size, dtype=dtype) + elif "vgg" in name: + n_layer = int(name.split('-')[1]) + mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype) + elif name == 'dcgan': + input_shape = (batch_size, 100) + mod, params = relay.testing.dcgan.get_workload(batch_size=batch_size, layout=layout) + elif name == 'dqn': + image_shape = (84, 84, 4) if layout == 'NHWC' else (4, 84, 84) + input_shape = (batch_size, *image_shape) + mod, params = relay.testing.dqn.get_workload(batch_size=batch_size, image_shape=image_shape, dtype=dtype) + elif name == 'mobilenet': + image_shape = (224, 224, 3) if layout == 'NHWC' else (3, 224, 224) + input_shape = (batch_size, *image_shape) + mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, layout=layout, image_shape=image_shape, dtype=dtype) + elif name == 'r3d_18': + import torch + import torchvision + + model = getattr(torchvision.models.video, name)(pretrained=False) + model = model.eval() + + # We grab the TorchScripted model via tracing + input_shape = [batch_size, 3, 16, 112, 112] + input_data = torch.randn(input_shape) + scripted_model = torch.jit.trace(model, input_data).eval() + + input_name = 'input0' # only one input, set it to this name + shape_list = {input_name: input_shape} + mod, params = relay.frontend.from_pytorch(scripted_model, + shape_list) + elif name == 'squeezenet_v1.1': + mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype) + elif name == 'inception_v3': + input_shape = (batch_size, 3, 299, 299) + mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) + elif name == 'mxnet': + # an example for mxnet model + from mxnet.gluon.model_zoo.vision import get_model + block = get_model('resnet18_v1', pretrained=True) + mod, params = relay.frontend.from_mxnet(block, shape={"input_name": input_shape}, dtype=dtype) + net = mod["main"] + net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) + mod = relay.Module.from_expr(net) + elif name == 'tflite-mobilenet-v2' or name == 'tflite-resnet-v2-50': + try: + import tflite.Model + except ImportError: + raise ImportError("The tflite package must be installed") + input_name = "input" + input_shape = (1, 224, 224, 3) + output_shape = (1, 1001) + input_dtype = "float32" + tflite_model_buf = open(model_path, "rb").read() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + mod, params = relay.frontend.from_tflite(tflite_model, + shape_dict={input_name: input_shape}, + dtype_dict={input_name: input_dtype}) + elif name == 'pytorch-mobilenet-v2': + import torch + + model = torch.hub.load('pytorch/vision:v0.5.0', 'mobilenet_v2', pretrained=False) + model.eval() + + input_shape = [batch_size, 3, 224, 224] + input_data = torch.randn(input_shape) + scripted_model = torch.jit.trace(model, input_data).eval() + + input_name = 'input0' + shape_list = {input_name: input_shape} + mod, params = relay.frontend.from_pytorch(scripted_model, + shape_list) + elif name == 'bert': + import tensorflow as tf + + bert_pb = './baseline/tensorflow/tf_models/bert/bert-B%d.pb' % batch_size + try: + with tf.compat.v1.gfile.GFile(bert_pb, 'rb') as f: + graph_def = tf.compat.v1.GraphDef() + 
graph_def.ParseFromString(f.read()) + except: + raise ValueError("Need to run ./baseline/tensorflow/bert/generate_bert_pb.py to get model first") + + input_shape = (batch_size, 128) + input_name = ['input'] + shape_dict = { + 'input': input_shape + } + out_names = [ + 'bert/pooler/dense/Tanh' + ] + + mod, params = relay.frontend.from_tensorflow(graph_def, + shape=shape_dict, + outputs=out_names) + elif name == 'tflite-textcnn': + try: + import tflite.Model + except ImportError: + raise ImportError("The tflite package must be installed") + model_path = './baseline/tensorflow/fake_textcnn.tflite' + input_name = "Placeholder" + input_shape = (batch_size, 200, 128, 1) + output_shape = (1, 1001) + input_dtype = "float32" + tflite_model_buf = open(model_path, "rb").read() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + mod, params = relay.frontend.from_tflite(tflite_model, + shape_dict={input_name: input_shape}, + dtype_dict={input_name: input_dtype}) + print(mod['main']) + elif name == 'textcnn': + import tensorflow as tf + + bert_pb = './baseline/tensorflow/fake_textcnn.pb' + try: + with tf.compat.v1.gfile.GFile(bert_pb, 'rb') as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + except: + raise ValueError("Need to run ./baseline/tensorflow/bert/generate_bert_pb.py to get model first") + + input_shape = (batch_size, 200, 128, 1) + input_name = ['Placeholder'] + shape_dict = { + 'Placeholder': input_shape + } + out_names = [ + 'concat/concat_dim' + ] + + mod, params = relay.frontend.from_tensorflow(graph_def, + shape=shape_dict, + outputs=out_names) + print(mod['main']) + elif name == 'tdnn': + import tensorflow as tf + + pb = './baseline/tensorflow/pruned_model_0407.pb' + #pb = './baseline/tensorflow/tdnn_4001.pb' + try: + with tf.compat.v1.gfile.GFile(pb, 'rb') as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + except: + raise ValueError("Need to run ./baseline/tensorflow/bert_convert.py to get model first") + + input_shape = (batch_size, 600, 64) + input_name = ['tf_loss_fn/Placeholder'] + + shape_dict = { + 'tf_loss_fn/Placeholder': input_shape, + } + out_names = [ + #"tf_loss_fn/ForwardPass/w2l_encoder/conv91/Conv2D" + "tf_loss_fn/ForwardPass/Softmax" + ] + mod, params = relay.frontend.from_tensorflow(graph_def, + shape=shape_dict, + outputs=out_names) + else: + raise ValueError("Unsupported network: " + name) + + return mod, params, input_name, input_shape, output_shape + + +def create_module(data_shape, graph, lib, target, input_name, params, debug_profile, + local_measure, ndk_cc, device_key, host, port, run_timeout, num_threads, seed=43): + # Upload parameters to device + if local_measure: + if target.target_name == "cuda": + ctx = tvm.gpu() + else: + ctx = tvm.cpu() + if num_threads: + config_threadpool = get_global_func('runtime.config_threadpool') + config_threadpool(0, num_threads) + else: + print("=============== Request Remote ===============") + if 'TVM_NDK_CC' not in os.environ: + os.environ['TVM_NDK_CC'] = ndk_cc + remote = request_remote(device_key, host, port, timeout=run_timeout) + + print("=============== Export ===============") + ctx = remote.cpu() + temp = util.tempdir() + path_lib = temp.relpath("deploy_lib.so") + lib.export_library(path_lib, ndk.create_shared) + + print("=============== Upload ===============") + remote.upload(path_lib) + + print("=============== Load ===============") + lib = remote.load_module("deploy_lib.so") + if num_threads: + config_threadpool = 
remote.get_function('runtime.config_threadpool') + config_threadpool(0, num_threads) + + np.random.seed(seed) + data_tvm = tvm.nd.array(100 * (np.random.uniform(size=data_shape)).astype(dtype), ctx=ctx) + if debug_profile: + module = debug_runtime.create(graph, lib, ctx) + else: + module = runtime.create(graph, lib, ctx) + if type(input_name) == list: + for name in input_name: + module.set_input(name, data_tvm) + else: + module.set_input(input_name, data_tvm) + for k, v in params.items(): + module.set_input(k, v) + + return module, ctx + + +def tune_and_evaluate(network_name, model_path, batch_size, target, target_host, + local_measure, device_key, host, port, n_parallel, ndk_cc, + build_timeout, run_timeout, num_threads, tune, check_correctness, + debug_profile, tuning_parameters, record_file, layout_set): + joint_tuner, model_type, policy, log_file, load_log_file = (tuning_parameters['joint_tuner'], + tuning_parameters['model_type'], tuning_parameters['policy'], + tuning_parameters['log_file'], tuning_parameters['load_log_file']) + + if layout_set: + layout = layout_set + elif target.target_name == 'cuda': + layout = 'NCHW' + else: + layout = "NHWC" + + # Extract workloads from relay program + print("=============== Extract workloads ===============") + mod, params, input_name, data_shape, out_shape = get_network(network_name, model_path, batch_size, layout) + + if tune: + workloads, wkl_weights = auto_scheduler.extract_from_program(mod, target=target, + params=params, ops=(relay.op.nn.dense, relay.op.nn.softmax, + relay.op.nn.conv2d, relay.op.nn.conv2d_transpose, + relay.op.nn.max_pool2d, relay.op.nn.avg_pool2d, + relay.op.nn.global_max_pool2d, relay.op.nn.global_avg_pool2d, + relay.op.nn.conv3d, relay.op.nn.adaptive_avg_pool3d, + relay.op.nn.batch_matmul, relay.op.mean, + )) + print("Total workload number: %d" % (len(workloads))) + #workloads = workloads[1:2] + #wkl_weights = wkl_weights[1:2] + #workloads = ['["2543426b0070d4a379a1f75a362a5f1b"]'] + + + # Tune workloads with auto scheduler + print("=============== Tuning ===============") + tasks = [] + for i, wkl_key in enumerate(workloads): + dag = auto_scheduler.workload_key_to_dag(wkl_key) + print("[========= Task %d =========]\n" % i, dag) + tasks.append(auto_scheduler.SearchTask(dag, wkl_key, target, target_host)) + + if joint_tuner != 'rl': + tuner = SimpleTaskScheduler(tasks, load_log_file=load_log_file) + elif joint_tuner == 'rl': + # put import here to remove pytorch dependency + from tvm.auto_scheduler.joint_tuner.rl_joint_tuner import RLJointTuner + tuner = RLJointTuner(tasks, weights=wkl_weights, load_log_file=load_log_file) + else: + raise ValueError("Invalid joint tuner: " + joint_tuner) + + if local_measure: + builder = LocalBuilder(timeout=build_timeout) + if target.target_name == "cuda": + ctx = tvm.context("cuda", 0) + cuda_arch = "sm_" + "".join(ctx.compute_version.split('.')) + tvm.autotvm.measure.measure_methods.set_cuda_target_arch(cuda_arch) + + tracker = Tracker('0.0.0.0', port=port, port_end=10000, silent=True) + if device_key is None: + device_key = '$local$device$%d' % tracker.port + server = Server('0.0.0.0', port=tracker.port, port_end=10000, + key=device_key, use_popen=True, silent=True, + tracker_addr=(tracker.host, tracker.port)) + runner = RPCRunner(device_key, host=host, port=tracker.port, + repeat=1, min_repeat_ms=400, + n_parallel=n_parallel) + else: + os.environ['TVM_AUTO_CACHE_FLUSH'] = "1" + runner = LocalRunner(repeat=10, number=1, min_repeat_ms=0, timeout=run_timeout) + else: + 
os.environ['TVM_NDK_CC'] = ndk_cc + builder = LocalBuilder(build_func='ndk', timeout=build_timeout) + runner = RPCRunner(device_key, host=host, port=port, + repeat=1, min_repeat_ms=400, + n_parallel=n_parallel, timeout=run_timeout) + + search_policy = "%s.%s" % (policy, model_type) + tune_option = TuneOption(n_trials=tuning_parameters['n_trials'], + early_stopping=tuning_parameters['early_stopping'], + num_measure_per_iter=tuning_parameters['num_measure_per_iter'], + builder=builder, + verbose=tuning_parameters['verbose'], + runner=runner, + measure_callbacks=[LogToFile(log_file)]) + if local_measure and target.target_name != 'cuda': + os.environ['TVM_BIND_MASTER_CORE_0'] = "1" + tuner.tune(tune_option, search_policy) + else: + tuner.tune(tune_option, search_policy) + + kernel_layout_rewrite = False + + # Compile graph with best states found by auto-scheduler + print("=============== Compile ===============") + with auto_scheduler.apply_history_best(log_file, args.log_n_lines): + #if True: + #with auto_scheduler.BlockingEmptyContext(): + os.environ['TVM_AUTO_CACHE_FLUSH'] = "0" + os.environ['TVM_BIND_MASTER_CORE_0'] = "1" + if kernel_layout_rewrite: + auto_scheduler.prepare_layout_rewrite(mod, target=target, + params=params, + ops=(relay.op.nn.dense, relay.op.nn.conv2d, relay.op.nn.conv3d)) + else: + # disable layout rewrite + auto_scheduler.LayoutRewriteLevel.BOTH_REWRITE = auto_scheduler.LayoutRewriteLevel.NO_REWRITE + auto_scheduler.LayoutRewriteLevel.COMPUTE_REWRITE = auto_scheduler.LayoutRewriteLevel.NO_REWRITE + + with relay.build_config(opt_level=3): + graph, lib, opt_params = relay.build_module.build( + mod, target=target, params=params) + ''' + from tvm.relay.backend import graph_runtime_codegen + with relay.build_config(opt_level=3): + opt_mod, _ = relay.optimize(mod, target, params) + grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) + grc.codegen(opt_mod["main"]) + with tvm.transform.PassContext(opt_level=3): + graph, lib, opt_params = relay.build_module.build( + mod, target=target, params=params) + ''' + auto_scheduler.finish_layout_rewrite() + print("=============== Compile Finish ===============") + + module, ctx = create_module(data_shape, graph, lib, target, input_name, opt_params, + debug_profile, local_measure, ndk_cc, + device_key, host, port, run_timeout, num_threads) + + # Evaluate + print("========== Evaluate ==========") + ftimer = module.module.time_evaluator("run", ctx, number=10, repeat=3) + prof_res = np.array(ftimer().results) + # display profile information + if debug_profile or check_correctness: + module.run() + if check_correctness: + actual_output = module.get_output(0).asnumpy() + print(actual_output) + + print("Mean inference time (std dev): %.2f ms (%.2f ms)" % + (np.mean(prof_res) * 1000, np.std(prof_res) * 1000)) + #log_line(BenchmarkRecord(target.target_name, 'gpu' if target.target_name == 'cuda' else 'cpu', 'network', + # "%s.B%d" % (network_name, batch_size), 'AutoSchedule', layout, + # {"costs": prof_res}, time.time()), record_file) + + if check_correctness: + print("========== Check Correctness ==========") + # clean relay cache + relay.backend.compile_engine.get().clear() + + # disable layout rewrite + auto_scheduler.LayoutRewriteLevel.BOTH_REWRITE = auto_scheduler.LayoutRewriteLevel.NO_REWRITE + auto_scheduler.LayoutRewriteLevel.COMPUTE_REWRITE = auto_scheduler.LayoutRewriteLevel.NO_REWRITE + target = tvm.target.create('llvm') + with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): + graph, lib, opt_params = 
relay.build_module.build( + mod, target=target, params=params) + + module, _ = create_module(data_shape, graph, lib, target, input_name, opt_params, + debug_profile, local_measure, ndk_cc, + device_key, host, port, run_timeout, num_threads) + module.run() + + expected_output = module.get_output(0).asnumpy() + np.testing.assert_allclose(actual_output, expected_output, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--network", type=str, required=True) + parser.add_argument("--model-path", type=str, default=None, help="The path of tflite model") + parser.add_argument("--n-trials", type=int, default=1000) + parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') + parser.add_argument("--target-host", type=str, default=None) + parser.add_argument("--policy", type=str, choices=['multi-stage', 'meta-rewrite'], + default='meta-rewrite') + parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) + parser.add_argument("--check-correctness", type=str2bool, nargs='?', const=True, default=False) + parser.add_argument("--debug-profile", type=str2bool, nargs='?', const=True, default=False) + parser.add_argument("--build-timeout", type=int, default=10) + parser.add_argument("--run-timeout", type=int, default=10) + parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") + parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') + parser.add_argument("--load-model", action='store_true') + parser.add_argument("--model-file", type=str, default='saved_model.xgb') + parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") + parser.add_argument("--out-file", type=str, default='results.tsv') + parser.add_argument("--seed", type=int, default=0, help='random seed') + parser.add_argument("--verbose", type=int, default=1) + parser.add_argument("--joint-tuner", type=str, default='bottleneck-decay', help='The type of joint tuner', + choices=['no', 'uniform', 'weighted', 'bottleneck', 'bottleneck-decay', 'sequential', 'round-robin', 'rl']) + parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) + parser.add_argument("--device-key", type=str, default=None) + parser.add_argument("--host", type=str, default='0.0.0.0') + parser.add_argument("--port", type=int, default=9190) + parser.add_argument("--n-parallel", type=int, default=1) + parser.add_argument("--ndk-cc", type=str, default=None) + parser.add_argument("--num-threads", type=int, default=None) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--num-measure-per-iter", type=int, default=48, + help="The number of programs to be measured at each iteration") + parser.add_argument("--layout", type=str, default=None) + parser.add_argument("--log-n-lines", type=int) + args = parser.parse_args() + + np.random.seed(args.seed) + random.seed(args.seed) + + logging.basicConfig() + logging.getLogger('auto_scheduler').setLevel(logging.DEBUG) + + target = tvm.target.create(args.target) + + tuning_parameters = { + 'n_trials': args.n_trials, + 'num_measure_per_iter': args.num_measure_per_iter, + 'log_file': args.log_file if args.log_file else "%s-B%d.json" % (args.network, args.batch_size), + 'model_type': args.model_type, + 'joint_tuner': args.joint_tuner, + 'policy': args.policy, + 'early_stopping': -1, + 'verbose': 1, + } + tuning_parameters['load_log_file'] = args.load_log or 
tuning_parameters['log_file'] + + os.environ["TOPHUB_LOCATION"] = "NONE" + tune_and_evaluate(args.network, args.model_path, args.batch_size, target, args.target_host, + args.local_measure, args.device_key, args.host, + args.port, args.n_parallel, args.ndk_cc, args.build_timeout, + args.run_timeout, args.num_threads, args.tune, args.check_correctness, + args.debug_profile, tuning_parameters, args.out_file, args.layout) diff --git a/topi/python/topi/ansor.py b/topi/python/topi/ansor.py new file mode 100644 index 000000000000..e821fd5bd42f --- /dev/null +++ b/topi/python/topi/ansor.py @@ -0,0 +1,95 @@ +"""All AutoSchedule Supported Operators""" +from __future__ import absolute_import as _abs +from tvm import ansor + +@ansor.register_topi_schedule() +def schedule_dense_nopack(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv2d_nhwc(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv2d_NCHWc(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_reduce(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_pool(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_adaptive_pool(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_softmax(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv2d_nchw_int8(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv2d_nchw(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_depthwise_conv2d_nchw(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_depthwise_conv2d_nhwc(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv2d_NCHWc_int8(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_depthwise_conv2d_NCHWc(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv2d_transpose_nchw(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv3d_ncdhw(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv3d_ndhwc(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv1d_ncw(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv1d_nwc(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_dense_pack(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_batch_matmul(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_bitserial_conv2d_nchw(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_bitserial_conv2d_nhwc(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_bitserial_dense(cfg, outs): + return ansor.gen_schedule(cfg, outs) diff --git a/topi/python/topi/arm_cpu/__init__.py b/topi/python/topi/arm_cpu/__init__.py index e121fbc7ec6d..e6ccadd4755f 100644 --- a/topi/python/topi/arm_cpu/__init__.py +++ 
b/topi/python/topi/arm_cpu/__init__.py @@ -26,3 +26,8 @@ from .bitserial_dense import * from .injective import * from . import cortex_m7 + +import os +use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "true") +if use_auto_scheduler.lower() == "true": + from ..ansor import * diff --git a/topi/python/topi/generic/__init__.py b/topi/python/topi/generic/__init__.py index 6171317cd80f..7f37ba78a06c 100644 --- a/topi/python/topi/generic/__init__.py +++ b/topi/python/topi/generic/__init__.py @@ -39,3 +39,8 @@ from .sort import * from .search import * from .image import * + +import os +use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "true") +if use_auto_scheduler.lower() == "true": + from ..ansor import * diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py index 659668cbbe4c..a334397249e3 100644 --- a/topi/python/topi/x86/__init__.py +++ b/topi/python/topi/x86/__init__.py @@ -39,3 +39,8 @@ from .conv3d_transpose import * from .sparse import * from .conv2d_alter_op import * + +import os +use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "true") +if use_auto_scheduler.lower() == "true": + from ..ansor import * From 674027f8d6b9943508ee9eaf0fba703189a1c781 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Fri, 12 Jun 2020 18:16:28 +0800 Subject: [PATCH 23/78] Add tune_op_subgraph.py & Some code clean for tune_network.py (#23) * Add single op tune scripts * Add tune subgraph support * Merge all op & all subgraph to one file * Rename file --- python/tvm/ansor/auto_schedule.py | 1 + python/tvm/ansor/relay_integration.py | 2 +- scripts/common.py | 2 +- scripts/shape_configs.py | 248 ++++++++++ scripts/tune_network.py | 230 +++------ scripts/tune_op_subgraph.py | 599 +++++++++++++++++++++++ scripts/tune_test.py | 79 +-- src/ansor/search_policy/search_policy.cc | 6 +- 8 files changed, 968 insertions(+), 199 deletions(-) create mode 100644 scripts/shape_configs.py create mode 100644 scripts/tune_op_subgraph.py diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index e1a0711a80be..09895302d25a 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -193,6 +193,7 @@ class TuneOption(Object): Callback functions called before the search process Candidates: - ansor.PreLoadMeasuredStates + - ansor.PreAddCustomRule """ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, verbose=1, builder='local', runner='local', measure_callbacks=None, diff --git a/python/tvm/ansor/relay_integration.py b/python/tvm/ansor/relay_integration.py index 7d7e18a94ddf..de2e12e389e7 100644 --- a/python/tvm/ansor/relay_integration.py +++ b/python/tvm/ansor/relay_integration.py @@ -196,7 +196,7 @@ def prepare_layout_rewrite(mod, params, ops, target): # wrap build call in thread to avoid multiprocessing problems build_thread = threading.Thread(target=_lower, - args=(mod, target, param)) + args=(mod, target, params)) build_thread.start() build_thread.join() relay.backend.compile_engine.get().clear() diff --git a/scripts/common.py b/scripts/common.py index 4400104bdfe6..84fbf8d6c731 100644 --- a/scripts/common.py +++ b/scripts/common.py @@ -168,7 +168,7 @@ def conv2d_nhwc_without_layout_rewrite(Input, Filter, stride, padding, dilation, # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 - pad_top, pad_left, pad_down, pad_right = topi.nn.util.get_pad_tuple( + pad_top, pad_left, pad_down, pad_right = 
topi.nn.get_pad_tuple(
         padding, (dilated_kernel_h, dilated_kernel_w))
     out_channel = num_filter
     out_height = topi.util.simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
diff --git a/scripts/shape_configs.py b/scripts/shape_configs.py
new file mode 100644
index 000000000000..95a1ba69634d
--- /dev/null
+++ b/scripts/shape_configs.py
@@ -0,0 +1,248 @@
+""" Shape configurations for single operator evaluation
+This file is shared by tune_op_subgraph.py and scripts in baseline/
+"""
+
+matmul_shapes = [
+    (1, 128, 128, 128),
+    (1, 512, 32, 512),
+    (1, 512, 512, 512),
+    (1, 1024, 1024, 1024),
+]
+
+conv1d_shapes = [
+    # derived from conv2d_shapes
+    (1, 256, 64, 128, 3, 2, 1),
+#    (1, 256, 64, 128, 1, 2, 0),
+#    (1, 256, 64, 64, 1, 1, 0),
+#    (1, 128, 128, 256, 3, 2, 1),
+    (1, 128, 128, 256, 1, 2, 0),
+#    (1, 128, 128, 128, 3, 1, 1),
+#    (1, 64, 256, 512, 3, 2, 1),
+#    (1, 64, 256, 512, 1, 2, 0),
+    (1, 64, 256, 256, 5, 1, 2),
+    (1, 32, 512, 512, 3, 1, 1),
+]
+
+conv2d_shapes = [
+    # all conv2d layers in resnet-18
+    (1, 224, 224, 3, 64, 7, 2, 3),
+#    (1, 56, 56, 64, 128, 3, 2, 1),
+#    (1, 56, 56, 64, 128, 1, 2, 0),
+#    (1, 56, 56, 64, 64, 3, 1, 1),
+    (1, 56, 56, 64, 64, 1, 1, 0),
+#    (1, 28, 28, 128, 256, 3, 2, 1),
+#    (1, 28, 28, 128, 256, 1, 2, 0),
+#    (1, 28, 28, 128, 128, 3, 1, 1),
+#    (1, 14, 14, 256, 512, 3, 2, 1),
+#    (1, 14, 14, 256, 512, 1, 2, 0),
+    (1, 14, 14, 256, 256, 3, 1, 1),
+    (1, 7, 7, 512, 512, 3, 1, 1),
+]
+
+conv3d_shapes = [
+    # Derived from conv2d_shapes. Use depth=16 for all configurations
+    (1, 16, 224, 224, 3, 64, 7, 2, 3),
+#    (1, 16, 56, 56, 64, 128, 3, 2, 1),
+#    (1, 16, 56, 56, 64, 128, 1, 2, 0),
+#    (1, 16, 56, 56, 64, 64, 3, 1, 1),
+    (1, 16, 56, 56, 64, 64, 1, 1, 0),
+#    (1, 16, 28, 28, 128, 256, 3, 2, 1),
+#    (1, 16, 28, 28, 128, 256, 1, 2, 0),
+#    (1, 16, 28, 28, 128, 128, 3, 1, 1),
+#    (1, 16, 14, 14, 256, 512, 3, 2, 1),
+#    (1, 16, 14, 14, 256, 512, 1, 2, 0),
+    (1, 16, 14, 14, 256, 256, 3, 1, 1),
+    (1, 16, 7, 7, 512, 512, 3, 1, 1),
+]
+
+group_conv2d_shapes = [
+    # Derived from conv2d_shapes. Use group=4 for all configurations
+    (1, 56, 56, 64, 128, 3, 2, 1, 1, 4),
+#    (1, 56, 56, 64, 128, 1, 2, 0, 1, 4),
+#    (1, 56, 56, 64, 64, 3, 1, 1, 1, 4),
+    (1, 56, 56, 64, 64, 1, 1, 0, 1, 4),
+#    (1, 28, 28, 128, 256, 3, 2, 1, 1, 4),
+#    (1, 28, 28, 128, 256, 1, 2, 0, 1, 4),
+#    (1, 28, 28, 128, 128, 3, 1, 1, 1, 4),
+#    (1, 14, 14, 256, 512, 3, 2, 1, 1, 4),
+#    (1, 14, 14, 256, 512, 1, 2, 0, 1, 4),
+    (1, 14, 14, 256, 256, 3, 1, 1, 1, 4),
+    (1, 7, 7, 512, 512, 3, 1, 1, 1, 4),
+]
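Each tuple above is splatted positionally into the matching compute definition registered in scripts/tune_op_subgraph.py; `tune_wkl` there turns every tuple into a task key via `ansor.make_workload_key_func(func, shape)`. A minimal sketch of that mapping (the import path is hypothetical, shown only for illustration):

    from tvm import ansor
    from tune_op_subgraph import conv2d_nhwc  # hypothetical import, for illustration only

    # a group_conv2d_shapes entry unpacks as
    # (N, H, W, CI, CO, kernel_size, stride, padding, dilation, groups)
    shape = (1, 56, 56, 64, 128, 3, 2, 1, 1, 4)
    wkl_key = ansor.make_workload_key_func(conv2d_nhwc, shape)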
+dilation_conv2d_shapes = [
+    # Derived from conv2d_shapes. Use dilation=2 for all configurations
+    (1, 224, 224, 3, 64, 7, 2, 3, 2),
+#    (1, 56, 56, 64, 128, 3, 2, 1, 2),
+#    (1, 56, 56, 64, 128, 1, 2, 0, 2),
+#    (1, 56, 56, 64, 64, 3, 1, 1, 2),
+    (1, 56, 56, 64, 64, 1, 1, 0, 2),
+#    (1, 28, 28, 128, 256, 3, 2, 1, 2),
+#    (1, 28, 28, 128, 256, 1, 2, 0, 2),
+#    (1, 28, 28, 128, 128, 3, 1, 1, 2),
+#    (1, 14, 14, 256, 512, 3, 2, 1, 2),
+#    (1, 14, 14, 256, 512, 1, 2, 0, 2),
+    (1, 14, 14, 256, 256, 3, 1, 1, 2),
+    (1, 7, 7, 512, 512, 3, 1, 1, 2),
+]
+
+depthwise_conv2d_shapes = [
+    # all depthwise conv2d layers in mobilenet
+    (1, 112, 112, 32, 3, 1, 1),
+    (1, 112, 112, 64, 3, 2, 1),
+#    (1, 56, 56, 128, 3, 1, 1),
+#    (1, 56, 56, 128, 3, 2, 1),
+#    (1, 28, 28, 256, 3, 1, 1),
+#    (1, 28, 28, 256, 3, 2, 1),
+#    (1, 14, 14, 512, 3, 1, 1),
+    (1, 14, 14, 512, 3, 2, 1),
+    (1, 7, 7, 1024, 3, 1, 1),
+]
+
+conv2d_transpose_shapes = [
+    # all conv2d transpose layers in DCGAN
+    (1, 4, 4, 512, 256, 4, 2, 1),
+    (1, 8, 8, 256, 128, 4, 2, 1),
+    (1, 16, 16, 128, 64, 4, 2, 1),
+    (1, 32, 32, 64, 3, 4, 2, 1),
+]
+
+conv2d_capsule_shapes = [
+    # all conv2d capsule layers in matrix capsules with EM routing (ICLR 2018)
+    (1, 16, 16, 32, 32, 3, 2, 1),
+    (1, 8, 8, 32, 32, 3, 1, 1),
+    (1, 16, 16, 8, 16, 3, 2, 1),
+    (1, 8, 8, 16, 16, 3, 1, 1),
+]
+
+conv2d_winograd_nhwc_shapes = [
+    (1, 56, 56, 64, 64, 3, 1, 1),
+    (1, 28, 28, 128, 128, 3, 1, 1),
+    (1, 14, 14, 256, 256, 3, 1, 1),
+    (1, 7, 7, 512, 512, 3, 1, 1),
+]
+
+conv2d_winograd_nchw_shapes = [
+    (1, 64, 56, 56, 64, 3, 1, 1),
+    (1, 128, 28, 28, 128, 3, 1, 1),
+    (1, 256, 14, 14, 256, 3, 1, 1),
+    (1, 512, 7, 7, 512, 3, 1, 1),
+]
+
+matmul_tensor_core_shapes = [
+    (16, 512, 512, 'float16', 'float32', True),
+    (32, 512, 512, 'float16', 'float32', True),
+    (512, 512, 512, 'float16', 'float32', True),
+]
+
+norm_shapes = [
+    (1, 256, 256),
+    (1, 512, 512),
+    (1, 1024, 1024),
+    (1, 4096, 1024),
+]
+
+softmax_shapes = [
+    (1, 1024),
+    (1, 4096),
+    (1, 16384),
+    (1, 65536),
+]
+
+single_op_shape_dict = {
+    'C1D': conv1d_shapes,
+    'C2D': conv2d_shapes,
+    'C3D': conv3d_shapes,
+    'GMM': matmul_shapes,
+    'GRP': group_conv2d_shapes,
+    'DIL': dilation_conv2d_shapes,
+    'DEP': depthwise_conv2d_shapes,
+    'T2D': conv2d_transpose_shapes,
+    'CAP': conv2d_capsule_shapes,
+    'NRM': norm_shapes,
+    #'SMX': softmax_shapes,
+
+# The following workloads are not in our single op evaluation plan.
+# They should be moved to `common.py` and be used by `tune_wkl.py`.
+# 'C2D_NCHW': conv2d_nchw_shapes, + 'C2DWG_NHWC': conv2d_winograd_nhwc_shapes, +# 'C2DWG_NCHW': conv2d_winograd_nchw_shapes, +# 'GMM_TC': matmul_tensor_core_shapes, +} + +conv2d_bn_relu_shapes = [ + (1, 224, 224, 3, 64, 7, 2, 3), + (1, 56, 56, 64, 128, 3, 2, 1), + (1, 28, 28, 128, 256, 1, 2, 0), + (1, 7, 7, 512, 512, 3, 1, 1, 1), + (16, 224, 224, 3, 64, 7, 2, 3), + (16, 56, 56, 64, 128, 3, 2, 1), + (16, 28, 28, 128, 256, 1, 2, 0), + (16, 7, 7, 512, 512, 3, 1, 1, 1), +] + +transpose_batch_matmul_shapes = [ + (1, 128, 12, 64), + (1, 128, 16, 64), + (1, 64, 12, 128), + (1, 128, 12, 128), + (16, 128, 12, 64), + (16, 128, 16, 64), + (16, 64, 12, 128), + (16, 128, 12, 128), +] + + +batch_norm_shapes = [ + (16, 256), + (16, 1024), + (16, 4096), + (16, 16384), + (16, 65536), +] + +subgraph_shape_dict = { + "conv2d_bn_relu": conv2d_bn_relu_shapes, + "transpose_batch_matmul": transpose_batch_matmul_shapes, + #"batch_norm": batch_norm_shapes, +} + +resnet_shapes = [ + (1, ), + (16, ), +] + +mobilenet_v2_shapes = [ + (1, ), + (16, ), +] + +dcgan_shapes = [ + (1, ), + (16, ), +] + +dqn_shapes = [ + (1, ), + (16, ), +] + +bert_shapes = [ + (1, ), + (16, ), +] + +resnet18_3d_shapes = [ + (1, ), + (16, ), +] + +network_shape_dict = { + 'resnet_50': resnet_shapes, + 'mobilenet_v2': mobilenet_v2_shapes, + 'dcgan': dcgan_shapes, + 'dqn': dqn_shapes, + 'bert': bert_shapes, + 'resnet_18_3d': resnet18_3d_shapes, +} + diff --git a/scripts/tune_network.py b/scripts/tune_network.py index 3d858ce60ab0..f1f7cd54f8c6 100644 --- a/scripts/tune_network.py +++ b/scripts/tune_network.py @@ -7,23 +7,17 @@ import numpy as np import tvm -from tvm.rpc.tracker import Tracker -from tvm.rpc.server import Server -from tvm import ansor as auto_scheduler -from tvm import relay -from tvm.rpc.tracker import Tracker -from tvm.rpc.server import Server -from tvm.relay import testing -#from tvm._ffi.function import get_global_func +from tvm import _ffi, relay, ansor import tvm.contrib.graph_runtime as runtime from tvm.contrib.debugger import debug_runtime from tvm.contrib import util, ndk -from common import str2bool -from tvm.ansor import LocalRunner, LogToFile, TuneOption, SimpleTaskScheduler, \ - RPCRunner, LocalBuilder +from tvm.relay import testing from tvm.ansor.utils import request_remote #from baseline.utils import log_line, BenchmarkRecord +from common import str2bool +from tune_test import create_tune_option + dtype = "float32" def get_network(name, model_path, batch_size, layout): @@ -43,7 +37,6 @@ def get_network(name, model_path, batch_size, layout): image_shape = (224, 224, 3) if layout == 'NHWC' else (3, 224, 224) input_shape = (batch_size, *image_shape) mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, layout=layout, image_shape=image_shape, dtype=dtype) - print(mod) elif "lstm" in name: mod, params = relay.testing.lstm.get_workload(iterations=10, num_hidden=512, batch_size=batch_size, dtype=dtype) elif "mlp" in name: @@ -54,9 +47,9 @@ def get_network(name, model_path, batch_size, layout): mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype) elif name == 'dcgan': input_shape = (batch_size, 100) - mod, params = relay.testing.dcgan.get_workload(batch_size=batch_size, layout=layout) + mod, params = relay.testing.dcgan.get_workload(batch_size=batch_size) elif name == 'dqn': - image_shape = (84, 84, 4) if layout == 'NHWC' else (4, 84, 84) + image_shape = (4, 84, 84) input_shape = (batch_size, *image_shape) mod, params = 
relay.testing.dqn.get_workload(batch_size=batch_size, image_shape=image_shape, dtype=dtype) elif name == 'mobilenet': @@ -143,71 +136,6 @@ def get_network(name, model_path, batch_size, layout): mod, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict, outputs=out_names) - elif name == 'tflite-textcnn': - try: - import tflite.Model - except ImportError: - raise ImportError("The tflite package must be installed") - model_path = './baseline/tensorflow/fake_textcnn.tflite' - input_name = "Placeholder" - input_shape = (batch_size, 200, 128, 1) - output_shape = (1, 1001) - input_dtype = "float32" - tflite_model_buf = open(model_path, "rb").read() - tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) - mod, params = relay.frontend.from_tflite(tflite_model, - shape_dict={input_name: input_shape}, - dtype_dict={input_name: input_dtype}) - print(mod['main']) - elif name == 'textcnn': - import tensorflow as tf - - bert_pb = './baseline/tensorflow/fake_textcnn.pb' - try: - with tf.compat.v1.gfile.GFile(bert_pb, 'rb') as f: - graph_def = tf.compat.v1.GraphDef() - graph_def.ParseFromString(f.read()) - except: - raise ValueError("Need to run ./baseline/tensorflow/bert/generate_bert_pb.py to get model first") - - input_shape = (batch_size, 200, 128, 1) - input_name = ['Placeholder'] - shape_dict = { - 'Placeholder': input_shape - } - out_names = [ - 'concat/concat_dim' - ] - - mod, params = relay.frontend.from_tensorflow(graph_def, - shape=shape_dict, - outputs=out_names) - print(mod['main']) - elif name == 'tdnn': - import tensorflow as tf - - pb = './baseline/tensorflow/pruned_model_0407.pb' - #pb = './baseline/tensorflow/tdnn_4001.pb' - try: - with tf.compat.v1.gfile.GFile(pb, 'rb') as f: - graph_def = tf.compat.v1.GraphDef() - graph_def.ParseFromString(f.read()) - except: - raise ValueError("Need to run ./baseline/tensorflow/bert_convert.py to get model first") - - input_shape = (batch_size, 600, 64) - input_name = ['tf_loss_fn/Placeholder'] - - shape_dict = { - 'tf_loss_fn/Placeholder': input_shape, - } - out_names = [ - #"tf_loss_fn/ForwardPass/w2l_encoder/conv91/Conv2D" - "tf_loss_fn/ForwardPass/Softmax" - ] - mod, params = relay.frontend.from_tensorflow(graph_def, - shape=shape_dict, - outputs=out_names) else: raise ValueError("Unsupported network: " + name) @@ -223,7 +151,7 @@ def create_module(data_shape, graph, lib, target, input_name, params, debug_prof else: ctx = tvm.cpu() if num_threads: - config_threadpool = get_global_func('runtime.config_threadpool') + config_threadpool = _ffi.get_global_func('runtime.config_threadpool') config_threadpool(0, num_threads) else: print("=============== Request Remote ===============") @@ -267,23 +195,19 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host, local_measure, device_key, host, port, n_parallel, ndk_cc, build_timeout, run_timeout, num_threads, tune, check_correctness, debug_profile, tuning_parameters, record_file, layout_set): - joint_tuner, model_type, policy, log_file, load_log_file = (tuning_parameters['joint_tuner'], + task_scheduler, model_type, policy, log_file, load_log_file = (tuning_parameters['task_scheduler'], tuning_parameters['model_type'], tuning_parameters['policy'], tuning_parameters['log_file'], tuning_parameters['load_log_file']) if layout_set: layout = layout_set - elif target.target_name == 'cuda': - layout = 'NCHW' - else: - layout = "NHWC" # Extract workloads from relay program print("=============== Extract workloads ===============") mod, params, input_name, 
data_shape, out_shape = get_network(network_name, model_path, batch_size, layout) if tune: - workloads, wkl_weights = auto_scheduler.extract_from_program(mod, target=target, + workloads, wkl_weights = ansor.extract_from_program(mod, target=target, params=params, ops=(relay.op.nn.dense, relay.op.nn.softmax, relay.op.nn.conv2d, relay.op.nn.conv2d_transpose, relay.op.nn.max_pool2d, relay.op.nn.avg_pool2d, @@ -292,85 +216,54 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host, relay.op.nn.batch_matmul, relay.op.mean, )) print("Total workload number: %d" % (len(workloads))) - #workloads = workloads[1:2] - #wkl_weights = wkl_weights[1:2] - #workloads = ['["2543426b0070d4a379a1f75a362a5f1b"]'] - # Tune workloads with auto scheduler print("=============== Tuning ===============") tasks = [] for i, wkl_key in enumerate(workloads): - dag = auto_scheduler.workload_key_to_dag(wkl_key) + dag = ansor.workload_key_to_dag(wkl_key) print("[========= Task %d =========]\n" % i, dag) - tasks.append(auto_scheduler.SearchTask(dag, wkl_key, target, target_host)) - - if joint_tuner != 'rl': - tuner = SimpleTaskScheduler(tasks, load_log_file=load_log_file) - elif joint_tuner == 'rl': - # put import here to remove pytorch dependency - from tvm.auto_scheduler.joint_tuner.rl_joint_tuner import RLJointTuner - tuner = RLJointTuner(tasks, weights=wkl_weights, load_log_file=load_log_file) - else: - raise ValueError("Invalid joint tuner: " + joint_tuner) - - if local_measure: - builder = LocalBuilder(timeout=build_timeout) - if target.target_name == "cuda": - ctx = tvm.context("cuda", 0) - cuda_arch = "sm_" + "".join(ctx.compute_version.split('.')) - tvm.autotvm.measure.measure_methods.set_cuda_target_arch(cuda_arch) - - tracker = Tracker('0.0.0.0', port=port, port_end=10000, silent=True) - if device_key is None: - device_key = '$local$device$%d' % tracker.port - server = Server('0.0.0.0', port=tracker.port, port_end=10000, - key=device_key, use_popen=True, silent=True, - tracker_addr=(tracker.host, tracker.port)) - runner = RPCRunner(device_key, host=host, port=tracker.port, - repeat=1, min_repeat_ms=400, - n_parallel=n_parallel) - else: - os.environ['TVM_AUTO_CACHE_FLUSH'] = "1" - runner = LocalRunner(repeat=10, number=1, min_repeat_ms=0, timeout=run_timeout) - else: - os.environ['TVM_NDK_CC'] = ndk_cc - builder = LocalBuilder(build_func='ndk', timeout=build_timeout) - runner = RPCRunner(device_key, host=host, port=port, - repeat=1, min_repeat_ms=400, - n_parallel=n_parallel, timeout=run_timeout) + tasks.append(ansor.SearchTask(dag, wkl_key, target, target_host)) + + def objective_func(costs): + return sum(c * w for c, w in zip(costs, wkl_weights)) + + tuner = ansor.SimpleTaskScheduler(tasks, objective_func, strategy=task_scheduler, + load_log_file=load_log_file, + load_model_file=tuning_parameters['load_model']) + tune_option, measure_ctx = create_tune_option(target, log_file, + tuning_parameters['n_trials'], tuning_parameters['num_measure_per_iter'], + tuning_parameters['verbose'], n_parallel, build_timeout, + local_measure, device_key, host, port, ndk_cc, + tuning_parameters['early_stopping']) search_policy = "%s.%s" % (policy, model_type) - tune_option = TuneOption(n_trials=tuning_parameters['n_trials'], - early_stopping=tuning_parameters['early_stopping'], - num_measure_per_iter=tuning_parameters['num_measure_per_iter'], - builder=builder, - verbose=tuning_parameters['verbose'], - runner=runner, - measure_callbacks=[LogToFile(log_file)]) + if local_measure and target.target_name != 
'cuda': os.environ['TVM_BIND_MASTER_CORE_0'] = "1" - tuner.tune(tune_option, search_policy) - else: - tuner.tune(tune_option, search_policy) + + tuner.tune(tune_option, search_policy) + + if measure_ctx: + del measure_ctx kernel_layout_rewrite = False # Compile graph with best states found by auto-scheduler print("=============== Compile ===============") - with auto_scheduler.apply_history_best(log_file, args.log_n_lines): + with ansor.apply_history_best(log_file, args.log_n_lines): #if True: - #with auto_scheduler.BlockingEmptyContext(): + #with ansor.BlockingEmptyContext(): os.environ['TVM_AUTO_CACHE_FLUSH'] = "0" os.environ['TVM_BIND_MASTER_CORE_0'] = "1" if kernel_layout_rewrite: - auto_scheduler.prepare_layout_rewrite(mod, target=target, + ansor.prepare_layout_rewrite(mod, target=target, params=params, ops=(relay.op.nn.dense, relay.op.nn.conv2d, relay.op.nn.conv3d)) else: # disable layout rewrite - auto_scheduler.LayoutRewriteLevel.BOTH_REWRITE = auto_scheduler.LayoutRewriteLevel.NO_REWRITE - auto_scheduler.LayoutRewriteLevel.COMPUTE_REWRITE = auto_scheduler.LayoutRewriteLevel.NO_REWRITE + ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE + ansor.LayoutRewriteLevel.COMPUTE_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE with relay.build_config(opt_level=3): graph, lib, opt_params = relay.build_module.build( @@ -385,7 +278,7 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host, graph, lib, opt_params = relay.build_module.build( mod, target=target, params=params) ''' - auto_scheduler.finish_layout_rewrite() + ansor.finish_layout_rewrite() print("=============== Compile Finish ===============") module, ctx = create_module(data_shape, graph, lib, target, input_name, opt_params, @@ -415,8 +308,8 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host, relay.backend.compile_engine.get().clear() # disable layout rewrite - auto_scheduler.LayoutRewriteLevel.BOTH_REWRITE = auto_scheduler.LayoutRewriteLevel.NO_REWRITE - auto_scheduler.LayoutRewriteLevel.COMPUTE_REWRITE = auto_scheduler.LayoutRewriteLevel.NO_REWRITE + ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE + ansor.LayoutRewriteLevel.COMPUTE_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE target = tvm.target.create('llvm') with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): graph, lib, opt_params = relay.build_module.build( @@ -433,28 +326,40 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host, if __name__ == "__main__": parser = argparse.ArgumentParser() + # Task related options parser.add_argument("--network", type=str, required=True) parser.add_argument("--model-path", type=str, default=None, help="The path of tflite model") - parser.add_argument("--n-trials", type=int, default=1000) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--layout", type=str, default='NHWC') parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') parser.add_argument("--target-host", type=str, default=None) - parser.add_argument("--policy", type=str, choices=['multi-stage', 'meta-rewrite'], - default='meta-rewrite') + parser.add_argument("--n-trials", type=int, default=1000) + parser.add_argument("--num-measure-per-iter", type=int, default=48, + help="The number of programs to be measured at each iteration") parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) parser.add_argument("--check-correctness", type=str2bool, 
nargs='?', const=True, default=False)
     parser.add_argument("--debug-profile", type=str2bool, nargs='?', const=True, default=False)
-    parser.add_argument("--build-timeout", type=int, default=10)
-    parser.add_argument("--run-timeout", type=int, default=10)
-    parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file")
+
+    # Strategy related options
+    parser.add_argument("--seed", type=int, default=0, help='random seed')
+    parser.add_argument("--policy", type=str, choices=['multi-stage', 'meta-rewrite'],
+                        default='meta-rewrite')
     parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb')
-    parser.add_argument("--load-model", action='store_true')
-    parser.add_argument("--model-file", type=str, default='saved_model.xgb')
+    parser.add_argument("--task-scheduler", type=str, default='gradient',
+                        choices=['no', 'gradient', 'round-robin'],
+                        help='The strategy of task scheduler')
+
+    # File related options
+    parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file")
+    parser.add_argument("--load-model", type=str, help="Load pre-trained cost model file")
     parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model")
     parser.add_argument("--out-file", type=str, default='results.tsv')
-    parser.add_argument("--seed", type=int, default=0, help='random seed')
+    parser.add_argument("--log-n-lines", type=int)
+
+    # Detailed control options
+    parser.add_argument("--build-timeout", type=int, default=10)
+    parser.add_argument("--run-timeout", type=int, default=10)
     parser.add_argument("--verbose", type=int, default=1)
-    parser.add_argument("--joint-tuner", type=str, default='bottleneck-decay', help='The type of joint tuner',
-                        choices=['no', 'uniform', 'weighted', 'bottleneck', 'bottleneck-decay', 'sequential', 'round-robin', 'rl'])
     parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True)
     parser.add_argument("--device-key", type=str, default=None)
     parser.add_argument("--host", type=str, default='0.0.0.0')
@@ -462,27 +367,22 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host,
     parser.add_argument("--n-parallel", type=int, default=1)
     parser.add_argument("--ndk-cc", type=str, default=None)
     parser.add_argument("--num-threads", type=int, default=None)
-    parser.add_argument("--batch-size", type=int, default=1)
-    parser.add_argument("--num-measure-per-iter", type=int, default=48,
-                        help="The number of programs to be measured at each iteration")
-    parser.add_argument("--layout", type=str, default=None)
-    parser.add_argument("--log-n-lines", type=int)
     args = parser.parse_args()

     np.random.seed(args.seed)
     random.seed(args.seed)
-    logging.basicConfig()
-    logging.getLogger('auto_scheduler').setLevel(logging.DEBUG)
+    logging.getLogger('ansor').setLevel(logging.DEBUG)

     target = tvm.target.create(args.target)

     tuning_parameters = {
         'n_trials': args.n_trials,
         'num_measure_per_iter': args.num_measure_per_iter,
-        'log_file': args.log_file if args.log_file else "%s-B%d.json" % (args.network, args.batch_size),
+        'log_file': args.log_file or "%s-B%d.json" % (args.network, args.batch_size),
+        'load_model': args.load_model,
         'model_type': args.model_type,
-        'joint_tuner': args.joint_tuner,
+        'task_scheduler': args.task_scheduler,
         'policy': args.policy,
         'early_stopping': -1,
         'verbose': 1,
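Taken together, the per-network flow that this patch assembles in scripts/tune_network.py reduces to a few calls. A condensed sketch, using only APIs that appear in this patch (variable values are placeholders; `create_tune_option` is the helper added to scripts/tune_test.py below):

    tasks = [ansor.SearchTask(ansor.workload_key_to_dag(key), key, target, target_host)
             for key in workloads]
    tuner = ansor.SimpleTaskScheduler(tasks, objective_func, strategy='gradient',
                                      load_log_file=load_log_file)
    tune_option, measure_ctx = create_tune_option(
        target, log_file, n_trials, num_measure_per_iter, verbose,
        n_parallel, build_timeout, local_measure, device_key, host, port, ndk_cc)
    tuner.tune(tune_option, "meta-rewrite.xgb")  # search_policy is "<policy>.<model_type>"
    if measure_ctx:
        del measure_ctx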
diff --git a/scripts/tune_op_subgraph.py b/scripts/tune_op_subgraph.py
new file mode 100644
index 000000000000..bf5cbe83c952
--- /dev/null
+++ b/scripts/tune_op_subgraph.py
@@ -0,0 +1,599 @@
+"""Tune all operators for single op & subgraph evaluation"""
+import argparse
+import logging
+import os
+import random
+
+import numpy as np
+
+import tvm
+from tvm import te, ansor
+import topi
+from topi.nn.winograd_util import winograd_transform_matrices
+from topi.util import get_const_tuple
+
+from common import measure_schedule, str2bool, \
+    norm_bmn, softmax_mn, conv2d_nhwc_bn_relu, conv2d_nchw_bn_relu
+from shape_configs import single_op_shape_dict, subgraph_shape_dict
+from tune_test import tune_workloads_jointly, replay_workload, create_tune_option
+
+# ========================== Single Ops ==========================
+
+@ansor.register_auto_scheduler_workload_func
+def batch_matmul_nkkm(B, N, M, K):
+    X = te.placeholder((B, N, K), name='A')
+    Y = te.placeholder((B, K, M), name='B')
+    k = te.reduce_axis((0, K), name='k')
+    Z = te.compute((B, N, M), lambda b, i, j: te.sum(X[b][i][k] * Y[b][k][j], axis=[k]), name='C')
+    return [X, Y, Z]
+
+@ansor.register_auto_scheduler_workload_func
+def conv1d_nlc(N, L, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1):
+    inputs = te.placeholder((N, L, CI), name='inputs')
+    weight = te.placeholder((kernel_size, CI//groups, CO), name='weight')
+
+    batch_size, in_len, in_channel = inputs.shape
+    k_len, channel_per_group, out_channel = weight.shape
+    out_channel_per_group = out_channel // groups
+    out_len = (in_len + 2 * padding - dilation * (k_len - 1) - 1) // stride + 1
+    rc = te.reduce_axis((0, channel_per_group), name='rc')
+    rl = te.reduce_axis((0, k_len), name='rl')
+
+    padded = topi.nn.pad(inputs, [0, padding, 0])
+    output = te.compute(
+        (batch_size, out_len, out_channel),
+        lambda n, l, co: te.sum(
+            (padded[n, l * stride + rl * dilation, co // out_channel_per_group * channel_per_group + rc] *
+             weight[rl, rc, co]), axis=[rl, rc]),
+        name='conv1d_nlc'
+    )
+    return [inputs, weight, output]
+
+@ansor.register_auto_scheduler_workload_func
+def conv2d_nhwc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1):
+    inputs = te.placeholder((N, H, W, CI), name='inputs')
+    weight = te.placeholder((kernel_size, kernel_size, CI//groups, CO), name='weight')
+    batch_size, in_h, in_w, in_channel = inputs.shape
+    k_h, k_w, channel_per_group, out_channel = weight.shape
+    out_channel_per_group = out_channel // groups
+
+    out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1
+    out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1
+    rh = te.reduce_axis((0, k_h), name="rh")
+    rw = te.reduce_axis((0, k_w), name="rw")
+    rc = te.reduce_axis((0, channel_per_group), name="rc")
+
+    padded = topi.nn.pad(inputs, [0, padding, padding, 0])
+    output = te.compute(
+        (batch_size, out_h, out_w, out_channel),
+        lambda n, h, w, co: te.sum(
+            (padded[n, h * stride + rh * dilation, w * stride + rw * dilation,
+                    co // out_channel_per_group * channel_per_group + rc]
+             * weight[rh, rw, rc, co]), axis=[rh, rw, rc]
+        ),
+        name='conv2d_nhwc'
+    )
+    return [inputs, weight, output]
+
+@ansor.register_auto_scheduler_workload_func
+def conv2d_nchw(N, CI, H, W, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1):
+    inputs = te.placeholder((N, CI, H, W), name='inputs')
+    weight = te.placeholder((CO, CI//groups, kernel_size, kernel_size), name='weight')
+    batch_size, in_channel, in_h, in_w = inputs.shape
+    out_channel, channel_per_group, k_h, k_w = weight.shape
+    out_channel_per_group = out_channel // groups
+
+    out_h = (in_h + 2 * padding - dilation * (k_h 
- 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rc = te.reduce_axis((0, channel_per_group), name="rc") + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + + padded = topi.nn.pad(inputs, [0, 0, padding, padding]) + output = te.compute( + (batch_size, out_channel, out_h, out_w), + lambda n, co, h, w: te.sum( + (padded[n, co // out_channel_per_group * channel_per_group + rc, + h * stride + rh * dilation, w * stride + rw * dilation] + * weight[co, rc, rh, rw]), axis=[rc, rh, rw] + ), + name='conv2d_nchw' + ) + return [inputs, weight, output] + +@ansor.register_auto_scheduler_workload_func +def conv3d_ndhwc(N, D, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): + inputs = te.placeholder((N, D, H, W, CI)) + weight = te.placeholder((kernel_size, kernel_size, kernel_size, CI//groups, CO)) + batch_size, in_d, in_h, in_w, in_channel = inputs.shape + k_d, k_h, k_w, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + + out_d = (in_d + 2 * padding - dilation * (k_d - 1) - 1) // stride + 1 + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rd = te.reduce_axis((0, k_d), name='rd') + rh = te.reduce_axis((0, k_h), name='rh') + rw = te.reduce_axis((0, k_w), name='rw') + rc = te.reduce_axis((0, channel_per_group), name='rc') + + padded = topi.nn.pad(inputs, [0, padding, padding, padding, 0]) + output = te.compute( + (batch_size, out_d, out_h, out_w, out_channel), + lambda n, d, h, w, co: te.sum( + (padded[n, d * stride + rd * dilation, + h * stride + rh * dilation, w * stride + rw * dilation, + co // out_channel_per_group * channel_per_group + rc] + * weight[rd, rh, rw, rc, co]), + axis=[rd, rh, rw, rc] + ), + name='conv3d_ndhwc' + ) + return [inputs, weight, output] + +@ansor.register_auto_scheduler_workload_func +def depthwise_conv2d_nhwc(N, H, W, C, kernel_size, stride=1, padding=0, dilation=1, factor=1): + inputs = te.placeholder((N, H, W, C)) + weight = te.placeholder((factor, kernel_size, kernel_size, C)) + + batch_size, in_h, in_w, in_channel = inputs.shape + factor, k_h, k_w, in_channel = weight.shape + out_channel = in_channel * factor + + assert factor.value == 1, "Not optimized for factor != 1" + + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rh = te.reduce_axis((0, k_h), name='rh') + rw = te.reduce_axis((0, k_w), name='rw') + + padded = topi.nn.pad(inputs, [0, padding, padding, 0]) + output = te.compute( + (batch_size, out_h, out_w, out_channel), + lambda n, h, w, c: te.sum( + (padded[n, h * stride + rh * dilation, w * stride + rw * dilation, c // factor] + * weight[c % factor, rh, rw, c // factor]), + axis=[rh, rw] + ), + name="depth_conv2d_nhwc" + ) + return [inputs, weight, output] + +@ansor.register_auto_scheduler_workload_func +def conv2d_transpose_nhwc(N, H, W, CI, CO, kernel_size, stride=1, padding=0): + inputs = te.placeholder((N, H, W, CI), name='inputs') + weight = te.placeholder((kernel_size, kernel_size, CI, CO), name='weight') + + batch, in_h, in_w, in_c = inputs.shape + filter_h, filter_w, in_c, out_c = weight.shape + stride_h, stride_w = (stride, stride) + + # compute padding + fpad_top, fpad_left, fpad_bottom, fpad_right = topi.nn.get_pad_tuple(padding, (filter_h, filter_w)) + bpad_top = filter_h - 1 - fpad_top + bpad_bottom = filter_h - 1 - 
fpad_bottom
+    bpad_left = filter_w - 1 - fpad_left
+    bpad_right = filter_w - 1 - fpad_right
+
+    # padding stage
+    padded = topi.nn.pad(inputs,
+                         [0, (bpad_top + stride_h - 1) // stride_h,
+                          (bpad_left + stride_w - 1) // stride_w, 0],
+                         [0, (bpad_bottom + stride_h - 1) // stride_h,
+                          (bpad_right + stride_w - 1) // stride_w, 0])
+
+    # remove extra padding introduced by dilation
+    idxdiv = te.indexdiv
+    idxmod = te.indexmod
+    border_h = idxmod(stride_h - idxmod(bpad_top, stride_h), stride_h)
+    border_w = idxmod(stride_w - idxmod(bpad_left, stride_w), stride_w)
+
+    # dilation stage
+    strides = [1, stride_h, stride_w, 1]
+    n = len(padded.shape)
+
+    # We should embed this dilation directly into te.compute rather than creating a new te.compute.
+    # Only in this way can we use unroll to eliminate the multiplication of zeros.
+    def _dilate(*indices):
+        not_zero = []
+        index_tuple = []
+        for i in range(n):
+            if not strides[i] == 1:
+                index_tuple.append(idxdiv(indices[i], strides[i]))
+                not_zero.append(idxmod(indices[i], strides[i]).equal(0))
+            else:
+                index_tuple.append(indices[i])
+        if not_zero:
+            not_zero = te.all(*not_zero)
+            return te.if_then_else(not_zero, padded(*index_tuple), tvm.tir.const(0.0, padded.dtype))
+        return padded(*index_tuple)
+
+    # convolution stage
+    out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
+    out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w
+    rc = te.reduce_axis((0, in_c), name='rc')
+    rh = te.reduce_axis((0, filter_h), name='rh')
+    rw = te.reduce_axis((0, filter_w), name='rw')
+
+    output = te.compute(
+        (batch, out_h, out_w, out_c),
+        lambda n, h, w, co: te.sum(
+            _dilate(n, h + rh + border_h, w + rw + border_w, rc) *
+            weight[filter_h - 1 - rh, filter_w - 1 - rw, rc, co],
+            axis=[rh, rw, rc]),
+        name="conv2d_transpose_nhwc",
+        attrs={"auto_scheduler_always_unroll_inner": ["h", "w", "rh", "rw", "h_c", "w_c"]})
+    # todo(lmzheng): add constraints on the tile size of h and w
+
+    return [inputs, weight, output]
+
+@ansor.register_auto_scheduler_workload_func
+def conv2d_capsule_nhwijc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, capsule_size=4):
+    inputs = te.placeholder((N, H, W, capsule_size, capsule_size, CI), name='inputs')
+    weight = te.placeholder((kernel_size, kernel_size, capsule_size, capsule_size, CI, CO), name='weight')
+    batch_size, in_h, in_w, _, _, in_channel = inputs.shape
+    k_h, k_w, _, _, _, out_channel = weight.shape
+
+    out_h = (in_h + 2 * padding - kernel_size) // stride + 1
+    out_w = (in_w + 2 * padding - kernel_size) // stride + 1
+
+    rh = te.reduce_axis((0, k_h), name="rh")
+    rw = te.reduce_axis((0, k_w), name="rw")
+    cap_k = te.reduce_axis((0, capsule_size), name='cap_k')
+    rc = te.reduce_axis((0, in_channel), name="rc")
+
+    padded = topi.nn.pad(inputs, [0, padding, padding, 0, 0, 0])
+    output = te.compute(
+        (batch_size, out_h, out_w, capsule_size, capsule_size, out_channel),
+        lambda n, h, w, cap_i, cap_j, co: te.sum(
+            (padded[n, h * stride + rh, w * stride + rw, cap_i, cap_k, rc]
+             * weight[rh, rw, cap_k, cap_j, rc, co]), axis=[rh, rw, cap_k, rc]
+        ),
+        name='conv2d_capsule_nhwijc'
+    )
+    return [inputs, weight, output]
+
+
+@ansor.register_auto_scheduler_workload_func
+def conv2d_winograd_nhwc(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, dilation=1):
+    # TODO: implement tile_size
+    tile_size = 4  #_infer_tile_size(data, kernel)
+    inputs = te.placeholder((N, H, W, CI), name='inputs')
+    #weight = te.placeholder((kernel_size, kernel_size, CI, CO), name='weight')
+    N, H, W, CI = 
get_const_tuple(inputs.shape) + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + # if dilation_h != 1 or dilation_w != 1: + # weight = topi.nn.dilate(weight, (1, 1, dilation_h, dilation_w)) + KH = KW = kernel_size + HPAD, WPAD, _, _ = topi.nn.get_pad_tuple(padding, (KH, KW)) + HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride + assert HSTR == 1 and WSTR == 1 and KH == KW + + data_pad = topi.nn.pad(inputs, (0, HPAD, WPAD, 0), (0, HPAD, WPAD, 0), name="data_pad") + + r = KW + m = tile_size + alpha = m + r - 1 + A, B, G = winograd_transform_matrices(m, r, 'float32') + + H = (H + 2 * HPAD - KH) // HSTR + 1 + W = (W + 2 * WPAD - KW) // WSTR + 1 + nH, nW = (H + m - 1) // m, (W + m - 1) // m + P = N * nH * nW + r_kh = te.reduce_axis((0, KH), name='r_kh') + r_kw = te.reduce_axis((0, KW), name='r_kw') + # kernel_pack = te.compute((alpha, alpha, CO, CI), lambda eps, nu, co, ci: + # weight[0][0][0][0], + # name='kernel_pack') + kshape = (alpha, alpha, CO, CI) + kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight") + + idxdiv = te.indexdiv + idxmod = te.indexmod + # pack input tile + input_tile = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci: + data_pad[idxdiv(p, (nH * nW))][idxmod(idxdiv(p, nW), nH) * m + eps] + [idxmod(p, nW) * m + nu][ci], name='input_tile',) + + # transform data + r_a = te.reduce_axis((0, alpha), 'r_a') + r_b = te.reduce_axis((0, alpha), 'r_b') + data_pack = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci: + te.sum(input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], + axis=[r_a, r_b]), name='data_pack', + attrs={"auto_scheduler_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], + "auto_scheduler_last_split_is_one": ["ci", "p"], + "auto_scheduler_always_unroll": ["eps", "nu", "r_a", "r_b"], + "auto_scheduler_no_cache_write": "True", + }) + + # do batch gemm + ci = te.reduce_axis((0, CI), name='ci') + bgemm = te.compute((alpha, alpha, P, CO), lambda eps, nu, p, co: + te.sum(data_pack[eps][nu][p][ci] * + kernel_pack[eps][nu][co][ci], + axis=[ci]), name='bgemm') + + # inverse transform + r_a = te.reduce_axis((0, alpha), 'r_a') + r_b = te.reduce_axis((0, alpha), 'r_b') + inverse = te.compute((m, m, P, CO), lambda vh, vw, p, co: + te.sum(bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], + axis=[r_a, r_b]), name='inverse', + attrs={"auto_scheduler_no_split_at_inner": ["vh", "vw", "r_a", "r_b"], + "auto_scheduler_always_unroll": ["vh", "vw", "r_a", "r_b"], + "auto_scheduler_last_split_is_one": ["co", "p"], + "auto_scheduler_no_cache_write": "True", + }) + + # output + output = te.compute((N, H, W, CO), lambda n, h, w, co: + inverse[idxmod(h, m), + idxmod(w, m), + n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), + co], + name='conv2d_winograd', + tag='conv2d_winograd_nhwc', + attrs={"auto_scheduler_no_split_at_outer": ["n", "h", "w", "co"],}) + return [inputs, kernel_pack, output] + +@ansor.register_auto_scheduler_workload_func +def conv2d_winograd_nchw(N, CI, H, W, CO, kernel_size=3, stride=1, padding=0, dilation=1, precompute=False): + # TODO: implement tile_size + tile_size = 4 #_infer_tile_size(data, kernel) + inputs = te.placeholder((N, CI, H, W), name='inputs') + #weight = te.placeholder((CO, CI, kernel_size, kernel_size), name='weight') + N, CI, H, W = get_const_tuple(inputs.shape) + # if isinstance(dilation, int): + # dilation_h = dilation_w = dilation + # else: + # dilation_h, dilation_w = dilation + # if dilation_h != 1 or dilation_w != 1: + # weight = 
topi.nn.dilate(weight, (1, 1, dilation_h, dilation_w))
+    KH = KW = kernel_size
+    HPAD, WPAD, _, _ = topi.nn.get_pad_tuple(padding, (KH, KW))
+    HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride
+    assert HSTR == 1 and WSTR == 1 and KH == KW
+
+    data_pad = topi.nn.pad(inputs, (0, 0, HPAD, WPAD), (0, 0, HPAD, WPAD), name="data_pad")
+
+    r = KW
+    m = tile_size
+    alpha = m + r - 1
+    A, B, G = winograd_transform_matrices(m, r, 'float32')
+
+    H = (H + 2 * HPAD - KH) // HSTR + 1
+    W = (W + 2 * WPAD - KW) // WSTR + 1
+    nH, nW = (H + m - 1) // m, (W + m - 1) // m
+    P = N * nH * nW
+    r_kh = te.reduce_axis((0, KH), name='r_kh')
+    r_kw = te.reduce_axis((0, KW), name='r_kw')
+    # kernel_pack = te.compute((alpha, alpha, CI, CO), lambda eps, nu, ci, co:
+    #                          weight[0][0][0][0],
+    #                          name='kernel_pack')
+    kshape = (alpha, alpha, CI, CO)
+    kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight")
+
+    idxdiv = te.indexdiv
+    idxmod = te.indexmod
+    # pack input tile
+    input_tile = te.compute((CI, P, alpha, alpha), lambda ci, p, eps, nu:
+                            data_pad[idxdiv(p, (nH * nW))][ci][idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu], name='input_tile')
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    data_pack = te.compute((alpha, alpha, CI, P), lambda eps, nu, ci, p:
+                           te.sum(input_tile[ci][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu],
+                                  axis=[r_a, r_b]), name='data_pack',
+                           attrs={"auto_scheduler_no_split_at_inner": ["eps", "nu", "r_a", "r_b"],
+                                  "auto_scheduler_no_split_at_outer": ["ci", "p"],
+                                  "auto_scheduler_always_unroll": ["eps", "nu", "r_a", "r_b"],
+                                  "auto_scheduler_no_cache_write": "True",
+                                  })
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute((alpha, alpha, CO, P), lambda eps, nu, co, p:
+                       te.sum(data_pack[eps][nu][ci][p] *
+                              kernel_pack[eps][nu][ci][co],
+                              axis=[ci]), name='bgemm')
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    inverse = te.compute((CO, P, m, m), lambda co, p, vh, vw:
+                         te.sum(bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw],
+                                axis=[r_a, r_b]), name='inverse',
+                         attrs={"auto_scheduler_no_split_at_outer": ["co", "p", "vh", "vw", "r_a", "r_b"],
+                                "auto_scheduler_always_unroll": ["vh", "vw", "r_a", "r_b"],
+                                "auto_scheduler_no_cache_write": "True"})
+
+    # output
+    output = te.compute((N, CO, H, W), lambda n, co, h, w:
+                        inverse[co, n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
+                                idxmod(h, m),
+                                idxmod(w, m)],
+                        name='conv2d_winograd',
+                        attrs={"auto_scheduler_no_split_at_outer": ["n", "co", "h", "w"],})
+    return [inputs, kernel_pack, output]
+
+# ========================== Subgraphs ==========================
+
+@ansor.register_auto_scheduler_workload_func
+def transpose_batch_matmul(batch, seq_len, n_head, n_dim):
+    query = te.placeholder((batch, seq_len, n_head, n_dim), name='query')
+    value = te.placeholder((batch, seq_len, n_head, n_dim), name='value')
+    query_T = te.compute((batch, n_head, seq_len, n_dim),
+                         lambda b, h, l, d: query[b, l, h, d], name="query_T")
+    value_T = te.compute((batch, n_head, n_dim, seq_len),
+                         lambda b, h, d, l: value[b, l, h, d], name="value_T")
+    k = te.reduce_axis((0, n_dim), name='k')
+    out = te.compute((batch, n_head, seq_len, seq_len), lambda b, h, i, j: te.sum(query_T[b][h][i][k] * value_T[b][h][k][j], axis=[k]), name='C')
+    return [query, value, out]
+
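For reference, transpose_batch_matmul above is the attention-style product of per-head transposed tensors; a minimal NumPy sketch of the same computation (illustration only, not part of the patch):

    import numpy as np

    def transpose_batch_matmul_ref(query, value):
        # query, value: (batch, seq_len, n_head, n_dim)
        query_t = query.transpose(0, 2, 1, 3)   # (batch, n_head, seq_len, n_dim)
        value_t = value.transpose(0, 2, 3, 1)   # (batch, n_head, n_dim, seq_len)
        # out[b, h, i, j] = sum_k query_t[b, h, i, k] * value_t[b, h, k, j]
        return query_t @ value_t                # (batch, n_head, seq_len, seq_len)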
+@ansor.register_auto_scheduler_workload_func
+def batch_norm(M, N, eps=1e-5):
+    A = te.placeholder((M, N), name='A')
+    k1 = te.reduce_axis((0, M), name='k1')
+    k2 = te.reduce_axis((0, M), name='k2')
+    mean = te.compute((N,), lambda j: te.sum(A[k1][j] / M, axis=k1), name="mean")
+    var = te.compute((N,),
+                     lambda j: te.sum((A[k2][j] - mean[j]) * (A[k2][j] - mean[j]) / (M - 1), k2),
+                     name="var")
+    B = te.compute((M, N), lambda i, j: (A[i][j] - mean[j]) / te.sqrt(var[j] + eps), name='B')
+
+    return [A, B]
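The batch_norm workload above normalizes each column of A with the unbiased (M - 1) sample variance. A minimal NumPy equivalent for checking the definition (illustration only):

    import numpy as np

    def batch_norm_ref(A, eps=1e-5):
        # mean[j] = sum_i A[i, j] / M; ddof=1 matches the (M - 1) divisor above
        mean = A.mean(axis=0)
        var = A.var(axis=0, ddof=1)
        return (A - mean) / np.sqrt(var + eps)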
+# 'C2D_NCHW': conv2d_nchw, + 'C2DWG_NHWC': conv2d_winograd_nhwc, +# 'C2DWG_NCHW': conv2d_winograd_nchw, +# 'GMM_TC': matmul_nkkm, +} + +subgraph_task_func_dict = { + 'conv2d_bn_relu': conv2d_nhwc_bn_relu, + #'conv2d_bn_relu': conv2d_nchw_bn_relu, # some old log uses conv2d_nchw_bn_relu + 'transpose_batch_matmul': transpose_batch_matmul, +} + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Task related options + parser.add_argument("--wkl", type=str, required=True, + help="all - For all workloads; \ + op - For all single ops; \ + subgraph - For all subgraphs; \ + Or specific wkl name") + parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') + parser.add_argument("--target-host", type=str, default=None) + parser.add_argument("--n-trials-per-shape", type=int, default=1000) + parser.add_argument("--num-measure-per-iter", type=int, default=48, + help="The number of programs to be measured at each iteration") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) + parser.add_argument("--fast-check", action='store_true', + help='Only run one shape for each workload. This is used for fast checking') + + # Strategy related options + parser.add_argument("--seed", type=int, default=0, help='random seed') + parser.add_argument("--policy", type=str, choices=['meta-rewrite', 'beam-search'], default='meta-rewrite') + parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') + parser.add_argument("--task-scheduler", type=str, default='gradient', + choices=['no', 'gradient', 'round-robin'], + help='The strategy of task scheduler') + + # File related options + parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") + parser.add_argument("--load-model", type=str, help="Load pre trained cost model file") + parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") + parser.add_argument("--out-file", type=str, default='results.tsv') + + # Detailed control options + parser.add_argument("--build-timeout", type=int, default=10) + parser.add_argument("--run-timeout", type=int, default=60) + parser.add_argument("--verbose", type=int, default=1) + parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) + parser.add_argument("--device-key", type=str, default=None) + parser.add_argument("--host", type=str, default='0.0.0.0') + parser.add_argument("--port", type=int, default=9190) + parser.add_argument("--n-parallel", type=int, default=1) + parser.add_argument("--ndk-cc", type=str, default=None) + args = parser.parse_args() + + np.random.seed(args.seed) + random.seed(args.seed) + logging.basicConfig() + logging.getLogger('ansor').setLevel(logging.DEBUG) + + # compute the number of tasks + num_tasks = 0 + for wkl_meta_name in single_op_task_func_dict: + if not args.wkl in ["all", "op", wkl_meta_name]: + continue + if args.fast_check: + num_tasks += 1 + else: + num_tasks += len(single_op_shape_dict[wkl_meta_name]) + for wkl_meta_name in subgraph_task_func_dict: + if not args.wkl in ["all", "subgraph", wkl_meta_name]: + continue + if args.fast_check: + num_tasks += 1 + else: + num_tasks += len(subgraph_shape_dict[wkl_meta_name]) + print("Number of tasks: %d\tTotal trials: %d" % (num_tasks, num_tasks * args.n_trials_per_shape)) + + # tune for tasks + tune_wkl(single_op_task_func_dict, single_op_shape_dict, "op", args) + 
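+    # With --tune set to false, tune_wkl replays the best record from the log
+    # instead of searching; a sketch of the underlying call (the same
+    # arguments tune_wkl passes above):
+    #   cost, gflops = replay_workload(wkl_key, target, args.target_host,
+    #                                  log_file, args.local_measure,
+    #                                  args.device_key, args.host,
+    #                                  args.port, args.ndk_cc, False)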
tune_wkl(subgraph_task_func_dict, subgraph_shape_dict, "subgraph", args) diff --git a/scripts/tune_test.py b/scripts/tune_test.py index 08f0cc19ade2..d6f552affbb1 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -12,26 +12,61 @@ from common import get_workload_keys, get_workload_weights, measure_schedule, str2bool +def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose, + n_parallel, build_timeout, local_measure, device_key, host, + port, ndk_cc, early_stopping=-1): + builder = runner = measure_ctx = None + if local_measure: + builder = ansor.LocalBuilder(timeout=build_timeout) + if target.target_name == "cuda": + measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) + runner = measure_ctx.runner + else: + runner = ansor.LocalRunner(repeat=1, min_repeat_ms=400) + else: + os.environ['TVM_NDK_CC'] = ndk_cc + builder = ansor.LocalBuilder(timeout=build_timeout, build_func='ndk') + runner = ansor.RPCRunner(key=device_key, host=host, port=port, + n_parallel=n_parallel, repeat=1, min_repeat_ms=400) + + tune_option = ansor.TuneOption(n_trials=n_trials, early_stopping=early_stopping, + num_measure_per_iter=num_measure_per_iter, + verbose=verbose, + builder=builder, + runner=runner, + measure_callbacks=[ansor.LogToFile(log_file)], + pre_search_callbacks=[ansor.PreLoadMeasuredStates(log_file)]) + + return tune_option, measure_ctx + def replay_workload(wkl_key, target, target_host, log_file, local_measure=True, device_key=None, host="0.0.0.0", - port=9190, ndk_cc=None): + port=9190, ndk_cc=None, show_lower_result=True): + cost = gflops = None + inp, res = ansor.best_measure_pair_in_file(log_file, wkl_key, target) if inp is None: print("Cannot find log for: %s" % (wkl_key)) else: dag = ansor.workload_key_to_dag(inp.task.workload_key) + print("Found schedule for: %s" % (wkl_key)) + s, bufs = dag.apply_steps_from_state(inp.state) + if show_lower_result: + print(tvm.lower(s, bufs, simple_mode=True)) - print("Found schedule for: %s" % (wkl_key)) - print(tvm.lower(s, bufs, simple_mode=True)) if local_measure: remote = None else: remote = request_remote(device_key, host, port, 1) + cost = np.mean((measure_schedule(s, bufs, target, remote=remote, ndk_cc=ndk_cc))) + gflops = ansor.ComputeDAG(bufs).flop_ct / cost / 1e9 print("Best schedule: %.2f GFLOPS\tcost: %.3f ms" % - (ansor.ComputeDAG(bufs).flop_ct / cost / 1e9, cost * 1e3)) + (gflops, cost * 1e3)) + + return cost, gflops def tune_workload(wkl_key, target, target_host, policy, model_type, load_model_file, @@ -99,6 +134,7 @@ def objective_func(costs): parser.add_argument("--num-measure-per-iter", type=int, default=48, help="The number of programs to be measured at each iteration") parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) + # Strategy related options parser.add_argument("--seed", type=int, default=0, help='random seed') parser.add_argument("--policy", type=str, choices=['meta-rewrite', 'beam-search'], default='meta-rewrite') @@ -106,13 +142,15 @@ def objective_func(costs): parser.add_argument("--task-scheduler", type=str, default='no', choices=['no', 'gradient', 'round-robin'], help='The strategy of task scheduler') + # File related options parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") parser.add_argument("--load-model", type=str, help="Load pre trained cost model file") parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") + # Detailed control options 
parser.add_argument("--build-timeout", type=int, default=10) - parser.add_argument("--run-timeout", type=int, default=60) + parser.add_argument("--run-timeout", type=int, default=60) parser.add_argument("--verbose", type=int, default=1) parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) parser.add_argument("--device-key", type=str, default=None) @@ -124,40 +162,21 @@ def objective_func(costs): np.random.seed(args.seed) random.seed(args.seed) - logging.basicConfig() logging.getLogger('ansor').setLevel(logging.DEBUG) - log_file = args.log_file or args.wkl + ".json" - - target = tvm.target.create(args.target) wkl_keys = get_workload_keys(args.wkl) + target = tvm.target.create(args.target) + log_file = args.log_file or args.wkl + ".json" if args.tune: load_log_file = args.load_log or log_file weights = get_workload_weights(args.wkl) - builder = runner = measure_ctx = None - if args.local_measure: - builder = ansor.LocalBuilder(timeout=args.build_timeout) - if target.target_name == "cuda": - measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) - runner = measure_ctx.runner - else: - runner = ansor.LocalRunner(repeat=1, min_repeat_ms=400) - else: - os.environ['TVM_NDK_CC'] = args.ndk_cc - builder = ansor.LocalBuilder(timeout=args.build_timeout, build_func='ndk') - runner = ansor.RPCRunner(args.device_key, host=args.host, port=args.port, - repeat=1, min_repeat_ms=400, n_parallel=args.n_parallel) - - tune_option = ansor.TuneOption(n_trials=args.n_trials, - num_measure_per_iter=args.num_measure_per_iter, - verbose=args.verbose, - builder=builder, - runner=runner, - measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=[ansor.PreLoadMeasuredStates(log_file)]) + tune_option, measure_ctx = create_tune_option(target, log_file, + args.n_trials, args.num_measure_per_iter, args.verbose, + args.n_parallel, args.build_timeout, args.local_measure, + args.device_key, args.host, args.port, args.ndk_cc) if args.task_scheduler == 'no': # tune workloads one by one diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index d52b868e180d..c07a3af7473c 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -54,9 +54,11 @@ void SearchPolicyNode::PreLoadMeasuredStates(const std::string& log_file) { measured_states_set_.insert(state.ToStr()); } - StdCout(verbose_) << "Measured States Set: " - << measured_states_set_.size() + StdCout(verbose_) << "Measured States Set: " << measured_states_set_.size() << " state hashes loaded from " << log_file << std::endl; + } else { + StdCout(verbose_) << "Measured States Set: no states found from " + << log_file << std::endl; } } From 2f241ed4f83763979e827da2d6cca55c7f28cb77 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 12 Jun 2020 15:42:38 -0700 Subject: [PATCH 24/78] add explicit_unroll_max_extent (#25) --- src/tir/transforms/unroll_loop.cc | 19 ++++++++++++--- .../test_tir_transform_unroll_loop.py | 24 +++++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index a15190665949..1c84304fb0e7 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -43,6 +43,7 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { int auto_max_depth; int auto_max_extent; int explicit_unroll; + int explicit_unroll_max_extent; TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") 
{ TVM_ATTR_FIELD(auto_max_step) @@ -57,6 +58,9 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { TVM_ATTR_FIELD(explicit_unroll) .describe("Whether to explicitly unroll the loop instead of setting a pragma") .set_default(true); + TVM_ATTR_FIELD(explicit_unroll_max_extent) + .describe("The maximum extent of a loop that can be unrolled explicitly (-1 means infinite)") + .set_default(32); } }; @@ -71,11 +75,12 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig); class LoopUnroller : public StmtExprMutator { public: explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent, - bool explicit_unroll) + bool explicit_unroll, int explicit_unroll_max_extent) : auto_max_step_(auto_max_step), auto_max_depth_(auto_max_depth), auto_max_extent_(auto_max_extent), - explicit_unroll_(explicit_unroll) {} + explicit_unroll_(explicit_unroll), + explicit_unroll_max_extent_(explicit_unroll_max_extent) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_auto_unroll_max_step") { @@ -165,6 +170,11 @@ class LoopUnroller : public StmtExprMutator { // For loop must have a constant integer extent CHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; if (value == 0) return Evaluate(0); + if (explicit_unroll_max_extent_ > 0 && value > explicit_unroll_max_extent_ && explicit_unroll_) { + // Do not unroll too long loops + ForType for_type = op->for_type == ForType::Unrolled ? ForType::Serial : op->for_type; + return ForNode::make(op->loop_var, op->min, op->extent, for_type, op->device_api, op->body); + } Stmt body = op->body; Map vmap; Array unrolled; @@ -197,7 +207,10 @@ class LoopUnroller : public StmtExprMutator { // max extent of loop to auto unroll // this not not count the total steps, only count the number of loops int auto_max_extent_; + // Whether to explicitly unroll the loop instead of setting a pragma bool explicit_unroll_; + // The maximum extent of a loop that can be unrolled explicitly (-1 means infinite) + int explicit_unroll_max_extent_; // Number of normal loops in scope int normal_loop_depth_{0}; // number of unrolled cases in current scope. 
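+  // Behavior of the new cap, as implemented above: when explicit_unroll_ is
+  // set and the constant extent exceeds explicit_unroll_max_extent_, the
+  // loop is kept as a plain for (a requested ForType::Unrolled is demoted to
+  // Serial); a non-positive value such as the documented -1 disables the cap.
+  // A usage sketch from Python, mirroring the unit test below:
+  //   with tvm.transform.PassContext(
+  //       config={"tir.UnrollLoop": {"explicit_unroll_max_extent": 16}}):
+  //       ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body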
@@ -210,7 +223,7 @@ class LoopUnroller : public StmtExprMutator { Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) { Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent, - cfg->explicit_unroll)(stmt); + cfg->explicit_unroll, cfg->explicit_unroll_max_extent)(stmt); if (!ret.same_as(stmt)) { return ConvertSSA(ret); } else { diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py b/tests/python/unittest/test_tir_transform_unroll_loop.py index 68639940bb05..12c686634548 100644 --- a/tests/python/unittest/test_tir_transform_unroll_loop.py +++ b/tests/python/unittest/test_tir_transform_unroll_loop.py @@ -110,7 +110,31 @@ def test_unroll_single_count_loops(): ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body assert ret == stmt +def test_unroll_explicitly_max_extent(): + n = 64 + A = te.placeholder((n,), name='A') + B = te.compute((n,), lambda *i: A(*i), name='B') + s = te.create_schedule(B.op) + s = s.normalize() + dom_map = tvm.te.schedule.InferBound(s) + stmt = tvm.te.schedule.ScheduleOps(s, dom_map) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + + with tvm.transform.PassContext(config={ + "tir.UnrollLoop": {"explicit_unroll_max_extent": n-1} + }): + ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body + assert tvm.ir.structural_equal(ret, stmt) + + with tvm.transform.PassContext(config={ + "tir.UnrollLoop": {"explicit_unroll_max_extent": n} + }): + ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body + assert not tvm.ir.structural_equal(ret, stmt) + + if __name__ == "__main__": test_unroll_loop() test_unroll_fake_loop() test_unroll_single_count_loops() + test_unroll_explicitly_max_extent() From 18d44b8cff0a7048e79394d9ef16da986ebc3ca5 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Mon, 15 Jun 2020 18:17:56 +0800 Subject: [PATCH 25/78] Add Index simplification & API update (#26) * Add vectorized cooperative_fetching test * Update math simplify for vectorized CF * File rename * Update tune_network * API update --- python/tvm/ansor/auto_schedule.py | 3 - python/tvm/ansor/compute_dag.py | 24 ++- python/tvm/ansor/feature.py | 10 +- python/tvm/ansor/loop_state.py | 14 +- python/tvm/ansor/measure.py | 2 +- python/tvm/ansor/serialization.py | 4 +- scripts/tune_network.py | 136 ++++++++-------- scripts/tune_test.py | 4 +- src/ansor/loop_state.cc | 5 - .../search_policy/meta_tile_rewrite_policy.cc | 3 + src/arith/rewrite_simplify.cc | 71 +++++++- tests/python/unittest/test_ansor_common.py | 2 +- .../python/unittest/test_ansor_compute_dag.py | 3 +- tests/python/unittest/test_ansor_feature.py | 4 +- tests/python/unittest/test_ansor_measure.py | 2 +- ...t_ansor_vectorized_cooperative_fetching.py | 152 ++++++++++++++++++ tutorials/ansor/tune_conv2d_cuda.py | 8 +- tutorials/ansor/tune_simple_subgraph.py | 8 +- 18 files changed, 344 insertions(+), 111 deletions(-) create mode 100644 tests/python/unittest/test_ansor_vectorized_cooperative_fetching.py diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 09895302d25a..127be4c7ad22 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -245,8 +245,6 @@ def auto_schedule(workload, target=None, Returns ------- - state : State - sch : tvm.Schedule tensors : List[Tensor] @@ -270,4 +268,3 @@ def auto_schedule(workload, target=None, else: raise ValueError("Invalid workload: " + workload + ". 
Expect a string or SearchTask") - diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index 0c8aa2055482..23ba1b32f5c4 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -20,7 +20,7 @@ import tvm._ffi from tvm.runtime import Object from tvm import te -from .loop_state import State +from .loop_state import State, StateObject from . import _ffi_api @@ -63,8 +63,12 @@ def apply_steps_from_state(self, state, layout_rewrite_level=None): sch : Schedule args : List[Tensor] """ - sch, args = _ffi_api.ComputeDAGApplyStepsFromState(self, state) - return sch, args + if isinstance(state, State): + return _ffi_api.ComputeDAGApplyStepsFromState(self, state.state_object) + elif isinstance(state, StateObject): + return _ffi_api.ComputeDAGApplyStepsFromState(self, state) + else: + raise ValueError("The input must be a State or StateObject") def print_python_code_from_state(self, state): """ @@ -76,7 +80,12 @@ def print_python_code_from_state(self, state): ------- str : Str """ - return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state) + if isinstance(state, State): + return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state.state_object) + elif isinstance(state, StateObject): + return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state) + else: + raise ValueError("The input must be a State or StateObject") def infer_bound_from_state(self, state): """ @@ -88,7 +97,12 @@ def infer_bound_from_state(self, state): ------- state : StateObject """ - return _ffi_api.ComputeDAGInferBoundFromState(self, state) + if isinstance(state, State): + return State(_ffi_api.ComputeDAGInferBoundFromState(self, state.state_object)) + elif isinstance(state, StateObject): + return State(_ffi_api.ComputeDAGInferBoundFromState(self, state)) + else: + raise ValueError("The input must be a State or StateObject") def gen_schedule(state, bufs): if not state or not state.complete: diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py index f91d7da169f5..4f9fdeb9e6cd 100644 --- a/python/tvm/ansor/feature.py +++ b/python/tvm/ansor/feature.py @@ -23,7 +23,7 @@ import struct import numpy as np -from .loop_state import StateObject +from .loop_state import State, StateObject from .measure import MeasureInput, MeasureResult from . 
import _ffi_api @@ -131,12 +131,16 @@ def get_per_stmt_features_from_measure_pairs(inputs: List[MeasureInput], return unpack_feature(byte_arr) -def get_per_stmt_features_from_states(states: List[StateObject], +def get_per_stmt_features_from_states(states, task: "SearchTask", max_n_bufs: int = None) -> List[np.ndarray]: """Get per_stmt features from states""" + if isinstance(states[0], State): + state_objects = [s.state_object for s in states] + elif isinstance(states[0], StateObject): + state_objects = states byte_arr = _ffi_api.GetPerStmtFeaturesFromStates( - states, task, max_n_bufs or DEFAULT_MAX_N_BUFS) + state_objects, task, max_n_bufs or DEFAULT_MAX_N_BUFS) return unpack_feature(byte_arr)[0] diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index 557bb9d3102b..0cf157147423 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -60,18 +60,6 @@ def iters(self): setattr(self, "iterators_cache", _ffi_api.StageGetIterators(self)) return getattr(self, "iterators_cache") - def iter(self, index): - """ - Parameters - ---------- - index : Int - - Returns - ------- - iter : Iterator - """ - return _ffi_api.StageGetIterator(self, index) - @tvm._ffi.register_object("ansor.State") class StateObject(Object): @@ -302,7 +290,7 @@ def bind_thread(self, stage_id, it, thread_name): } thread_id = trans_table[thread_name] - self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id, it, thread_id) + self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id, it, thread_id) self.clear_cache() return res diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 610e9529090f..b82327ec67c4 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -62,7 +62,7 @@ class MeasureInput(Object): """ def __init__(self, task, state): - self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state) + self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object) @tvm._ffi.register_object("ansor.BuildResult") diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py index 3d7ed7733a78..e11a589a7522 100644 --- a/python/tvm/ansor/serialization.py +++ b/python/tvm/ansor/serialization.py @@ -22,6 +22,7 @@ import tvm._ffi from tvm.runtime import Object from .measure import MeasureCallback, MeasureErrorNo +from .loop_state import State from . 
import _ffi_api
@@ -74,7 +75,8 @@ def write_measure_records_to_file(filename, inputs, results):
 
 def get_states_from_measure_inputs(inputs, task):
     """Get states from measure inputs"""
-    return _ffi_api.GetStatesFromMeasureInputs(inputs, task)
+    state_objects = _ffi_api.GetStatesFromMeasureInputs(inputs, task)
+    return [State(s) for s in state_objects]
 
 def best_measure_pair_in_file(filename, workload_key=None, target=None):
diff --git a/scripts/tune_network.py b/scripts/tune_network.py
index f1f7cd54f8c6..5f22e31d50f7 100644
--- a/scripts/tune_network.py
+++ b/scripts/tune_network.py
@@ -191,22 +191,14 @@ def create_module(data_shape, graph, lib, target, input_name, params, debug_prof
     return module, ctx
 
-def tune_and_evaluate(network_name, model_path, batch_size, target, target_host,
-                      local_measure, device_key, host, port, n_parallel, ndk_cc,
-                      build_timeout, run_timeout, num_threads, tune, check_correctness,
-                      debug_profile, tuning_parameters, record_file, layout_set):
-    task_scheduler, model_type, policy, log_file, load_log_file = (tuning_parameters['task_scheduler'],
-            tuning_parameters['model_type'], tuning_parameters['policy'],
-            tuning_parameters['log_file'], tuning_parameters['load_log_file'])
-
-    if layout_set:
-        layout = layout_set
-
+def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune,
+                      debug_profile, check_correctness, network_parameters,
+                      task_scheduler_parameters, tune_parameters, module_parameters):
     # Extract workloads from relay program
-    print("=============== Extract workloads ===============")
-    mod, params, input_name, data_shape, out_shape = get_network(network_name, model_path, batch_size, layout)
+    mod, params, input_name, data_shape, out_shape = get_network(**network_parameters)
 
     if tune:
+        print("=============== Extracting workloads ===============")
         workloads, wkl_weights = ansor.extract_from_program(mod, target=target,
                 params=params, ops=(relay.op.nn.dense, relay.op.nn.softmax,
                 relay.op.nn.conv2d, relay.op.nn.conv2d_transpose,
@@ -215,7 +207,7 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host,
                 relay.op.nn.conv3d, relay.op.nn.adaptive_avg_pool3d,
                 relay.op.nn.batch_matmul, relay.op.mean,
                 ))
-        print("Total workload number: %d" % (len(workloads)))
+        print("Extracted %d workloads in total."
% (len(workloads))) # Tune workloads with auto scheduler print("=============== Tuning ===============") @@ -225,23 +217,13 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host, print("[========= Task %d =========]\n" % i, dag) tasks.append(ansor.SearchTask(dag, wkl_key, target, target_host)) - def objective_func(costs): - return sum(c * w for c, w in zip(costs, wkl_weights)) - - tuner = ansor.SimpleTaskScheduler(tasks, objective_func, strategy=task_scheduler, - load_log_file=load_log_file, - load_model_file=tuning_parameters['load_model']) + tuner = ansor.SimpleTaskScheduler(tasks, + lambda costs: sum(c * w for c, w in zip(costs, wkl_weights)), + **task_scheduler_parameters) + tune_option, measure_ctx = create_tune_option(target, **tune_parameters) - tune_option, measure_ctx = create_tune_option(target, log_file, - tuning_parameters['n_trials'], tuning_parameters['num_measure_per_iter'], - tuning_parameters['verbose'], n_parallel, build_timeout, - local_measure, device_key, host, port, ndk_cc, - tuning_parameters['early_stopping']) - search_policy = "%s.%s" % (policy, model_type) - - if local_measure and target.target_name != 'cuda': + if tune_parameters['local_measure'] and target.target_name != 'cuda': os.environ['TVM_BIND_MASTER_CORE_0'] = "1" - tuner.tune(tune_option, search_policy) if measure_ctx: @@ -251,15 +233,13 @@ def objective_func(costs): # Compile graph with best states found by auto-scheduler print("=============== Compile ===============") - with ansor.apply_history_best(log_file, args.log_n_lines): - #if True: - #with ansor.BlockingEmptyContext(): + with ansor.apply_history_best(tune_parameters['log_file'], log_n_lines): os.environ['TVM_AUTO_CACHE_FLUSH'] = "0" os.environ['TVM_BIND_MASTER_CORE_0'] = "1" if kernel_layout_rewrite: ansor.prepare_layout_rewrite(mod, target=target, - params=params, - ops=(relay.op.nn.dense, relay.op.nn.conv2d, relay.op.nn.conv3d)) + params=params, + ops=(relay.op.nn.dense, relay.op.nn.conv2d, relay.op.nn.conv3d)) else: # disable layout rewrite ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE @@ -268,22 +248,12 @@ def objective_func(costs): with relay.build_config(opt_level=3): graph, lib, opt_params = relay.build_module.build( mod, target=target, params=params) - ''' - from tvm.relay.backend import graph_runtime_codegen - with relay.build_config(opt_level=3): - opt_mod, _ = relay.optimize(mod, target, params) - grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) - grc.codegen(opt_mod["main"]) - with tvm.transform.PassContext(opt_level=3): - graph, lib, opt_params = relay.build_module.build( - mod, target=target, params=params) - ''' + ansor.finish_layout_rewrite() print("=============== Compile Finish ===============") - module, ctx = create_module(data_shape, graph, lib, target, input_name, opt_params, - debug_profile, local_measure, ndk_cc, - device_key, host, port, run_timeout, num_threads) + module, ctx = create_module(data_shape, graph, lib, target, input_name, + opt_params, debug_profile, **module_parameters) # Evaluate print("========== Evaluate ==========") @@ -315,9 +285,8 @@ def objective_func(costs): graph, lib, opt_params = relay.build_module.build( mod, target=target, params=params) - module, _ = create_module(data_shape, graph, lib, target, input_name, opt_params, - debug_profile, local_measure, ndk_cc, - device_key, host, port, run_timeout, num_threads) + module, _ = create_module(data_shape, graph, lib, target, input_name, + opt_params, debug_profile, 
**module_parameters) module.run() expected_output = module.get_output(0).asnumpy() @@ -343,7 +312,7 @@ def objective_func(costs): # Strategy related options parser.add_argument("--seed", type=int, default=0, help='random seed') parser.add_argument("--policy", type=str, choices=['multi-stage', 'meta-rewrite'], - default='meta-rewrite') + default='meta-rewrite') parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') parser.add_argument("--task-scheduler", type=str, default='gradient', choices=['no', 'gradient', 'round-robin'], @@ -359,6 +328,7 @@ def objective_func(costs): # Detailed control options parser.add_argument("--build-timeout", type=int, default=10) parser.add_argument("--run-timeout", type=int, default=10) + parser.add_argument("--early-stopping", type=int, default=-1) parser.add_argument("--verbose", type=int, default=1) parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) parser.add_argument("--device-key", type=str, default=None) @@ -375,23 +345,59 @@ def objective_func(costs): logging.getLogger('ansor').setLevel(logging.DEBUG) target = tvm.target.create(args.target) + log_file = args.log_file or "%s-B%d-%s.json" % (args.network, args.batch_size, + target.target_name) + load_log_file = args.load_log or log_file + search_policy = "%s.%s" % (args.policy, args.model_type) + if args.layout: + layout = args.layout + elif target.target_name == "cuda": + layout = "NCHW" + else: + layout = "NHWC" + + network_parameters = { + 'name': args.network, + 'model_path': args.model_path, + 'batch_size': args.batch_size, + 'layout': layout + } + + task_scheduler_parameters = { + 'strategy': args.task_scheduler, + 'load_log_file': load_log_file, + 'load_model_file': args.load_model, + 'verbose': args.verbose, + } - tuning_parameters = { + control_parameters = { + 'local_measure': args.local_measure, + 'device_key': args.device_key, + 'host': args.host, + 'port': args.port, + 'ndk_cc': args.ndk_cc, + } + + tune_parameters = { + 'log_file': log_file, 'n_trials': args.n_trials, 'num_measure_per_iter': args.num_measure_per_iter, - 'log_file': args.log_file or "%s-B%d.json" % (args.network, args.batch_size), - 'load_model': args.load_model, - 'model_type': args.model_type, - 'task_scheduler': args.task_scheduler, - 'policy': args.policy, - 'early_stopping': -1, - 'verbose': 1, + 'verbose': args.verbose, + 'n_parallel': args.n_parallel, + 'build_timeout': args.build_timeout, + 'run_timeout': args.run_timeout, + 'early_stopping': args.early_stopping, + **control_parameters + } + + module_parameters = { + 'run_timeout': args.run_timeout, + 'num_threads': args.num_threads, + **control_parameters } - tuning_parameters['load_log_file'] = args.load_log or tuning_parameters['log_file'] os.environ["TOPHUB_LOCATION"] = "NONE" - tune_and_evaluate(args.network, args.model_path, args.batch_size, target, args.target_host, - args.local_measure, args.device_key, args.host, - args.port, args.n_parallel, args.ndk_cc, args.build_timeout, - args.run_timeout, args.num_threads, args.tune, args.check_correctness, - args.debug_profile, tuning_parameters, args.out_file, args.layout) + tune_and_evaluate(target, args.target_host, args.log_n_lines, search_policy, + args.tune, args.debug_profile, args.check_correctness, + network_parameters, task_scheduler_parameters, tune_parameters, + module_parameters) diff --git a/scripts/tune_test.py b/scripts/tune_test.py index d6f552affbb1..a49ecd088afc 100644 --- a/scripts/tune_test.py +++ 
b/scripts/tune_test.py @@ -14,7 +14,7 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose, n_parallel, build_timeout, local_measure, device_key, host, - port, ndk_cc, early_stopping=-1): + port, ndk_cc, early_stopping=-1, run_timeout=10): builder = runner = measure_ctx = None if local_measure: builder = ansor.LocalBuilder(timeout=build_timeout) @@ -26,7 +26,7 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose else: os.environ['TVM_NDK_CC'] = ndk_cc builder = ansor.LocalBuilder(timeout=build_timeout, build_func='ndk') - runner = ansor.RPCRunner(key=device_key, host=host, port=port, + runner = ansor.RPCRunner(key=device_key, host=host, port=port, timeout=run_timeout, n_parallel=n_parallel, repeat=1, min_repeat_ms=400) tune_option = ansor.TuneOption(n_trials=n_trials, early_stopping=early_stopping, diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index faaac94f3323..77361dbf837c 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -1019,11 +1019,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) PrintState(&p->stream, node, true); }); - -TVM_REGISTER_GLOBAL("ansor.StageGetIterator").set_body_typed([](const Stage& stage, int index) { - return stage->iters[index]; -}); - TVM_REGISTER_GLOBAL("ansor.StageGetIterators").set_body_typed([](const Stage& stage) { return Array(stage->iters); }); diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc index 5703e17ba29f..4a045d31a487 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc @@ -1267,6 +1267,9 @@ void MetaTileRewritePolicyNode::SampleInitPopulation(const std::vector& m if (InitPopulationThreadBind(this, &tmp_s)) { continue_count++; + if (continue_count == out_size) { + StdCout(verbose_) << "Initial Population Sampling..." 
<< std::endl; + } continue; } diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 4887ef0ee47d..d3af64a4f576 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -132,6 +132,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), ramp(b1 + x, s1, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), ramp(x + b1, s1, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes), broadcast(x + y, lanes)); + if ((x + broadcast(y, lanes)).Match(ret)) { + if (auto ps = y.Eval().as()) { + if (ps->value == 0.0) { + return x.Eval(); + } + } + } } if (IsIndexType(op->dtype)) { @@ -422,6 +429,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x * y, lanes)); TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes), ramp(b1 * x, s1 * x, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes), ramp(b1 * x, s1 * x, lanes)); + if ((broadcast(x, lanes) * y).Match(ret)) { + if (auto ps = x.Eval().as()) { + if (ps->value == 0.0) { + return make_const(op->dtype, 0.0); + } + } + } } if (IsIndexType(op->dtype)) { @@ -700,9 +714,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y, z, b1; + PVar w, x, y, z, b1; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3, c4; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -767,6 +781,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), max(floordiv(y, c2), x * floordiv(c1, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), floordiv(x * c1, c2), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c2.Eval()->value % c1.Eval()->value == 0 && + CanProveGreaterEqual(-y.Eval(), -c1.Eval()->value + 1)); + // Rules involving 3-operands. 
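+  // A concrete instance of the tail-dropping rules added in this patch:
+  // with c1 = 16, c2 = 4, c3 = 64 and 0 <= z <= 3 (the shape of
+  // split-generated indices), floordiv(x*16 + y*4 + z, 64) simplifies to
+  // floordiv(x*16 + y*4, 64), because x*16 + y*4 is a multiple of 4 and
+  // adding z < 4 can never cross a multiple of 64.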
TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); @@ -783,6 +802,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(x * c1 + y * c2 + z, c3), floordiv(x * c1 + y * c2, c3), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && c3.Eval()->value > 0 && + c3.Eval()->value % c1.Eval()->value == 0 && + c3.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(-z.Eval(), + std::max(-c1.Eval()->value, -c2.Eval()->value) + 1)); + TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); @@ -807,6 +833,18 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(y + z * x, z), floordiv(y, z) + x, CanProveGreaterEqual(z.Eval(), 0)); + + // Rules involving 4-operands + TVM_TRY_REWRITE_IF(floordiv(w * c1 + x * c2 + y * c3 + z, c4), + floordiv(w * c1 + x * c2 + y * c3, c4), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c3.Eval()->value > 0 && c4.Eval()->value > 0 && + c4.Eval()->value % c1.Eval()->value == 0 && + c4.Eval()->value % c2.Eval()->value == 0 && + c4.Eval()->value % c3.Eval()->value == 0 && + CanProveGreaterEqual(-z.Eval(), + std::max(-c1.Eval()->value, + std::max(-c2.Eval()->value, -c3.Eval()->value)) + 1)); } return ret; } @@ -818,9 +856,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y, z, b1; + PVar w, x, y, z, b1; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2, c3, c4; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -864,6 +902,31 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x, floordiv(c2, c1)) * c1 + y, + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c2.Eval()->value % c1.Eval()->value == 0 && + CanProveGreaterEqual(-y.Eval(), -c1.Eval()->value + 1)); + + // TODO(jcf94): For the next three rules, better use the max common factor + // of c1, c2, c3 to do the simplify + TVM_TRY_REWRITE_IF(floormod(x * c1 + y * c2 + z, c3), + floormod(x * floordiv(c1, c2) + y, floordiv(c3, c2)) * c2 + z, + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c3.Eval()->value > 0 && + c3.Eval()->value % c2.Eval()->value == 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(-z.Eval(), -c2.Eval()->value + 1)); + + TVM_TRY_REWRITE_IF(floormod(w * c1 + x * c2 + y * c3 + z, c4), + floormod(w * floordiv(c1, c3) + x * floordiv(c2, c3) + y, + floordiv(c4, c3)) * c3 + z, + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c3.Eval()->value > 0 && c4.Eval()->value > 0 && + c4.Eval()->value % c3.Eval()->value == 0 && + c1.Eval()->value % c3.Eval()->value == 0 && + c2.Eval()->value % c3.Eval()->value == 0 && + CanProveGreaterEqual(-z.Eval(), -c3.Eval()->value + 1)); + // try modular analysis if (floormod(x, c1).Match(ret)) { ModularSet mod = analyzer_->modular_set(x.Eval()); diff --git 
a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index 1790b06bcb60..e23dba2aa4e3 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -77,5 +77,5 @@ def get_tiled_matmul(): C += 1 C_global += 1 s0.compute_at(A_global, C_global, s0.stages[C_global].iters[2]) - return dag, s0.state_object + return dag, s0 diff --git a/tests/python/unittest/test_ansor_compute_dag.py b/tests/python/unittest/test_ansor_compute_dag.py index b60136d4265f..b8afcc6a5b23 100644 --- a/tests/python/unittest/test_ansor_compute_dag.py +++ b/tests/python/unittest/test_ansor_compute_dag.py @@ -33,7 +33,6 @@ def test_apply_steps(): def test_infer_bound(): dag, s = get_tiled_matmul() s = dag.infer_bound_from_state(s) - s = ansor.loop_state.State(s) A_global, B_global, C_global = 1, 3, 4 assert s.stages[B_global].iters[0].range.extent == 512 @@ -62,7 +61,7 @@ def test_lower_legalize_invalid_attach(): s.compute_at(A, B, s.stages[B].iters[1]) s.split(B, s.stages[B].iters[1], [2]) - sch, tensors = dag.apply_steps_from_state(s.state_object) + sch, tensors = dag.apply_steps_from_state(s) stmt = tvm.lower(sch, tensors, simple_mode=True) diff --git a/tests/python/unittest/test_ansor_feature.py b/tests/python/unittest/test_ansor_feature.py index 3da1c7aa332e..567fc080c6f8 100644 --- a/tests/python/unittest/test_ansor_feature.py +++ b/tests/python/unittest/test_ansor_feature.py @@ -47,7 +47,7 @@ def test_cpu_matmul(): target = tvm.target.create('llvm') task = ansor.SearchTask(dag, "test", target) names = ansor.feature.get_per_stmt_feature_names() - fea = ansor.feature.get_per_stmt_features_from_states([s.state_object], task)[0] + fea = ansor.feature.get_per_stmt_features_from_states([s], task)[0] stage_0 = fea[0] assert len(stage_0) == len(names), "%d vs %d" % (len(stage_0), len(names)) @@ -91,7 +91,7 @@ def fusion_test(N, M): target = tvm.target.create('llvm') task = ansor.SearchTask(dag, "test", target) names = ansor.feature.get_per_stmt_feature_names() - fea = ansor.feature.get_per_stmt_features_from_states([s.state_object], task)[0] + fea = ansor.feature.get_per_stmt_features_from_states([s], task)[0] found = False for stage_fea in fea: diff --git a/tests/python/unittest/test_ansor_measure.py b/tests/python/unittest/test_ansor_measure.py index 2ac54d3c765b..d457dd2c55cc 100644 --- a/tests/python/unittest/test_ansor_measure.py +++ b/tests/python/unittest/test_ansor_measure.py @@ -43,7 +43,7 @@ def test_serialization(): s2 = dag.infer_bound_from_state(inputs[0].state) assert s1 == s2 - assert not (s1 == dag.get_init_state().state_object) + assert not (s1 == dag.get_init_state()) def test_measure_local_builder_runner(): diff --git a/tests/python/unittest/test_ansor_vectorized_cooperative_fetching.py b/tests/python/unittest/test_ansor_vectorized_cooperative_fetching.py new file mode 100644 index 000000000000..c41abc7bcb3d --- /dev/null +++ b/tests/python/unittest/test_ansor_vectorized_cooperative_fetching.py @@ -0,0 +1,152 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Test for vectorized cooperative fetching """ + +import numpy as np +import tvm +from tvm import ansor, te +import topi + +from test_ansor_common import matmul_ansor_test, conv2d_nchw_bn_relu + + +def init_common(): + dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) + s0 = dag.get_init_state() + A, B, C = 0, 1, 2 + B_shared = s0.cache_read(B, "shared", [C], dag) + C += 1 + B_local = s0.cache_read(B_shared, "local", [C], dag) + C += 1 + A_shared = s0.cache_read(A, "shared", [C], dag) + B += 1 + B_shared += 1 + B_local += 1 + C += 1 + A_local = s0.cache_read(A_shared, "local", [C], dag) + B += 1 + B_shared += 1 + B_local += 1 + C += 1 + + return A_shared, A_local, B_shared, B_local, C, dag, s0 + +def check_common(dag, state): + s, args = dag.apply_steps_from_state(state) + # To check if every vectorize loop transforms to ramp expr successfully + # TODO(jcf94): Find a better way to process the check in AST + print(tvm.lower(s, args)) + + if tvm.context("cuda", 0).exist: + tgt = tvm.target.cuda() + mod = tvm.build(s, args, tgt) + # To check if every vectorize loop transforms to correct instruction + print(mod.imported_modules[0].get_source()) + + ctx = tvm.context("cuda", 0) + dtype = dag.tensors[0].dtype + a = tvm.nd.array(np.random.uniform(size=(512, 512)).astype(dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(512, 512)).astype(dtype), ctx) + c = tvm.nd.array(np.zeros((512, 512), dtype=dtype), ctx) + mod(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), np.dot( + a.asnumpy(), b.asnumpy()), rtol=1e-5) + else: + print("CUDA device not found, skip this test.") + +def test_vectorized_cooperative_fetching_x(): + A_shared, A_local, B_shared, B_local, C, dag, s0 = init_common() + + its0 = s0.split(C, s0.stages[C].iters[0], [1, 8, 2, 4]) + its1 = s0.split(C, s0.stages[C].iters[5], [2, 8, 2, 4]) + its2 = s0.split(C, s0.stages[C].iters[10], [8, 8]) + s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its2[0], + its2[1], its0[3], its1[3], its2[2], its0[4], its1[4]]) + s0.fuse(C, [s0.stages[C].iters[0], s0.stages[C].iters[1]]) + s0.bind_thread(C, s0.stages[C].iters[0], "blockIdx.x") + s0.fuse(C, [s0.stages[C].iters[1], s0.stages[C].iters[2]]) + s0.bind_thread(C, s0.stages[C].iters[1], "vthread") + s0.fuse(C, [s0.stages[C].iters[2], s0.stages[C].iters[3]]) + s0.bind_thread(C, s0.stages[C].iters[2], "threadIdx.x") + s0.vectorize(C, its1[4]) + + s0.compute_at(B_shared, C, s0.stages[C].iters[3]) + fused_it = s0.fuse(B_shared, s0.stages[B_shared].iters[:]) + its = s0.split(B_shared, fused_it, [64, 4]) + s0.bind_thread(B_shared, its[1], "threadIdx.x") + s0.vectorize(B_shared, its[2]) + s0.compute_at(B_local, C, s0.stages[C].iters[4]) + fused_it = s0.fuse(B_local, s0.stages[B_local].iters[:]) + its = s0.split(B_local, fused_it, [4]) + s0.vectorize(B_local, its[1]) + + s0.compute_at(A_shared, C, s0.stages[C].iters[3]) + fused_it = s0.fuse(A_shared, s0.stages[A_shared].iters[:]) + its = s0.split(A_shared, fused_it, [64, 4]) + s0.bind_thread(A_shared, its[1], "threadIdx.x") + s0.vectorize(A_shared, its[2]) + s0.compute_at(A_local, C, s0.stages[C].iters[4]) + 
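+    # (A_local mirrors B_local above: fuse all iterators of the local
+    # cache-read stage, split out a factor-4 tail, and vectorize it.)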
fused_it = s0.fuse(A_local, s0.stages[A_local].iters[:])
+    its = s0.split(A_local, fused_it, [4])
+    s0.vectorize(A_local, its[1])
+
+    check_common(dag, s0)
+
+def test_vectorized_cooperative_fetching_xy():
+    A_shared, A_local, B_shared, B_local, C, dag, s0 = init_common()
+
+    its0 = s0.split(C, s0.stages[C].iters[0], [1, 8, 2, 4])
+    its1 = s0.split(C, s0.stages[C].iters[5], [2, 8, 2, 4])
+    its2 = s0.split(C, s0.stages[C].iters[10], [8, 8])
+    s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its2[0],
+                   its2[1], its0[3], its1[3], its2[2], its0[4], its1[4]])
+    s0.fuse(C, [s0.stages[C].iters[0], s0.stages[C].iters[1]])
+    s0.bind_thread(C, s0.stages[C].iters[0], "blockIdx.x")
+    s0.fuse(C, [s0.stages[C].iters[1], s0.stages[C].iters[2]])
+    s0.bind_thread(C, s0.stages[C].iters[1], "vthread")
+    s0.bind_thread(C, s0.stages[C].iters[2], "threadIdx.x")
+    s0.bind_thread(C, s0.stages[C].iters[3], "threadIdx.y")
+    s0.vectorize(C, its1[4])
+
+    s0.compute_at(B_shared, C, s0.stages[C].iters[4])
+    fused_it = s0.fuse(B_shared, s0.stages[B_shared].iters[:])
+    its = s0.split(B_shared, fused_it, [8, 8, 4])
+    s0.bind_thread(B_shared, its[1], "threadIdx.x")
+    s0.bind_thread(B_shared, its[2], "threadIdx.y")
+    s0.vectorize(B_shared, its[3])
+    s0.compute_at(B_local, C, s0.stages[C].iters[5])
+    fused_it = s0.fuse(B_local, s0.stages[B_local].iters[:])
+    its = s0.split(B_local, fused_it, [4])
+    s0.vectorize(B_local, its[1])
+
+    s0.compute_at(A_shared, C, s0.stages[C].iters[4])
+    fused_it = s0.fuse(A_shared, s0.stages[A_shared].iters[:])
+    its = s0.split(A_shared, fused_it, [8, 8, 4])
+    s0.bind_thread(A_shared, its[1], "threadIdx.x")
+    s0.bind_thread(A_shared, its[2], "threadIdx.y")
+    s0.vectorize(A_shared, its[3])
+    s0.compute_at(A_local, C, s0.stages[C].iters[5])
+    fused_it = s0.fuse(A_local, s0.stages[A_local].iters[:])
+    its = s0.split(A_local, fused_it, [4])
+    s0.vectorize(A_local, its[1])
+
+    check_common(dag, s0)
+
+if __name__ == "__main__":
+    test_vectorized_cooperative_fetching_x()
+    test_vectorized_cooperative_fetching_xy()
diff --git a/tutorials/ansor/tune_conv2d_cuda.py b/tutorials/ansor/tune_conv2d_cuda.py
index caa040d1b3bc..14a6ee797276 100644
--- a/tutorials/ansor/tune_conv2d_cuda.py
+++ b/tutorials/ansor/tune_conv2d_cuda.py
@@ -122,11 +122,17 @@ def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding):
 # During the searching process, we may generate several invalid schedules and they
 # will be filtered out. It's fine to see "Encountered errors during feature extraction."
 # in the tuning logs.
+# :code:`ansor.LogToFile` callback will log the tuning results into a
+# log file, which can be used to get the best config later.
+# :code:`ansor.PreLoadMeasuredStates` callback will load measured states
+# from the history log before schedule search; we can add this callback to
+# make sure the same schedule is never measured more than once.
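+# (A record written by :code:`ansor.LogToFile` can later be read back with
+# :code:`ansor.best_measure_pair_in_file(log_file)`; a sketch, assuming the
+# same log file as used below.)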
measure_ctx = ansor.LocalRPCMeasureContext(repeat=3, min_repeat_ms=100, timeout=4)
 tune_option = ansor.TuneOption(n_trials=20,
                                runner=measure_ctx.runner,
-                               measure_callbacks=[ansor.LogToFile(log_file)])
+                               measure_callbacks=[ansor.LogToFile(log_file)],
+                               pre_search_callbacks=[ansor.PreLoadMeasuredStates(log_file)])
 s, arg_bufs = ansor.auto_schedule(task, search_policy=search_policy,
                                   tune_option=tune_option)
 print("==== Get Lowered Stmt ====")
diff --git a/tutorials/ansor/tune_simple_subgraph.py b/tutorials/ansor/tune_simple_subgraph.py
index fedbb399d0cf..dfd36e89fd4c 100644
--- a/tutorials/ansor/tune_simple_subgraph.py
+++ b/tutorials/ansor/tune_simple_subgraph.py
@@ -146,8 +146,11 @@ def matmul_add(N, L, M, dtype):
 #
 # We only make 5 trials in this tutorial for demonstration. In practice,
 # you can do more trials according to your time budget.
-# The :code:`ansor.LogToFile` callback will log the tuning results into a
+# :code:`ansor.LogToFile` callback will log the tuning results into a
 # log file, which can be used to get the best config later.
+# :code:`ansor.PreLoadMeasuredStates` callback will load measured states
+# from the history log before schedule search; we can add this callback to
+# make sure the same schedule is never measured more than once.
 
 log_file = "matmul_add.json"
 
@@ -157,7 +160,8 @@ def matmul_add(N, L, M, dtype):
 search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed)
 
 tune_option = ansor.TuneOption(n_trials=5,
-                               measure_callbacks=[ansor.LogToFile(log_file)])
+                               measure_callbacks=[ansor.LogToFile(log_file)],
+                               pre_search_callbacks=[ansor.PreLoadMeasuredStates(log_file)])
 
 ################################################################
 # Then just call :code:`ansor.auto_schedule` and Ansor will try to find a high

From 4ea67122b82a9ce50f5299018a99eeeed1d37ee5 Mon Sep 17 00:00:00 2001
From: Chenfan
Date: Tue, 16 Jun 2020 14:00:52 +0800
Subject: [PATCH 26/78] Update PreLoadMeasuredStates & Some bug fix (#27)

* Add a threading wrapper to fix the test bug
* Set default TVM_USE_AUTO_SCHEDULER to false
* Update PreLoadMeasuredStates callback
---
 python/tvm/ansor/auto_schedule.py             |  3 +-
 python/tvm/ansor/relay_integration.py         | 18 +++-
 python/tvm/ansor/task_scheduler.py            | 11 +++
 scripts/tune_network.py                       |  2 +-
 .../search_policy/meta_tile_rewrite_policy.h  |  6 --
 src/ansor/search_policy/search_policy.cc      | 30 ++++--
 src/ansor/search_policy/search_policy.h       |  4 +
 src/ansor/search_policy/utils.cc              |  5 +-
 .../unittest/test_ansor_relay_Integration.py  | 96 +++++++++++++++++++
 .../unittest/test_ansor_search_policy.py      |  8 +-
 .../unittest/test_ansor_task_scheduler.py     | 19 +++-
 topi/python/topi/arm_cpu/__init__.py          |  2 +-
 topi/python/topi/generic/__init__.py          |  2 +-
 topi/python/topi/x86/__init__.py              |  2 +-
 14 files changed, 178 insertions(+), 30 deletions(-)
 create mode 100644 tests/python/unittest/test_ansor_relay_Integration.py

diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py
index 127be4c7ad22..232c24ee89ea 100644
--- a/python/tvm/ansor/auto_schedule.py
+++ b/python/tvm/ansor/auto_schedule.py
@@ -81,6 +81,7 @@ def set_verbose(self, verbose):
     def run_callbacks(self, callbacks):
         _ffi_api.SearchPolicyRunCallbacks(self, callbacks)
 
+
 @tvm._ffi.register_object("ansor.MetaTileRewritePolicy")
 class MetaTileRewritePolicy(SearchPolicy):
     """ The search policy that searches with meta tiling and random rewrite
@@ -231,7 +232,7 @@ def auto_schedule(workload, target=None,
 
     Parameters
     ----------
-    workload : Str or SearchTask
+    workload : Union[SearchTask, str]
target : Target diff --git a/python/tvm/ansor/relay_integration.py b/python/tvm/ansor/relay_integration.py index de2e12e389e7..348828eec4b4 100644 --- a/python/tvm/ansor/relay_integration.py +++ b/python/tvm/ansor/relay_integration.py @@ -20,6 +20,9 @@ 99.9% copy-paste of implementation by @MerryMercy """ +import os +os.environ['TVM_USE_AUTO_SCHEDULER'] = 'true' + import threading import warnings import tvm @@ -95,7 +98,7 @@ def init_op_to_schedule_map(): relay.op.nn.batch_matmul: [topi.generic.schedule_batch_matmul], } -def extract_from_program(mod, params, ops, target, target_host=None): +def extract_from_program(mod, params, target, target_host=None, ops=None): """ Extract tuning tasks from a relay program. This function is the single program version of extract_from_multiple_program. @@ -117,9 +120,9 @@ def extract_from_program(mod, params, ops, target, target_host=None): ------- workloads: Array of Tuple(wkl_key, target) """ - return extract_from_multiple_program([mod], [params], ops, target, target_host) + return extract_from_multiple_program([mod], [params], target, target_host, ops) -def extract_from_multiple_program(mods, params, ops, target, target_host=None): +def extract_from_multiple_program(mods, params, target, target_host=None, ops=None): """ Extract tuning tasks from multiple relay programs. This function collects tuning tasks by building a list of programs @@ -148,6 +151,15 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None): init_op_to_schedule_map() topi_scheds = [] + + if not ops: + ops = [relay.op.nn.dense, relay.op.nn.softmax, relay.op.nn.conv2d, + relay.op.nn.conv2d_transpose, relay.op.nn.max_pool2d, + relay.op.nn.avg_pool2d, relay.op.nn.global_max_pool2d, + relay.op.nn.global_avg_pool2d, relay.op.nn.conv3d, + relay.op.nn.adaptive_avg_pool3d, relay.op.nn.batch_matmul, + relay.op.mean] + for op_name in ops: if op_name in OP_TO_SCHEDULE: topi_scheds.extend(OP_TO_SCHEDULE[op_name]) diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py index 082b2d265140..f8d3f419dcb4 100644 --- a/python/tvm/ansor/task_scheduler.py +++ b/python/tvm/ansor/task_scheduler.py @@ -145,6 +145,17 @@ def __init__(self, self.sequential_now_task_begin_ct = 0 def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPolicy]] = 'default'): + """ Tune tasks. + + Notice: This method does not have return value, make sure to set `LogToFile` + measure callback in `tune_option`. 
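+        The tuned results can later be loaded back from that log, e.g. via
+        `ansor.apply_history_best` as done in scripts/tune_network.py.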
+ + Parameters + ---------- + tune_option: TuneOption + + search_policy: Str or List[SearchPolicy] + """ # init members self.task_cts = [0 for _ in range(len(self.tasks))] self.task_costs_history = [[] for _ in range(len(self.tasks))] diff --git a/scripts/tune_network.py b/scripts/tune_network.py index 5f22e31d50f7..5e5a337c7bce 100644 --- a/scripts/tune_network.py +++ b/scripts/tune_network.py @@ -7,7 +7,7 @@ import numpy as np import tvm -from tvm import _ffi, relay, ansor +from tvm import _ffi, ansor, relay import tvm.contrib.graph_runtime as runtime from tvm.contrib.debugger import debug_runtime from tvm.contrib import util, ndk diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/meta_tile_rewrite_policy.h index befc002b6aa2..6930a71038a3 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.h +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.h @@ -103,12 +103,6 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { SplitFactorizationMemo split_memo_; // Memorize split space for Split std::mt19937 rand_gen_; // Random generator int num_measure_per_iter_; // The number of states to measure per iteration - - // The array of already measured states. - std::vector measured_states_vector_; - - // The throughputs of already measured states - std::vector measured_states_throughputs_; }; TVM_DEFINE_MUTABLE_OBJECT_REF(MetaTileRewritePolicy, MetaTileRewritePolicyNode); diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index c07a3af7473c..685052f3f71f 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -37,28 +37,44 @@ TVM_REGISTER_OBJECT_TYPE(PreLoadMeasuredStatesNode); void SearchPolicyNode::PreLoadMeasuredStates(const std::string& log_file) { LogReader reader = LogReaderNode::make(log_file); const auto& res = reader->ReadLines(-1); - if (res.first.size()) { + size_t log_size = res.first.size(); + CHECK_EQ(log_size, res.second.size()); + if (log_size) { std::vector measured_states; - for (const auto& inp : res.first) { + std::vector measured_throughputs; + for (size_t i = 0; i < log_size; i++) { + const auto& inp = res.first[i]; if (inp->task->workload_key == cur_task_->workload_key && inp->task->target->target_name.compare( cur_task_->target->target_name) == 0) { State state = cur_task_->compute_dag.GetInitState(); state.CopyOnWrite()->transform_steps = inp->state->transform_steps; state.DoSteps(inp->state->transform_steps, cur_task_->compute_dag); - measured_states.push_back(std::move(state)); + measured_states.emplace_back(std::move(state)); + measured_throughputs.push_back(res.second[i]->error_no == 0 ? 
+ (1.0 / FloatArrayMean(res.second[i]->costs)) : 0.0); } } cur_task_->compute_dag.InferBound(&measured_states); - for (auto state : measured_states) { - measured_states_set_.insert(state.ToStr()); + for (size_t i = 0; i < measured_states.size(); i ++) { + auto& state = measured_states[i]; + const auto& state_str = state.ToStr(); + if (!measured_states_set_.count(state_str)) { + measured_states_set_.insert(state_str); + if (measured_throughputs[i] != 0.0) { + measured_states_vector_.emplace_back(std::move(state)); + measured_states_throughputs_.emplace_back(measured_throughputs[i]); + } + } } StdCout(verbose_) << "Measured States Set: " << measured_states_set_.size() - << " state hashes loaded from " << log_file << std::endl; + << " state hashes loaded from " << log_file + << " for " << cur_task_->workload_key << std::endl; } else { StdCout(verbose_) << "Measured States Set: no states found from " - << log_file << std::endl; + << log_file << " for " << cur_task_->workload_key + << std::endl; } } diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index 2dfbd9429648..6085fd1816e8 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -101,6 +101,10 @@ class SearchPolicyNode : public Object { // The set of the already measured states. // We store the string format for redundancy check std::unordered_set measured_states_set_; + // The array of already measured states. + std::vector measured_states_vector_; + // The throughputs of already measured states + std::vector measured_states_throughputs_; }; TVM_DEFINE_MUTABLE_OBJECT_REF(SearchPolicy, SearchPolicyNode); diff --git a/src/ansor/search_policy/utils.cc b/src/ansor/search_policy/utils.cc index 608b89da118c..e0fd00b23e7b 100644 --- a/src/ansor/search_policy/utils.cc +++ b/src/ansor/search_policy/utils.cc @@ -311,9 +311,10 @@ State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split CHECK(ps != nullptr); extent = GetIntImm(ps->extent); retry_ct += 1; - } while (retry_ct < static_cast(split_step_ids.size()) << 2 && extent == 1); + } while (retry_ct < static_cast(split_step_ids.size()) << 2 && + (extent == 1 || extent == 0)); - if (extent == 1) { + if (extent == 0 || extent == 1) { return State(); } diff --git a/tests/python/unittest/test_ansor_relay_Integration.py b/tests/python/unittest/test_ansor_relay_Integration.py new file mode 100644 index 000000000000..9c423220844c --- /dev/null +++ b/tests/python/unittest/test_ansor_relay_Integration.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+""" Test Relay Integration """ + +import tempfile +import numpy as np + +import tvm +from tvm import ansor, relay +import tvm.contrib.graph_runtime as runtime + +from test_ansor_common import get_tiled_matmul + +def dense_graph(N, dtype="float32"): + ori_data = relay.var("data", shape=(N, N), dtype=dtype) + weight = relay.var("weight", shape=(N, N), dtype=dtype) + data = relay.multiply(ori_data, relay.const(2, dtype=dtype)) + dense = relay.nn.dense(data, weight, out_dtype=dtype) + dense = relay.add(dense, weight) + dense = relay.nn.dense(dense, weight, out_dtype=dtype) + return ori_data, weight, dense + +def test_dense_integration(): + N = 128 + data, weight, dense = dense_graph(N) + mod = relay.Function([data, weight], dense) + mod = tvm.IRModule.from_expr(mod) + + ctx = tvm.context("llvm") + target = tvm.target.create("llvm") + d = tvm.nd.array(np.random.uniform(size=(N, N)).astype(data.type_annotation.dtype), ctx) + w = tvm.nd.array(np.random.uniform(size=(N, N)).astype(weight.type_annotation.dtype), ctx) + workloads, wkl_weights = ansor.extract_from_program(mod, {}, target=target) + + assert len(workloads) == 2 + assert len(wkl_weights) == 2 + + tasks = [] + for wkl_key in workloads: + dag = ansor.workload_key_to_dag(wkl_key) + tasks.append(ansor.SearchTask(dag, wkl_key, target)) + + assert str(tasks[0].compute_dag) == "placeholder = PLACEHOLDER [128, 128]\n" + \ + "placeholder = PLACEHOLDER [128, 128]\n" + \ + "compute(z, y, x) += (placeholder[z, ((k*16) + x)]*placeholder[y, ((k*16) + x)])\n" + \ + "compute(y, x) += compute[y, x, kk]\n" + + assert str(tasks[1].compute_dag) == "placeholder = PLACEHOLDER [128, 128]\n" + \ + "placeholder = PLACEHOLDER [128, 128]\n" + \ + "compute(z, y, x) += (placeholder[z, ((k*16) + x)]*placeholder[y, ((k*16) + x)])\n" + \ + "compute(y, x) += compute[y, x, kk]\n" + \ + "T_add(ax0, ax1) = (compute[ax0, ax1] + placeholder[ax0, ax1])\n" + + tuner = ansor.SimpleTaskScheduler(tasks) + measure_ctx = ansor.LocalRPCMeasureContext() + with tempfile.NamedTemporaryFile() as fp: + tuner.tune(ansor.TuneOption(n_trials=4, runner=measure_ctx.runner, + measure_callbacks=[ansor.LogToFile(fp.name)])) + with ansor.apply_history_best(fp.name): + with relay.build_config(opt_level=3): + graph, lib, opt_params = relay.build_module.build( + mod, target=target) + + m = runtime.create(graph, lib, ctx) + m.set_input('data', d) + m.set_input('weight', w) + m.run() + res = m.get_output(0) + if measure_ctx: + del measure_ctx + + d = d.asnumpy() + d = d * 2 + w = w.asnumpy() + d = np.dot(d, np.transpose(w)) + d = d + w + d = np.dot(d, np.transpose(w)) + + tvm.testing.assert_allclose(res.asnumpy(), d, rtol=1e-5) + +if __name__ == "__main__": + test_dense_integration() diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index b86dfa95f9bd..839992c67e0f 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -20,6 +20,7 @@ import random import numpy as np import tempfile +import threading import tvm from tvm import ansor @@ -73,8 +74,11 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' def test_search_basic(): - search_common(seed=944563397) - + # Ansor search process with local runner has some modification on thread + # binding, wrap this to a subprocess to eliminate the impacts to other tests + t = threading.Thread(target=search_common, kwargs={'seed': 944563397}) + t.start() + t.join() def 
test_search_xgb_model_rpc_runner(): measure_ctx = ansor.LocalRPCMeasureContext() diff --git a/tests/python/unittest/test_ansor_task_scheduler.py b/tests/python/unittest/test_ansor_task_scheduler.py index e95d65d4b5ce..53cf2059c1f3 100644 --- a/tests/python/unittest/test_ansor_task_scheduler.py +++ b/tests/python/unittest/test_ansor_task_scheduler.py @@ -17,6 +17,8 @@ """Test the task scheduler """ +import threading + import tvm from tvm import ansor @@ -30,13 +32,20 @@ def test_task_scheduler_basic(): task1 = ansor.SearchTask(dag, "test", tgt) task2 = ansor.SearchTask(dag, "test", tgt) - def objective(costs): - return sum(costs) + def basic_test_func(task1, task2): + def objective(costs): + return sum(costs) - task_scheduler = ansor.SimpleTaskScheduler([task1, task2], objective) - tune_option = ansor.TuneOption(n_trials=3, runner='local') + task_scheduler = ansor.SimpleTaskScheduler([task1, task2], objective) + tune_option = ansor.TuneOption(n_trials=3, runner='local') + task_scheduler.tune(tune_option) - task_scheduler.tune(tune_option) + # Ansor search process with local runner has some modification on thread + # binding, wrap this to a subprocess to eliminate the impacts to other tests + t = threading.Thread(target=basic_test_func, + kwargs={'task1': task1, 'task2': task2}) + t.start() + t.join() if __name__ == "__main__": diff --git a/topi/python/topi/arm_cpu/__init__.py b/topi/python/topi/arm_cpu/__init__.py index e6ccadd4755f..0c0979763dba 100644 --- a/topi/python/topi/arm_cpu/__init__.py +++ b/topi/python/topi/arm_cpu/__init__.py @@ -28,6 +28,6 @@ from . import cortex_m7 import os -use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "true") +use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "false") if use_auto_scheduler.lower() == "true": from ..ansor import * diff --git a/topi/python/topi/generic/__init__.py b/topi/python/topi/generic/__init__.py index 7f37ba78a06c..d44fca8548d2 100644 --- a/topi/python/topi/generic/__init__.py +++ b/topi/python/topi/generic/__init__.py @@ -41,6 +41,6 @@ from .image import * import os -use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "true") +use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "false") if use_auto_scheduler.lower() == "true": from ..ansor import * diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py index a334397249e3..28e9e862f4d8 100644 --- a/topi/python/topi/x86/__init__.py +++ b/topi/python/topi/x86/__init__.py @@ -41,6 +41,6 @@ from .conv2d_alter_op import * import os -use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "true") +use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "false") if use_auto_scheduler.lower() == "true": from ..ansor import * From 6126cdbefe7c30be19bc88c59af73e396161e81b Mon Sep 17 00:00:00 2001 From: Chenfan Date: Fri, 19 Jun 2020 14:47:19 +0800 Subject: [PATCH 27/78] Add tensorize step for loop_state (#31) * Add tensorize step --- python/tvm/ansor/loop_state.py | 25 +++++++- python/tvm/ansor/task_scheduler.py | 2 + src/ansor/compute_dag.cc | 5 +- src/ansor/loop_state.cc | 59 ++++++++++++++++--- src/ansor/loop_state.h | 20 +++++-- .../search_policy/meta_tile_rewrite_policy.cc | 20 ++++++- src/ansor/search_policy/utils.h | 10 ++++ src/ansor/serialization.cc | 16 ++++- src/ansor/transform_step.cc | 36 +++++++++++ src/ansor/transform_step.h | 43 +++++++++++--- .../python/unittest/test_ansor_loop_state.py | 38 ++++++++++++ 11 files changed, 246 insertions(+), 28 deletions(-) diff --git 
a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py
index 0cf157147423..67ec3ed12b05 100644
--- a/python/tvm/ansor/loop_state.py
+++ b/python/tvm/ansor/loop_state.py
@@ -411,14 +411,33 @@ def storage_align(self, stage_id, it, factor, offset):
         it : Iterator
         factor : Int
         offset : Int
+        """
+        self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, it, factor, offset)
+        self.clear_cache()
+
+    def tensorize(self, stage_id, it, ti_func_name):
+        """ The `ti_func_name` corresponds to a global registered function
+        that returns a TensorIntrin
+
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to do tensorize
+        it : Iterator
+            The target iterator
+        ti_func_name : Str
+            Tensorize intrinsic function name

         Returns
         -------
-        state : State
-            The updated state
+        res_it : Iterator
+            The tensorized Iterator
         """
-        self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, it, factor, offset)
+        self.state_object, res = _ffi_api.StateTensorize(self.state_object,
+                                                         stage_id, it,
+                                                         ti_func_name)
         self.clear_cache()
+        return res

     def __str__(self):
         return str(self.state_object)
diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py
index f8d3f419dcb4..89b4afd84e86 100644
--- a/python/tvm/ansor/task_scheduler.py
+++ b/python/tvm/ansor/task_scheduler.py
@@ -248,6 +248,8 @@ def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPol
         else:
             raise ValueError("Invalid strategy: " + self.strategy)

+        if self.verbose >= 1:
+            print("Next tuning task: %d" % task_idx)
         self.tune_task(task_idx)

     def tune_task(self, task_idx):
diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc
index de3b98a5106b..5ca0c8503662 100644
--- a/src/ansor/compute_dag.cc
+++ b/src/ansor/compute_dag.cc
@@ -1086,7 +1086,8 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const {
       new_iters.push_back(IteratorNode::make(iter->name, (*find_res).second,
                                              iter->iter_type, iter->annotation,
-                                             &iter->ori_iters));
+                                             &iter->ori_iters,
+                                             iter->attr));
     } else {
       LOG(FATAL) << "Infer bound fails";
     }
@@ -1161,6 +1162,8 @@ std::pair<te::Schedule, Array<te::Tensor> > ComputeDAG::ReplaySteps(
       ps->ApplyToSchedule(stages, stage_to_axes, &schedule);
     } else if (auto ps = step.as<StorageAlignStepNode>()) {
       ps->ApplyToSchedule(stages, stage_to_axes);
+    } else if (auto ps = step.as<TensorizeStepNode>()) {
+      ps->ApplyToSchedule(stages, stage_to_axes);
     } else {
       LOG(FATAL) << "Invalid Step";
     }
diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc
index 77361dbf837c..b6e6d854e3e5 100644
--- a/src/ansor/loop_state.cc
+++ b/src/ansor/loop_state.cc
@@ -39,7 +39,8 @@ TVM_REGISTER_NODE_TYPE(IteratorNode);
 // Maker for other classes
 Iterator IteratorNode::make(std::string name, Range range,
                             IteratorType iter_type, IteratorAnnotation annotation,
-                            const std::vector<Iterator>* ori_iters) {
+                            const std::vector<Iterator>* ori_iters,
+                            std::string attr) {
   auto node = make_object<IteratorNode>();
   node->name = std::move(name);
   node->range = std::move(range);
@@ -48,6 +49,7 @@ Iterator IteratorNode::make(std::string name, Range range,
   if (ori_iters != nullptr) {
     node->ori_iters = *ori_iters;
   }
+  node->attr = std::move(attr);
   return Iterator(node);
 }
@@ -310,6 +312,15 @@ void State::storage_align(int stage_id, const Iterator& it, int factor,
   return DoStorageAlignStep(step);
 }

+Iterator State::tensorize(int stage_id, const Iterator& it,
+                          std::string ti_func_name) {
+  const Stage& stage = operator->()->stages[stage_id];
+  TensorizeStep step = TensorizeStepNode::make(
+      stage_id, GetIndex(stage->iters, it), ti_func_name);
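+  // Record the step in the transform history, then apply it immediately so
+  // the caller gets back the iterator annotated as tensorized.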
CopyOnWrite()->transform_steps.push_back(step); + return DoTensorizeStep(step); +} + // Steps' implementations void State::DoReorderStep(const ReorderStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; @@ -509,8 +520,10 @@ Iterator State::DoAnnotationStep(const AnnotationStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; Iterator it = stage->iters[step->iter_id]; + CHECK_EQ(it->annotation, IteratorAnnotation::kNone); Iterator new_it = IteratorNode::make(it->name, it->range, it->iter_type, - step->annotation, &it->ori_iters); + step->annotation, &it->ori_iters, + it->attr); Stage new_stage = stage; new_stage.CopyOnWrite()->iters[step->iter_id] = new_it; StateNode* pstate = CopyOnWrite(); @@ -538,7 +551,8 @@ void State::DoComputeAtStep(const ComputeAtStep& step) { new_iters.push_back(it); } else { new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters)); + it->annotation, &it->ori_iters, + it->attr)); } } @@ -559,7 +573,8 @@ void State::DoComputeRootStep(const ComputeRootStep& step) { std::vector new_iters; for (const Iterator& it : stage->iters) { new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters)); + it->annotation, &it->ori_iters, + it->attr)); } // update attach map @@ -747,6 +762,18 @@ void State::DoStorageAlignStep(const StorageAlignStep& step) { stage->storage_offset = step->offset; } +Iterator State::DoTensorizeStep(const TensorizeStep& step) { + const Stage& stage = operator->()->stages[step->stage_id]; + Iterator it = stage->iters[step->iter_id]; + Iterator new_it = IteratorNode::make(it->name, it->range, it->iter_type, + IteratorAnnotation::kTensorized, &it->ori_iters, step->ti_func_name); + Stage new_stage = stage; + new_stage.CopyOnWrite()->iters[step->iter_id] = new_it; + StateNode* pstate = CopyOnWrite(); + pstate->stages[step->stage_id] = std::move(new_stage); + return new_it; +} + void State::DoStep(const Step& step, const ComputeDAG& dag) { if (auto ps = step.as()) { DoReorderStep(GetRef(ps)); @@ -776,6 +803,8 @@ void State::DoStep(const Step& step, const ComputeDAG& dag) { DoRfactorStep(GetRef(ps), dag); } else if (auto ps = step.as()) { DoStorageAlignStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoTensorizeStep(GetRef(ps)); } else { LOG(FATAL) << "Invalid step: " << step; } @@ -854,15 +883,22 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, case kThreadY: *os << "gpu.threadIdx.y "; break; + case kTensorized: + *os << "tensorize "; + break; + default: + LOG(FATAL) << "Invalid Annotation " << iter->annotation; break; } if (iter->range.defined()) { *os << iter->name << " (" << iter->range->min << "," - << iter->range->extent << ")" - << "\n"; + << iter->range->extent << ")"; } else { - *os << iter->name << " (None)" - << "\n"; + *os << iter->name << " (None)"; } + if (!iter->attr.empty()) { + *os << " " << iter->attr; + } + *os << "\n"; indent += 2; } @@ -1174,6 +1210,13 @@ TVM_REGISTER_GLOBAL("ansor.StateStorageAlign") return state; }); +TVM_REGISTER_GLOBAL("ansor.StateTensorize") +.set_body_typed([](State state, int stage_id, const Iterator& it, + std::string ti_func) { + const auto& res = state.tensorize(stage_id, it, ti_func); + return Array{state, res}; +}); + TVM_REGISTER_GLOBAL("ansor.StateEqual") .set_body_typed([](State state1, State state2) { return std::equal_to()(state1, state2); diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 90ba48cd92ac..6eef404ae272 100644 
--- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -74,7 +74,8 @@ enum IteratorType { /*! \brief The type of an iterator's annotation */ enum IteratorAnnotation { kNone, kUnroll, kVectorize, kParallel, - kVThread, kBlockX, kThreadX, kBlockY, kThreadY + kVThread, kBlockX, kThreadX, kBlockY, kThreadY, + kTensorized }; class Iterator; @@ -90,14 +91,17 @@ class IteratorNode : public Object { IteratorType iter_type; IteratorAnnotation annotation; std::vector ori_iters; // The original iterators before fusion + std::string attr; static Iterator make(std::string name, Range range, IteratorType iter_type, IteratorAnnotation annotation, - const std::vector* ori_iters = nullptr); + const std::vector* ori_iters = nullptr, + std::string attr = ""); void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); v->Visit("range", &range); + v->Visit("attr", &attr); } static constexpr const char *_type_key = "ansor.Iterator"; @@ -115,6 +119,7 @@ class FuseStep; class AnnotationStep; class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep; class CacheReadStep; class CacheWriteStep; class PragmaStep; class RfactorStep; class StorageAlignStep; +class TensorizeStep; /*! * \brief A stage in the compute declaration @@ -254,19 +259,21 @@ class State : public ObjectRef { Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); Iterator bind_thread(int stage_id, const Iterator& it, IteratorAnnotation thread_type); + Iterator tensorize(int stage_id, const Iterator& it, + std::string ti_func_name); void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); void compute_root(int stage_id); void compute_inline(int stage_id); + void pragma(int stage_id, const Iterator& it, const std::string& pragma_type); + void storage_align(int stage_id, const Iterator& it, int factor, int offset); int cache_read(int stage_id, const std::string& scope_name, const std::vector& reader_stage_ids, const ComputeDAG& task_dag); int cache_write(int stage_id, const std::string& scope_name, const ComputeDAG& task_dag); - void pragma(int stage_id, const Iterator& it, const std::string& pragma_type); int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& task_dag); - void storage_align(int stage_id, const Iterator& it, int factor, int offset); /* Do transform steps * Note: The following functions only change loop state but do not change transform_history. 
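From the Python side, the new step is driven by a global registered function
that builds a te.TensorIntrin; its name is passed to state.tensorize. A
minimal sketch of that usage, condensed from the test_tensorize unit test
added later in this patch (the gemv_intrin name, the explicit matmul, and the
external gemv_update symbol are illustrative; the schedule is only lowered
here since gemv_update is not linked):

    import tvm
    from tvm import ansor, te

    @tvm._ffi.register_func
    def gemv_intrin():
        # Returns a te.TensorIntrin computing c[i] += sum_k a[k] * b[k, i].
        a = te.placeholder((64,), name='a')
        b = te.placeholder((64, 16), name='b')
        k = te.reduce_axis((0, 64), name='k')
        c = te.compute((16,), lambda i: te.sum(a[k] * b[k, i], axis=k), name='c')
        Ab = tvm.tir.decl_buffer(a.shape, a.dtype, offset_factor=1, strides=[1])
        Bb = tvm.tir.decl_buffer(b.shape, b.dtype, offset_factor=1,
                                 strides=[te.var("s0"), 1])
        Cb = tvm.tir.decl_buffer(c.shape, c.dtype, offset_factor=1, strides=[1])

        def intrin_func(ins, outs):
            ib = tvm.tir.ir_builder.create()
            aa, bb = ins
            ib.emit(tvm.tir.call_extern("float32", "gemv_update",
                                        outs[0].access_ptr("w"),
                                        aa.access_ptr("r"),
                                        bb.access_ptr("r")))
            return ib.get()

        return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})

    A = te.placeholder((1024, 64), name='A')
    B = te.placeholder((64, 512), name='B')
    k = te.reduce_axis((0, 64), name='k')
    C = te.compute((1024, 512), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
                   name='C')

    dag = ansor.ComputeDAG([A, B, C])
    s0 = dag.get_init_state()
    its = s0.split(2, s0.stages[2].iters[1], [16])  # stage 2 is C; split j
    s0.tensorize(2, its[1], "gemv_intrin")          # its[1] becomes kTensorized
    sch, tensors = dag.apply_steps_from_state(s0)
    print(tvm.lower(sch, tensors, simple_mode=True))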
@@ -278,14 +285,15 @@ class State : public ObjectRef { std::vector DoFollowFusedSplitStep(const FollowFusedSplitStep& step); Iterator DoFuseStep(const FuseStep& step); Iterator DoAnnotationStep(const AnnotationStep& step); + Iterator DoTensorizeStep(const TensorizeStep& step); void DoComputeAtStep(const ComputeAtStep& step); void DoComputeRootStep(const ComputeRootStep& step); void DoComputeInlineStep(const ComputeInlineStep& step); + void DoPragmaStep(const PragmaStep& step); + void DoStorageAlignStep(const StorageAlignStep& step); int DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag); int DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag); - void DoPragmaStep(const PragmaStep& step); int DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag); - void DoStorageAlignStep(const StorageAlignStep& step); // General do step functions with a runtime dynamic dispatcher void DoStep(const Step& step, const ComputeDAG& dag); diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc index 4a045d31a487..7e022e3be3c3 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc @@ -751,6 +751,11 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, continue; } + if (HasAnnotationIter(stage, IteratorAnnotation::kThreadX)) { + // Skip if this stage has already done thread bind + continue; + } + std::vector to_fuse; // This stage has not been tiled, but in GPU schedule, we must tile it @@ -861,10 +866,16 @@ int InitPopulationCooperativeFetching(const MetaTileRewritePolicyNode* policy, !HasCacheWriteStage((*state), stage_id - 1)) || (stage_id > 1 && HasCacheReadStage((*state), stage_id - 2) && HasCacheWriteStage((*state), stage_id - 2))) { + const Stage& target_stage = (*state)->stages[stage_id]; + if (HasAnnotationIter(target_stage, IteratorAnnotation::kThreadX) || + HasAnnotationIter(target_stage, IteratorAnnotation::kTensorized)) { + // Skip if this stage has already done thread bind or has been + // tensorized + continue; + } // Get spatial_split_step_ids from the root stage std::unordered_set consumers; std::vector spatial_split_step_ids; - const Stage& target_stage = (*state)->stages[stage_id]; GetConsumers(policy->cur_task_, (*state), target_stage->op, &consumers); CHECK_EQ(consumers.size(), 1); int target_stage_id = OperationToStage(*consumers.begin(), (*state)); @@ -1129,6 +1140,11 @@ int InitPopulationVectorization(const MetaTileRewritePolicyNode* policy, continue; } + if (HasAnnotationIter(stage, IteratorAnnotation::kTensorized)) { + // Skip if this stage has been tensorized + continue; + } + // try to fuse and vectorize the space iterators in the inner most tile int cum_length_prod = 1; @@ -1224,7 +1240,7 @@ int InitPopulationUnroll(const MetaTileRewritePolicyNode* policy, n--; } - } else if (stage->op->attrs.count(policy->always_unroll_key)) { + } else if (stage->op->attrs.count(policy->always_unroll_key)) { // Special unroll policy auto to_unroll_name_set = GetIterNameSetParam(stage->op->attrs, policy->always_unroll_key); diff --git a/src/ansor/search_policy/utils.h b/src/ansor/search_policy/utils.h index 3d0611173c94..472e90771879 100644 --- a/src/ansor/search_policy/utils.h +++ b/src/ansor/search_policy/utils.h @@ -143,6 +143,16 @@ inline bool HasReduceIter(const Stage& stage) { return false; } +// Return whether the stage has specific annotated iterators +inline bool HasAnnotationIter(const Stage& stage, 
IteratorAnnotation type) { + for (const auto& iter : stage->iters) { + if (iter->annotation == type) { + return true; + } + } + return false; +} + // Return whether an op needs multi level tiling inline bool NeedsMultilevelTiling(const SearchTask& task, const State& state, const te::Operation& op) { diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index b03acb1edc3c..ed5d4b868c27 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -167,6 +167,11 @@ struct Handler > { writer->WriteArrayItem(ps->iter_id); writer->WriteArrayItem(ps->factor); writer->WriteArrayItem(ps->offset); + } else if (auto ps = data[i].as<::tvm::ansor::TensorizeStepNode>()) { + writer->WriteArrayItem(std::string("TS")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + writer->WriteArrayItem(ps->ti_func_name); } else { LOG(FATAL) << "Invalid step: " << data[i]; } @@ -179,7 +184,7 @@ struct Handler > { std::vector<::tvm::ansor::Step> * data) { std::vector int_list; bool s, inner_to_outer, factor_or_nparts; - std::string name, scope_name, pragma_type; + std::string name, scope_name, pragma_type, ti_func_name; int stage_id, target_stage_id, iter_id, src_step_id, n_split, ann, extent; int level, factor_iter_id, factor, offset; @@ -311,6 +316,15 @@ struct Handler > { reader->Read(&offset); data->push_back(::tvm::ansor::StorageAlignStepNode::make( stage_id, iter_id, factor, offset)); + } else if (name == "TS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&ti_func_name); + data->push_back(::tvm::ansor::TensorizeStepNode::make( + stage_id, iter_id, ti_func_name)); } else { LOG(FATAL) << "Invalid step format"; } diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index 3f59ff736e9d..b0e67a481ae3 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -26,6 +26,7 @@ #include "transform_step.h" #include +#include #include #include "utils.h" @@ -801,5 +802,40 @@ std::string StorageAlignStepNode::PrintAsPythonAPI( return ss.str(); } +/********** Tensorize **********/ +TensorizeStep TensorizeStepNode::make(int stage_id, int iter_id, + std::string ti_func_name) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->ti_func_name = ti_func_name; + return TensorizeStep(node); +} + +void TensorizeStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + auto func = tvm::runtime::Registry::Get(ti_func_name); + CHECK(func != nullptr) << "Cannot find the tensorize intrinsic func"; + tvm::te::TensorIntrin res = (*func)(); + CHECK(res.defined()) << "Tensorize intrinsic func must return a " + << "tvm::te::TensorIntrin object"; + stage.tensorize(axes[iter_id], res); +} + +std::string TensorizeStepNode::PrintAsPythonAPI( + std::vector *stages, StageToAxesMap *stage_to_axes, + te::Schedule *schedule, const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + ss << "s[" << CleanName(stage->op->func_name()) << "].tensorize(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " + << ti_func_name << "())\n"; + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + } // namespace ansor } // namespace tvm diff --git a/src/ansor/transform_step.h 
b/src/ansor/transform_step.h
index 8240623ae3b1..9af14429bf61 100644
--- a/src/ansor/transform_step.h
+++ b/src/ansor/transform_step.h
@@ -23,17 +23,18 @@
  *
  * \Note How to add a new transform step.
  * Take fuse for example:
- * 1. Define class FuseStepNode, FuseStep in transform_steps.h, and implement its make function
- *    in FuseStepNode::make(...) transform_steps.cc
- * 2. Implement FuseStepNode::ApplyToSchedule and FuseStepNode::PrintAsPythonAPI.
- *    - In these two functions you need to lower this step with tvm's schedule API
- * 3. Implement State::fuse and State::DoFuseStep.
+ * 1. Define class `FuseStepNode`, `FuseStep` in `transform_step.h`, and implement its make function
+ *    `FuseStepNode::make(...)` in `transform_step.cc`
+ * 2. Implement `FuseStepNode::ApplyToSchedule` and `FuseStepNode::PrintAsPythonAPI`.
+ *    - In these two functions you need to lower this step with tvm's te schedule API
+ * 3. Implement `State::fuse` and `State::DoFuseStep`.
  *    - In these two functions you need to incrementally update all data structures in State with
  *      CopyOnWrite style
- * 4. Add you step to ComputeDAG::ReplaySteps and make sure it works.
+ * 4. Add your step to `ComputeDAG::ReplaySteps` and make sure it works.
  * 5. Add serialization support in `struct Handler<std::vector<::tvm::ansor::Step> >`
-      (in serialization.cc)
+ *    in `serialization.cc`
 * 6. Add hash support in `struct hash<::tvm::ansor::Step>` (search for this function in this file)
+ * 7. Add its corresponding Python API to `loop_state.py` and necessary unit test
 */

 #ifndef TVM_ANSOR_TRANSFORM_STEP_H_
@@ -365,6 +366,29 @@ class StorageAlignStepNode: public StepNode {
 };
 TVM_DEFINE_COW_OBJECT_REF(StorageAlignStep, Step, StorageAlignStepNode);

+/*! \brief Tensorize step that corresponds to te::Schedule::tensorize
+ * \Note This step takes a global registered function name as input.
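+ * The function is looked up through tvm::runtime::Registry::Get when the
+ * step is applied, and it must return a te::TensorIntrin object (this is
+ * checked in TensorizeStepNode::ApplyToSchedule above).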
*/ +class TensorizeStepNode: public StepNode { + public: + int iter_id; + std::string ti_func_name; + + static TensorizeStep make(int stage_id, int iter_id, + std::string ti_func_name); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.TensorizeStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeStepNode, Object); +}; +TVM_DEFINE_COW_OBJECT_REF(TensorizeStep, Step, TensorizeStepNode); + } // namespace ansor } // namespace tvm @@ -451,6 +475,11 @@ struct hash<::tvm::ansor::Step> { ::dmlc::HashCombine(std::hash()(ps->iter_id), ::dmlc::HashCombine(std::hash()(ps->factor), ps->offset)))); + } else if (auto ps = step.as<::tvm::ansor::TensorizeStepNode>()) { + return ::dmlc::HashCombine(15, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->iter_id), + ps->ti_func_name))); } else { LOG(FATAL) << "Invalid step"; } diff --git a/tests/python/unittest/test_ansor_loop_state.py b/tests/python/unittest/test_ansor_loop_state.py index 612d320036d8..a2c09aafc07b 100644 --- a/tests/python/unittest/test_ansor_loop_state.py +++ b/tests/python/unittest/test_ansor_loop_state.py @@ -17,6 +17,7 @@ """Test loop state and schedule primitives""" +import tvm from tvm import ansor, te import topi @@ -468,9 +469,46 @@ def test_rfactor(): " C.repl = ...\n" +@tvm._ffi.register_func +def test_intrin_gemv(): + m = 16 + l = 64 + a = te.placeholder((l,), name='a') + b = te.placeholder((l, m), name='b') + k = te.reduce_axis((0, l), name='k') + c = te.compute((m,), lambda i: te.sum(a[k] * b[k, i], axis=k), name='c') + Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="A", + offset_factor=1, strides=[1]) + Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="B", + offset_factor=1, strides=[te.var("s0"), 1]) + Cb = tvm.tir.decl_buffer(c.shape, c.dtype, name="C", + offset_factor=1, strides=[1]) + def intrin_func(ins, outs): + ib = tvm.tir.ir_builder.create() + aa, bb = ins + cc = outs[0] + ib.emit(tvm.tir.call_extern("float32", "gemv_update", + cc.access_ptr("w"), + aa.access_ptr("r"), + bb.access_ptr("r"))) + return ib.get() + return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb}) + +def test_tensorize(): + dag = ansor.ComputeDAG(matmul_ansor_test(1024, 512, 64)) + s0 = dag.get_init_state() + C = 2 + + its = s0.split(C, s0.stages[C].iters[1], [16]) + s0.tensorize(C, its[1], "test_intrin_gemv") + + sch, tensors = dag.apply_steps_from_state(s0) + tvm.lower(sch, tensors, simple_mode=True) + if __name__ == "__main__": test_split_fuse_reorder_annotation() test_follow_split_follow_fused_split() test_compute_at_root_inline() test_cache_read_write() test_rfactor() + test_tensorize() From c7364df568922d1643d50b85f5e0c3fa3acb64d2 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Fri, 19 Jun 2020 18:19:39 +0800 Subject: [PATCH 28/78] State python api update (#33) * Start to update api * Add compute_dag to state * API update --- python/tvm/ansor/compute_dag.py | 6 +- python/tvm/ansor/loop_state.py | 177 +++++++++-- python/tvm/ansor/serialization.py | 2 +- src/ansor/loop_state.cc | 4 - tests/python/unittest/test_ansor_common.py | 29 +- .../python/unittest/test_ansor_compute_dag.py | 19 +- tests/python/unittest/test_ansor_feature.py | 4 +- .../python/unittest/test_ansor_loop_state.py | 275 ++++++++++++------ .../unittest/test_ansor_search_policy.py | 
11 +- ...t_ansor_vectorized_cooperative_fetching.py | 152 ---------- 10 files changed, 374 insertions(+), 305 deletions(-) delete mode 100644 tests/python/unittest/test_ansor_vectorized_cooperative_fetching.py diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index 23ba1b32f5c4..6d82942aa744 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -49,7 +49,7 @@ def get_init_state(self): ------- state : State """ - return State(_ffi_api.ComputeDAGGetInitState(self)) + return State(_ffi_api.ComputeDAGGetInitState(self), self) def apply_steps_from_state(self, state, layout_rewrite_level=None): """ @@ -98,9 +98,9 @@ def infer_bound_from_state(self, state): state : StateObject """ if isinstance(state, State): - return State(_ffi_api.ComputeDAGInferBoundFromState(self, state.state_object)) + return State(_ffi_api.ComputeDAGInferBoundFromState(self, state.state_object), self) elif isinstance(state, StateObject): - return State(_ffi_api.ComputeDAGInferBoundFromState(self, state)) + return State(_ffi_api.ComputeDAGInferBoundFromState(self, state), self) else: raise ValueError("The input must be a State or StateObject") diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index 67ec3ed12b05..23289c027293 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -77,16 +77,48 @@ class State: ----- This is a wrapper class of StateObject to deal with copy-on-write property """ - def __init__(self, state_object): + def __init__(self, state_object, dag): self.state_object = state_object + self.compute_dag = dag self.stages_cache = None + self.stage_id_map = {} + self.__update_tensor_stage_map() + + def __getitem__(self, k): + if not self.stages_cache: + self.stages_cache = _ffi_api.StateGetStages(self.state_object) + if isinstance(k, tvm.te.Tensor): + return self.stages_cache[self.stage_id_map[k.op]] + else: + raise ValueError("Item must be Tensor") + + def __update_tensor_stage_map(self): + if not self.stages_cache: + self.stages_cache = _ffi_api.StateGetStages(self.state_object) + for index, stage in enumerate(self.stages_cache): + self.stage_id_map[stage.op] = index + + def __insert_new_stage(self, new_stage_id): + new_stage_id = int(new_stage_id) + self.stages_cache = _ffi_api.StateGetStages(self.state_object) + added_stage_tensor = self.stages_cache[new_stage_id].op.output(0) + + for key, value in self.stage_id_map.items(): + if value >= new_stage_id: + self.stage_id_map[key] = value + 1 + self.stage_id_map[added_stage_tensor.op] = new_stage_id + self.__update_tensor_stage_map() + + return added_stage_tensor def clear_cache(self): self.stages_cache = None def copy(self): - return State(self.state_object) + state = State(self.state_object, self.compute_dag) + state.stage_id_map = self.stage_id_map.copy() + return state @property def stages(self): @@ -99,6 +131,17 @@ def stages(self): self.stages_cache = _ffi_api.StateGetStages(self.state_object) return self.stages_cache + @property + def stage_tensors(self): + """ + Returns + ------- + Tensor + """ + if not self.stages_cache: + self.stages_cache = _ffi_api.StateGetStages(self.state_object) + return [stage.op.output(0) for stage in self.stages_cache] + def transform_steps_size(self): """ Return the size of transform_steps """ @@ -113,6 +156,11 @@ def reorder(self, stage_id, order): order : List[Iterator] Iterators in the expected order """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not 
isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order) self.clear_cache() @@ -135,6 +183,11 @@ def split(self, stage_id, it, lengths, inner_to_outer=True): res_its : List[Iterator] The splitted new Iterators """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, it, lengths, inner_to_outer) self.clear_cache() @@ -158,6 +211,11 @@ def follow_split(self, stage_id, it, src_step_id, n_split): res_its : List[Iterator] The splitted new Iterators """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, stage_id, it, src_step_id, n_split) self.clear_cache() @@ -185,6 +243,11 @@ def follow_fused_split(self, stage_id, it, src_step_ids, level, res_its : List[Iterator] The splitted new Iterators """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, stage_id, it, src_step_ids, level, factor_or_nparts) @@ -205,6 +268,11 @@ def fuse(self, stage_id, iters): res_it : Iterator The fused Iterator """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters) self.clear_cache() return res @@ -223,6 +291,11 @@ def vectorize(self, stage_id, it): res_it : Iterator The vectorized Iterator """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateVectorize(self.state_object, stage_id, it) self.clear_cache() return res @@ -241,6 +314,11 @@ def parallel(self, stage_id, it): res_it : Iterator The parallelized Iterator """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateParallel(self.state_object, stage_id, it) self.clear_cache() return res @@ -261,6 +339,11 @@ def unroll(self, stage_id, it, max_unroll=-1): res_it : Iterator The unrolled Iterator """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id, it, max_unroll) self.clear_cache() return res @@ -290,6 +373,11 @@ def bind_thread(self, stage_id, it, thread_name): } thread_id = trans_table[thread_name] + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id, it, thread_id) self.clear_cache() return res @@ -305,6 +393,15 @@ def compute_at(self, stage_id, 
target_stage_id, target_iter): target_iter : Iterator The target Iterator of compute_at """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + if isinstance(target_stage_id, tvm.te.Tensor): + target_stage_id = self.stage_id_map[target_stage_id.op] + elif not isinstance(target_stage_id, int): + raise ValueError("target_stage_id must be Tensor or Int") + self.state_object = _ffi_api.StateComputeAt(self.state_object, stage_id, target_stage_id, target_iter) self.clear_cache() @@ -316,6 +413,11 @@ def compute_root(self, stage_id): stage_id : Int The index of the stage to compute root """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object = _ffi_api.StateComputeRoot(self.state_object, stage_id) self.clear_cache() @@ -326,10 +428,15 @@ def compute_inline(self, stage_id): stage_id : Int The index of the stage to compute inline """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object = _ffi_api.StateComputeInline(self.state_object, stage_id) self.clear_cache() - def cache_read(self, stage_id, scope_name, reader_stage_ids, task_dag): + def cache_read(self, stage_id, scope_name, reader_stage_ids): """ Parameters ---------- @@ -337,37 +444,55 @@ def cache_read(self, stage_id, scope_name, reader_stage_ids, task_dag): The index of the stage to do cache_read scope_name : Str reader_stage_ids : List[Int] - task_dag : ComputeDAG Returns ------- new_stage_id : Int The added staged id """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + if isinstance(reader_stage_ids, list): + tmp_list = [] + for reader_stage_id in reader_stage_ids: + if isinstance(reader_stage_id, tvm.te.Tensor): + tmp_list.append(self.stage_id_map[reader_stage_id.op]) + elif isinstance(reader_stage_id, int): + tmp_list.append(reader_stage_id) + else: + raise ValueError("reader_stage_id must be Tensor or Int") + reader_stage_ids = tmp_list + else: + raise ValueError("reader_stage_ids must be list of Tensor or Int") + self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object, stage_id, scope_name, reader_stage_ids, - task_dag) - self.clear_cache() - return int(new_stage_id) + self.compute_dag) + return self.__insert_new_stage(new_stage_id) - def cache_write(self, stage_id, scope_name, task_dag): + def cache_write(self, stage_id, scope_name): """ Parameters ---------- stage_id : Int The index of the stage to do cache read scope_name : Str - task_dag : ComputeDAG Returns ------- new_stage_id : Int The added staged id """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, new_stage_id = _ffi_api.StateCacheWrite(self.state_object, stage_id, - scope_name, task_dag) - self.clear_cache() - return int(new_stage_id) + scope_name, self.compute_dag) + return self.__insert_new_stage(new_stage_id) def pragma(self, stage_id, it, pragma_type): """ @@ -379,10 +504,15 @@ def pragma(self, stage_id, it, pragma_type): The iterator to add pragma pragma_type : Str 
""" + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object = _ffi_api.StatePragma(self.state_object, stage_id, it, pragma_type) self.clear_cache() - def rfactor(self, stage_id, it, factor_iter_id, task_dag): + def rfactor(self, stage_id, it, factor_iter_id): """ Parameters ---------- @@ -390,17 +520,20 @@ def rfactor(self, stage_id, it, factor_iter_id, task_dag): The index of the stage to do reduction factor it : Iterator factor_iter_id : Int - task_dag : ComputeDAG Returns ------- new_stage_id : Int The added staged id """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object, stage_id, it, - factor_iter_id, task_dag) - self.clear_cache() - return int(new_stage_id) + factor_iter_id, self.compute_dag) + return self.__insert_new_stage(new_stage_id) def storage_align(self, stage_id, it, factor, offset): """ @@ -412,6 +545,11 @@ def storage_align(self, stage_id, it, factor, offset): factor : Int offset : Int """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, it, factor, offset) self.clear_cache() @@ -433,6 +571,11 @@ def tensorize(self, stage_id, it, ti_func_name): res_it : Iterator The tensorized Iterator """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateTensorize(self.state_object, stage_id, it, ti_func_name) diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py index e11a589a7522..d9b8a2f5c075 100644 --- a/python/tvm/ansor/serialization.py +++ b/python/tvm/ansor/serialization.py @@ -76,7 +76,7 @@ def write_measure_records_to_file(filename, inputs, results): def get_states_from_measure_inputs(inputs, task): """Get states from measure inputs""" state_objects = _ffi_api.GetStatesFromMeasureInputs(inputs, task) - return [State(s) for s in state_objects] + return [State(s, task.compute_dag) for s in state_objects] def best_measure_pair_in_file(filename, workload_key=None, target=None): diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index b6e6d854e3e5..7569c91e3368 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -1063,10 +1063,6 @@ TVM_REGISTER_GLOBAL("ansor.StateGetStages").set_body_typed([](const State& state return Array(state->stages); }); -TVM_REGISTER_GLOBAL("ansor.StateGetStage").set_body_typed([](const State& state, int index) { - return state->stages[index]; -}); - TVM_REGISTER_GLOBAL("ansor.StateGetTransformStepsSize").set_body_typed([](const State& state) { return static_cast(state->transform_steps.size()); }); diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index e23dba2aa4e3..083bd2721cb6 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -56,26 +56,19 @@ def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation def get_tiled_matmul(): - dag = 
ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) + A, B, C = matmul_ansor_test(512, 512, 512) + dag = ansor.ComputeDAG([A, B, C]) s0 = dag.get_init_state() - A, B, C = 0, 1, 2 - C_global = s0.cache_write(C, "global", dag) - C += 1 - its0 = s0.split(C, s0.stages[C].iters[0], [4, 8, 8]) - its1 = s0.split(C, s0.stages[C].iters[4], [8, 4, 4]) + C_global = s0.cache_write(C, "global") + its0 = s0.split(C, s0[C].iters[0], [4, 8, 8]) + its1 = s0.split(C, s0[C].iters[4], [8, 4, 4]) s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], its1[3]]) - s0.compute_at(C_global, C, s0.stages[C].iters[3]) - s0.split(C_global, s0.stages[C_global].iters[2], [16]) - B_global = s0.cache_read(B, "global", [C_global], dag) - C += 1 - C_global += 1 - s0.compute_at(B_global, C_global, s0.stages[C_global].iters[0]) - A_global = s0.cache_read(A, "global", [C_global], dag) - B += 1 - B_global += 1 - C += 1 - C_global += 1 - s0.compute_at(A_global, C_global, s0.stages[C_global].iters[2]) + s0.compute_at(C_global, C, s0[C].iters[3]) + s0.split(C_global, s0[C_global].iters[2], [16]) + B_global = s0.cache_read(B, "global", [C_global]) + s0.compute_at(B_global, C_global, s0[C_global].iters[0]) + A_global = s0.cache_read(A, "global", [C_global]) + s0.compute_at(A_global, C_global, s0[C_global].iters[2]) return dag, s0 diff --git a/tests/python/unittest/test_ansor_compute_dag.py b/tests/python/unittest/test_ansor_compute_dag.py index b8afcc6a5b23..313dc1f89902 100644 --- a/tests/python/unittest/test_ansor_compute_dag.py +++ b/tests/python/unittest/test_ansor_compute_dag.py @@ -34,12 +34,14 @@ def test_infer_bound(): dag, s = get_tiled_matmul() s = dag.infer_bound_from_state(s) - A_global, B_global, C_global = 1, 3, 4 - assert s.stages[B_global].iters[0].range.extent == 512 - assert s.stages[B_global].iters[1].range.extent == 16 - assert s.stages[A_global].iters[0].range.extent == 1 - assert s.stages[A_global].iters[1].range.extent == 16 - assert s.stages[C_global].iters[0].range.extent == 64 + A_global = s.stage_tensors[1] + B_global = s.stage_tensors[3] + C_global = s.stage_tensors[4] + assert s[B_global].iters[0].range.extent == 512 + assert s[B_global].iters[1].range.extent == 16 + assert s[A_global].iters[0].range.extent == 1 + assert s[A_global].iters[1].range.extent == 16 + assert s[C_global].iters[0].range.extent == 64 def test_estimate_flop(): @@ -57,9 +59,8 @@ def test_lower_legalize_invalid_attach(): dag = ansor.ComputeDAG([A, B]) s = dag.get_init_state() - A, B = 0, 1 - s.compute_at(A, B, s.stages[B].iters[1]) - s.split(B, s.stages[B].iters[1], [2]) + s.compute_at(A, B, s[B].iters[1]) + s.split(B, s[B].iters[1], [2]) sch, tensors = dag.apply_steps_from_state(s) stmt = tvm.lower(sch, tensors, simple_mode=True) diff --git a/tests/python/unittest/test_ansor_feature.py b/tests/python/unittest/test_ansor_feature.py index 567fc080c6f8..bb19b84a970d 100644 --- a/tests/python/unittest/test_ansor_feature.py +++ b/tests/python/unittest/test_ansor_feature.py @@ -33,9 +33,9 @@ def fequal(a, b): def test_cpu_matmul(): dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) s = dag.get_init_state() - C = 2 + C = s.stage_tensors[2] - i, j, k = s.stages[C].iters + i, j, k = s[C].iters io, ii = s.split(C, i, [16]) jo, ji = s.split(C, j, [8]) s.reorder(C, [io, jo, k, ji, ii]) diff --git a/tests/python/unittest/test_ansor_loop_state.py b/tests/python/unittest/test_ansor_loop_state.py index a2c09aafc07b..87688e276469 100644 --- a/tests/python/unittest/test_ansor_loop_state.py +++ 
b/tests/python/unittest/test_ansor_loop_state.py @@ -17,6 +17,8 @@ """Test loop state and schedule primitives""" +import numpy as np + import tvm from tvm import ansor, te import topi @@ -25,16 +27,16 @@ def test_split_fuse_reorder_annotation(): - dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) + A, B, C = matmul_ansor_test(512, 512, 512) + dag = ansor.ComputeDAG([A, B, C]) s0 = dag.get_init_state() - C = 2 - i, j, k = s0.stages[C].iters + i, j, k = s0[C].iters assert i.range.extent == 512 io, ii = s0.split(C, i, [16]) - assert s0.stages[C].iters[0] == io - assert s0.stages[C].iters[1] == ii + assert s0[C].iters[0] == io + assert s0[C].iters[1] == ii assert io.range.extent == 32 assert ii.range.extent == 16 @@ -43,21 +45,21 @@ def test_split_fuse_reorder_annotation(): assert ji.range.extent == 8 s0.reorder(C, [io, jo, k, ji, ii]) - assert s0.stages[C].iters[2].range.extent == 512 + assert s0[C].iters[2].range.extent == 512 fused_it = s0.fuse(C, [io, jo]) assert fused_it.range.extent == 2048 s1 = dag.get_init_state() - i, j, _ = s1.stages[C].iters + i, j, _ = s1[C].iters i1, i2, i3 = s1.split(C, i, [8, 2]) j1, j2, j3 = s1.split(C, j, [32, 8], False) - assert s1.stages[C].iters[0].range.extent == 32 - assert s1.stages[C].iters[1].range.extent == 8 - assert s1.stages[C].iters[2].range.extent == 2 - assert s1.stages[C].iters[3].range.extent == 32 - assert s1.stages[C].iters[4].range.extent == 8 - assert s1.stages[C].iters[5].range.extent == 2 + assert s1[C].iters[0].range.extent == 32 + assert s1[C].iters[1].range.extent == 8 + assert s1[C].iters[2].range.extent == 2 + assert s1[C].iters[3].range.extent == 32 + assert s1[C].iters[4].range.extent == 8 + assert s1[C].iters[5].range.extent == 2 s1.parallel(C, j1) s1.unroll(C, j2) @@ -68,23 +70,22 @@ def test_split_fuse_reorder_annotation(): def test_follow_split_follow_fused_split(): - dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) + A, B, C = matmul_ansor_test(512, 512, 512) + dag = ansor.ComputeDAG([A, B, C]) s0 = dag.get_init_state() - C = 2 - C_global = s0.cache_write(C, "global", dag) - C += 1 + C_global = s0.cache_write(C, "global") - its0 = s0.split(C, s0.stages[C].iters[0], [4, 2, 8, 4], True) + its0 = s0.split(C, s0[C].iters[0], [4, 2, 8, 4], True) split_step0 = s0.transform_steps_size() - 1 for level in range(1, 6): tmp = s0.copy() - tmp.follow_split(C_global, tmp.stages[C_global].iters[0], split_step0, level) + tmp.follow_split(C_global, tmp[C_global].iters[0], split_step0, level) for i in range(0, level): - assert tmp.stages[C].iters[i].range.extent == \ - tmp.stages[C_global].iters[i].range.extent + assert tmp[C].iters[i].range.extent == \ + tmp[C_global].iters[i].range.extent - its1 = s0.split(C, s0.stages[C].iters[5], [2, 2, 4, 8]) + its1 = s0.split(C, s0[C].iters[5], [2, 2, 4, 8]) split_step1 = s0.transform_steps_size() - 1 its = [] for i0, i1 in zip(its0, its1): @@ -92,40 +93,41 @@ def test_follow_split_follow_fused_split(): its.append(i1) s0.reorder(C, its) for i in range(0, 5): - s0.fuse(C, [s0.stages[C].iters[i], s0.stages[C].iters[i + 1]]) + s0.fuse(C, [s0[C].iters[i], s0[C].iters[i + 1]]) for level in range(0, 4): tmp = s0.copy() - tmp.follow_fused_split(C_global, tmp.stages[C_global].iters[0], + tmp.follow_fused_split(C_global, tmp[C_global].iters[0], [split_step0, split_step1], level, False) - assert tmp.stages[C].iters[level + 1].range.extent == \ - tmp.stages[C_global].iters[0].range.extent + assert tmp[C].iters[level + 1].range.extent == \ + tmp[C_global].iters[0].range.extent for level in range(0, 
4): tmp = s0.copy() - tmp.follow_fused_split(C_global, tmp.stages[C_global].iters[0], + tmp.follow_fused_split(C_global, tmp[C_global].iters[0], [split_step0, split_step1], level, True) - assert tmp.stages[C].iters[level + 1].range.extent == \ - tmp.stages[C_global].iters[1].range.extent + assert tmp[C].iters[level + 1].range.extent == \ + tmp[C_global].iters[1].range.extent def test_compute_at_root_inline(): dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) + s0 = dag.get_init_state() # data, padding, kernel = 0, 1, 2 - conv = 3 + conv = s0.stage_tensors[3] # bias = 4 - bias_add = 5 + bias_add = s0.stage_tensors[5] # bn_scale = 6 - bn_mul = 7 + bn_mul = s0.stage_tensors[7] # bn_offset = 8 - bn_add, relu = 9, 10 + bn_add = s0.stage_tensors[9] + relu = s0.stage_tensors[10] - s0 = dag.get_init_state() s0.compute_inline(bn_add) s0.compute_inline(bn_mul) s0.compute_inline(bias_add) - s0.compute_at(conv, relu, s0.stages[relu].iters[2]) + s0.compute_at(conv, relu, s0[relu].iters[2]) assert str(s0) == \ "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ "for i1 (0,3)\n" + \ @@ -186,33 +188,27 @@ def test_cache_read_write(): name='Kernel') conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1) relu = topi.nn.relu(conv) - out = topi.add(data, relu) + add = topi.add(data, relu) + + dag = ansor.ComputeDAG([data, kernel_data, add]) + s0 = dag.get_init_state() - dag = ansor.ComputeDAG([data, kernel_data, out]) - data, pad_temp, kernel_data, kernel_split, kernel, conv, relu, add = 0, 1, 2, 3, 4, 5, 6, 7 + pad_temp = s0.stage_tensors[1] + kernel_split = s0.stage_tensors[3] # 0: init state - s0 = dag.get_init_state() - ori_its = s0.stages[add].iters - its = s0.split(add, s0.stages[add].iters[0], [2]) + ori_its = s0[add].iters + its = s0.split(add, s0[add].iters[0], [2]) s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]]) s0.compute_inline(relu) # 1: simple cache_write with compute_at - conv_global = s0.cache_write(conv, "global", dag) - conv += 1 - relu += 1 - add += 1 - s0.compute_at(conv_global, conv, s0.stages[conv].iters[3]) + conv_global = s0.cache_write(conv, "global") + s0.compute_at(conv_global, conv, s0[conv].iters[3]) # 2: simple cache_read with compute_at - kernel_global = s0.cache_read(kernel, "global", [conv_global], dag) - conv_global += 1 - conv += 1 - relu += 1 - add += 1 - s0.compute_at(kernel_global, conv_global, - s0.stages[conv_global].iters[4]) + kernel_global = s0.cache_read(kernel, "global", [conv_global]) + s0.compute_at(kernel_global, conv_global, s0[conv_global].iters[4]) assert str(s0) == \ "Placeholder: Data, Kernel_data\n" + \ "for i0 (0,4)\n" + \ @@ -257,41 +253,14 @@ def test_cache_read_write(): # 3: two level cache_read with compute_at # preparing for GPU's shared memory & local memory - pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global], dag) - kernel_data += 1 - kernel_split += 1 - kernel += 1 - kernel_global += 1 - conv_global += 1 - conv += 1 - relu += 1 - add += 1 - pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global], dag) - kernel_data += 1 - kernel_split += 1 - kernel += 1 - kernel_global += 1 - conv_global += 1 - conv += 1 - relu += 1 - add += 1 - s0.compute_at(pad_temp_global, conv_global, s0.stages[conv_global].iters[2]) - s0.compute_at(pad_temp_shared, conv_global, s0.stages[conv_global].iters[4]) + pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global]) + pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global]) + 
s0.compute_at(pad_temp_global, conv_global, s0[conv_global].iters[2]) + s0.compute_at(pad_temp_shared, conv_global, s0[conv_global].iters[4]) # 4: cache_read with multi readers # This stage cannot be compute at to its consumer - data_global = s0.cache_read(data, "global", [pad_temp, add], dag) - pad_temp += 1 - pad_temp_global += 1 - pad_temp_shared += 1 - kernel_data += 1 - kernel_split += 1 - kernel += 1 - kernel_global += 1 - conv_global += 1 - conv += 1 - relu += 1 - add += 1 + s0.cache_read(data, "global", [pad_temp, add]) assert str(s0) == \ "Placeholder: Data, Kernel_data\n" + \ "for ax0 (0,4)\n" + \ @@ -364,7 +333,7 @@ def test_cache_read_write(): # Seems there's bug with the input/output tensor. Such multi outputs case # should be unusual, so we make some hack on DoCacheWrite # To be fixed in the future - s0.cache_write(kernel_split, "global", dag) + s0.cache_write(kernel_split, "global") assert str(s0) == \ "Placeholder: Data, Kernel_data\n" + \ "for ax0 (0,4)\n" + \ @@ -434,14 +403,14 @@ def test_cache_read_write(): def test_rfactor(): - dag = ansor.ComputeDAG(matmul_ansor_test(8, 8, 512)) + A, B, C = matmul_ansor_test(8, 8, 512) + dag = ansor.ComputeDAG([A, B, C]) s0 = dag.get_init_state() - C = 2 - ko, ki = s0.split(C, s0.stages[C].iters[2], [16]) + ko, ki = s0.split(C, s0[C].iters[2], [16]) s1 = s0.copy() - s1.rfactor(C, ko, 2, dag) + s1.rfactor(C, ko, 2) assert str(s1) == \ "Placeholder: A, B\n" + \ "for i (0,8)\n" + \ @@ -455,7 +424,7 @@ def test_rfactor(): " C.repl = ...\n" s2 = s0.copy() - s2.rfactor(C, ki, 2, dag) + s2.rfactor(C, ki, 2) assert str(s2) == \ "Placeholder: A, B\n" + \ "for i (0,8)\n" + \ @@ -469,6 +438,122 @@ def test_rfactor(): " C.repl = ...\n" +def vcf_init_common(): + A, B, C = matmul_ansor_test(512, 512, 512) + dag = ansor.ComputeDAG([A, B, C]) + s0 = dag.get_init_state() + B_shared = s0.cache_read(B, "shared", [C]) + B_local = s0.cache_read(B_shared, "local", [C]) + A_shared = s0.cache_read(A, "shared", [C]) + A_local = s0.cache_read(A_shared, "local", [C]) + + return A_shared, A_local, B_shared, B_local, C, dag, s0 + + +def vcf_check_common(dag, state): + s, args = dag.apply_steps_from_state(state) + # To check if every vectorize loop transforms to ramp expr successfully + # TODO(jcf94): Find a better way to process the check in AST + print(tvm.lower(s, args)) + + if tvm.context("cuda", 0).exist: + tgt = tvm.target.cuda() + mod = tvm.build(s, args, tgt) + # To check if every vectorize loop transforms to correct instruction + print(mod.imported_modules[0].get_source()) + + ctx = tvm.context("cuda", 0) + dtype = dag.tensors[0].dtype + a = tvm.nd.array(np.random.uniform(size=(512, 512)).astype(dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(512, 512)).astype(dtype), ctx) + c = tvm.nd.array(np.zeros((512, 512), dtype=dtype), ctx) + mod(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), np.dot( + a.asnumpy(), b.asnumpy()), rtol=1e-5) + else: + print("CUDA device not found, skip this test.") + + +def test_vectorized_cooperative_fetching_x(): + A_shared, A_local, B_shared, B_local, C, dag, s0 = vcf_init_common() + + its0 = s0.split(C, s0[C].iters[0], [1, 8, 2, 4]) + its1 = s0.split(C, s0[C].iters[5], [2, 8, 2, 4]) + its2 = s0.split(C, s0[C].iters[10], [8, 8]) + s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its2[0], + its2[1], its0[3], its1[3], its2[2], its0[4], its1[4]]) + s0.fuse(C, [s0[C].iters[0], s0[C].iters[1]]) + s0.bind_thread(C, s0[C].iters[0], "blockIdx.x") + s0.fuse(C, [s0[C].iters[1], s0[C].iters[2]]) + 
s0.bind_thread(C, s0[C].iters[1], "vthread") + s0.fuse(C, [s0[C].iters[2], s0[C].iters[3]]) + s0.bind_thread(C, s0[C].iters[2], "threadIdx.x") + s0.vectorize(C, its1[4]) + + s0.compute_at(B_shared, C, s0[C].iters[3]) + fused_it = s0.fuse(B_shared, s0[B_shared].iters[:]) + its = s0.split(B_shared, fused_it, [64, 4]) + s0.bind_thread(B_shared, its[1], "threadIdx.x") + s0.vectorize(B_shared, its[2]) + s0.compute_at(B_local, C, s0[C].iters[4]) + fused_it = s0.fuse(B_local, s0[B_local].iters[:]) + its = s0.split(B_local, fused_it, [4]) + s0.vectorize(B_local, its[1]) + + s0.compute_at(A_shared, C, s0[C].iters[3]) + fused_it = s0.fuse(A_shared, s0[A_shared].iters[:]) + its = s0.split(A_shared, fused_it, [64, 4]) + s0.bind_thread(A_shared, its[1], "threadIdx.x") + s0.vectorize(A_shared, its[2]) + s0.compute_at(A_local, C, s0[C].iters[4]) + fused_it = s0.fuse(A_local, s0[A_local].iters[:]) + its = s0.split(A_local, fused_it, [4]) + s0.vectorize(A_local, its[1]) + + vcf_check_common(dag, s0) + + +def test_vectorized_cooperative_fetching_xy(): + A_shared, A_local, B_shared, B_local, C, dag, s0 = vcf_init_common() + + its0 = s0.split(C, s0[C].iters[0], [1, 8, 2, 4]) + its1 = s0.split(C, s0[C].iters[5], [2, 8, 2, 4]) + its2 = s0.split(C, s0[C].iters[10], [8, 8]) + s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its2[0], + its2[1], its0[3], its1[3], its2[2], its0[4], its1[4]]) + s0.fuse(C, [s0[C].iters[0], s0[C].iters[1]]) + s0.bind_thread(C, s0[C].iters[0], "blockIdx.x") + s0.fuse(C, [s0[C].iters[1], s0[C].iters[2]]) + s0.bind_thread(C, s0[C].iters[1], "vthread") + s0.bind_thread(C, s0[C].iters[2], "threadIdx.x") + s0.bind_thread(C, s0[C].iters[3], "threadIdx.y") + s0.vectorize(C, its1[4]) + + s0.compute_at(B_shared, C, s0[C].iters[4]) + fused_it = s0.fuse(B_shared, s0[B_shared].iters[:]) + its = s0.split(B_shared, fused_it, [8, 8, 4]) + s0.bind_thread(B_shared, its[1], "threadIdx.x") + s0.bind_thread(B_shared, its[2], "threadIdx.y") + s0.vectorize(B_shared, its[3]) + s0.compute_at(B_local, C, s0[C].iters[5]) + fused_it = s0.fuse(B_local, s0[B_local].iters[:]) + its = s0.split(B_local, fused_it, [4]) + s0.vectorize(B_local, its[1]) + + s0.compute_at(A_shared, C, s0[C].iters[4]) + fused_it = s0.fuse(A_shared, s0[A_shared].iters[:]) + its = s0.split(A_shared, fused_it, [8, 8, 4]) + s0.bind_thread(A_shared, its[1], "threadIdx.x") + s0.bind_thread(A_shared, its[2], "threadIdx.y") + s0.vectorize(A_shared, its[3]) + s0.compute_at(A_local, C, s0[C].iters[5]) + fused_it = s0.fuse(A_local, s0[A_local].iters[:]) + its = s0.split(A_local, fused_it, [4]) + s0.vectorize(A_local, its[1]) + + vcf_check_common(dag, s0) + + @tvm._ffi.register_func def test_intrin_gemv(): m = 16 @@ -495,11 +580,11 @@ def intrin_func(ins, outs): return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb}) def test_tensorize(): - dag = ansor.ComputeDAG(matmul_ansor_test(1024, 512, 64)) + A, B, C = matmul_ansor_test(1024, 512, 64) + dag = ansor.ComputeDAG([A, B, C]) s0 = dag.get_init_state() - C = 2 - its = s0.split(C, s0.stages[C].iters[1], [16]) + its = s0.split(C, s0[C].iters[1], [16]) s0.tensorize(C, its[1], "test_intrin_gemv") sch, tensors = dag.apply_steps_from_state(s0) @@ -511,4 +596,6 @@ def test_tensorize(): test_compute_at_root_inline() test_cache_read_write() test_rfactor() + test_vectorized_cooperative_fetching_x() + test_vectorized_cooperative_fetching_xy() test_tensorize() diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py 
index 839992c67e0f..9b1716175b5a 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -80,6 +80,7 @@ def test_search_basic(): t.start() t.join() + def test_search_xgb_model_rpc_runner(): measure_ctx = ansor.LocalRPCMeasureContext() search_common(seed=456787236, cost_model=ansor.XGBModel(), @@ -123,13 +124,13 @@ def apply_func1(meta_policy, state, stage_id): # Stage by stage way ret = [] if stage_id == 2: - state = ansor.loop_state.State(state) + state = ansor.loop_state.State(state, meta_policy.cur_task.compute_dag) state.split(2, state.stages[2].iters[0], [4, 4]) state.split(2, state.stages[2].iters[3], [4, 4]) ret.append([state.state_object, stage_id - 1]) elif stage_id == 1: - state = ansor.loop_state.State(state) - state.cache_read(1, "global", [2], meta_policy.cur_task.compute_dag) + state = ansor.loop_state.State(state, meta_policy.cur_task.compute_dag) + state.cache_read(1, "global", [2]) state.compute_at(2, 3, state.stages[3].iters[4]) ret.append([state.state_object, stage_id - 1]) else: @@ -139,11 +140,11 @@ def apply_func1(meta_policy, state, stage_id): def apply_func2(meta_policy, state, stage_id): # More template like way ret = [] - state = ansor.loop_state.State(state) + state = ansor.loop_state.State(state, meta_policy.cur_task.compute_dag) state.split(2, state.stages[2].iters[0], [4, 4]) state.split(2, state.stages[2].iters[3], [4, 4]) - state.cache_read(1, "global", [2], meta_policy.cur_task.compute_dag) + state.cache_read(1, "global", [2]) state.compute_at(2, 3, state.stages[3].iters[4]) ret.append([state.state_object, -1]) diff --git a/tests/python/unittest/test_ansor_vectorized_cooperative_fetching.py b/tests/python/unittest/test_ansor_vectorized_cooperative_fetching.py deleted file mode 100644 index c41abc7bcb3d..000000000000 --- a/tests/python/unittest/test_ansor_vectorized_cooperative_fetching.py +++ /dev/null @@ -1,152 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -""" Test for vectorized cooperative fetching """ - -import numpy as np -import tvm -from tvm import ansor, te -import topi - -from test_ansor_common import matmul_ansor_test, conv2d_nchw_bn_relu - - -def init_common(): - dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) - s0 = dag.get_init_state() - A, B, C = 0, 1, 2 - B_shared = s0.cache_read(B, "shared", [C], dag) - C += 1 - B_local = s0.cache_read(B_shared, "local", [C], dag) - C += 1 - A_shared = s0.cache_read(A, "shared", [C], dag) - B += 1 - B_shared += 1 - B_local += 1 - C += 1 - A_local = s0.cache_read(A_shared, "local", [C], dag) - B += 1 - B_shared += 1 - B_local += 1 - C += 1 - - return A_shared, A_local, B_shared, B_local, C, dag, s0 - -def check_common(dag, state): - s, args = dag.apply_steps_from_state(state) - # To check if every vectorize loop transforms to ramp expr successfully - # TODO(jcf94): Find a better way to process the check in AST - print(tvm.lower(s, args)) - - if tvm.context("cuda", 0).exist: - tgt = tvm.target.cuda() - mod = tvm.build(s, args, tgt) - # To check if every vectorize loop transforms to correct instruction - print(mod.imported_modules[0].get_source()) - - ctx = tvm.context("cuda", 0) - dtype = dag.tensors[0].dtype - a = tvm.nd.array(np.random.uniform(size=(512, 512)).astype(dtype), ctx) - b = tvm.nd.array(np.random.uniform(size=(512, 512)).astype(dtype), ctx) - c = tvm.nd.array(np.zeros((512, 512), dtype=dtype), ctx) - mod(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), np.dot( - a.asnumpy(), b.asnumpy()), rtol=1e-5) - else: - print("CUDA device not found, skip this test.") - -def test_vectorized_cooperative_fetching_x(): - A_shared, A_local, B_shared, B_local, C, dag, s0 = init_common() - - its0 = s0.split(C, s0.stages[C].iters[0], [1, 8, 2, 4]) - its1 = s0.split(C, s0.stages[C].iters[5], [2, 8, 2, 4]) - its2 = s0.split(C, s0.stages[C].iters[10], [8, 8]) - s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its2[0], - its2[1], its0[3], its1[3], its2[2], its0[4], its1[4]]) - s0.fuse(C, [s0.stages[C].iters[0], s0.stages[C].iters[1]]) - s0.bind_thread(C, s0.stages[C].iters[0], "blockIdx.x") - s0.fuse(C, [s0.stages[C].iters[1], s0.stages[C].iters[2]]) - s0.bind_thread(C, s0.stages[C].iters[1], "vthread") - s0.fuse(C, [s0.stages[C].iters[2], s0.stages[C].iters[3]]) - s0.bind_thread(C, s0.stages[C].iters[2], "threadIdx.x") - s0.vectorize(C, its1[4]) - - s0.compute_at(B_shared, C, s0.stages[C].iters[3]) - fused_it = s0.fuse(B_shared, s0.stages[B_shared].iters[:]) - its = s0.split(B_shared, fused_it, [64, 4]) - s0.bind_thread(B_shared, its[1], "threadIdx.x") - s0.vectorize(B_shared, its[2]) - s0.compute_at(B_local, C, s0.stages[C].iters[4]) - fused_it = s0.fuse(B_local, s0.stages[B_local].iters[:]) - its = s0.split(B_local, fused_it, [4]) - s0.vectorize(B_local, its[1]) - - s0.compute_at(A_shared, C, s0.stages[C].iters[3]) - fused_it = s0.fuse(A_shared, s0.stages[A_shared].iters[:]) - its = s0.split(A_shared, fused_it, [64, 4]) - s0.bind_thread(A_shared, its[1], "threadIdx.x") - s0.vectorize(A_shared, its[2]) - s0.compute_at(A_local, C, s0.stages[C].iters[4]) - fused_it = s0.fuse(A_local, s0.stages[A_local].iters[:]) - its = s0.split(A_local, fused_it, [4]) - s0.vectorize(A_local, its[1]) - - check_common(dag, s0) - -def test_vectorized_cooperative_fetching_xy(): - A_shared, A_local, B_shared, B_local, C, dag, s0 = init_common() - - its0 = s0.split(C, s0.stages[C].iters[0], [1, 8, 2, 4]) - its1 = s0.split(C, s0.stages[C].iters[5], [2, 8, 2, 4]) - its2 = s0.split(C, 
s0.stages[C].iters[10], [8, 8]) - s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its2[0], - its2[1], its0[3], its1[3], its2[2], its0[4], its1[4]]) - s0.fuse(C, [s0.stages[C].iters[0], s0.stages[C].iters[1]]) - s0.bind_thread(C, s0.stages[C].iters[0], "blockIdx.x") - s0.fuse(C, [s0.stages[C].iters[1], s0.stages[C].iters[2]]) - s0.bind_thread(C, s0.stages[C].iters[1], "vthread") - s0.bind_thread(C, s0.stages[C].iters[2], "threadIdx.x") - s0.bind_thread(C, s0.stages[C].iters[3], "threadIdx.y") - s0.vectorize(C, its1[4]) - - s0.compute_at(B_shared, C, s0.stages[C].iters[4]) - fused_it = s0.fuse(B_shared, s0.stages[B_shared].iters[:]) - its = s0.split(B_shared, fused_it, [8, 8, 4]) - s0.bind_thread(B_shared, its[1], "threadIdx.x") - s0.bind_thread(B_shared, its[2], "threadIdx.y") - s0.vectorize(B_shared, its[3]) - s0.compute_at(B_local, C, s0.stages[C].iters[5]) - fused_it = s0.fuse(B_local, s0.stages[B_local].iters[:]) - its = s0.split(B_local, fused_it, [4]) - s0.vectorize(B_local, its[1]) - - s0.compute_at(A_shared, C, s0.stages[C].iters[4]) - fused_it = s0.fuse(A_shared, s0.stages[A_shared].iters[:]) - its = s0.split(A_shared, fused_it, [8, 8, 4]) - s0.bind_thread(A_shared, its[1], "threadIdx.x") - s0.bind_thread(A_shared, its[2], "threadIdx.y") - s0.vectorize(A_shared, its[3]) - s0.compute_at(A_local, C, s0.stages[C].iters[5]) - fused_it = s0.fuse(A_local, s0.stages[A_local].iters[:]) - its = s0.split(A_local, fused_it, [4]) - s0.vectorize(A_local, its[1]) - - check_common(dag, s0) - -if __name__ == "__main__": - test_vectorized_cooperative_fetching_x() - test_vectorized_cooperative_fetching_xy() From 36cd9ef474664490c9736c43282912df4c48c257 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Minmin=20Sun=20=28=E5=AD=99=E6=95=8F=E6=95=8F=29?= Date: Fri, 19 Jun 2020 18:24:30 +0800 Subject: [PATCH 29/78] kernel layout rewrite (#28) * kernel layout rewrite * remove some hacks * add defuse_ops pass and move kernel_layout_rewrite pass after fuse_ops pass * set TVM_RELAY_DISABLE_BUILD_CACHE for task extraction and prepare_layout_rewrite --- include/tvm/relay/attrs/transform.h | 13 + include/tvm/relay/transform.h | 14 + python/tvm/ansor/__init__.py | 2 +- python/tvm/ansor/compute_dag.py | 9 +- python/tvm/ansor/measure.py | 1 - python/tvm/ansor/relay_integration.py | 7 +- python/tvm/ansor/topi_integration.py | 13 +- python/tvm/relay/op/_transform.py | 2 + python/tvm/relay/op/op_attrs.py | 3 + python/tvm/relay/op/transform.py | 21 + python/tvm/relay/testing/dqn.py | 25 +- python/tvm/relay/testing/resnet.py | 4 + python/tvm/te/tensor.py | 6 +- scripts/tune_network.py | 9 +- src/ansor/compute_dag.cc | 725 ++++++++++-------- src/ansor/compute_dag.h | 2 +- src/relay/analysis/type_solver.cc | 1 + src/relay/backend/build_module.cc | 13 + src/relay/backend/compile_engine.cc | 5 + src/relay/backend/compile_engine.h | 3 + src/relay/transforms/defuse_ops.cc | 98 +++ .../transforms/kernel_layout_transform.cc | 63 ++ .../transforms/kernel_layout_transform.h | 75 ++ src/relay/transforms/pattern_util.h | 2 + topi/python/topi/nn/conv2d.py | 24 +- 25 files changed, 787 insertions(+), 353 deletions(-) create mode 100644 src/relay/transforms/defuse_ops.cc create mode 100644 src/relay/transforms/kernel_layout_transform.cc create mode 100644 src/relay/transforms/kernel_layout_transform.h diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 750a8a43163c..95476ed61bdd 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -296,6 
+296,19 @@ struct LayoutTransformAttrs : public tvm::AttrsNode {
   }
 };

+/*! \brief Attributes for KernelLayoutTransform operator */
+struct KernelLayoutTransformAttrs : public tvm::AttrsNode {
+  std::string src_layout;
+  std::string dst_layout;
+
+  TVM_DECLARE_ATTRS(KernelLayoutTransformAttrs, "relay.attrs.KernelLayoutTransformAttrs") {
+    TVM_ATTR_FIELD(src_layout)
+        .describe("The source layout of the tensor. (e.g. 1N32C112H112W)");
+    TVM_ATTR_FIELD(dst_layout)
+        .describe("The destination layout of the tensor. (e.g. 1N2C112H112W16c)");
+  }
+};
+
 /*! \brief Attributes for ShapeOf operator */
 struct ShapeOfAttrs : public tvm::AttrsNode {
   DataType dtype;
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 1b8b31aee5d1..5f5d9b643633 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -277,6 +277,20 @@ TVM_DLL Pass CanonicalizeOps();
  */
 TVM_DLL Pass AlterOpLayout();

+/*!
+ * \brief Alter the layouts of kernels.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass KernelLayoutTransform();
+
+/*!
+ * \brief The reverse of FuseOps.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass DeFuseOps();
+
 /*!
  * \brief Given a dest layout, this pass transforms the expr such that most of the ops input data
  * layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one
diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py
index 6ea8a0ce904f..b43b21a60144 100644
--- a/python/tvm/ansor/__init__.py
+++ b/python/tvm/ansor/__init__.py
@@ -44,4 +44,4 @@ FallbackContext, clear_fallback_cache, ApplyGraphBest, BlockingEmptyContext
 from .topi_integration import register_topi_schedule, TaskExtractEnv
 from .relay_integration import extract_from_program, extract_from_multiple_program, \
-    finish_layout_rewrite
+    finish_layout_rewrite, prepare_layout_rewrite
diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py
index 6d82942aa744..c54c14ec123a 100644
--- a/python/tvm/ansor/compute_dag.py
+++ b/python/tvm/ansor/compute_dag.py
@@ -64,12 +64,17 @@ def apply_steps_from_state(self, state, layout_rewrite_level=None):
         args : List[Tensor]
         """
         if isinstance(state, State):
-            return _ffi_api.ComputeDAGApplyStepsFromState(self, state.state_object)
+            return _ffi_api.ComputeDAGApplyStepsFromState(self, state.state_object,
+                                                          layout_rewrite_level)
         elif isinstance(state, StateObject):
-            return _ffi_api.ComputeDAGApplyStepsFromState(self, state)
+            return _ffi_api.ComputeDAGApplyStepsFromState(self, state,
+                                                          layout_rewrite_level)
         else:
             raise ValueError("The input must be a State or StateObject")

+    def rewrite_layout_from_state(self, state: State):
+        return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state)
+
     def print_python_code_from_state(self, state):
         """
         Parameters
diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py
index b82327ec67c4..8b38f91647b2 100644
--- a/python/tvm/ansor/measure.py
+++ b/python/tvm/ansor/measure.py
@@ -534,4 +534,3 @@ def timed_func(inp, build_res):
         print("")

     return measure_results
-
diff --git a/python/tvm/ansor/relay_integration.py b/python/tvm/ansor/relay_integration.py
index 348828eec4b4..383471ee060d 100644
--- a/python/tvm/ansor/relay_integration.py
+++ b/python/tvm/ansor/relay_integration.py
@@ -54,7 +54,7 @@ def _lower(mod,
     # If compilation fails, fall back to the VM compiler.
     # TODO: Currently the VM compiler is likely to stack overflow for large models.
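+    # AlterOpLayout is disabled in the build config below because Ansor performs
+    # its own kernel layout rewrite (the KernelLayoutTransform pass), and the two
+    # layout passes should not both rewrite the same module.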
    try:
-        with relay.build_config(opt_level=3):
+        with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
             opt_mod, _ = relay.optimize(mod, target, params)
             grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
             grc.codegen(opt_mod["main"])
@@ -191,7 +191,7 @@ def prepare_layout_rewrite(mod, params, ops, target):
     """Prepare for kernel layout rewrite. This function writes layout infos to a
     global static variable; these layout infos are then used by the relay pass
     `kernel_layout_transform`.
     """
-    from .. import relay
+    from tvm import relay

     env = TaskExtractEnv.get(do_layout_rewrite=True)

@@ -203,9 +203,8 @@ def prepare_layout_rewrite(mod, params, ops, target):
         else:
             warnings.warn("Op %s is not tunable, ignored." % op_name)

+    env.reset(topi_scheds)
     with env:
-        env.reset(topi_scheds)
-
         # wrap the build call in a thread to avoid multiprocessing problems
         build_thread = threading.Thread(target=_lower,
                                         args=(mod, target, params))
diff --git a/python/tvm/ansor/topi_integration.py b/python/tvm/ansor/topi_integration.py
index b4c15f74ea44..77def00cf9ec 100644
--- a/python/tvm/ansor/topi_integration.py
+++ b/python/tvm/ansor/topi_integration.py
@@ -26,14 +26,17 @@
 See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
 """

+import os
+import json
 import tvm.te._ffi_api
 from tvm import target as _target
 from tvm.te import tensor
 from tvm.te.tensor import PlaceholderOp, ComputeOp

-from .dispatcher import DispatchContext
+from .dispatcher import DispatchContext, BlockingEmptyContext
 from .workload_registry import register_auto_scheduler_workload_bufs, \
     make_workload_key_bufs, compute_dag_hash
+from .compute_dag import ComputeDAG

 def traverse_to_get_io_tensors(outs):
     layout_free_ops = []
@@ -77,11 +80,14 @@ def __init__(self, do_layout_rewrite=False):
     def __enter__(self):
         self.tracing = True
         self.wkl_key_collection = {}
+        self.relay_disable_build_cache_ = os.environ.get("TVM_RELAY_DISABLE_BUILD_CACHE", "false")
+        os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = "true"

         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
         self.tracing = False
+        os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = self.relay_disable_build_cache_

     def reset(self, wanted_relay_ops=None):
         """Reset task collections
@@ -144,7 +150,7 @@ def get(do_layout_rewrite=False):
             The single instance of TaskExtractEnv
         """
         if not TaskExtractEnv.current:
-            TaskExtractEnv.current = TaskExtractEnv()
+            TaskExtractEnv.current = TaskExtractEnv(do_layout_rewrite)
         else:
             TaskExtractEnv.current.do_layout_rewrite = do_layout_rewrite
         return TaskExtractEnv.current
@@ -188,7 +194,7 @@ def wrapper(outs, *args, **kwargs):
             # Rewrite the dag and update the transform history for
             # the new dag in DispatchContext
             dispatch_ctx = DispatchContext.current
-            tgt = _target.current_target()
+            tgt = _target.Target.current()
             state = dispatch_ctx.query(tgt, key)
             dag = ComputeDAG(outs)
             new_dag = dag.rewrite_layout_from_state(state)
@@ -199,7 +205,6 @@ def wrapper(outs, *args, **kwargs):
                 task_env.layout_rewrite_success_ct += 1

             # Call schedule_func under FallbackContext() to avoid layout rewrite
-            tgt = _target.Target.current()
             cfg = BlockingEmptyContext().query(tgt, key)
             return topi_schedule(cfg, outs)
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index d104c1b1c2f8..41bd10cabe3e 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -74,6 +74,8 @@ def compute_strided_set(attrs, inputs, output_type):
 # layout_transform
 _reg.register_injective_schedule("layout_transform")
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE) +_reg.register_injective_schedule("kernel_layout_transform") +_reg.register_pattern("kernel_layout_transform", OpPattern.INJECTIVE) # argwhere @_reg.register_compute("argwhere") diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 486d63c36ff0..58b9269a4c48 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -261,6 +261,9 @@ class ClipAttrs(Attrs): class LayoutTransformAttrs(Attrs): """Attributes for transform.layout_transform""" +@tvm._ffi.register_object("relay.attrs.KernelLayoutTransformAttrs") +class KernelLayoutTransformAttrs(Attrs): + """Attributes for transform.kernel_layout_transform""" @tvm._ffi.register_object("relay.attrs.ShapeOfAttrs") class ShapeOfAttrs(Attrs): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index a37226ea4f58..f2fa2b5f5b90 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -815,6 +815,27 @@ def layout_transform(data, src_layout, dst_layout): """ return _make.layout_transform(data, src_layout, dst_layout) +def kernel_layout_transform(data, src_layout, dst_layout): + """Transform the layout of a kernel + + Parameters + ---------- + data : relay.Expr + The source tensor to be transformed + + src_layout: str + The source layout. (e.g 1N32C112H112W) + + dst_layout: str + The destination layout. (e.g. 1N2C112H112W16c) + + Returns + ------- + ret : relay.Expr + The transformed tensor. + """ + return _make.kernel_layout_transform(data, src_layout, dst_layout) + def reverse_reshape(data, newshape): """Reshapes the input array where the special values are inferred from diff --git a/python/tvm/relay/testing/dqn.py b/python/tvm/relay/testing/dqn.py index 10da37001f12..b65e0ad5cae9 100644 --- a/python/tvm/relay/testing/dqn.py +++ b/python/tvm/relay/testing/dqn.py @@ -26,27 +26,32 @@ from . 
import layers from .init import create_workload -def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"): +def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32", layout="NCHW"): """get symbol of nature dqn""" data_shape = (batch_size,) + image_shape data = relay.var("data", shape=data_shape, dtype=dtype) + bias_axis = layout.index('C') + conv1_bias = relay.var("conv1_bias") conv1 = layers.conv2d(data, kernel_size=(8, 8), strides=(4, 4), padding=(0, 0), - channels=32, name="conv1") - conv1 = relay.nn.bias_add(conv1, conv1_bias) + channels=32, name="conv1", data_layout=layout, + kernel_layout=layers.conv_kernel_layout(layout)) + conv1 = relay.nn.bias_add(conv1, conv1_bias, bias_axis) relu1 = relay.nn.relu(conv1) conv2_bias = relay.var("conv2_bias") conv2 = layers.conv2d(relu1, kernel_size=(4, 4), strides=(2, 2), padding=(0, 0), - channels=64, name="conv2") - conv2 = relay.nn.bias_add(conv2, conv2_bias) + channels=64, name="conv2", data_layout=layout, + kernel_layout=layers.conv_kernel_layout(layout)) + conv2 = relay.nn.bias_add(conv2, conv2_bias, bias_axis) relu2 = relay.nn.relu(conv2) conv3_bias = relay.var("conv3_bias") conv3 = layers.conv2d(relu2, kernel_size=(3, 3), strides=(1, 1), padding=(0, 0), - channels=64, name="conv3") - conv3 = relay.nn.bias_add(conv3, conv3_bias) + channels=64, name="conv3", data_layout=layout, + kernel_layout=layers.conv_kernel_layout(layout)) + conv3 = relay.nn.bias_add(conv3, conv3_bias, bias_axis) relu3 = relay.nn.relu(conv3) bf1 = relay.nn.batch_flatten(relu3) @@ -58,7 +63,7 @@ def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32" return relay.Function(args, dense2) -def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"): +def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32", layout="NCHW"): """Get benchmark workload for a Deep Q Network Parameters ---------- @@ -72,10 +77,10 @@ def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="flo The data type Returns ------- - mod : tvm.IRModule + mod : tvm.relay.Module The relay module that contains a DQN network. params : dict of str to NDArray The parameters. 
""" - net = get_net(batch_size, num_actions=num_actions, image_shape=image_shape, dtype=dtype) + net = get_net(batch_size, num_actions=num_actions, image_shape=image_shape, dtype=dtype, layout=layout) return create_workload(net) diff --git a/python/tvm/relay/testing/resnet.py b/python/tvm/relay/testing/resnet.py index b431dd096f9d..8633879465bd 100644 --- a/python/tvm/relay/testing/resnet.py +++ b/python/tvm/relay/testing/resnet.py @@ -162,6 +162,8 @@ def resnet(units, data = relay.var("data", shape=data_shape, dtype=dtype) data = layers.batch_norm_infer(data=data, epsilon=2e-5, scale=False, name='bn_data') (_, _, height, _) = data_shape + if layout == "NHWC": + (_, height, _, _) = data_shape if height <= 32: # such as cifar10 body = layers.conv2d( data=data, channels=filter_list[0], kernel_size=(3, 3), @@ -209,6 +211,8 @@ def get_net(batch_size, Original author Wei Wu """ (_, height, _) = image_shape + if layout == "NHWC": + (height, _, _) = image_shape data_shape = (batch_size,) + image_shape if height <= 28: num_stages = 3 diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 7d73bf42ab7d..6539aabaa48f 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -57,8 +57,10 @@ class Tensor(DataProducer, _expr.ExprOp): def __call__(self, *indices): ndim = self.ndim - if len(indices) != ndim: - raise ValueError("Need to provide %d index in tensor slice" % ndim) + # After ansor kernel layout rewrite, len(indices) <= ndim, + # and the indices will get modified by Ansor during schedule generation. + # if len(indices) != ndim: + # raise ValueError("Need to provide %d index in tensor slice" % ndim) indices = convert_to_object(indices) args = [] for x in indices: diff --git a/scripts/tune_network.py b/scripts/tune_network.py index 5e5a337c7bce..dc17f407d003 100644 --- a/scripts/tune_network.py +++ b/scripts/tune_network.py @@ -49,9 +49,10 @@ def get_network(name, model_path, batch_size, layout): input_shape = (batch_size, 100) mod, params = relay.testing.dcgan.get_workload(batch_size=batch_size) elif name == 'dqn': - image_shape = (4, 84, 84) + layout = "NHWC" + image_shape = (84, 84, 4) input_shape = (batch_size, *image_shape) - mod, params = relay.testing.dqn.get_workload(batch_size=batch_size, image_shape=image_shape, dtype=dtype) + mod, params = relay.testing.dqn.get_workload(batch_size=batch_size, image_shape=image_shape, dtype=dtype, layout=layout) elif name == 'mobilenet': image_shape = (224, 224, 3) if layout == 'NHWC' else (3, 224, 224) input_shape = (batch_size, *image_shape) @@ -229,7 +230,7 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, if measure_ctx: del measure_ctx - kernel_layout_rewrite = False + kernel_layout_rewrite = False # Compile graph with best states found by auto-scheduler print("=============== Compile ===============") @@ -245,7 +246,7 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE ansor.LayoutRewriteLevel.COMPUTE_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE - with relay.build_config(opt_level=3): + with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): graph, lib, opt_params = relay.build_module.build( mod, target=target, params=params) diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 5ca0c8503662..fec301dc54bc 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -37,8 +37,8 @@ #include #include #include "transform_step.h" -#include "utils.h" 
-// #include "../relay/pass/kernel_layout_transform.h" +#include "search_policy/utils.h" +#include "../relay/transforms/kernel_layout_transform.h" namespace tvm { namespace ansor { @@ -595,325 +595,383 @@ std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); } -// class IndexRewriter : public ExprMutator { -// public: -// IndexRewriter(const OperationMap >& placeholder_new_names, -// const OperationMap >& placeholder_new_shapes): -// placeholder_new_names_(placeholder_new_names), -// placeholder_new_shapes_(placeholder_new_shapes) {} - -// Expr Mutate_(const Call* op, const Expr& e) { -// Expr op_ = IRMutator::Mutate_(op, e); - -// const Call* call = op_.as(); - -// if (call->call_type == Call::CallType::Halide) { -// Tensor t = Downcast(call->func).output(call->value_index); -// auto it = placeholder_new_names_.find(t->op); -// if (it != placeholder_new_names_.end()) { -// const std::vector& new_names = it->second; -// const Array& new_shape = placeholder_new_shapes_.at(t->op); -// std::unordered_map name_to_arg; -// for (const auto& arg : call->args) { -// std::string axis_name; -// if (const auto* pimm = arg.as()) { -// CHECK_EQ(pimm->value, 0); -// axis_name = "IntImm"; -// } else { -// axis_name = BaseName(CleanName(Downcast(arg)->name_hint)); -// CHECK_EQ(name_to_arg.count(axis_name), 0); -// name_to_arg[axis_name] = arg; -// } -// } - -// std::unordered_map div_factors; -// std::vector r_new_args; -// for (int i = new_names.size() - 1; i >= 0; --i) { -// auto ori_iter_name = new_names[i]; -// auto name_it = name_to_arg.find(ori_iter_name); -// CHECK(name_it != name_to_arg.end()); -// Expr ori_arg = name_it->second; - -// Expr mod_factor = new_shape[i]; - -// Expr div_factor = 1; -// if (div_factors.count(ori_iter_name)) { -// div_factor = div_factors[ori_iter_name]; -// } -// div_factors[ori_iter_name] = div_factor * new_shape[i]; - -// Expr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor); - -// r_new_args.push_back(new_arg); -// } - -// Array new_args(std::make_move_iterator(r_new_args.rbegin()), -// std::make_move_iterator(r_new_args.rend())); - -// return Call::make(call->type, call->name, new_args, call->call_type, -// call->func, call->value_index); -// } -// } -// return op_; -// } - -// private: -// const OperationMap >& placeholder_new_names_; -// const OperationMap >& placeholder_new_shapes_; -// }; - -// // TODO(minminsun): spill out new functions -// void ComputeDAG::RewriteLayout( -// const std::vector &transform_steps, LayoutRewriteLevel layout_rewrite_level) const { -// ComputeDAGNode* pdag = const_cast(this)->CopyOnWrite(); -// const State& state = ReplayAndInferBound(transform_steps); - -// OperationMap > placeholder_new_names; -// OperationMap > placeholder_new_shapes; -// int stage_id = -1; -// for (const auto& stage : state->stages) { -// stage_id += 1; -// const Operation& op = stage->op; -// if (op->IsInstance()) { -// const Map& attrs = op->attrs; -// if (attrs.count(layout_free_placeholders_key)) { -// const ObjectRef& attr_value = attrs[layout_free_placeholders_key]; -// Array placeholders = Downcast>(attr_value); -// for (auto& placeholder : placeholders) { -// const auto placeholder_op = placeholder->op; - -// // Check whether this placeholder has already been handled -// if (placeholder_new_names.count(placeholder_op)) { -// continue; -// } - -// // skip the op that is not direct consumer of this placeholder, -// // mostly due to cache read/write. 
-// bool direct_consumer = false; -// for (auto& t : op->InputTensors()) { -// if (t->op == placeholder_op) { -// direct_consumer = true; -// break; -// } -// } -// if (!direct_consumer) { -// continue; -// } - -// std::set placeholder_axis_names; -// TensorAccessExtractor extractor; -// for (const auto& exp : op.as()->body) { -// extractor.Extract(exp); -// } -// bool rewrite_placeholder = (layout_rewrite_level == kPlaceholderRewrite || -// layout_rewrite_level == kBothRewrite); -// bool rewrite_body = (layout_rewrite_level == kComputeRewrite || -// layout_rewrite_level == kBothRewrite); -// std::ostringstream os; - -// uint i = 0; -// if (extractor.buf_accesses.count(placeholder_op)) { -// for (const auto& ev : extractor.buf_accesses[placeholder_op]) { -// for (const auto& e : ev) { -// // TODO(minminsun): check whether the extents match the shape of placeholder -// std::string axis_name; -// if (const auto* pimm = e.as()) { -// CHECK_EQ(pimm->value, 0); -// // CHECK_EQ(placeholder->shape[i].as()->value, 1); -// axis_name = "IntImm"; -// } else { -// axis_name = BaseName(CleanName(Downcast(e)->name_hint)); -// } - -// placeholder_axis_names.insert(axis_name); -// if (rewrite_placeholder) { -// os << placeholder->shape[i++] << axis_name; -// } -// } -// } - -// if (rewrite_placeholder) { -// CHECK_EQ(placeholder_axis_names.size(), placeholder->shape.size()); -// std::string ori_layout = os.str(); -// os.str(""); -// ::tvm::relay::KernelLayoutVisitor::global_ori_layouts_queue.push_back(ori_layout); -// } -// } - -// std::vector stage_iters; - -// auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id); -// int attach_pos = -1; -// size_t iters_before_attach = 0; -// if (attach_it != state->attach_map->stage_to_attach_iter.end()) { -// auto attach = attach_it->second; -// const auto& attach_stage = state->stages[attach.first]; -// attach_pos = attach.second; -// stage_iters.insert(stage_iters.end(), -// attach_stage->iters.begin(), -// attach_stage->iters.begin() + attach_pos + 1); -// } - -// stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end()); - -// std::vector iters; -// for (size_t i = 0; i < stage_iters.size(); ++i) { -// const auto& iter = stage_iters[i]; -// if (iter->ori_iters.empty()) { -// iters.push_back(iter); -// } else { -// for (const Iterator& ori_iter : iter->ori_iters) { -// iters.push_back(ori_iter); -// } -// } -// if (static_cast(i) == attach_pos) { -// iters_before_attach = iters.size(); -// } -// } - -// std::vector new_names; -// Array new_shape; -// std::vector new_axis_names; -// for (const Iterator& iter : iters) { -// std::set ori_iter_names; -// ExtractOriginalIterators(iter->name, &ori_iter_names); -// // fused iters have been replaced with iter->ori_iters. -// // So there should be only one ori iter name extracted from iter->name. 
-// CHECK_EQ(ori_iter_names.size(), 1); -// auto ori_iter_name = BaseName(*ori_iter_names.begin()); -// new_axis_names.push_back(ori_iter_name); -// } -// for (size_t i = 0; i < new_axis_names.size(); ++i) { -// auto iter = iters[i]; -// std::string ori_iter_name; -// if (i < iters_before_attach) { -// ori_iter_name = new_axis_names[i + iters_before_attach]; -// } else { -// ori_iter_name = new_axis_names[i]; -// } -// if (placeholder_axis_names.count(ori_iter_name)) { -// os << iter->range->extent << ori_iter_name; -// new_names.push_back(ori_iter_name); -// new_shape.push_back(iter->range->extent); -// } -// } -// std::string new_layout = os.str(); -// os.str(""); -// ::tvm::relay::KernelLayoutVisitor::global_new_layouts_queue.push_back(new_layout); -// placeholder_new_names[placeholder_op] = new_names; -// placeholder_new_shapes[placeholder_op] = new_shape; - -// Array old_ops = pdag->ops; -// ArrayNode* pops = pdag->ops.CopyOnWrite(); - -// // Create new placeholder -// Operation new_placeholder_op; -// if (rewrite_placeholder) { -// new_placeholder_op = -// te::PlaceholderOpNode::make(placeholder_op->name, -// new_shape, -// placeholder_op.as()->dtype); -// } else { -// new_placeholder_op = placeholder_op; -// } - -// Operation new_compute_op, old_compute_op; -// if (rewrite_body) { -// Array new_body; -// IndexRewriter index_rewriter(placeholder_new_names, -// placeholder_new_shapes); -// for (auto& op : old_ops) { -// if (auto* pop = op.as()) { -// bool need_update = false; -// for (auto& t : op->InputTensors()) { -// if (t->op == placeholder_op) { -// need_update = true; -// break; -// } -// } -// if (need_update) { -// for (auto& body : pop->body) { -// new_body.push_back(index_rewriter.Mutate(body)); -// } -// old_compute_op = op; -// CHECK(!new_compute_op.defined()); -// new_compute_op = ComputeOpNode::make( -// pop->name, pop->tag, pop->attrs, pop->axis, new_body); -// } -// } -// } -// } - -// // construct the map from old_op to new_op -// std::unordered_map updated_ops; -// for (size_t i = 0; i < old_ops.size(); ++i) { -// auto old_op = old_ops[i]; -// if (rewrite_placeholder && old_op == placeholder_op) { -// pops->data[i] = new_placeholder_op; -// updated_ops[placeholder_op] = new_placeholder_op; -// } else if (rewrite_body && old_op == old_compute_op) { -// pops->data[i] = new_compute_op; -// updated_ops[old_compute_op] = new_compute_op; -// } else { -// pops->data[i] = old_op; -// } -// } - -// // Because ops is sorted in topo-order, only do one pass linear scan here. 
-// for (size_t i = 0; i < pops->data.size(); ++i) { -// auto old_op = Downcast(pops->data[i]); -// if (auto* pop = old_op.as()) { -// auto inputs = pop->InputTensors(); -// std::unordered_map rmap; -// for (auto input : inputs) { -// auto it = updated_ops.find(input->op); -// Operation new_op; -// while (it != updated_ops.end()) { -// new_op = it->second; -// it = updated_ops.find(new_op); -// } -// if (new_op.defined()) { -// int index = input->value_index; -// rmap[input] = new_op.output(index); -// } -// } -// if (!rmap.empty()) { -// Operation new_op = pop->ReplaceInputs(old_op, rmap); -// updated_ops[old_op] = new_op; -// pops->data[i] = new_op; -// } -// } -// } - -// pdag->init_state = StateNode::make(pdag->ops); - -// Array old_tensors = pdag->tensors; -// ArrayNode* ptensors = pdag->tensors.CopyOnWrite(); - -// for (size_t i = 0; i < old_tensors.size(); ++i) { -// const auto& old_tensor = old_tensors[i]; -// auto it = updated_ops.find(old_tensor->op); -// Operation new_op; -// while (it != updated_ops.end()) { -// new_op = it->second; -// it = updated_ops.find(new_op); -// } -// if (new_op.defined()) { -// if (layout_rewrite_level == kBothRewrite) { -// auto index = old_tensor->value_index; -// ptensors->data[i] = new_op.output(index); -// } else if (layout_rewrite_level == kComputeRewrite) { -// TensorNode* old_tensor_node = -// const_cast(old_tensor.as()); -// old_tensor_node->op = new_op; -// } -// } -// } -// } // end for placeholder -// } -// } -// } // end for stage -// } +class IndexRewriter : public StmtExprMutator { + public: + IndexRewriter(const OperationMap >& placeholder_new_names, + const OperationMap >& placeholder_new_shapes): + placeholder_new_names_(placeholder_new_names), + placeholder_new_shapes_(placeholder_new_shapes) {} + + PrimExpr Rewrite(PrimExpr expr) { + return this->VisitExpr(expr); + } + + PrimExpr VisitExpr_(const ProducerLoadNode* op) final { + te::Tensor t = Downcast(op->producer); + auto it = placeholder_new_names_.find(t->op); + if (it != placeholder_new_names_.end()) { + const std::vector& new_names = it->second; + const Array& new_shape = placeholder_new_shapes_.at(t->op); + std::unordered_map name_to_arg; + for (const auto& arg : op->indices) { + std::string axis_name; + if (const auto* pimm = arg.as()) { + CHECK_EQ(pimm->value, 0); + axis_name = "IntImm"; + } else { + axis_name = BaseName(CleanName(Downcast(arg)->name_hint)); + CHECK_EQ(name_to_arg.count(axis_name), 0); + name_to_arg[axis_name] = arg; + } + } + + std::unordered_map div_factors; + std::vector r_new_args; + for (int i = new_names.size() - 1; i >= 0; --i) { + auto ori_iter_name = new_names[i]; + auto name_it = name_to_arg.find(ori_iter_name); + CHECK(name_it != name_to_arg.end()); + PrimExpr ori_arg = name_it->second; + + PrimExpr mod_factor = new_shape[i]; + + PrimExpr div_factor = 1; + if (div_factors.count(ori_iter_name)) { + div_factor = div_factors[ori_iter_name]; + } + div_factors[ori_iter_name] = div_factor * new_shape[i]; + + PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor); + + r_new_args.push_back(new_arg); + } + + Array new_args(std::make_move_iterator(r_new_args.rbegin()), + std::make_move_iterator(r_new_args.rend())); + + return ProducerLoad(op->producer, new_args); + } + return GetRef(op); + } + + /* + PrimExpr Mutate_(const Call* op, const PrimExpr& e) { + PrimExpr op_ = IRMutator::Mutate_(op, e); + + const Call* call = op_.as(); + + if (call->call_type == Call::CallType::Halide) { + te::Tensor t = 
Downcast(call->func).output(call->value_index); + auto it = placeholder_new_names_.find(t->op); + if (it != placeholder_new_names_.end()) { + const std::vector& new_names = it->second; + const Array& new_shape = placeholder_new_shapes_.at(t->op); + std::unordered_map name_to_arg; + for (const auto& arg : call->args) { + std::string axis_name; + if (const auto* pimm = arg.as()) { + CHECK_EQ(pimm->value, 0); + axis_name = "IntImm"; + } else { + axis_name = BaseName(CleanName(Downcast(arg)->name_hint)); + CHECK_EQ(name_to_arg.count(axis_name), 0); + name_to_arg[axis_name] = arg; + } + } + + std::unordered_map div_factors; + std::vector r_new_args; + for (int i = new_names.size() - 1; i >= 0; --i) { + auto ori_iter_name = new_names[i]; + auto name_it = name_to_arg.find(ori_iter_name); + CHECK(name_it != name_to_arg.end()); + PrimExpr ori_arg = name_it->second; + + PrimExpr mod_factor = new_shape[i]; + + PrimExpr div_factor = 1; + if (div_factors.count(ori_iter_name)) { + div_factor = div_factors[ori_iter_name]; + } + div_factors[ori_iter_name] = div_factor * new_shape[i]; + + PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor); + + r_new_args.push_back(new_arg); + } + + Array new_args(std::make_move_iterator(r_new_args.rbegin()), + std::make_move_iterator(r_new_args.rend())); + + return Call::make(call->type, call->name, new_args, call->call_type, + call->func, call->value_index); + } + } + return op_; + } + */ + + private: + const OperationMap >& placeholder_new_names_; + const OperationMap >& placeholder_new_shapes_; +}; + +void ComputeDAG::RewriteLayout( + const std::vector &transform_steps, LayoutRewriteLevel layout_rewrite_level) const { + ComputeDAGNode* pdag = const_cast(this)->CopyOnWrite(); + const State& state = ReplayAndInferBound(transform_steps); + + OperationMap > placeholder_new_names; + OperationMap > placeholder_new_shapes; + int stage_id = -1; + for (const auto& stage : state->stages) { + stage_id += 1; + const te::Operation& op = stage->op; + if (op->IsInstance()) { + const Map& attrs = op->attrs; + if (attrs.count(layout_free_placeholders_key)) { + const ObjectRef& attr_value = attrs[layout_free_placeholders_key]; + Array placeholders = Downcast>(attr_value); + for (auto& placeholder : placeholders) { + const auto placeholder_op = placeholder->op; + + // Check whether this placeholder has already been handled + if (placeholder_new_names.count(placeholder_op)) { + continue; + } + + // skip the op that is not direct consumer of this placeholder, + // mostly due to cache read/write. 
+ bool direct_consumer = false; + for (auto& t : op->InputTensors()) { + if (t->op == placeholder_op) { + direct_consumer = true; + break; + } + } + if (!direct_consumer) { + continue; + } + + std::set placeholder_axis_names; + TensorAccessExtractor extractor; + for (const auto& exp : op.as()->body) { + extractor.Extract(exp); + } + bool rewrite_placeholder = (layout_rewrite_level == kPlaceholderRewrite || + layout_rewrite_level == kBothRewrite); + bool rewrite_body = (layout_rewrite_level == kComputeRewrite || + layout_rewrite_level == kBothRewrite); + std::ostringstream os; + + uint i = 0; + if (extractor.buf_accesses.count(placeholder_op)) { + for (const auto& ev : extractor.buf_accesses[placeholder_op]) { + for (const auto& e : ev) { + // TODO(minminsun): check whether the extents match the shape of placeholder + std::string axis_name; + if (const auto* pimm = e.as()) { + CHECK_EQ(pimm->value, 0); + // CHECK_EQ(placeholder->shape[i].as()->value, 1); + axis_name = "IntImm"; + } else { + axis_name = BaseName(CleanName(Downcast(e)->name_hint)); + } + + placeholder_axis_names.insert(axis_name); + if (rewrite_placeholder) { + os << placeholder->shape[i++] << axis_name; + } + } + } + + if (rewrite_placeholder) { + CHECK_EQ(placeholder_axis_names.size(), placeholder->shape.size()); + std::string ori_layout = os.str(); + os.str(""); + ::tvm::relay::KernelLayoutVisitor::global_ori_layouts_queue.push_back(ori_layout); + } + } + + std::vector stage_iters; + + auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id); + int attach_pos = -1; + size_t iters_before_attach = 0; + if (attach_it != state->attach_map->stage_to_attach_iter.end()) { + auto attach = attach_it->second; + const auto& attach_stage = state->stages[attach.first]; + attach_pos = attach.second; + stage_iters.insert(stage_iters.end(), + attach_stage->iters.begin(), + attach_stage->iters.begin() + attach_pos + 1); + } + + stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end()); + + std::vector iters; + for (size_t i = 0; i < stage_iters.size(); ++i) { + const auto& iter = stage_iters[i]; + if (iter->ori_iters.empty()) { + iters.push_back(iter); + } else { + for (const Iterator& ori_iter : iter->ori_iters) { + iters.push_back(ori_iter); + } + } + if (static_cast(i) == attach_pos) { + iters_before_attach = iters.size(); + } + } + + std::vector new_names; + Array new_shape; + std::vector new_axis_names; + for (const Iterator& iter : iters) { + std::set ori_iter_names; + ExtractOriginalIterators(iter->name, &ori_iter_names); + // fused iters have been replaced with iter->ori_iters. + // So there should be only one ori iter name extracted from iter->name. 
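+              // BaseName() keeps everything before the last '_' (e.g. an
+              // iterator named "i_1" maps back to "i"), so each sub-iterator
+              // below recovers the name of the placeholder's original axis.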
+ CHECK_EQ(ori_iter_names.size(), 1); + auto ori_iter_name = BaseName(*ori_iter_names.begin()); + new_axis_names.push_back(ori_iter_name); + } + for (size_t i = 0; i < new_axis_names.size(); ++i) { + auto iter = iters[i]; + std::string ori_iter_name; + if (i < iters_before_attach) { + ori_iter_name = new_axis_names[i + iters_before_attach]; + } else { + ori_iter_name = new_axis_names[i]; + } + if (placeholder_axis_names.count(ori_iter_name)) { + os << iter->range->extent << ori_iter_name; + new_names.push_back(ori_iter_name); + new_shape.push_back(iter->range->extent); + } + } + std::string new_layout = os.str(); + os.str(""); + ::tvm::relay::KernelLayoutVisitor::global_new_layouts_queue.push_back(new_layout); + placeholder_new_names[placeholder_op] = new_names; + placeholder_new_shapes[placeholder_op] = new_shape; + + Array old_ops = pdag->ops; + ArrayNode* pops = pdag->ops.CopyOnWrite(); + + // Create new placeholder + te::Operation new_placeholder_op; + if (rewrite_placeholder) { + new_placeholder_op = + te::PlaceholderOpNode::make(placeholder_op->name, + new_shape, + placeholder_op.as()->dtype); + } else { + new_placeholder_op = placeholder_op; + } + + te::Operation new_compute_op, old_compute_op; + if (rewrite_body) { + Array new_body; + IndexRewriter index_rewriter(placeholder_new_names, + placeholder_new_shapes); + for (auto& op : old_ops) { + if (auto* pop = op.as()) { + bool need_update = false; + for (auto& t : op->InputTensors()) { + if (t->op == placeholder_op) { + need_update = true; + break; + } + } + if (need_update) { + for (auto& body : pop->body) { + new_body.push_back(index_rewriter.Rewrite(body)); + } + old_compute_op = op; + CHECK(!new_compute_op.defined()); + new_compute_op = te::ComputeOpNode::make( + pop->name, pop->tag, pop->attrs, pop->axis, new_body); + } + } + } + } + + // construct the map from old_op to new_op + std::unordered_map updated_ops; + for (size_t i = 0; i < old_ops.size(); ++i) { + auto old_op = old_ops[i]; + if (rewrite_placeholder && old_op == placeholder_op) { + //pops->data[i] = new_placeholder_op; + pops->SetItem(i, new_placeholder_op); + updated_ops[placeholder_op] = new_placeholder_op; + } else if (rewrite_body && old_op == old_compute_op) { + //pops->data[i] = new_compute_op; + pops->SetItem(i, new_compute_op); + updated_ops[old_compute_op] = new_compute_op; + } else { + //pops->data[i] = old_op; + pops->SetItem(i, old_op); + } + } + + // Because ops is sorted in topo-order, only do one pass linear scan here. 
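+      // An entry in updated_ops may map to an op that was itself replaced
+      // later, so the lookups below follow the chain until they reach the
+      // final operation.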
+      for (size_t i = 0; i < pops->size(); ++i) {
+        auto old_op = Downcast(pops->at(i));
+        if (auto* pop = old_op.as()) {
+          auto inputs = pop->InputTensors();
+          std::unordered_map rmap;
+          for (auto input : inputs) {
+            auto it = updated_ops.find(input->op);
+            te::Operation new_op;
+            while (it != updated_ops.end()) {
+              new_op = it->second;
+              it = updated_ops.find(new_op);
+            }
+            if (new_op.defined()) {
+              int index = input->value_index;
+              rmap[input] = new_op.output(index);
+            }
+          }
+          if (!rmap.empty()) {
+            te::Operation new_op = pop->ReplaceInputs(old_op, rmap);
+            updated_ops[old_op] = new_op;
+            //pops->data[i] = new_op;
+            pops->SetItem(i, new_op);
+          }
+        }
+      }
+
+      pdag->init_state = StateNode::make(pdag->ops);
+
+      Array old_tensors = pdag->tensors;
+      ArrayNode* ptensors = pdag->tensors.CopyOnWrite();
+
+      for (size_t i = 0; i < old_tensors.size(); ++i) {
+        const auto& old_tensor = old_tensors[i];
+        auto it = updated_ops.find(old_tensor->op);
+        te::Operation new_op;
+        while (it != updated_ops.end()) {
+          new_op = it->second;
+          it = updated_ops.find(new_op);
+        }
+        if (new_op.defined()) {
+          if (layout_rewrite_level == kBothRewrite) {
+            auto index = old_tensor->value_index;
+            //ptensors->data[i] = new_op.output(index);
+            ptensors->SetItem(i, new_op.output(index));
+          } else if (layout_rewrite_level == kComputeRewrite) {
+            te::TensorNode* old_tensor_node =
+                const_cast(old_tensor.as());
+            old_tensor_node->op = new_op;
+          }
+        }
+      }
+        }  // end for placeholder
+      }
+    }
+  }  // end for stage
+}

 void UpdateStageAxis(const te::Stage& stage, StageToAxesMap *stage_to_axes) {
@@ -1273,6 +1331,30 @@ TVM_REGISTER_GLOBAL("ansor.ComputeDAG")
 TVM_REGISTER_GLOBAL("ansor.ComputeDAGGetInitState")
 .set_body_method(&ComputeDAG::GetInitState);

+TVM_REGISTER_GLOBAL("ansor.ComputeDAGRewriteLayoutFromState")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  ComputeDAG dag = args[0];
+  State state = args[1];
+
+  dag.RewriteLayout(state->transform_steps, kPlaceholderRewrite);
+  *ret = dag;
+});
+
+TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  ComputeDAG dag = args[0];
+  State state = args[1];
+  LayoutRewriteLevel layout_rewrite_level = kNoRewrite;
+  if (args.size() >= 3) {
+    layout_rewrite_level = LayoutRewriteLevel(static_cast((args[2])));
+  }
+
+  te::Schedule sch;
+  Array return_tensors;
+  std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps, layout_rewrite_level);
+  *ret = Array{sch, return_tensors};
+});
+/*
 TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState")
 .set_body_typed([](const ComputeDAG& dag, const State& state) {
   te::Schedule sch;
@@ -1280,6 +1362,7 @@ TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState")
   std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps);
   return Array{sch, return_tensors};
 });
+*/

 TVM_REGISTER_GLOBAL("ansor.ComputeDAGPrintPythonCodeFromState")
 .set_body_typed([](const ComputeDAG& dag, const State& state) {
diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h
index 60c1790a0cfb..c71c4f1b6586 100644
--- a/src/ansor/compute_dag.h
+++ b/src/ansor/compute_dag.h
@@ -146,7 +146,7 @@ class ComputeDAG: public ObjectRef {
   // Rewrite the layout of "layout free" placeholders according to transform steps
   void RewriteLayout(const std::vector& transform_steps,
-                     LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const {}
+                     LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const;

   // Print transform steps as equivalent Python schedule API
   std::string PrintStepsAsPython(const std::vector&
steps) const; diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index a192002825e6..5b063eca4337 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -219,6 +219,7 @@ class TypeSolver::Unifier : public TypeFunctor { return Type(nullptr); } + tt1 = tt2; tvm::Array shape; if (tt1->shape.size() != tt2->shape.size()) { this->solver_->ReportError(ErrorBuilder() << "tensor type `" << PrettyPrint(tt1) << "` has " diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 34c3487e3ef2..8bd5eca7c93d 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -287,6 +287,7 @@ class RelayBuildModule : public runtime::ModuleNode { // Alter layout transformation is only applied to homogeneous execution yet. if (targets.size() == 1) { pass_seqs.push_back(transform::AlterOpLayout()); + //pass_seqs.push_back(transform::KernelLayoutTransform()); } // Fast math optimizations. @@ -315,6 +316,18 @@ class RelayBuildModule : public runtime::ModuleNode { // Fuse the operations if it is needed. relay_module = transform::FuseOps()(relay_module); + + if (targets.size() == 1) { + pass_seqs.push_back(transform::KernelLayoutTransform()); + pass_seqs.push_back(transform::DeFuseOps()); + pass_seqs.push_back(transform::FoldConstant()); + transform::Pass seq = transform::Sequential(pass_seqs); + const auto& it = targets.begin(); + With tctx((*it).second); + relay_module = seq(relay_module); + relay_module = transform::FuseOps()(relay_module); + } + relay_module = transform::InferType()(relay_module); // Inline the functions that have been lifted by the module scope. // diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 2aae8546248f..fde880b10f1d 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -68,6 +68,11 @@ CCacheKey::CCacheKey(Function source_func, Target target) { auto n = make_object(); n->source_func = std::move(source_func); n->target = std::move(target); + n->disabled = false; + char* envar = getenv("TVM_RELAY_DISABLE_BUILD_CACHE"); + if (envar != nullptr && strcmp(envar, "true") == 0) { + n->disabled = true; + } data_ = std::move(n); } diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index a5f3f6359f89..b290462a4b22 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -115,6 +115,8 @@ class CCacheKeyNode : public Object { /*! \brief The hardware target.*/ Target target; + bool disabled; + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("source_func", &source_func); v->Visit("target", &target); @@ -259,6 +261,7 @@ inline size_t CCacheKeyNode::Hash() const { } inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { + if (disabled) return false; if (Hash() != other->Hash()) return false; return this->target->str() == other->target->str() && tvm::StructuralEqual()(this->source_func, other->source_func); diff --git a/src/relay/transforms/defuse_ops.cc b/src/relay/transforms/defuse_ops.cc new file mode 100644 index 000000000000..f7c9037df687 --- /dev/null +++ b/src/relay/transforms/defuse_ops.cc @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "pattern_util.h" + +namespace tvm { +namespace relay { + +class DefuseOpsMutator : public ExprMutator { + public: + + class FuncBodyMutator : public ExprMutator { + public: + Array args_; + + FuncBodyMutator(const Array& args) : ExprMutator() { + args_ = args; + } + + Expr VisitExpr_(const VarNode* n) { + const std::string& name = n->name_hint(); + CHECK_EQ(name[0], 'p'); + std::string id_str = name.substr(1); + int id = atoi(id_str.c_str()); + CHECK(id >= 0 && size_t(id) < args_.size()); + return args_[id]; + } + }; + + Expr VisitExpr_(const CallNode* n) { + auto new_n = ExprMutator::VisitExpr_(n); + + const auto* call = new_n.as(); + if (call) { + const auto* func = call->op.as(); + if (func) { + const auto& func_call = func->body.as(); + if (func_call) { + return FuncBodyMutator(call->args).Mutate(func->body); + } + } + } + return new_n; + } +}; + +Expr DeFuseOps(const Expr& expr) { + return DefuseOpsMutator().Mutate(expr); +} + +namespace transform { + +Pass DeFuseOps() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::DeFuseOps(f)); + }; + return CreateFunctionPass(pass_func, 3, "DeFuseOps", + {"InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.DeFuseOps") +.set_body_typed(DeFuseOps); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/kernel_layout_transform.cc b/src/relay/transforms/kernel_layout_transform.cc new file mode 100644 index 000000000000..681785c8123c --- /dev/null +++ b/src/relay/transforms/kernel_layout_transform.cc @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include +#include +#include +#include +#include "kernel_layout_transform.h" + +namespace tvm { +namespace relay { + +// Todo: do not use global variables +std::deque KernelLayoutVisitor::global_ori_layouts_queue; +std::deque KernelLayoutVisitor::global_new_layouts_queue; + +Expr KernelLayoutTransform(const Expr& expr) { + KernelLayoutVisitor visitor; + + // Do a pre-order DFS to gather the optimal kernel layouts for all conv2d nodes. + // These layouts were written to global static variables in python function `prepare_layout_rewrite` + visitor.VisitExpr(expr); + + // Do a post-order DSF to mutate layout for all conv2d nodes + return KernelLayoutTransformer(&visitor).Mutate(expr); +} + +namespace transform { + +Pass KernelLayoutTransform() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::KernelLayoutTransform(f)); + }; + return CreateFunctionPass(pass_func, 3, "KernelLayoutTransform", + {"InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.KernelLayoutTransform") +.set_body_typed(KernelLayoutTransform); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/kernel_layout_transform.h b/src/relay/transforms/kernel_layout_transform.h new file mode 100644 index 000000000000..b4b806c20e28 --- /dev/null +++ b/src/relay/transforms/kernel_layout_transform.h @@ -0,0 +1,75 @@ +#include +#include +#include +#include + +#include "pattern_util.h" + +#include "../../ansor/compute_dag.h" + +namespace tvm { +namespace relay { + +/*! \brief A visitor to gather the optimal kernel layout for all conv2d nodes. */ +class KernelLayoutVisitor : public ExprVisitor { + public: + void VisitExpr_(const CallNode *n) { + if (n && n->op.as() && + (std::find(op_white_lists.begin(), op_white_lists.end(), n->op.as()->name) != + op_white_lists.end()) && n->args[1]->type_as()->shape[3].as()->value > 1 && + !global_ori_layouts_queue.empty() && !global_new_layouts_queue.empty()) { + ori_layouts_map[n] = global_ori_layouts_queue.front(); + new_layouts_map[n] = global_new_layouts_queue.front(); + std::cout << "ori_layout " << global_ori_layouts_queue.front() << " Filter_shape " << n->args[1]->type_as()->shape << std::endl; + global_ori_layouts_queue.pop_front(); + global_new_layouts_queue.pop_front(); + } + ExprVisitor::VisitExpr_(n); + } + + std::unordered_map ori_layouts_map; + std::unordered_map new_layouts_map; + std::vector op_white_lists {"nn.contrib_conv2d_winograd_without_weight_transform", + "nn.conv2d", "nn.conv3d"}; + + static std::deque global_ori_layouts_queue; + static std::deque global_new_layouts_queue; +}; + + +/*! 
\brief A mutator to rewrite kernel layout for all conv2d nodes */ +class KernelLayoutTransformer : public ExprMutator { + public: + KernelLayoutTransformer(KernelLayoutVisitor* visitor): ExprMutator(), visitor_(visitor) {} + + Expr VisitExpr_(const CallNode* n) { + auto new_n = ExprMutator::VisitExpr_(n); + + const auto* call = new_n.as(); + std::vector op_white_lists {"nn.contrib_conv2d_winograd_without_weight_transform", + "nn.conv2d", "nn.conv3d"}; + if (call && call->op.as() && + (std::find(op_white_lists.begin(), op_white_lists.end(), n->op.as()->name) != + op_white_lists.end() && n->args[1]->type_as()->shape[3].as()->value > 1)) { + auto ori_layout_iter = visitor_->ori_layouts_map.find(n); + auto new_layout_iter = visitor_->new_layouts_map.find(n); + if (ori_layout_iter != visitor_->ori_layouts_map.end() && + new_layout_iter != visitor_->new_layouts_map.end()) { + const std::string& ori_layout = ori_layout_iter->second; + const std::string& new_layout = new_layout_iter->second; + Expr updated_kernel = MakeKernelLayoutTransform(call->args[1], ori_layout, new_layout); + Array updated_args = {call->args[0], updated_kernel}; + new_n = Call(call->op, updated_args, + call->attrs); + } + } + return new_n; + } + + private: + KernelLayoutVisitor* visitor_; +}; + + +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index 7518eb9ac81a..a9d3b5168e47 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -685,6 +685,8 @@ Expr MakeExpandDims(Expr data, int axis, int num_newaxis); Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout); +Expr MakeKernelLayoutTransform(Expr data, String src_layout, String dst_layout); + Expr StopFusion(Expr data); Expr CastHint(Expr data, DataType dtype); diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 4c7941b49692..de02367a4dff 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -342,7 +342,24 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'): dilation_h, dilation_w = dilation batch, in_height, in_width, in_channel = Input.shape - kernel_h, kernel_w, channel, num_filter = Filter.shape + if len(Filter.shape) == 10: + kernel_h = Filter.shape[2] * Filter.shape[6] + kernel_w = Filter.shape[3] * Filter.shape[7] + channel = Filter.shape[4] * Filter.shape[8] + num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[5] * Filter.shape[9] + elif len(Filter.shape) == 11: + kernel_h = Filter.shape[3] * Filter.shape[7] + kernel_w = Filter.shape[4] * Filter.shape[8] + channel = Filter.shape[5] * Filter.shape[9] + num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[2] * Filter.shape[6] * Filter.shape[10] + elif len(Filter.shape) == 12: + kernel_h = Filter.shape[4] * Filter.shape[8] + kernel_w = Filter.shape[5] * Filter.shape[9] + channel = Filter.shape[6] * Filter.shape[10] + num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[2] * Filter.shape[3] * Filter.shape[7] * Filter.shape[11] + else: + kernel_h, kernel_w, channel, num_filter = Filter.shape + # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 @@ -362,8 +379,9 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'): lambda nn, yy, xx, ff: te.sum( PaddedInput[nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * - Filter[ry, rx, rc, 
ff].astype(out_dtype), axis=[ry, rx, rc]), - name="Conv2dOutput", tag="conv2d_nhwc") + Filter[ry, rx, rc, ff].astype(out_dtype) + , axis=[ry, rx, rc]), + name="Conv2dOutput", tag="conv2d_nhwc", attrs={"layout_free_placeholders": [Filter]}) return Output From 145e61cf072b5e976eac07484beb25711b222c25 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Sat, 20 Jun 2020 00:29:19 +0800 Subject: [PATCH 30/78] [cache flush] port cache flush to ansor (#32) --- scripts/tune_test.py | 3 ++- src/runtime/rpc/rpc_module.cc | 31 +++++++++++++++++++++++++++++++ src/runtime/threading_backend.cc | 9 +++++++-- 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/scripts/tune_test.py b/scripts/tune_test.py index a49ecd088afc..7831aea9dd4a 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -22,7 +22,8 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) runner = measure_ctx.runner else: - runner = ansor.LocalRunner(repeat=1, min_repeat_ms=400) + os.environ['TVM_AUTO_CACHE_FLUSH'] = "1" + runner = ansor.LocalRunner(repeat=10, number=1, min_repeat_ms=0, timeout=run_timeout) else: os.environ['TVM_NDK_CC'] = ndk_cc builder = ansor.LocalBuilder(timeout=build_timeout, build_func='ndk') diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 89f3e7c6c7f8..b95d5ba25926 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -24,9 +24,14 @@ #include #include +#include #include #include +#if defined(_M_X64) || defined(__x86_64__) +#include +#endif + #include "rpc_endpoint.h" #include "rpc_session.h" @@ -300,6 +305,23 @@ std::shared_ptr RPCModuleGetSession(Module mod) { return rmod->sess(); } +inline void CacheFlush(const char* p, unsigned int allocation_size) { +// TODO: (FrozenGene) +// Support ARM. +#if (defined(_M_X64) || defined(__x86_64__)) + size_t cache_line = 64; + + if (p == nullptr || allocation_size <= 0) { + return; + } + + for (size_t i = 0; i < allocation_size; i += cache_line) { + _mm_clflush(static_cast(&p[i])); + } + +#endif +} + PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repeat, int min_repeat_ms) { CHECK(pf != nullptr); @@ -313,12 +335,21 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repe auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue* rv) mutable { TVMRetValue temp; std::ostringstream os; + const char* cache_flush = std::getenv("TVM_AUTO_CACHE_FLUSH"); // skip first time call, to activate lazy compilation components. 
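+      // A note on the flush protocol below (a sketch of the intended usage;
+      // only the CHECK enforces it): when TVM_AUTO_CACHE_FLUSH is set to a
+      // non-zero integer, `number` must be 1 so that each of the `repeat`
+      // timings starts from a cold cache. The loop starts at j = 1, so every
+      // DLTensor argument except args[0] is flushed, and CacheFlush itself
+      // is a no-op on non-x86 targets (ARM support is still a TODO).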
pf.CallPacked(args, &temp); DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); for (int i = 0; i < repeat; ++i) { + if (cache_flush && std::atoi(cache_flush) != 0) { + CHECK_EQ(number, 1); + // we want to keep input data + for (int j = 1; j < args.size(); j++) { + CacheFlush((char*)(args[j].operator DLTensor*()->data), + GetDataSize(*(args[j].operator DLTensor*()))); + } + } std::chrono::time_point tbegin, tend; double duration_ms = 0.0; diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc index e5520efe30a6..3b1889aed8ef 100644 --- a/src/runtime/threading_backend.cc +++ b/src/runtime/threading_backend.cc @@ -166,8 +166,13 @@ class ThreadGroup::Impl { #if defined(_M_X64) || defined(__x86_64__) big_count /= 2; // ignore hyper-threading #endif - for (int i = 0; i < big_count; ++i) { - CPU_SET(sorted_order_[i], &cpuset); + const char* bind_master_core_0 = getenv("TVM_BIND_MASTER_CORE_0"); + if (bind_master_core_0 && atoi(bind_master_core_0) != 0) { + CPU_SET(sorted_order_[0], &cpuset); + } else { + for (int i = 0; i < big_count; ++i) { + CPU_SET(sorted_order_[i], &cpuset); + } } } #if defined(__ANDROID__) From 2c2781690313894dd35f578a8be48f940e8d7125 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 19 Jun 2020 17:42:00 -0700 Subject: [PATCH 31/78] Improve relay integration (#34) * tmp checkpoint * Improve relay integration * Improve relay integration --- python/tvm/ansor/__init__.py | 11 +- python/tvm/ansor/auto_schedule.py | 67 +++-- python/tvm/ansor/compute_dag.py | 50 +--- python/tvm/ansor/cost_model/cost_model.py | 2 + python/tvm/ansor/dispatcher.py | 18 +- python/tvm/ansor/env.py | 2 +- python/tvm/ansor/feature.py | 7 +- python/tvm/ansor/loop_state.py | 4 +- python/tvm/ansor/measure.py | 10 +- python/tvm/ansor/relay_integration.py | 259 ++++++++++-------- python/tvm/ansor/serialization.py | 7 +- python/tvm/ansor/topi_integration.py | 220 --------------- python/tvm/ansor/utils.py | 6 +- python/tvm/relay/backend/compile_engine.py | 5 +- python/tvm/relay/build_module.py | 7 + python/tvm/relay/op/strategy/x86.py | 63 +++-- python/tvm/relay/testing/resnet.py | 17 +- scripts/tune_network.py | 15 +- scripts/tune_test.py | 2 +- src/ansor/compute_dag.cc | 5 - src/ansor/compute_dag.h | 2 +- src/ansor/feature.cc | 21 +- src/ansor/measure.cc | 24 +- src/ansor/measure.h | 21 +- .../search_policy/meta_tile_rewrite_policy.cc | 136 ++++----- src/ansor/search_policy/search_policy.cc | 48 ++-- src/ansor/search_policy/search_policy.h | 41 +-- src/ansor/search_policy/utils.cc | 169 ------------ src/ansor/search_policy/utils.h | 8 - src/ansor/serialization.cc | 1 + src/relay/backend/build_module.cc | 21 +- .../transforms/kernel_layout_transform.h | 3 +- ...ion.py => test_ansor_relay_integration.py} | 79 ++++-- topi/python/topi/ansor.py | 95 ------- topi/python/topi/arm_cpu/__init__.py | 5 - topi/python/topi/generic/__init__.py | 5 - topi/python/topi/nn/conv2d.py | 47 ++-- topi/python/topi/x86/__init__.py | 5 - tutorials/ansor/tune_conv2d_cuda.py | 4 +- tutorials/ansor/tune_simple_subgraph.py | 4 +- 40 files changed, 551 insertions(+), 965 deletions(-) delete mode 100644 python/tvm/ansor/topi_integration.py rename tests/python/unittest/{test_ansor_relay_Integration.py => test_ansor_relay_integration.py} (53%) delete mode 100644 topi/python/topi/ansor.py diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index b43b21a60144..977e100e63c6 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -28,10 +28,9 @@ from . 
import task_scheduler # Shortcut -from .compute_dag import ComputeDAG, LayoutRewriteLevel, gen_schedule +from .compute_dag import ComputeDAG, LayoutRewriteLevel from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams, \ - PreLoadMeasuredStates, PreAddCustomRule -from .auto_schedule import auto_schedule + PreloadMeasuredStates, PreAddCustomRule, auto_schedule from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext from .cost_model import RandomModel from .cost_model.xgb_model import XGBModel @@ -41,7 +40,7 @@ workload_key_to_dag, make_workload_key_func from .task_scheduler import TaskScheduler, SimpleTaskScheduler from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest as apply_history_best, \ - FallbackContext, clear_fallback_cache, ApplyGraphBest, BlockingEmptyContext -from .topi_integration import register_topi_schedule, TaskExtractEnv + FallbackContext, clear_fallback_cache, ApplyGraphBest from .relay_integration import extract_from_program, extract_from_multiple_program, \ - finish_layout_rewrite, prepare_layout_rewrite + finish_layout_rewrite, prepare_layout_rewrite, auto_schedule_topi +from .env import GLOBAL_SCOPE diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 232c24ee89ea..acf8982d6e89 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -"""Meta information for a search task""" +"""User interface for auto-scheduler""" import random @@ -29,35 +29,36 @@ @tvm._ffi.register_object("ansor.HardwareParams") class HardwareParams(Object): """ + The parameters of target hardware + Parameters ---------- - num_cores : Int - vector_unit_bytes : Int - cache_line_bytes : Int - max_unroll_vec : Int - max_innermost_split_factor : Int + num_cores : int + vector_unit_bytes : int + cache_line_bytes : int + max_unroll_vec : int + max_innermost_split_factor : int """ - def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes, max_unroll_vec, max_innermost_split_factor): self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores, vector_unit_bytes, cache_line_bytes, - max_unroll_vec, - max_innermost_split_factor) + max_unroll_vec, max_innermost_split_factor) @tvm._ffi.register_object("ansor.SearchTask") class SearchTask(Object): """ + The meta-information of a search task + Parameters ---------- dag : ComputeDAG - workload_key : Str - target : tvm.target - target_host : tvm.target + workload_key : str + target : tvm.target.Target + target_host : tvm.target.Target hardware_params : HardwareParams """ - def __init__(self, dag, workload_key, target, target_host=None, hardware_params=None): self.__init_handle_by_constructor__(_ffi_api.SearchTask, dag, @@ -67,10 +68,10 @@ def __init__(self, dag, workload_key, target, target_host=None, @tvm._ffi.register_object("ansor.SearchPolicy") class SearchPolicy(Object): - """ The base search policy class - """ + """ The base class for search policy """ def continue_search(self, task, num_measure, verbose, measurer): - return _ffi_api.SearchPolicyContinueSearchOneRound(self, task, num_measure, verbose, measurer) + return _ffi_api.SearchPolicyContinueSearchOneRound(self, task, + num_measure, verbose, measurer) def set_task(self, task): _ffi_api.SearchPolicySetTask(self, task) @@ -89,7 +90,7 @@ class MetaTileRewritePolicy(SearchPolicy): Parameters ---------- program_cost_model: CostModel - Cost 
model for complete programs + Cost model for programs params: int Parameters of the search policy, go meta_tile_rewrite_policy.h to find the definitions. See code below to find the default values @@ -130,21 +131,22 @@ def __init__(self, @tvm._ffi.register_object("ansor.SearchCallback") class SearchCallback(Object): + """Callback function before or after search process""" pass -@tvm._ffi.register_object("ansor.PreLoadMeasuredStates") -class PreLoadMeasuredStates(SearchCallback): - """ A SearchCallback that used for search policy to load measured hash - from the log file. +@tvm._ffi.register_object("ansor.PreloadMeasuredStates") +class PreloadMeasuredStates(SearchCallback): + """ A SearchCallback to load measured states from the log file for a search policy. + This can resume the state of the search policy. Parameters ---------- - filename: Str + filename: str """ def __init__(self, filename: str): self.__init_handle_by_constructor__( - _ffi_api.PreLoadMeasuredStates, filename) + _ffi_api.PreloadMeasuredStates, filename) @tvm._ffi.register_object("ansor.PreAddCustomRule") @@ -153,8 +155,10 @@ class PreAddCustomRule(SearchCallback): A SearchCallback for MetaTileRewritePolicy that allowing users to add custom sketch rule. - Notice: This is an advanced feature, make sure you're clear how it - works and this should only be used in MetaTileRewritePolicy. + Notes + ----- + This is an advanced feature. Make sure you're clear how it + works and this should only be used in MetaTileRewritePolicy. Parameters ---------- @@ -193,7 +197,7 @@ class TuneOption(Object): pre_search_callbacks: List[SearchCallback] Callback functions called before the search process Candidates: - - ansor.PreLoadMeasuredStates + - ansor.PreloadMeasuredStates - ansor.PreAddCustomRule """ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, @@ -225,7 +229,7 @@ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, def auto_schedule(workload, target=None, target_host=None, search_policy='default', hardware_params=None, tune_option=None): - """ Do auto schedule for a compute declaration. + """ Do auto scheduling for a computation declaration. The workload parameter can be a `string` as workload_key, or directly passing a `SearchTask` as input. @@ -233,21 +237,15 @@ def auto_schedule(workload, target=None, Parameters ---------- workload : Union[SearchTask, str] - target : Target - target_host : Target = None - search_policy : Union[SearchPolicy, str] - hardware_params : HardwareParams - tune_option : TuneOption Returns ------- sch : tvm.Schedule - tensors : List[Tensor] """ if isinstance(search_policy, str): @@ -267,5 +265,4 @@ def auto_schedule(workload, target=None, sch, tensors = _ffi_api.AutoScheduleBySearchTask(workload, search_policy, tune_option) return sch, tensors else: - raise ValueError("Invalid workload: " + workload + - ". Expect a string or SearchTask") + raise ValueError("Invalid workload: " + workload + ". Expect a string or SearchTask") diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index c54c14ec123a..f35c9d8221f3 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -19,7 +19,6 @@ import tvm._ffi from tvm.runtime import Object -from tvm import te from .loop_state import State, StateObject from . 
import _ffi_api @@ -34,11 +33,12 @@ class LayoutRewriteLevel(object): @tvm._ffi.register_object("ansor.ComputeDAG") class ComputeDAG(Object): """ + Computation declaration graph + Parameters ---------- tensors : List[Tensor] """ - def __init__(self, tensors): self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, tensors) @@ -51,29 +51,20 @@ def get_init_state(self): """ return State(_ffi_api.ComputeDAGGetInitState(self), self) - def apply_steps_from_state(self, state, layout_rewrite_level=None): + def apply_steps_from_state(self, state, layout_rewrite_level=LayoutRewriteLevel.NO_REWRITE): """ Parameters ---------- state : StateObject - layout_rewrite_level : LayoutRewriteLevel(***) + layout_rewrite_level : LayoutRewriteLevel Returns ------- sch : Schedule args : List[Tensor] """ - if isinstance(state, State): - return _ffi_api.ComputeDAGApplyStepsFromState(self, state.state_object, - layout_rewrite_level) - elif isinstance(state, StateObject): - return _ffi_api.ComputeDAGApplyStepsFromState(self, state, - layout_rewrite_level) - else: - raise ValueError("The input must be a State or StateObject") - - def rewrite_layout_from_state(self, state: State): - return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state) + state_obj = state if isinstance(state, StateObject) else state.state_object + return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj, layout_rewrite_level) def print_python_code_from_state(self, state): """ @@ -85,12 +76,8 @@ def print_python_code_from_state(self, state): ------- str : Str """ - if isinstance(state, State): - return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state.state_object) - elif isinstance(state, StateObject): - return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state) - else: - raise ValueError("The input must be a State or StateObject") + state_obj = state if isinstance(state, StateObject) else state.state_object + return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state_obj) def infer_bound_from_state(self, state): """ @@ -102,19 +89,8 @@ def infer_bound_from_state(self, state): ------- state : StateObject """ - if isinstance(state, State): - return State(_ffi_api.ComputeDAGInferBoundFromState(self, state.state_object), self) - elif isinstance(state, StateObject): - return State(_ffi_api.ComputeDAGInferBoundFromState(self, state), self) - else: - raise ValueError("The input must be a State or StateObject") - -def gen_schedule(state, bufs): - if not state or not state.complete: - return te.create_schedule([x.op for x in bufs]) - else: - dag = ComputeDAG(bufs) - # only update compute body, layout_rewrite_level = LayoutRewriteLevel.COMPUTE_REWRITE, - # since kernel layout has already been rewritten in relay pass - schedule, _ = dag.apply_steps_from_state(state, layout_rewrite_level=LayoutRewriteLevel.COMPUTE_REWRITE) - return schedule + state_obj = state if isinstance(state, StateObject) else state.state_object + return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self) + + def rewrite_layout_from_state(self, state: State): + return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state) diff --git a/python/tvm/ansor/cost_model/cost_model.py b/python/tvm/ansor/cost_model/cost_model.py index fd9b67927185..47ea5092b302 100644 --- a/python/tvm/ansor/cost_model/cost_model.py +++ b/python/tvm/ansor/cost_model/cost_model.py @@ -34,6 +34,7 @@ class RandomModel(Object): def __init__(self): self.__init_handle_by_constructor__(_ffi_api.RandomModel) + # A random number generator func for c++'s RandomModel 
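+# It is called from C++ as a packed func with (n, return_ptr); the ctypes
+# cast below reinterprets return_ptr as a float32 buffer of length n and
+# fills it in place, so no array has to be copied back across the FFI.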
@tvm._ffi.register_func("ansor.cost_model.random_number") def random_number(n, return_ptr): @@ -43,6 +44,7 @@ def random_number(n, return_ptr): array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,)) array_wrapper[:] = np.random.uniform(0, 1, (n,)) + @tvm._ffi.register_object("ansor.PythonBasedModel") class PythonBasedModel(CostModel): def __init__(self): diff --git a/python/tvm/ansor/dispatcher.py b/python/tvm/ansor/dispatcher.py index 2f00c355d285..0ef07197ea92 100644 --- a/python/tvm/ansor/dispatcher.py +++ b/python/tvm/ansor/dispatcher.py @@ -36,9 +36,7 @@ from decorator import decorate from tvm import target as _target -from tvm.tir.expr import StringImm, FloatImm - -from .loop_state import State, StateObject +from tvm.tir.expr import FloatImm logger = logging.getLogger('auto_scheduler') @@ -360,19 +358,6 @@ def update(self, target, workload, state): self._best_user_defined[key] = state -class BlockingEmptyContext(DispatchContext): - """ - An empty context which returns emtpy State() for all queries. - This also blocks the queries, so the queries won't affect the global FallbackContext. - """ - def __init__(self): - super(BlockingEmptyContext, self).__init__() - - def query(self, target, workload): - #return StateObject() - return None - - class FallbackContext(DispatchContext): """ A fallback dispatch context. @@ -400,7 +385,6 @@ def _query_inside(self, target, workload): if msg not in self.messages: self.messages.add(msg) logger.warning(msg) - #cfg = StateObject() cfg = None # cache this config diff --git a/python/tvm/ansor/env.py b/python/tvm/ansor/env.py index 6d2bbd2c92af..9e44ad66048b 100644 --- a/python/tvm/ansor/env.py +++ b/python/tvm/ansor/env.py @@ -1,4 +1,4 @@ -""" The scope to store global variables in auto_scheduelr """ +""" The scope to store global variables in ansor """ class AutoschedulerGlobalScope(object): def __init__(self): diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py index 4f9fdeb9e6cd..9496533da6cc 100644 --- a/python/tvm/ansor/feature.py +++ b/python/tvm/ansor/feature.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -""""Python API for Feature extraction. +"""" +Python API for Feature extraction. The specification of features can be found in `autoscheduler_doc/per_stage_feature.md` """ @@ -28,8 +29,10 @@ from . import _ffi_api +# Maximum number of buffers for one statement to extract feature for DEFAULT_MAX_N_BUFS = 5 +# The length of the feature vector DEFAULT_FEATURE_VEC_LEN = 164 @@ -145,6 +148,6 @@ def get_per_stmt_features_from_states(states, def get_per_stmt_feature_names(max_n_bufs: int = None) -> List[str]: - """Get names of the elements in the flatten feature vector""" + """Get names for the elements in the flatten feature vector""" return [x for x in _ffi_api.GetPerStmtFeatureNames(max_n_bufs or DEFAULT_MAX_N_BUFS)] diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index 23289c027293..3c60c3f09a8d 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -26,9 +26,9 @@ We don't use the existing TVM IR because 1. We want fast incremental change to the loop structures 2. We want serializable history for replay and backtracking -3. We may create some Macro schedule primitives +3. We may create some new macro schedule primitives -After search is done, we will lower this IR to TVM IR with TVM schedule primitives. +After search is done, we will lower this IR to TVM IR with TVM's schedule primitives. 
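+(The lowering is a replay: `ComputeDAG.apply_steps_from_state` re-applies the
+recorded transform steps to a fresh `te.Schedule`.)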
Because we share a lot common objects during search, the transformation is implemented in copy on write style. All objects are immutable, which is similar to TVM IR. diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 8b38f91647b2..3d9c33860cae 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -38,18 +38,20 @@ from tvm.rpc.tracker import Tracker from tvm.rpc.server import Server from tvm.autotvm.measure.measure_methods import set_cuda_target_arch -from ..contrib import tar, ndk +from tvm.contrib import tar, ndk +from . import _ffi_api from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, check_remote from .compute_dag import LayoutRewriteLevel -from . import _ffi_api logger = logging.getLogger('ansor') +# The maximum length of error message MAX_ERROR_MSG_LEN = 512 @tvm._ffi.register_object("ansor.MeasureCallback") class MeasureCallback(Object): + """Base class for measurement callback function""" pass @tvm._ffi.register_object("ansor.MeasureInput") @@ -103,7 +105,7 @@ def __init__(self, costs, error_no, error_msg, all_cost, timestamp): @tvm._ffi.register_object("ansor.Builder") class Builder(Object): - def build(self, measure_inputs, verbose=0): + def build(self, measure_inputs, verbose=1): """ Parameters ---------- @@ -119,7 +121,7 @@ def build(self, measure_inputs, verbose=0): @tvm._ffi.register_object("ansor.Runner") class Runner(Object): - def run(self, measure_inputs, build_results, verbose=0): + def run(self, measure_inputs, build_results, verbose=1): """ Parameters ---------- diff --git a/python/tvm/ansor/relay_integration.py b/python/tvm/ansor/relay_integration.py index 383471ee060d..85c4d8813f69 100644 --- a/python/tvm/ansor/relay_integration.py +++ b/python/tvm/ansor/relay_integration.py @@ -15,90 +15,33 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=unused-variable,invalid-name -""" -Decorator and utilities for the integration with TOPI and Relay -99.9% copy-paste of implementation by @MerryMercy +""" +Integrate ansor into relay. It implements the following items: +1. Extract search tasks from a relay program +2. Provide auto-scheduling for all TOPI compute functions """ import os -os.environ['TVM_USE_AUTO_SCHEDULER'] = 'true' - +import json import threading -import warnings -import tvm - -from .topi_integration import TaskExtractEnv -from .dispatcher import BlockingEmptyContext +from tvm import target, te, transform +from tvm.te.tensor import PlaceholderOp, ComputeOp +from .dispatcher import DispatchContext +from .workload_registry import register_auto_scheduler_workload_bufs, compute_dag_hash +from .compute_dag import ComputeDAG, LayoutRewriteLevel from .env import GLOBAL_SCOPE -def _lower(mod, - target, - params): - """ Helper to lower VTA properly. - """ +def call_all_topi_funcs(mod, target, params): + """Call all TOPI compute + schedule to extract tasks in a relay program""" # pylint: disable=import-outside-toplevel from tvm import relay - from tvm.relay.backend import graph_runtime_codegen - - if hasattr(target, 'device_name') and target.device_name == "vta": - import vta - with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): - mod, _ = relay.optimize(mod, target, params) - grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) - grc.codegen(mod["main"]) - return - # default case - # Try graph codegen first to extract autotvm tasks. - # If failed to compile, then fallback to use VM compiler. 
- # TODO: Currently VM compiler is likely to stack overflow for large models. - try: - with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): - opt_mod, _ = relay.optimize(mod, target, params) - grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) - grc.codegen(opt_mod["main"]) - except tvm.TVMError: - compiler = relay.vm.VMCompiler() - if params: - compiler.set_params(params) - compiler.lower(mod, target=target) - -OP_TO_SCHEDULE = {} - -def init_op_to_schedule_map(): - # init the global map OP_TO_SCHEDULE inside a function, this is used to resolve import issues - global OP_TO_SCHEDULE - from tvm import relay - import topi - - if OP_TO_SCHEDULE: - return - - OP_TO_SCHEDULE = { - relay.op.nn.conv2d: [topi.generic.schedule_conv2d_nchw, - topi.generic.schedule_conv2d_nhwc, - topi.generic.schedule_depthwise_conv2d_nchw, - topi.generic.schedule_depthwise_conv2d_nhwc, - topi.generic.schedule_group_conv2d_nchw, - topi.generic.schedule_conv2d_winograd_without_weight_transform], - relay.op.nn.conv2d_transpose: [topi.generic.schedule_conv2d_transpose_nchw], - relay.op.nn.dense: [topi.generic.schedule_dense], - relay.op.nn.softmax: [topi.generic.schedule_softmax], - relay.op.nn.max_pool2d: [topi.generic.schedule_pool], - relay.op.nn.avg_pool2d: [topi.generic.schedule_pool], - relay.op.nn.global_avg_pool2d: [topi.generic.schedule_adaptive_pool], - relay.op.nn.global_max_pool2d: [topi.generic.schedule_adaptive_pool], - relay.op.nn.deformable_conv2d: [topi.generic.schedule_deformable_conv2d_nchw], - relay.op.mean: [topi.generic.schedule_reduce], - relay.op.prod: [topi.generic.schedule_reduce], - relay.op.nn.conv3d: [topi.generic.schedule_conv3d_ncdhw, - topi.generic.schedule_conv3d_ndhwc], - relay.op.nn.adaptive_avg_pool3d: [topi.generic.schedule_adaptive_pool], - relay.op.nn.batch_matmul: [topi.generic.schedule_batch_matmul], - } - -def extract_from_program(mod, params, target, target_host=None, ops=None): + with transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): + bld_mod = relay.build_module.BuildModule() + bld_mod.call_all_topi_funcs(mod, target=target, params=params) + +def extract_from_program(mod, params, target, target_host=None): """ Extract tuning tasks from a relay program. This function is the single program version of extract_from_multiple_program. @@ -120,14 +63,11 @@ def extract_from_program(mod, params, target, target_host=None, ops=None): ------- workloads: Array of Tuple(wkl_key, target) """ - return extract_from_multiple_program([mod], [params], target, target_host, ops) + return extract_from_multiple_program([mod], [params], target, target_host) -def extract_from_multiple_program(mods, params, target, target_host=None, ops=None): +def extract_from_multiple_program(mods, params, target, target_host=None): """ Extract tuning tasks from multiple relay programs. - This function collects tuning tasks by building a list of programs - with a "tracing" target and tracing all the calls to topi. 
- Parameters ---------- mods : List of relay.Module @@ -145,35 +85,17 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No ------- workloads: Array of Tuple(wkl_key, target) """ + # pylint: disable=import-outside-toplevel from tvm import relay - env = TaskExtractEnv.get() - - init_op_to_schedule_map() - topi_scheds = [] - - if not ops: - ops = [relay.op.nn.dense, relay.op.nn.softmax, relay.op.nn.conv2d, - relay.op.nn.conv2d_transpose, relay.op.nn.max_pool2d, - relay.op.nn.avg_pool2d, relay.op.nn.global_max_pool2d, - relay.op.nn.global_avg_pool2d, relay.op.nn.conv3d, - relay.op.nn.adaptive_avg_pool3d, relay.op.nn.batch_matmul, - relay.op.mean] - - for op_name in ops: - if op_name in OP_TO_SCHEDULE: - topi_scheds.extend(OP_TO_SCHEDULE[op_name]) - else: - warnings.warn("Op %s is not tunable, ignored." % op_name) - - # run compiler to collect all TOPI calls during compilation - env.reset(topi_scheds) + env = TracingEnvironment(TracingMode.EXTRACT_TASK) with env: + # run compiler to collect all TOPI calls during compilation for mod, param in zip(mods, params): - # wrap build call in thread to avoid multiprocessing problems - with BlockingEmptyContext(): - build_thread = threading.Thread(target=_lower, - args=(mod, target, param)) + # wrap build call in a new thread to avoid the conflict + # between python's multiprocessing and tvm's thread pool + build_thread = threading.Thread(target=call_all_topi_funcs, + args=(mod, target, param)) build_thread.start() build_thread.join() relay.backend.compile_engine.get().clear() @@ -181,32 +103,26 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No # create tasks for target wkl_keys = [] wkl_weights = [] - for wkl_key, wkl_weight in env.get_wkl_keys().items(): + for wkl_key, wkl_weight in env.wkl_key_collection.items(): wkl_keys.append(wkl_key) wkl_weights.append(wkl_weight) return wkl_keys, wkl_weights -def prepare_layout_rewrite(mod, params, ops, target): - """Prepare for kernel layout rewrite. This function will write layout infos to a global static variable, - then these layout info will be used by a relay pass `kernel_layout_transform`. + +def prepare_layout_rewrite(mod, params, target): + """ + Prepare for kernel layout rewrite. This function will write layout infos to a global static variable. + Then these layout info will be used by a relay pass `kernel_layout_transform`. """ + # pylint: disable=import-outside-toplevel from tvm import relay - env = TaskExtractEnv.get(do_layout_rewrite=True) - - init_op_to_schedule_map() - topi_scheds = [] - for op_name in ops: - if op_name in OP_TO_SCHEDULE: - topi_scheds.extend(OP_TO_SCHEDULE[op_name]) - else: - warnings.warn("Op %s is not tunable, ignored." 
% op_name) - - env.reset(topi_scheds) + env = TracingEnvironment(TracingMode.PREPARE_LAYOUT_REWRITE) with env: - # wrap build call in thread to avoid multiprocessing problems - build_thread = threading.Thread(target=_lower, + # wrap build call in a new thread to avoid the conflict + # between python's multiprocessing and tvm's thread pool + build_thread = threading.Thread(target=call_all_topi_funcs, args=(mod, target, params)) build_thread.start() build_thread.join() @@ -218,3 +134,104 @@ def prepare_layout_rewrite(mod, params, ops, target): def finish_layout_rewrite(): """Clear the global flag for layout rewrite""" GLOBAL_SCOPE.topi_in_compute_rewrite_mode = False + + +class TracingMode: + """Two modes for tracing""" + EXTRACT_TASK = 0 # trace all topi calls to extract tasks + PREPARE_LAYOUT_REWRITE = 1 # trace all topi calls to prepare layout rewrite + +class TracingEnvironment: + """Global environment for tracing all topi function calls""" + current = None + + def __init__(self, tracing_mode): + self.tracing_mode = tracing_mode + self.relay_disable_build_cache = "false" + self.layout_rewrite_success_ct = 0 + self.wkl_key_collection = {} + + def __enter__(self): + self.relay_disable_build_cache = os.environ.get("TVM_RELAY_DISABLE_BUILD_CACHE", "false") + os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = "true" + TracingEnvironment.current = self + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = self.relay_disable_build_cache + TracingEnvironment.current = None + + def add_workload_key(self, key): + """Add the workload key of an Ansor search task + + Parameters + ---------- + key: str + """ + if key in self.wkl_key_collection: + self.wkl_key_collection[key] += 1 + else: + self.wkl_key_collection[key] = 1 + + +def traverse_to_get_io_tensors(outs): + """Traverse from a list of output tensors to get a whole computational DAG""" + layout_free_ops = [] + inputs = [] + + visited = set() + + def traverse(t): + if t in visited: + return + if isinstance(t.op, PlaceholderOp): + inputs.append(t) + elif isinstance(t.op, ComputeOp): + if "layout_free_placeholders" in t.op.attrs: + layout_free_ops.append(t.op) + for x in t.op.input_tensors: + traverse(x) + visited.add(t) + + for t in outs: + traverse(t) + + has_layout_free = (len(layout_free_ops) > 0) + return inputs + [t for t in outs], has_layout_free + + +def auto_schedule_topi(outs): + """ Use ansor to auto-schedule a topi compute declaration """ + io_tensors, has_layout_free = traverse_to_get_io_tensors(outs) + key = register_auto_scheduler_workload_bufs(io_tensors) + + env = TracingEnvironment.current + if env is None: # in the final build mode + state = DispatchContext.current.query(target.Target.current(), key) + dag = ComputeDAG(io_tensors) + # Only update compute body, layout_rewrite_level = LayoutRewriteLevel.COMPUTE_REWRITE, + # Since kernel layout has already been rewritten in relay pass + schedule, _ = dag.apply_steps_from_state(state, + layout_rewrite_level=LayoutRewriteLevel.COMPUTE_REWRITE) + return schedule + elif env.tracing_mode == TracingMode.EXTRACT_TASK: # in the task extraction mode + env.add_workload_key(key) + return te.create_schedule([x.op for x in outs]) + elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE: + # in prepare_layout_rewrite mode + if has_layout_free: + # Rewrite the DAG and update the transform history for + # the new dag in DispatchContext + dispatch_ctx = DispatchContext.current + tgt = target.Target.current() + state = 
dispatch_ctx.query(tgt, key) + assert state is not None + dag = ComputeDAG(outs) + new_dag = dag.rewrite_layout_from_state(state) + new_key = json.dumps((compute_dag_hash(new_dag),)) + dispatch_ctx.update(tgt, new_key, state) + if new_key != key: + env.layout_rewrite_success_ct += 1 + return te.create_schedule([x.op for x in outs]) + else: + raise ValueError("Invalid tracing mode: " + env.tracing_mode) diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py index d9b8a2f5c075..97903b38bb0b 100644 --- a/python/tvm/ansor/serialization.py +++ b/python/tvm/ansor/serialization.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -"""Tuning log I/O Utilities""" +"""Serialization and other I/O support for tuning logs (measurement records)""" import numpy as np @@ -29,7 +29,7 @@ @tvm._ffi.register_object("ansor.LogToFile") class LogToFile(MeasureCallback): """ - A measurement callback that writes tuning logs into a file + A measurement callback that writes measurement records into a file Parameters ---------- @@ -65,6 +65,7 @@ def __iter__(self): yield ret[0], ret[1] # (input, result) def load_from_file(filename: str): + """Load measurement records from a file""" return zip(*LogReader(filename).read_lines()) @@ -80,7 +81,7 @@ def get_states_from_measure_inputs(inputs, task): def best_measure_pair_in_file(filename, workload_key=None, target=None): - """ Return best results form log file + """ Return the best measurement pair form a log file Parameters ---------- diff --git a/python/tvm/ansor/topi_integration.py b/python/tvm/ansor/topi_integration.py deleted file mode 100644 index 77def00cf9ec..000000000000 --- a/python/tvm/ansor/topi_integration.py +++ /dev/null @@ -1,220 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-variable,invalid-name,unused-argument -""" -Decorators for registering tunable templates to TOPI. - -These decorators can make your simple implementation be able to use different configurations -for different workloads. -Here we directly use all arguments to the TOPI call as "workload", so make sure all the arguments -(except tvm.te.Tensor) in you calls are hashable. For tvm.te.Tensor, -we will serialize it to a hashable tuple. - -See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage. 
-""" -import os -import json -import tvm.te._ffi_api -from tvm import target as _target -from tvm.te import tensor -from tvm.te.tensor import PlaceholderOp, ComputeOp - -from .dispatcher import DispatchContext, BlockingEmptyContext -from .workload_registry import register_auto_scheduler_workload_bufs, \ - make_workload_key_bufs, compute_dag_hash -from .compute_dag import ComputeDAG - -def traverse_to_get_io_tensors(outs): - layout_free_ops = [] - inputs = [] - - visited = set() - - def traverse(t): - if t in visited: - return - if isinstance(t.op, PlaceholderOp): - inputs.append(t) - elif isinstance(t.op, ComputeOp): - if "layout_free_placeholders" in t.op.attrs: - layout_free_ops.append(t.op) - for x in t.op.input_tensors: - traverse(x) - visited.add(t) - - for t in outs: - traverse(t) - - has_layout_free = (len(layout_free_ops) > 0) - return inputs + [t for t in outs], has_layout_free - -# Task extractor for relay program -class TaskExtractEnv: - """Global environment for extracting tuning tasks from graph""" - current = None - registered = None - - def __init__(self, do_layout_rewrite=False): - self.do_layout_rewrite = do_layout_rewrite - self.wanted_relay_ops = None - self.modified_funcs = [] - self.tracing = False - self.relay_disable_build_cache_ = "false" - self.layout_rewrite_success_ct = 0 - self.wkl_key_collection = {} - - def __enter__(self): - self.tracing = True - self.wkl_key_collection = {} - self.relay_disable_build_cache_ = os.environ.get("TVM_RELAY_DISABLE_BUILD_CACHE", "false") - os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = "true" - - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.tracing = False - os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = self.relay_disable_build_cache_ - - def reset(self, wanted_relay_ops=None): - """Reset task collections - - Parameters - ---------- - wanted_relay_ops: List of tvm.ir.Op - The relay ops to be extracted - """ - self.wanted_relay_ops = wanted_relay_ops - self.relay_disable_build_cache_ = "false" - self.layout_rewrite_success_ct = 0 - self.wkl_key_collection = {} - - def add_task(self, key): - """Add AutoTVM task - - Parameters - ---------- - task_name: str - AutoTVM task name. - - args: tuple - Arguments to the TOPI function. - """ - if key in self.wkl_key_collection: - self.wkl_key_collection[key] += 1 - else: - self.wkl_key_collection[key] = 1 - - def get_tasks(self): - """Get collected tasks - - Returns - ------- - tasks: List of tuple(name, args) - A list of tasks extracted from the graph - """ - return self.wkl_key_collection - - def get_wkl_keys(self): - """Get collected tasks - - Returns - ------- - wkl_keys: List of autoschedule workload_key - """ - return self.wkl_key_collection - - @staticmethod - def get(do_layout_rewrite=False): - """Get the single instance of TaskExtractEnv - - Parameters - ---------- - - Returns - ------- - env: TaskExtractEnv - The single instance of TaskExtractEnv - """ - if not TaskExtractEnv.current: - TaskExtractEnv.current = TaskExtractEnv(do_layout_rewrite) - else: - TaskExtractEnv.current.do_layout_rewrite = do_layout_rewrite - return TaskExtractEnv.current - -def register_topi_schedule(func=None): - """Register a tunable template for a topi schedule function. - - The registration will wrap this topi schedule to take `cfg` as the first argument, - followed by the original argument list. - - Note that this function will try to find "workload" from all the ComputeOp in the input. - You can attach "workload" to your compute op by using :any:`register_topi_compute`. 
- - The task name has to be the same as that of the corresponding topi compute function. - - Parameters - ---------- - task_name: str - The AutoTVM task name - - func: None or callable - If it is None, return a decorator. - If is callable, decorate this function. - - Returns - ------- - decorator: callable - A decorator - - Examples - -------- - See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage. - """ - def _decorate(topi_schedule): - def wrapper(outs, *args, **kwargs): - io_tensors, has_layout_free = traverse_to_get_io_tensors(outs) - key = register_auto_scheduler_workload_bufs(io_tensors) - task_env = TaskExtractEnv.current - if task_env is not None and task_env.tracing: - if task_env.do_layout_rewrite and has_layout_free: - # Rewrite the dag and update the transform history for - # the new dag in DispatchContext - dispatch_ctx = DispatchContext.current - tgt = _target.Target.current() - state = dispatch_ctx.query(tgt, key) - dag = ComputeDAG(outs) - new_dag = dag.rewrite_layout_from_state(state) - new_key = json.dumps((compute_dag_hash(new_dag),)) - dispatch_ctx.update(tgt, new_key, state) - - if new_key != key: - task_env.layout_rewrite_success_ct += 1 - - # Call schedule_func under FallbackContext() to avoid layout rewrite - cfg = BlockingEmptyContext().query(tgt, key) - return topi_schedule(cfg, outs) - - task_env.add_task(key) - - """wrapper function for topi schedule""" - tgt = _target.Target.current() - cfg = DispatchContext.current.query(tgt, key) - return topi_schedule(cfg, outs) - return wrapper - if func: - return _decorate(func) - return _decorate diff --git a/python/tvm/ansor/utils.py b/python/tvm/ansor/utils.py index 5ed9bd46d355..9e3c857aba36 100644 --- a/python/tvm/ansor/utils.py +++ b/python/tvm/ansor/utils.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -"""Common utilities""" +"""Common utilities for ansor""" import multiprocessing import multiprocessing.pool @@ -30,7 +30,7 @@ except ImportError: psutil = None -from .. import rpc as _rpc +from tvm import rpc from tvm.tir import expr from tvm.tir.transform import Simplify from tvm.ir.transform import Sequential @@ -205,7 +205,7 @@ def request_remote(device_key, host=None, port=None, priority=1, timeout=60): host = host or os.environ['TVM_TRACKER_HOST'] port = port or int(os.environ['TVM_TRACKER_PORT']) - tracker = _rpc.connect_tracker(host, port) + tracker = rpc.connect_tracker(host, port) remote = tracker.request(device_key, priority=priority, session_timeout=timeout) return remote diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 8e6698e4a164..66ef5cd4c852 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -21,6 +21,7 @@ import logging import numpy as np import tvm +import os from tvm import te from tvm.runtime import Object from ... import target as _target @@ -141,7 +142,6 @@ def get_valid_implementations(op, attrs, inputs, out_type, target): ret.append(impl) return ret - def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True): """Select the best implementation from the op strategy. @@ -179,6 +179,9 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) ret : tuple(relay.op.OpImplementation, List[tvm.te.Tensor]) The best op implementation and the corresponding output tensors. 
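+
+    Note
+    ----
+    If the environment variable `TVM_USE_AUTOTVM` is unset or equal to
+    "false", `use_autotvm` is forced to False below, so AutoTVM tuning
+    records are ignored and the implementation is picked by priority
+    level, which the Ansor entries win on x86 (plevel=11).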
""" + if os.environ.get('TVM_USE_AUTOTVM', 'false') == 'false': + use_autotvm = False + all_impls = get_valid_implementations(op, attrs, inputs, out_type, target) best_plevel_impl = None diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 30c5971e32b9..d1a39ceb630e 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -72,6 +72,7 @@ def __init__(self): self._get_module = self.mod["get_module"] self._build = self.mod["build"] self._optimize = self.mod["optimize"] + self._call_all_topi_funcs = self.mod["call_all_topi_funcs"] self._set_params_func = self.mod["set_params"] self._get_params_func = self.mod["get_params"] @@ -160,6 +161,12 @@ def optimize(self, mod, target=None, params=None): return mod, params + def call_all_topi_funcs(self, mod, target=None, target_host=None, params=None): + """Call all topi compute and schedule used in a relay function""" + target = _update_target(target) + if params: + self._set_params(params) + self._call_all_topi_funcs(mod, target, target_host) def _set_params(self, params): self._set_params_func(_convert_param_map(params)) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index b02db416bdc8..2a0ddd1329b5 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -16,14 +16,16 @@ # under the License. """Definition of x86 operator strategy.""" # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import -import logging -import re -import topi +import os from tvm.te import SpecializedCondition +from tvm import ansor from .generic import * from .. import op as _op +# Set the priority level to use the Ansor auto-scheduler +ansor_plevel = 11 + logger = logging.getLogger('strategy') _NCHWc_matcher = re.compile("^NCHW[0-9]+c$") @@ -39,7 +41,7 @@ def schedule_injective_cpu(attrs, outs, target): def schedule_reduce_cpu(attrs, outs, target): """schedule reduction ops for x86""" with target: - return topi.x86.schedule_reduce(outs) + return ansor.auto_schedule_topi(outs) @schedule_concatenate.register("cpu") def schedule_concatenate_cpu(attrs, outs, target): @@ -51,13 +53,13 @@ def schedule_concatenate_cpu(attrs, outs, target): def schedule_pool_cpu(attrs, outs, target): """schedule pooling ops for x86""" with target: - return topi.x86.schedule_pool(outs, attrs.layout) + return ansor.auto_schedule_topi(outs) @schedule_adaptive_pool.register("cpu") def schedule_adaptive_pool_cpu(attrs, outs, target): """schedule adaptive pooling ops for x86""" with target: - return topi.x86.schedule_adaptive_pool(outs) + return ansor.auto_schedule_topi(outs) @softmax_strategy.register("cpu") def softmax_strategy_cpu(attrs, inputs, out_type, target): @@ -65,15 +67,15 @@ def softmax_strategy_cpu(attrs, inputs, out_type, target): strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_softmax(topi.nn.softmax), - wrap_topi_schedule(topi.x86.schedule_softmax), - name="softmax.x86") + wrap_topi_schedule(ansor.auto_schedule_topi), + name="ansor") return strategy @schedule_log_softmax.register("cpu") def schedule_log_softmax_cpu(attrs, outs, target): """schedule log_softmax op for x86""" with target: - return topi.x86.schedule_softmax(outs) + return ansor.auto_schedule_topi(outs) @conv2d_strategy.register("cpu") def conv2d_strategy_cpu(attrs, inputs, out_type, target): @@ -105,18 +107,18 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif 
layout == "NHWC": assert kernel_layout == "HWIO" - logger.warning("For x86 target, NCHW layout is recommended for conv2d.") + #logger.warning("For x86 target, NCHW layout is recommended for conv2d.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_nhwc), - wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc), - name="conv2d_nhwc.x86") + wrap_topi_schedule(ansor.auto_schedule_topi), + name="ansor") elif layout == "HWCN": assert kernel_layout == "HWIO" - logger.warning("conv2d HWCN layout is not optimized for x86.") + #logger.warning("conv2d HWCN layout is not optimized for x86.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_hwcn), - wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn), - name="conv2d_hwcn.generic") + wrap_topi_schedule(ansor.auto_schedule_topi), + name="ansor") else: raise RuntimeError("Unsupported conv2d layout {} for x86".format(layout)) elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): @@ -143,8 +145,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): logger.warning("depthwise_conv2d NHWC layout is not optimized for x86.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), - wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc), - name="depthwise_conv2d_nhwc.generic") + wrap_topi_schedule(ansor.auto_schedule_topi), + name="ansor") else: raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout)) else: # group_conv2d @@ -153,8 +155,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): logger.warning("group_conv2d is not optimized for x86.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True), - wrap_topi_schedule(topi.generic.schedule_group_conv2d_nchw), - name="group_conv2d_nchw.generic") + wrap_topi_schedule(ansor.auto_schedule_topi), + name="ansor") else: raise RuntimeError("Unsupported group_conv2d layout {}".format(layout)) return strategy @@ -231,8 +233,8 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target): name="conv3d_ncdhw.x86") elif layout == "NDHWC": strategy.add_implementation(wrap_compute_conv3d(topi.x86.conv3d_ndhwc), - wrap_topi_schedule(topi.x86.schedule_conv3d_ndhwc), - name="conv3d_ndhwc.x86") + wrap_topi_schedule(ansor.auto_schedule_topi), + name="ansor") else: raise ValueError("Not support this layout {} yet".format(layout)) return strategy @@ -251,8 +253,8 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target): name="conv1d_ncw.x86") elif layout == "NWC": strategy.add_implementation(wrap_compute_conv1d(topi.nn.conv1d_nwc), - wrap_topi_schedule(topi.x86.schedule_conv1d_nwc), - name="conv1d_nwc.x86") + wrap_topi_schedule(ansor.auto_schedule_topi), + name="ansor") else: raise ValueError("Unsupported conv1d layout {}".format(layout)) return strategy @@ -261,16 +263,23 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target): def dense_strategy_cpu(attrs, inputs, out_type, target): """dense x86 strategy""" strategy = _op.OpStrategy() - m, _ = inputs[0].shape + + strategy.add_implementation(wrap_compute_dense(topi.nn.dense), + wrap_topi_schedule(ansor.auto_schedule_topi), + name='ansor', + plevel=ansor_plevel) + strategy.add_implementation(wrap_compute_dense(topi.x86.dense_nopack), wrap_topi_schedule(topi.x86.schedule_dense_nopack), name="dense_nopack.x86", plevel=10) + if "cblas" in target.libs: strategy.add_implementation(wrap_compute_dense(topi.x86.dense_cblas), wrap_topi_schedule(topi.x86.schedule_dense_cblas), name="dense_cblas.x86", plevel=15) + m, _ = 
inputs[0].shape with SpecializedCondition(m >= 16): # this implementation may not be well-optimized, so use plevel=8 for now. strategy.add_implementation(wrap_compute_dense(topi.x86.dense_pack), @@ -283,6 +292,12 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): """batch_matmul x86 strategy""" strategy = _op.OpStrategy() + + strategy.add_implementation(wrap_compute_dense(topi.nn.batch_matmul), + wrap_topi_schedule(ansor.auto_schedule_topi), + name='ansor', + plevel=ansor_plevel) + strategy.add_implementation(wrap_compute_batch_matmul(topi.x86.batch_matmul), wrap_topi_schedule(topi.x86.schedule_batch_matmul), name="batch_matmul.x86", diff --git a/python/tvm/relay/testing/resnet.py b/python/tvm/relay/testing/resnet.py index 8633879465bd..4383157d9f06 100644 --- a/python/tvm/relay/testing/resnet.py +++ b/python/tvm/relay/testing/resnet.py @@ -59,9 +59,11 @@ def residual_unit(data, name : str Base name of the operators """ + bn_axis = data_layout.index('C') if bottle_neck: bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, + axis=bn_axis, name=name + '_bn1') act1 = relay.nn.relu(data=bn1) conv1 = layers.conv2d( @@ -73,13 +75,13 @@ def residual_unit(data, name=name + '_conv1', data_layout=data_layout, kernel_layout=kernel_layout) - bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2') + bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, axis=bn_axis, name=name + '_bn2') act2 = relay.nn.relu(data=bn2) conv2 = layers.conv2d( data=act2, channels=int(num_filter*0.25), kernel_size=(3, 3), strides=(1, 1), padding=(1, 1), name=name + '_conv2', data_layout=data_layout, kernel_layout=kernel_layout) - bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, name=name + '_bn3') + bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, axis=bn_axis, name=name + '_bn3') act3 = relay.nn.relu(data=bn3) conv3 = layers.conv2d( data=act3, channels=num_filter, kernel_size=(1, 1), @@ -94,13 +96,13 @@ def residual_unit(data, data_layout=data_layout, kernel_layout=kernel_layout) return relay.add(conv3, shortcut) - bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, name=name + '_bn1') + bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, axis=bn_axis, name=name + '_bn1') act1 = relay.nn.relu(data=bn1) conv1 = layers.conv2d( data=act1, channels=num_filter, kernel_size=(3, 3), strides=stride, padding=(1, 1), name=name + '_conv1', data_layout=data_layout, kernel_layout=kernel_layout) - bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2') + bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, axis=bn_axis, name=name + '_bn2') act2 = relay.nn.relu(data=bn2) conv2 = layers.conv2d( data=act2, channels=num_filter, kernel_size=(3, 3), @@ -156,11 +158,12 @@ def resnet(units, data_layout = layout kernel_layout = "OIHW" if layout == "NCHW" else "HWIO" + bn_axis = data_layout.index('C') num_unit = len(units) assert num_unit == num_stages data = relay.var("data", shape=data_shape, dtype=dtype) - data = layers.batch_norm_infer(data=data, epsilon=2e-5, scale=False, name='bn_data') + data = layers.batch_norm_infer(data=data, epsilon=2e-5, axis=bn_axis, scale=False, name='bn_data') (_, _, height, _) = data_shape if layout == "NHWC": (_, height, _, _) = data_shape @@ -174,7 +177,7 @@ def resnet(units, data=data, channels=filter_list[0], kernel_size=(7, 7), strides=(2, 2), padding=(3, 3), name="conv0", data_layout=data_layout, kernel_layout=kernel_layout) - body = layers.batch_norm_infer(data=body, 
epsilon=2e-5, name='bn0') + body = layers.batch_norm_infer(data=body, epsilon=2e-5, axis=bn_axis, name='bn0') body = relay.nn.relu(data=body) body = relay.nn.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), layout=data_layout) @@ -189,7 +192,7 @@ def resnet(units, body, filter_list[i+1], (1, 1), True, name='stage%d_unit%d' % (i + 1, j + 2), bottle_neck=bottle_neck, data_layout=data_layout, kernel_layout=kernel_layout) - bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn1') + bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, axis=bn_axis, name='bn1') relu1 = relay.nn.relu(data=bn1) # Although kernel is not used here when global_pool=True, we should put one pool1 = relay.nn.global_avg_pool2d(data=relu1, layout=data_layout) diff --git a/scripts/tune_network.py b/scripts/tune_network.py index dc17f407d003..d4f1afd95572 100644 --- a/scripts/tune_network.py +++ b/scripts/tune_network.py @@ -200,14 +200,7 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, if tune: print("=============== Extracting workloads ===============") - workloads, wkl_weights = ansor.extract_from_program(mod, target=target, - params=params, ops=(relay.op.nn.dense, relay.op.nn.softmax, - relay.op.nn.conv2d, relay.op.nn.conv2d_transpose, - relay.op.nn.max_pool2d, relay.op.nn.avg_pool2d, - relay.op.nn.global_max_pool2d, relay.op.nn.global_avg_pool2d, - relay.op.nn.conv3d, relay.op.nn.adaptive_avg_pool3d, - relay.op.nn.batch_matmul, relay.op.mean, - )) + workloads, wkl_weights = ansor.extract_from_program(mod, target=target, params=params) print("Totally %d workload extracted." % (len(workloads))) # Tune workloads with auto scheduler @@ -238,15 +231,13 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, os.environ['TVM_AUTO_CACHE_FLUSH'] = "0" os.environ['TVM_BIND_MASTER_CORE_0'] = "1" if kernel_layout_rewrite: - ansor.prepare_layout_rewrite(mod, target=target, - params=params, - ops=(relay.op.nn.dense, relay.op.nn.conv2d, relay.op.nn.conv3d)) + ansor.prepare_layout_rewrite(mod, target=target, params=params) else: # disable layout rewrite ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE ansor.LayoutRewriteLevel.COMPUTE_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE - with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): + with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): graph, lib, opt_params = relay.build_module.build( mod, target=target, params=params) diff --git a/scripts/tune_test.py b/scripts/tune_test.py index 7831aea9dd4a..86f055caf889 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -36,7 +36,7 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose builder=builder, runner=runner, measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=[ansor.PreLoadMeasuredStates(log_file)]) + pre_search_callbacks=[ansor.PreloadMeasuredStates(log_file)]) return tune_option, measure_ctx diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index fec301dc54bc..6269b9f16f71 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -902,15 +902,12 @@ void ComputeDAG::RewriteLayout( for (size_t i = 0; i < old_ops.size(); ++i) { auto old_op = old_ops[i]; if (rewrite_placeholder && old_op == placeholder_op) { - //pops->data[i] = new_placeholder_op; pops->SetItem(i, new_placeholder_op); updated_ops[placeholder_op] = new_placeholder_op; } else if (rewrite_body && old_op == old_compute_op) { - 
//pops->data[i] = new_compute_op; pops->SetItem(i, new_compute_op); updated_ops[old_compute_op] = new_compute_op; } else { - //pops->data[i] = old_op; pops->SetItem(i, old_op); } } @@ -936,7 +933,6 @@ void ComputeDAG::RewriteLayout( if (!rmap.empty()) { te::Operation new_op = pop->ReplaceInputs(old_op, rmap); updated_ops[old_op] = new_op; - //pops->data[i] = new_op; pops->SetItem(i, new_op); } } @@ -958,7 +954,6 @@ void ComputeDAG::RewriteLayout( if (new_op.defined()) { if (layout_rewrite_level == kBothRewrite) { auto index = old_tensor->value_index; - //ptensors->data[i] = new_op.output(index); ptensors->SetItem(i, new_op.output(index)); } else if (layout_rewrite_level == kComputeRewrite) { te::TensorNode* old_tensor_node = diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index c71c4f1b6586..8da71f005f19 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -105,7 +105,7 @@ typedef std::unordered_map, ObjectHash void UpdateStageAxis(const tvm::te::Stage& stage, StageToAxesMap *stage_to_axes); -/*! \brief Compute declaration graph */ +/*! \brief Computation declaration graph */ class ComputeDAGNode : public Object { public: Array tensors; // Input and output tensors diff --git a/src/ansor/feature.cc b/src/ansor/feature.cc index 497a3ac4222b..3c6976a0e25a 100644 --- a/src/ansor/feature.cc +++ b/src/ansor/feature.cc @@ -653,9 +653,9 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { fea.vec_prod *= GetIntImm(pfor->extent); } fea.vec_type = kPosMixed; - // todo(lmzheng): this feature requires operation (tvm.compute) information - //GetAnnotationPosEncoding(vec_for_stack.back()->loop_var, - //node->args, pcompute->axis, pcompute->reduce_axis); + // todo(lmzheng): this feature requires operation (tvm.compute) information + // GetAnnotationPosEncoding(vec_for_stack.back()->loop_var, + // node->args, pcompute->axis, pcompute->reduce_axis); } fea.unroll_num = unroll_for_stack.size(); @@ -666,8 +666,8 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { fea.unroll_prod *= GetIntImm(pfor->extent); } fea.unroll_type = kPosMixed; - //GetAnnotationPosEncoding(unroll_for_stack.back()->loop_var, - //node->args, pcompute->axis, pcompute->reduce_axis); + // GetAnnotationPosEncoding(unroll_for_stack.back()->loop_var, + // node->args, pcompute->axis, pcompute->reduce_axis); } fea.parallel_num = parallel_for_stack.size(); @@ -678,8 +678,8 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { fea.parallel_prod *= GetIntImm(pfor->extent); } fea.parallel_type = kPosMixed; - //GetAnnotationPosEncoding(parallel_for_stack.back()->loop_var, - //node->args, pcompute->axis, pcompute->reduce_axis); + // GetAnnotationPosEncoding(parallel_for_stack.back()->loop_var, + // node->args, pcompute->axis, pcompute->reduce_axis); } // GPU threads @@ -1213,7 +1213,8 @@ void GetPerStmtFeaturesWorkerFunc(const SearchTask& task, const State& state, const auto& optimize = tir::transform::Sequential(pass_list); optimize(mod); } - const auto& optimize = tir::transform::Sequential(Array{tir::transform::Simplify()}); + const auto& optimize = tir::transform::Sequential( + Array{tir::transform::Simplify()}); mod = optimize(std::move(mod)); const auto& it = mod->functions.find(global_var); CHECK(it != mod->functions.end()); @@ -1241,8 +1242,8 @@ void GetPerStmtFeaturesFromStates(const Array& states, for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) { pool.Enqueue(GetPerStmtFeaturesWorkerFunc, task, states[i], max_n_bufs, &(*features)[i], &error_ct); - 
//GetPerStmtFeaturesWorkerFunc(task, states[i], - // max_n_bufs, &(*features)[i], &error_ct); + // GetPerStmtFeaturesWorkerFunc(task, states[i], + // max_n_bufs, &(*features)[i], &error_ct); } pool.WaitBatch(); diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 73bbade241c5..474ea048ebad 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -1,6 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors + * Copyright (c) 2020 by Contributors + * \file ansor/measure.cc + * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs */ + #include "measure.h" #include #include diff --git a/src/ansor/measure.h b/src/ansor/measure.h index 780a30514d46..6e432ba9c88b 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -1,6 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
 * Copyright (c) 2020 by Contributors
- * \file ansor/search_task.h
+ * \file ansor/measure.h
 * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs
 */
diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc
index 7e022e3be3c3..8b5b97224c08 100644
--- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc
+++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc
@@ -62,8 +62,8 @@ State MetaTileRewritePolicyNode::Search(SearchTask task, int n_trials,
                                         int verbose, ProgramMeasurer measurer,
                                         Array pre_search_callbacks) {
   std::vector best_states, random_states;
-  cur_task_ = task;
-  verbose_ = verbose;
+  this->cur_task = task;
+  this->verbose = verbose;
   num_measure_per_iter_ = num_measure_per_iter;

   RunCallbacks(pre_search_callbacks);
@@ -85,17 +85,17 @@ State MetaTileRewritePolicyNode::Search(SearchTask task, int n_trials,
     while (ct < n_trials) {
       if (!inputs.empty()) {
         // retrain cost models
-        PrintTitle("Train cost model", verbose_);
+        PrintTitle("Train cost model", verbose);
         program_cost_model->Update(inputs, results);
       }

       // Search one round to get promising states
-      PrintTitle("Search", verbose_);
+      PrintTitle("Search", verbose);
       SearchOneRound(&best_states, num_random, &random_states);

       // Fill correct bound. This is necessary for computing the correct ToStr() for redundancy check
-      cur_task_->compute_dag.InferBound(&best_states);
-      cur_task_->compute_dag.InferBound(&random_states);
+      cur_task->compute_dag.InferBound(&best_states);
+      cur_task->compute_dag.InferBound(&random_states);

       // Pick `num_measure_per_iter` states to measure, check hash to remove already measured state
       // Also pick some random states to do eps-greedy
@@ -108,11 +108,11 @@ State MetaTileRewritePolicyNode::Search(SearchTask task, int n_trials,
       }

       // Measure candidate states
-      PrintTitle("Measure", verbose_);
-      measurer->Measure(cur_task_, GetRef(this), inputs, &results);
+      PrintTitle("Measure", verbose);
+      measurer->Measure(cur_task, GetRef(this), inputs, &results);
       ct += inputs.size();

-      if (ct - measurer->best_ct[cur_task_->workload_key] > early_stopping) {
+      if (ct - measurer->best_ct[cur_task->workload_key] > early_stopping) {
         StdCout(verbose) << "Meet the early stopping condition." << std::endl;
         break;
       }
@@ -122,21 +122,21 @@ State MetaTileRewritePolicyNode::Search(SearchTask task, int n_trials,
         measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs));
       }
     }
-    PrintTitle("Done", verbose_);
+    PrintTitle("Done", verbose);

-    return measurer->best_state[cur_task_->workload_key];
+    return measurer->best_state[cur_task->workload_key];
   }
 }

 std::pair, Array >
     MetaTileRewritePolicyNode::ContinueSearchOneRound(
     SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) {
-  if (cur_task_.defined()) {
-    CHECK_EQ(cur_task_, task);
+  if (cur_task.defined()) {
+    CHECK_EQ(cur_task, task);
   } else {
-    cur_task_ = task;
+    cur_task = task;
   }
-  verbose_ = verbose;
+  this->verbose = verbose;
   num_measure_per_iter_ = num_measure;

   std::vector best_states, random_states;
@@ -149,8 +149,8 @@ std::pair, Array >
   SearchOneRound(&best_states, num_random * 2, &random_states);

   // Fill correct bound. This is necessary for computing the correct ToStr() for redundancy check
-  cur_task_->compute_dag.InferBound(&best_states);
-  cur_task_->compute_dag.InferBound(&random_states);
+  cur_task->compute_dag.InferBound(&best_states);
+  cur_task->compute_dag.InferBound(&random_states);

   // Pick `num_measure` states to measure, check hash to remove already measured state
   // Also pick some random states to do eps-greedy
@@ -158,7 +158,7 @@ std::pair, Array >

   // Measure candidate states
   PrintTitle("Measure", verbose);
-  measurer->Measure(cur_task_, GetRef(this), inputs, &results);
+  measurer->Measure(cur_task, GetRef(this), inputs, &results);

   // Update throughputs of measured states. These states will join the LocalMutation in later rounds
   for (const auto& res : results) {
@@ -219,7 +219,7 @@ void MetaTileRewritePolicyNode::PickStatesWithEpsGreedy(
       if (measured_states_set_.count(state_str)) { continue; }
       measured_states_set_.insert(state_str);

-      inputs->push_back(MeasureInputNode::make(cur_task_, *pstate));
+      inputs->push_back(MeasureInputNode::make(cur_task, *pstate));
       measured_states_vector_.push_back(std::move(*pstate));
     }
   }
@@ -288,7 +288,7 @@ class SketchGenerationRule {
 static inline bool ShouldBeCacheRead(
     const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) {
-  const SearchTask& task = policy->cur_task_;
+  const SearchTask& task = policy->cur_task;
   const Stage& stage = state->stages[stage_id];

   if (HasAttrsFlag(state, stage_id,
@@ -320,7 +320,7 @@ static inline bool ShouldBeCacheRead(
 static inline bool ShouldAlwaysBeInlined(
     const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) {
-  const SearchTask& task = policy->cur_task_;
+  const SearchTask& task = policy->cur_task;
   const Stage& stage = state->stages[stage_id];

   if (stage->op->IsInstance()) {
@@ -336,7 +336,7 @@ static inline bool ShouldAlwaysBeInlined(
   if (HasAttrsFlag(state, stage_id, SearchPolicyNode::always_compute_inline_key) ||
       IsStrictInlineable(task, state, stage->op) ||
-      (IS_GPU(policy->cur_task_) &&
+      (IS_GPU(policy->cur_task) &&
        !ShouldBeCacheRead(policy, state, stage_id))) {
     return true;
   }
@@ -367,7 +367,7 @@ class RuleSkipStage : public SketchGenerationRule {
  public:
   ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
                               const State& state, int stage_id) final {
-    const SearchTask& task = policy->cur_task_;
+    const SearchTask& task = policy->cur_task;
     const Stage& stage = state->stages[stage_id];
     const auto& attrs = stage->op->attrs;

@@ -392,16 +392,16 @@ class RuleMultiLevelTiling : public SketchGenerationRule {
  public:
   ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
                               const State& state, int stage_id) final {
-    const SearchTask& task = policy->cur_task_;
+    const SearchTask& task = policy->cur_task;
     const Stage& stage = state->stages[stage_id];

     return NeedsMultilevelTiling(task, state, stage->op) ?
-        (IS_GPU(policy->cur_task_) ? kApplyAndSkipRest : kApply) : kPass;
+        (IS_GPU(policy->cur_task) ? kApplyAndSkipRest : kApply) : kPass;
   }

   std::vector > Apply(const MetaTileRewritePolicyNode* policy,
                                               const State& state, int stage_id) final {
-    std::string multi_level_tiling_structure = IS_GPU(policy->cur_task_) ?
+    std::string multi_level_tiling_structure = IS_GPU(policy->cur_task) ?
GetStringParam(policy->params, "gpu_multi_level_tiling_structure") : GetStringParam(policy->params, "cpu_multi_level_tiling_structure"); @@ -418,12 +418,12 @@ class RuleMultiLevelTilingWithFusion : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; int target_stage_id; - if (IS_GPU(policy->cur_task_)) { + if (IS_GPU(policy->cur_task)) { return NeedsMultilevelTiling(task, state, stage->op) && HasSingleElementwiseMatchedConsumer(task, state, stage, &target_stage_id) && @@ -440,9 +440,9 @@ class RuleMultiLevelTilingWithFusion : public SketchGenerationRule { std::vector > Apply(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; - std::string multi_level_tiling_structure = IS_GPU(policy->cur_task_) ? + std::string multi_level_tiling_structure = IS_GPU(policy->cur_task) ? GetStringParam(policy->params, "gpu_multi_level_tiling_structure") : GetStringParam(policy->params, "cpu_multi_level_tiling_structure"); @@ -457,7 +457,7 @@ class RuleMultiLevelTilingWithFusion : public SketchGenerationRule { base_state = DoMultiLevelTiling(base_state, stage_id, multi_level_tiling_structure, &spatial_split_step_ids); std::vector follow_tiling_levels; - if (IS_GPU(policy->cur_task_)) { + if (IS_GPU(policy->cur_task)) { follow_tiling_levels.push_back(3); } else { follow_tiling_levels.push_back(1); @@ -487,7 +487,7 @@ class RuleAddCacheWrite : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; int target_stage_id; @@ -505,7 +505,7 @@ class RuleAddCacheWrite : public SketchGenerationRule { std::vector > Apply(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; State tmp_s = state; tmp_s.cache_write(stage_id, "local", task->compute_dag); @@ -526,7 +526,7 @@ class RuleAddCacheRead : public SketchGenerationRule { std::vector > Apply(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; std::unordered_set consumers; @@ -551,7 +551,7 @@ class RuleAddRfactor : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; return NeedsRfactor(task, state, stage->op) && @@ -561,7 +561,7 @@ class RuleAddRfactor : public SketchGenerationRule { std::vector > Apply(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; std::vector > ret; @@ -613,7 +613,7 @@ class RuleAddRfactor : public SketchGenerationRule { void 
MetaTileRewritePolicyNode::GenerateMetaSketch( std::vector* out_states) { - State init_state = cur_task_->compute_dag.GetInitState(); + State init_state = cur_task->compute_dag.GetInitState(); std::string cpu_multi_level_tiling_structure = GetStringParam(params, "cpu_multi_level_tiling_structure"); @@ -644,7 +644,7 @@ void MetaTileRewritePolicyNode::GenerateMetaSketch( sketch_rules.push_back(&rule_multi_level_tiling); sketch_rules.push_back(&rule_add_rfactor); sketch_rules.push_back(&rule_skip_stage); - if (IS_GPU(cur_task_)) { + if (IS_GPU(cur_task)) { // Try cache read first before cache write sketch_rules.insert(sketch_rules.begin() + 1, &rule_add_cache_read_stage); } @@ -705,7 +705,7 @@ void MetaTileRewritePolicyNode::GenerateMetaSketch( } } - StdCout(verbose_) << "Synthesize Meta Structure\t\t#s: " << out_states->size() << std::endl; + StdCout(verbose) << "Synthesize Meta Structure\t\t#s: " << out_states->size() << std::endl; } int InitPopulationFillTileSize(const MetaTileRewritePolicyNode* policy, @@ -728,7 +728,7 @@ int InitPopulationFillTileSize(const MetaTileRewritePolicyNode* policy, const std::vector >& candidate_lens = split_memo->GetFactorizationSchemes( extent, ps->lengths.size(), - policy->cur_task_->hardware_params->max_innermost_split_factor); + policy->cur_task->hardware_params->max_innermost_split_factor); StateNode* pstate = state->CopyOnWrite(); pstate->transform_steps[step_id] = SplitStepNode::make( @@ -771,11 +771,11 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, // Set default vthread=1 & threadIdx.x=default_warp_size // EvolutionarySearch will try more possiblity if (GetExtent(fused_it) <= - policy->cur_task_->hardware_params->warp_size) { + policy->cur_task->hardware_params->warp_size) { state->bind_thread(stage_id, fused_it, kThreadX); } else { const auto& split_its = state->split(stage_id, fused_it, - {1, policy->cur_task_->hardware_params->warp_size}); + {1, policy->cur_task->hardware_params->warp_size}); state->bind_thread(stage_id, split_its[0], kBlockX); state->bind_thread(stage_id, split_its[1], kVThread); state->bind_thread(stage_id, split_its[2], kThreadX); @@ -793,7 +793,7 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, } // TODO(..): Add ThreadBind support for rfactor - if (total_space_extent <= policy->cur_task_->hardware_params->warp_size) { + if (total_space_extent <= policy->cur_task->hardware_params->warp_size) { for (const auto& it : (*state)->stages[stage_id]->iters) { if (it->iter_type == kReduce) { break; @@ -828,7 +828,7 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, } const auto& vthread_it = state->fuse(stage_id, to_fuse); if (GetExtent(vthread_it) > - policy->cur_task_->hardware_params->max_vthread_extent) { + policy->cur_task->hardware_params->max_vthread_extent) { return -1; } state->bind_thread(stage_id, vthread_it, kVThread); @@ -844,7 +844,7 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, } const auto& threadidx_it = state->fuse(stage_id, to_fuse); if (GetExtent(threadidx_it) < - policy->cur_task_->hardware_params->warp_size) { + policy->cur_task->hardware_params->warp_size) { return -1; } state->bind_thread(stage_id, threadidx_it, kThreadX); @@ -876,7 +876,7 @@ int InitPopulationCooperativeFetching(const MetaTileRewritePolicyNode* policy, // Get spatial_split_step_ids from the root stage std::unordered_set consumers; std::vector spatial_split_step_ids; - GetConsumers(policy->cur_task_, (*state), target_stage->op, &consumers); + 
GetConsumers(policy->cur_task, (*state), target_stage->op, &consumers); CHECK_EQ(consumers.size(), 1); int target_stage_id = OperationToStage(*consumers.begin(), (*state)); GetSpaceSplitStepIds((*state), target_stage_id, &spatial_split_step_ids); @@ -915,13 +915,13 @@ int InitPopulationChangeComputeLocation(const MetaTileRewritePolicyNode* policy, continue; } - if (NeedsMultilevelTiling(policy->cur_task_, (*state), stage->op)) { + if (NeedsMultilevelTiling(policy->cur_task, (*state), stage->op)) { continue; } std::unordered_set consumers; - GetConsumers(policy->cur_task_, (*state), stage->op, &consumers); + GetConsumers(policy->cur_task, (*state), stage->op, &consumers); if (consumers.empty()) { continue; } @@ -1083,7 +1083,7 @@ int InitPopulationParallel(const MetaTileRewritePolicyNode* policy, to_fuse.push_back(it); parallel_degree *= GetExtent(it); - if (parallel_degree > policy->cur_task_->hardware_params->num_cores * 16) { + if (parallel_degree > policy->cur_task->hardware_params->num_cores * 16) { break; } @@ -1135,7 +1135,7 @@ int InitPopulationVectorization(const MetaTileRewritePolicyNode* policy, } // Skip cooperative fetching stage - if (IS_GPU(policy->cur_task_) && + if (IS_GPU(policy->cur_task) && HasCacheReadStage((*state), stage_id - 1)) { continue; } @@ -1179,7 +1179,7 @@ int InitPopulationVectorization(const MetaTileRewritePolicyNode* policy, } cum_length_prod *= GetExtent(it); - if (cum_length_prod > policy->cur_task_->hardware_params->max_unroll_vec) { + if (cum_length_prod > policy->cur_task->hardware_params->max_unroll_vec) { break; } @@ -1278,13 +1278,13 @@ void MetaTileRewritePolicyNode::SampleInitPopulation(const std::vector& m InitPopulationFillTileSize(this, &tmp_s, &rand_gen_, &split_memo_); - if (IS_GPU(cur_task_)) { - tmp_s = cur_task_->compute_dag.InferBound(tmp_s); + if (IS_GPU(cur_task)) { + tmp_s = cur_task->compute_dag.InferBound(tmp_s); if (InitPopulationThreadBind(this, &tmp_s)) { continue_count++; if (continue_count == out_size) { - StdCout(verbose_) << "Initial Population Sampling..." << std::endl; + StdCout(verbose) << "Initial Population Sampling..." 
<< std::endl; } continue; } @@ -1293,7 +1293,7 @@ void MetaTileRewritePolicyNode::SampleInitPopulation(const std::vector& m } else { InitPopulationChangeComputeLocation(this, &tmp_s, &rand_gen_); - tmp_s = cur_task_->compute_dag.InferBound(tmp_s); + tmp_s = cur_task->compute_dag.InferBound(tmp_s); InitPopulationParallel(this, &tmp_s); } @@ -1305,8 +1305,8 @@ void MetaTileRewritePolicyNode::SampleInitPopulation(const std::vector& m out_states->push_back(std::move(tmp_s)); } - StdCout(verbose_) << "Sample Initial Population\t\t#s: " - << out_states->size() << std::endl; + StdCout(verbose) << "Sample Initial Population\t\t#s: " + << out_states->size() << std::endl; } void MetaTileRewritePolicyNode::EvolutionarySearch( @@ -1350,9 +1350,9 @@ void MetaTileRewritePolicyNode::EvolutionarySearch( // Genetic Algorithm for (int k = 0; k < num_iters + 1; ++k) { // Maintain the heap - cur_task_->compute_dag.InferBound(pnow); + cur_task->compute_dag.InferBound(pnow); PruneUndefined(pnow); - cost_model->Predict(cur_task_, *pnow, &scores); + cost_model->Predict(cur_task, *pnow, &scores); for (size_t i = 0; i < pnow->size(); ++i) { const State& state = (*pnow)[i]; @@ -1379,10 +1379,10 @@ void MetaTileRewritePolicyNode::EvolutionarySearch( } if (k % 5 == 0 || k == num_iters) { - StdCout(verbose_) << "GA Iter: " << k << std::fixed << std::setprecision(4) - << "\tMax score: " << max_score - << "\tMin score: " << heap.front().second - << "\tPop size: " << pnow->size() << std::endl; + StdCout(verbose) << "GA Iter: " << k << std::fixed << std::setprecision(4) + << "\tMax score: " << max_score + << "\tMin score: " << heap.front().second + << "\tPop size: " << pnow->size() << std::endl; } if (k == num_iters) { @@ -1431,7 +1431,7 @@ void MetaTileRewritePolicyNode::EvolutionarySearch( if (rule_id == 0) { // Mutate Tile Size State tmp_s = RandomMutateTileSize((*pnow)[id], &split_memo_, &rand_gen_, - cur_task_->hardware_params->max_innermost_split_factor); + cur_task->hardware_params->max_innermost_split_factor); if (tmp_s.defined()) { pnext->push_back(std::move(tmp_s)); } else { @@ -1463,9 +1463,9 @@ void MetaTileRewritePolicyNode::EvolutionarySearch( double duration = std::chrono::duration_cast >( std::chrono::high_resolution_clock::now()- tic_begin).count(); - StdCout(verbose_) << "EvolutionarySearch\t\t#s: " << best_states->size() - << "\tTime elapsed: " - << std::fixed << std::setprecision(2) << duration << std::endl; + StdCout(verbose) << "EvolutionarySearch\t\t#s: " << best_states->size() + << "\tTime elapsed: " + << std::fixed << std::setprecision(2) << duration << std::endl; } class RuleCustomSketch : public SketchGenerationRule { @@ -1519,7 +1519,7 @@ void PreAddCustomRuleNode::callback(SearchPolicyNode* policy) { auto meta_policy = dynamic_cast(policy); meta_policy->sketch_rules.emplace_back( new RuleCustomSketch(meet_condition_func, apply_func)); - StdCout(policy->verbose_) << "Custom sketch rule added." << std::endl; + StdCout(policy->verbose) << "Custom sketch rule added." 
<< std::endl; } TVM_REGISTER_GLOBAL("ansor.MetaTileRewritePolicy") diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index 685052f3f71f..c9bccfdce806 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -23,18 +23,16 @@ */ #include "search_policy.h" - #include - #include "../serialization.h" namespace tvm { namespace ansor { TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); -TVM_REGISTER_OBJECT_TYPE(PreLoadMeasuredStatesNode); +TVM_REGISTER_OBJECT_TYPE(PreloadMeasuredStatesNode); -void SearchPolicyNode::PreLoadMeasuredStates(const std::string& log_file) { +void SearchPolicyNode::PreloadMeasuredStates(const std::string& log_file) { LogReader reader = LogReaderNode::make(log_file); const auto& res = reader->ReadLines(-1); size_t log_size = res.first.size(); @@ -44,18 +42,18 @@ void SearchPolicyNode::PreLoadMeasuredStates(const std::string& log_file) { std::vector measured_throughputs; for (size_t i = 0; i < log_size; i++) { const auto& inp = res.first[i]; - if (inp->task->workload_key == cur_task_->workload_key && + if (inp->task->workload_key == cur_task->workload_key && inp->task->target->target_name.compare( - cur_task_->target->target_name) == 0) { - State state = cur_task_->compute_dag.GetInitState(); + cur_task->target->target_name) == 0) { + State state = cur_task->compute_dag.GetInitState(); state.CopyOnWrite()->transform_steps = inp->state->transform_steps; - state.DoSteps(inp->state->transform_steps, cur_task_->compute_dag); + state.DoSteps(inp->state->transform_steps, cur_task->compute_dag); measured_states.emplace_back(std::move(state)); measured_throughputs.push_back(res.second[i]->error_no == 0 ? (1.0 / FloatArrayMean(res.second[i]->costs)) : 0.0); } } - cur_task_->compute_dag.InferBound(&measured_states); + cur_task->compute_dag.InferBound(&measured_states); for (size_t i = 0; i < measured_states.size(); i ++) { auto& state = measured_states[i]; const auto& state_str = state.ToStr(); @@ -68,33 +66,32 @@ void SearchPolicyNode::PreLoadMeasuredStates(const std::string& log_file) { } } - StdCout(verbose_) << "Measured States Set: " << measured_states_set_.size() - << " state hashes loaded from " << log_file - << " for " << cur_task_->workload_key << std::endl; + StdCout(verbose) << "Successfully load " << measured_states_set_.size() + << " measurement records from " << log_file + << " for " << cur_task->workload_key << std::endl; } else { - StdCout(verbose_) << "Measured States Set: no states found from " - << log_file << " for " << cur_task_->workload_key - << std::endl; + StdCout(verbose) << "No measurement records found in " + << log_file << " for " << cur_task->workload_key << std::endl; } } void SearchPolicyNode::RunCallbacks(const Array& callbacks) { if (callbacks.defined() && callbacks.size()) { - PrintTitle("Process search callbacks", verbose_); + PrintTitle("Call search callbacks", verbose); for (const auto& callback : callbacks) { callback->callback(this); } } } -SearchCallback PreLoadMeasuredStatesNode::make(std::string filename) { - auto node = make_object(); +SearchCallback PreloadMeasuredStatesNode::make(std::string filename) { + auto node = make_object(); node->filename = std::move(filename); return SearchCallback(node); } -void PreLoadMeasuredStatesNode::callback(SearchPolicyNode* policy) { - policy->PreLoadMeasuredStates(filename); +void PreloadMeasuredStatesNode::callback(SearchPolicyNode* policy) { + policy->PreloadMeasuredStates(filename); } // Search Policy @@ -103,8 
+100,7 @@ TVM_REGISTER_GLOBAL("ansor.SearchPolicyContinueSearchOneRound") int verbose, ProgramMeasurer measurer) { Array inputs; Array results; - std::tie(inputs, results) = policy->ContinueSearchOneRound(task, num_measure, - verbose, measurer); + std::tie(inputs, results) = policy->ContinueSearchOneRound(task, num_measure, verbose, measurer); return Array{inputs, results}; }); @@ -115,17 +111,17 @@ TVM_REGISTER_GLOBAL("ansor.SearchPolicyRunCallbacks") TVM_REGISTER_GLOBAL("ansor.SearchPolicySetTask") .set_body_typed([](SearchPolicy policy, SearchTask task) { - policy->cur_task_ = task; + policy->cur_task = task; }); TVM_REGISTER_GLOBAL("ansor.SearchPolicySetVerbose") .set_body_typed([](SearchPolicy policy, int verbose) { - policy->verbose_ = verbose; + policy->verbose = verbose; }); -TVM_REGISTER_GLOBAL("ansor.PreLoadMeasuredStates") +TVM_REGISTER_GLOBAL("ansor.PreloadMeasuredStates") .set_body_typed([](std::string filename) { - return PreLoadMeasuredStatesNode::make(filename); + return PreloadMeasuredStatesNode::make(filename); }); } // namespace ansor diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index 6085fd1816e8..f1f6f45fce9a 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -25,12 +25,12 @@ #ifndef TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ #define TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ +#include "../search_task.h" #include #include #include #include #include -#include "../search_task.h" #include "../measure.h" namespace tvm { @@ -39,6 +39,7 @@ namespace ansor { class SearchPolicy; class SearchPolicyNode; +/*! Callback function to be called before or after the search process */ class SearchCallbackNode : public Object { public: virtual void callback(SearchPolicyNode* policy) = 0; @@ -47,7 +48,9 @@ class SearchCallbackNode : public Object { }; TVM_DEFINE_MUTABLE_OBJECT_REF(SearchCallback, SearchCallbackNode); -class PreLoadMeasuredStatesNode : public SearchCallbackNode { +/*! \brief Preload measured states from a log file. + * This can resume the state of the search policy */ +class PreloadMeasuredStatesNode : public SearchCallbackNode { public: std::string filename; @@ -55,44 +58,48 @@ class PreLoadMeasuredStatesNode : public SearchCallbackNode { void callback(SearchPolicyNode* policy) final; - static constexpr const char *_type_key = "ansor.PreLoadMeasuredStates"; - TVM_DECLARE_FINAL_OBJECT_INFO(PreLoadMeasuredStatesNode, SearchCallbackNode); + static constexpr const char *_type_key = "ansor.PreloadMeasuredStates"; + TVM_DECLARE_FINAL_OBJECT_INFO(PreloadMeasuredStatesNode, SearchCallbackNode); }; /*! \brief The base class for search policy */ class SearchPolicyNode : public Object { public: + SearchTask cur_task; // The current task + int verbose; // Verbose level (0 means silent) + + void VisitAttrs(AttrVisitor* v) { + v->Visit("cur_task", &cur_task); + v->Visit("verbose", &verbose); + } + + // Search for a task virtual State Search(SearchTask task, int n_trials, int early_stopping, int num_measure_per_iter, int verbose, ProgramMeasurer measurer, Array pre_search_callbacks) = 0; + // Continue search one round for a task. + // This is used in the task scheduler for searching for multiple tasks together. 
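+  // A round hands back the (MeasureInput, MeasureResult) pairs it measured,
+  // so the driver (e.g. the task scheduler) can update its cost model and
+  // re-prioritize tasks between rounds. This matches the
+  // ansor.SearchPolicyContinueSearchOneRound global in search_policy.cc,
+  // which forwards the {inputs, results} pair back to Python.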
virtual std::pair, Array > ContinueSearchOneRound( SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) = 0; - void PreLoadMeasuredStates(const std::string& log_file); - void RunCallbacks(const Array& callbacks); - - SearchTask cur_task_; // The current task - int verbose_; // Verbose level (0 means silent) + // Preload measured states from a log file to resume the state of the search policy + void PreloadMeasuredStates(const std::string& log_file); - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("cur_task", &cur_task_); - } + // Run a list of callback functions + void RunCallbacks(const Array& callbacks); - // Dict keys + // Dict keys to give hints to the policy static constexpr const char* always_unroll_inner_key = "ansor_always_unroll_inner"; static constexpr const char* always_unroll_key = "ansor_always_unroll"; static constexpr const char* no_split_at_inner_key = "ansor_no_split_at_inner"; static constexpr const char* no_split_at_outer_key = "ansor_no_split_at_outer"; - static constexpr const char* debug_skip_region_key = "ansor_debug_skip_region"; static constexpr const char* last_split_is_one_key = "ansor_last_split_is_one"; - - // Flag keys + // Flag keys to give hints to the policy static constexpr const char* always_compute_inline_key = "ansor_always_compute_inline"; static constexpr const char* no_cache_write_key = "ansor_no_cache_write"; static constexpr const char* no_cache_read_key = "ansor_no_cache_read"; - static constexpr const char* tensor_core_support_key = "ansor_tensor_core_support"; static constexpr const char *_type_key = "ansor.SearchPolicy"; TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); diff --git a/src/ansor/search_policy/utils.cc b/src/ansor/search_policy/utils.cc index e0fd00b23e7b..ba42ca55611c 100644 --- a/src/ansor/search_policy/utils.cc +++ b/src/ansor/search_policy/utils.cc @@ -62,7 +62,6 @@ void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatia } } -// Apply multi-tiling structure according to a string format State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format, std::vector* spatial_split_step_ids) { std::vector > space_levels; @@ -187,8 +186,6 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo return tmp_s; } -// Apply tiling structure: space, space -// But use tile sizes from other SplitStep State FollowTiling(const State& state, int stage_id, const std::vector& split_step_ids, int n_split) { if (n_split < 1 || n_split > 3) { @@ -280,7 +277,6 @@ State FollowTiling(const State& state, int stage_id, return tmp_s; } -// Randomly mutate the tile size of one SplitStep State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split_memo, std::mt19937* random_gen, int max_innermost_split_factor) { State tmp_s = old_state; @@ -382,7 +378,6 @@ State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split return State(); } -// Randomly mutate the value of one auto_unroll_max_step PragmaStep State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen, const std::vector& auto_unroll_configs) { State tmp_s = old_state; @@ -411,170 +406,6 @@ State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen return tmp_s; } -// Mutate a parallel loop. 
-State MutataParallel(const State& state, SplitFactorizationMemo* split_memo, - std::mt19937* random_gen, const SearchTask& task, int verbose) { - // To make this mutation simple but promising, we only focus on a specific case that - // parallel was added to the outermost loop and the loop is generated by fusing other loops. - // In short, we mutate the step pattern of (fuse -> parallel). - - // Extract all parallel steps. - std::vector parallel_steps; - for (size_t s = 0; s < state->transform_steps.size(); ++s) { - auto ps = state->transform_steps[s].as(); - if (!ps || ps->annotation != kParallel) { - continue; - } - parallel_steps.push_back(s); - } - if (parallel_steps.size() == 0) { - StdCout(verbose) << "Parallel mutation failed: No parallel annotations" << std::endl; - return State(); - } - - // Randomly pick one step. - int retry_ct = 0; - size_t step_id = 0; - size_t stage_id = 0; - do { - step_id = parallel_steps[(*random_gen)() % parallel_steps.size()]; - auto step = state->transform_steps[step_id].as(); - stage_id = step->stage_id; - - // Check assumptions. - auto iter_id = step->iter_id; - if (iter_id == 0 && step_id > 0 && state->transform_steps[step_id - 1].as()) { - break; - } - retry_ct++; - } while (retry_ct <= 3); - - if (retry_ct > 3) { - StdCout(verbose) << "Parallel mutation failed: No valid parallel annotations" << std::endl; - return State(); - } - - // 0: fuse less; 1: fuse more. - std::vector fuse_dir = {0.5, 1.0}; - - // The iter is an attached target so we can only fuse less. - if (state->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, 0)) > 0) { - fuse_dir[0] = 1.0; - } - - // Determine the fuse direction. - auto fuse_step = state->transform_steps[step_id - 1].as(); - std::vector fused_ids = fuse_step->fused_ids; - int iter_offset = 0; - if (RandomChoose(fuse_dir, random_gen) == 0) { - StdCout(verbose) << "Parallel mutation: release iter " << fused_ids.back() << std::endl; - fused_ids.pop_back(); - iter_offset = 1; - } else { - StdCout(verbose) << "Parallel mutation: include iter " << fused_ids.back() + 1 << std::endl; - fused_ids.push_back(fused_ids.back() + 1); - iter_offset = -1; - } - - // Replay a new state. - State tmp_s = task->compute_dag.GetInitState(); - for (size_t s = 0; s < state->transform_steps.size(); ++s) { - auto step = state->transform_steps[s]; - if (s == step_id - 1) { - step = FuseStepNode::make(step->stage_id, fused_ids); - } else if (s > step_id && step->stage_id == static_cast(stage_id)) { - // Since we change the loop structure, iter ID in later steps to the same stage - // has to be adjusted. - auto ps = step.as(); - if (ps) { - CHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size()); - step = AnnotationStepNode::make(ps->stage_id, ps->iter_id + iter_offset, ps->annotation); - } else { - StdCout(verbose) << "Parallel mutation: Cannot apply " << step << " after fuse" - << std::endl; - return State(); - } - } - tmp_s.CopyOnWrite()->transform_steps.push_back(step); - tmp_s.DoStep(step, task->compute_dag); - } - return state; -} - -// Create all possible tile size states for all SplitStep -void GridMutateTileSize(const State& old_state, std::vector* cands, - SplitFactorizationMemo* split_memo, int max_innermost_split_factor) { - // Extract all SplitStep. 
- std::vector split_step_ids; - for (size_t i = 0; i < old_state->transform_steps.size(); ++i) { - if (old_state->transform_steps[i]->IsInstance()) { - split_step_ids.push_back(i); - } - } - if (split_step_ids.empty()) { - return; - } - - // Move tile sizes and generate candidates. - for (size_t step_id : split_step_ids) { - const SplitStepNode* ps = old_state->transform_steps[step_id].as(); - CHECK(ps != nullptr); - - int extent = GetIntImm(ps->extent); - if (extent == 1) { - continue; - } - - // Get the current tile sizes. - std::vector lengths(ps->lengths.size(), 1); - for (int i = 0; i < static_cast(ps->lengths.size()); ++i) { - lengths[i] = GetIntImm(ps->lengths[i]); - } - - const std::vector& const_factors = split_memo->GetFactors(extent); - CHECK_GE(const_factors.size(), 1); - - // Move tile size. - for (size_t i = 0; i < ps->lengths.size(); ++i) { - int old_length = lengths[i]; - - for (int factor : const_factors) { - if (i == ps->lengths.size() - 1 && factor > max_innermost_split_factor) { - // Limit the innermost factor. - break; - } - - // Make new length experssions and a new state. - std::vector length_exprs; - lengths[i] = factor; - int outermost = extent / ElementProduct(lengths); - if (outermost == 0) { - break; - } - - // std::cout << "Mutated extent " << extent << ": " << outermost; - for (size_t j = 0; j < lengths.size(); ++j) { - // std::cout << ", " << lengths[j]; - length_exprs.emplace_back(lengths[j]); - } - // std::cout << std::endl; - - State tmp_s = old_state; - const SplitStepNode* new_ps = tmp_s->transform_steps[step_id].as(); - auto pstate = tmp_s.CopyOnWrite(); - pstate->transform_steps[step_id] = - SplitStepNode::make(new_ps->stage_id, new_ps->iter_id, new_ps->extent, length_exprs, - new_ps->inner_to_outer); - if (tmp_s.defined()) { - cands->push_back(std::move(tmp_s)); - } - } - lengths[i] = old_length; - } - } -} - -// Prune undefined states. void PruneUndefined(std::vector* states) { size_t pt = 0; for (size_t i = 0; i < states->size(); ++i) { diff --git a/src/ansor/search_policy/utils.h b/src/ansor/search_policy/utils.h index 472e90771879..5f15397e7e90 100644 --- a/src/ansor/search_policy/utils.h +++ b/src/ansor/search_policy/utils.h @@ -464,14 +464,6 @@ State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen, const std::vector& auto_unroll_configs); -// Mutate a parallel loop. 
-State MutataParallel(const State& old_state, SplitFactorizationMemo* split_memo, - std::mt19937* random_gen, const SearchTask& task, int verbose = 0); - -// Create all possible tile size states for all SplitStep -void GridMutateTileSize(const State& old_state, std::vector* cands, - SplitFactorizationMemo* split_memo, int max_innermost_split_factor); - // GA: Crossover two states State CrossOverState(const State& p1, const State& p2); diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index ed5d4b868c27..454305c04ef5 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include #include "serialization.h" #include "loop_state.h" diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 8bd5eca7c93d..a8cd1d3c2462 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -153,6 +153,11 @@ class RelayBuildModule : public runtime::ModuleNode { CHECK_EQ(args.num_args, 2); *rv = this->Optimize(args[0], args[1], this->params_); }); + } else if (name == "call_all_topi_funcs") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue *rv) { + CHECK_EQ(args.num_args, 3); + this->CallAllTopiFuncs(args[0], args[1], args[2]); + }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); @@ -227,6 +232,21 @@ class RelayBuildModule : public runtime::ModuleNode { BuildRelay(mod, params_); } + /*! \brief Call all used TOPI compute and schedule in a relay function */ + void CallAllTopiFuncs(IRModule mod, + const TargetsMap& targets, + const tvm::Target& target_host) { + targets_ = targets; + target_host_ = target_host; + + IRModule relay_module = Optimize(mod, targets_, params_); + auto func = Downcast(relay_module->Lookup("main")); + + graph_codegen_ = std::unique_ptr(new GraphCodegen()); + graph_codegen_->Init(nullptr, targets_); + graph_codegen_->Codegen(func); + } + protected: /*! * \brief Optimize a Relay IRModule. @@ -287,7 +307,6 @@ class RelayBuildModule : public runtime::ModuleNode { // Alter layout transformation is only applied to homogeneous execution yet. if (targets.size() == 1) { pass_seqs.push_back(transform::AlterOpLayout()); - //pass_seqs.push_back(transform::KernelLayoutTransform()); } // Fast math optimizations. 
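The `call_all_topi_funcs` packed function added to build_module.cc above is the hook that lets workload extraction run the Relay compilation pipeline exactly once: graph codegen visits every kernel, so a tracing mode can record each auto-schedulable TOPI workload without producing code. A minimal sketch of the intended Python-side flow, using only APIs that appear elsewhere in this series (the dqn network is just an example workload):

    import tvm
    from tvm import ansor
    from tvm.relay.testing import dqn

    # An example Relay workload.
    mod, params = dqn.get_workload(1, image_shape=(84, 84, 4), layout='NHWC')
    target = tvm.target.create('llvm')

    # extract_from_program compiles the module once and records one workload
    # key per distinct kernel; the weights presumably count how often each
    # workload appears in the network.
    wkl_keys, wkl_weights = ansor.extract_from_program(mod, params=params, target=target)

    # Each key expands back into a ComputeDAG, which defines a tuning task.
    tasks = [ansor.SearchTask(ansor.workload_key_to_dag(k), k, target)
             for k in wkl_keys]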
diff --git a/src/relay/transforms/kernel_layout_transform.h b/src/relay/transforms/kernel_layout_transform.h index b4b806c20e28..c82a96b30612 100644 --- a/src/relay/transforms/kernel_layout_transform.h +++ b/src/relay/transforms/kernel_layout_transform.h @@ -20,7 +20,8 @@ class KernelLayoutVisitor : public ExprVisitor { !global_ori_layouts_queue.empty() && !global_new_layouts_queue.empty()) { ori_layouts_map[n] = global_ori_layouts_queue.front(); new_layouts_map[n] = global_new_layouts_queue.front(); - std::cout << "ori_layout " << global_ori_layouts_queue.front() << " Filter_shape " << n->args[1]->type_as()->shape << std::endl; + // std::cout << "ori_layout " << global_ori_layouts_queue.front() + // << " Filter_shape " << n->args[1]->type_as()->shape << std::endl; global_ori_layouts_queue.pop_front(); global_new_layouts_queue.pop_front(); } diff --git a/tests/python/unittest/test_ansor_relay_Integration.py b/tests/python/unittest/test_ansor_relay_integration.py similarity index 53% rename from tests/python/unittest/test_ansor_relay_Integration.py rename to tests/python/unittest/test_ansor_relay_integration.py index 9c423220844c..f3f424ab321b 100644 --- a/tests/python/unittest/test_ansor_relay_Integration.py +++ b/tests/python/unittest/test_ansor_relay_integration.py @@ -22,19 +22,18 @@ import tvm from tvm import ansor, relay import tvm.contrib.graph_runtime as runtime +from tvm.relay.testing import dqn -from test_ansor_common import get_tiled_matmul +def test_tune_dense_graph(): + def dense_graph(N, dtype="float32"): + ori_data = relay.var("data", shape=(N, N), dtype=dtype) + weight = relay.var("weight", shape=(N, N), dtype=dtype) + data = relay.multiply(ori_data, relay.const(2, dtype=dtype)) + dense = relay.nn.dense(data, weight, out_dtype=dtype) + dense = relay.add(dense, weight) + dense = relay.nn.dense(dense, weight, out_dtype=dtype) + return ori_data, weight, dense -def dense_graph(N, dtype="float32"): - ori_data = relay.var("data", shape=(N, N), dtype=dtype) - weight = relay.var("weight", shape=(N, N), dtype=dtype) - data = relay.multiply(ori_data, relay.const(2, dtype=dtype)) - dense = relay.nn.dense(data, weight, out_dtype=dtype) - dense = relay.add(dense, weight) - dense = relay.nn.dense(dense, weight, out_dtype=dtype) - return ori_data, weight, dense - -def test_dense_integration(): N = 128 data, weight, dense = dense_graph(N) mod = relay.Function([data, weight], dense) @@ -44,34 +43,23 @@ def test_dense_integration(): target = tvm.target.create("llvm") d = tvm.nd.array(np.random.uniform(size=(N, N)).astype(data.type_annotation.dtype), ctx) w = tvm.nd.array(np.random.uniform(size=(N, N)).astype(weight.type_annotation.dtype), ctx) - workloads, wkl_weights = ansor.extract_from_program(mod, {}, target=target) + wkl_keys, wkl_weights = ansor.extract_from_program(mod, {}, target=target) - assert len(workloads) == 2 + assert len(wkl_keys) == 2 assert len(wkl_weights) == 2 tasks = [] - for wkl_key in workloads: + for wkl_key in wkl_keys: dag = ansor.workload_key_to_dag(wkl_key) tasks.append(ansor.SearchTask(dag, wkl_key, target)) - assert str(tasks[0].compute_dag) == "placeholder = PLACEHOLDER [128, 128]\n" + \ - "placeholder = PLACEHOLDER [128, 128]\n" + \ - "compute(z, y, x) += (placeholder[z, ((k*16) + x)]*placeholder[y, ((k*16) + x)])\n" + \ - "compute(y, x) += compute[y, x, kk]\n" - - assert str(tasks[1].compute_dag) == "placeholder = PLACEHOLDER [128, 128]\n" + \ - "placeholder = PLACEHOLDER [128, 128]\n" + \ - "compute(z, y, x) += (placeholder[z, ((k*16) + x)]*placeholder[y, 
((k*16) + x)])\n" + \ - "compute(y, x) += compute[y, x, kk]\n" + \ - "T_add(ax0, ax1) = (compute[ax0, ax1] + placeholder[ax0, ax1])\n" - tuner = ansor.SimpleTaskScheduler(tasks) measure_ctx = ansor.LocalRPCMeasureContext() with tempfile.NamedTemporaryFile() as fp: - tuner.tune(ansor.TuneOption(n_trials=4, runner=measure_ctx.runner, + tuner.tune(ansor.TuneOption(n_trials=2, runner=measure_ctx.runner, measure_callbacks=[ansor.LogToFile(fp.name)])) with ansor.apply_history_best(fp.name): - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): graph, lib, opt_params = relay.build_module.build( mod, target=target) @@ -80,8 +68,8 @@ def test_dense_integration(): m.set_input('weight', w) m.run() res = m.get_output(0) - if measure_ctx: - del measure_ctx + + del measure_ctx d = d.asnumpy() d = d * 2 @@ -92,5 +80,36 @@ def test_dense_integration(): tvm.testing.assert_allclose(res.asnumpy(), d, rtol=1e-5) + +def test_tune_dqn(): + mod, params = dqn.get_workload(1, image_shape=(84, 84, 4), layout='NHWC') + target = tvm.target.create('llvm') + ctx = tvm.context("llvm") + + wkl_keys, wkl_weights = ansor.extract_from_program(mod, params, target) + + tasks = [] + for wkl_key in wkl_keys: + dag = ansor.workload_key_to_dag(wkl_key) + tasks.append(ansor.SearchTask(dag, wkl_key, target)) + + assert len(tasks) == 5 + + tuner = ansor.SimpleTaskScheduler(tasks) + measure_ctx = ansor.LocalRPCMeasureContext() + with tempfile.NamedTemporaryFile() as fp: + tuner.tune(ansor.TuneOption(n_trials=len(tasks), runner=measure_ctx.runner, + measure_callbacks=[ansor.LogToFile('tmp.json')]), + search_policy='meta-rewrite.random') + with ansor.apply_history_best('tmp.json'): + ansor.prepare_layout_rewrite(mod, params, target) + with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): + graph, lib, opt_params = relay.build_module.build(mod, target=target) + ansor.finish_layout_rewrite() + + del measure_ctx + if __name__ == "__main__": - test_dense_integration() + test_tune_dense_graph() + test_tune_dqn() + diff --git a/topi/python/topi/ansor.py b/topi/python/topi/ansor.py deleted file mode 100644 index e821fd5bd42f..000000000000 --- a/topi/python/topi/ansor.py +++ /dev/null @@ -1,95 +0,0 @@ -"""All AutoSchedule Supported Operators""" -from __future__ import absolute_import as _abs -from tvm import ansor - -@ansor.register_topi_schedule() -def schedule_dense_nopack(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv2d_nhwc(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv2d_NCHWc(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_reduce(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_pool(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_adaptive_pool(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_softmax(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv2d_nchw_int8(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv2d_nchw(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_depthwise_conv2d_nchw(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def 
schedule_depthwise_conv2d_nhwc(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv2d_NCHWc_int8(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_depthwise_conv2d_NCHWc(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv2d_transpose_nchw(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv3d_ncdhw(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv3d_ndhwc(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv1d_ncw(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv1d_nwc(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_dense_pack(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_batch_matmul(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_bitserial_conv2d_nchw(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_bitserial_conv2d_nhwc(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_bitserial_dense(cfg, outs): - return ansor.gen_schedule(cfg, outs) diff --git a/topi/python/topi/arm_cpu/__init__.py b/topi/python/topi/arm_cpu/__init__.py index 0c0979763dba..e121fbc7ec6d 100644 --- a/topi/python/topi/arm_cpu/__init__.py +++ b/topi/python/topi/arm_cpu/__init__.py @@ -26,8 +26,3 @@ from .bitserial_dense import * from .injective import * from . import cortex_m7 - -import os -use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "false") -if use_auto_scheduler.lower() == "true": - from ..ansor import * diff --git a/topi/python/topi/generic/__init__.py b/topi/python/topi/generic/__init__.py index d44fca8548d2..6171317cd80f 100644 --- a/topi/python/topi/generic/__init__.py +++ b/topi/python/topi/generic/__init__.py @@ -39,8 +39,3 @@ from .sort import * from .search import * from .image import * - -import os -use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "false") -if use_auto_scheduler.lower() == "true": - from ..ansor import * diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index de02367a4dff..6800129c12aa 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs from collections import namedtuple import tvm -from tvm import te +from tvm import te, ansor from .pad import pad from .util import get_pad_tuple @@ -342,23 +342,36 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'): dilation_h, dilation_w = dilation batch, in_height, in_width, in_channel = Input.shape - if len(Filter.shape) == 10: - kernel_h = Filter.shape[2] * Filter.shape[6] - kernel_w = Filter.shape[3] * Filter.shape[7] - channel = Filter.shape[4] * Filter.shape[8] - num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[5] * Filter.shape[9] - elif len(Filter.shape) == 11: - kernel_h = Filter.shape[3] * Filter.shape[7] - kernel_w = Filter.shape[4] * Filter.shape[8] - channel = Filter.shape[5] * Filter.shape[9] - num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[2] * Filter.shape[6] * Filter.shape[10] - elif len(Filter.shape) == 12: - kernel_h = 
Filter.shape[4] * Filter.shape[8] - kernel_w = Filter.shape[5] * Filter.shape[9] - channel = Filter.shape[6] * Filter.shape[10] - num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[2] * Filter.shape[3] * Filter.shape[7] * Filter.shape[11] + if ansor.GLOBAL_SCOPE.topi_in_compute_rewrite_mode: + # infer shape for the rewritten layout + if len(Filter.shape) >= 10: + # For cpu tile structure SSRSRS + base = len(Filter.shape) - 10 + kernel_h = Filter.shape[2 + base] * Filter.shape[6 + base] + kernel_w = Filter.shape[3 + base] * Filter.shape[7 + base] + channel = Filter.shape[4 + base] * Filter.shape[8 + base] + num_filter = Filter.shape[5 + base] * Filter.shape[9 + base] + for i in range(base + 2): + num_filter *= Filter.shape[i] + elif len(Filter.shape) == 6: + # For cpu tile structure SRS + num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[5] + kernel_h = Filter.shape[2] + kernel_w = Filter.shape[3] + channel = Filter.shape[4] + elif len(Filter.shape) == 5: + # For cpu tile structure SRS + num_filter = Filter.shape[0] * Filter.shape[4] + kernel_h = Filter.shape[1] + kernel_w = Filter.shape[2] + channel = Filter.shape[3] + elif len(Filter.shape) == 4: + num_filter, kernel_h, kernel_w, channel = Filter.shape + else: + raise ValueError("Don't know how to infer the layout for filter shape: %s. " \ + "You can add a new branch for it to fix this." % str(Filter)) else: - kernel_h, kernel_w, channel, num_filter = Filter.shape + kernel_h, kernel_w, channel, num_filter = Filter.shape # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py index 28e9e862f4d8..659668cbbe4c 100644 --- a/topi/python/topi/x86/__init__.py +++ b/topi/python/topi/x86/__init__.py @@ -39,8 +39,3 @@ from .conv3d_transpose import * from .sparse import * from .conv2d_alter_op import * - -import os -use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "false") -if use_auto_scheduler.lower() == "true": - from ..ansor import * diff --git a/tutorials/ansor/tune_conv2d_cuda.py b/tutorials/ansor/tune_conv2d_cuda.py index 14a6ee797276..437323d79791 100644 --- a/tutorials/ansor/tune_conv2d_cuda.py +++ b/tutorials/ansor/tune_conv2d_cuda.py @@ -124,7 +124,7 @@ def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding): # in the tuning logs. # :code:`ansor.LogToFile` callback will log the tuning results into a # log file, which can be used to get the best config later. -# :code:`ansor.PreLoadMeasuredStates` callback will load measured states +# :code:`ansor.PreloadMeasuredStates` callback will load measured states # from the history log before schedule search. We can add this callback to make # sure the same schedule is never measured multiple times.
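Once tuning has filled `log_file`, the records can also be consumed outside the search loop. A minimal sketch, assuming the ansor API of this patch series and a hypothetical Relay workload (`mod`, `params`, and `target` are placeholders, as in the integration tests of this series):

import tvm
from tvm import ansor, relay

# apply_history_best reads log_file and dispatches the best measured
# schedule for every workload encountered during compilation
with ansor.apply_history_best(log_file):
    with tvm.transform.PassContext(opt_level=3):
        graph, lib, opt_params = relay.build_module.build(mod, target=target, params=params)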
@@ -132,7 +132,7 @@ def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding): tune_option = ansor.TuneOption(n_trials=20, runner=measure_ctx.runner, measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=[ansor.PreLoadMeasuredStates(log_file)]) + pre_search_callbacks=[ansor.PreloadMeasuredStates(log_file)]) s, arg_bufs = ansor.auto_schedule(task, search_policy=search_policy, tune_option=tune_option) print("==== Get Lowered Stmt ====") diff --git a/tutorials/ansor/tune_simple_subgraph.py b/tutorials/ansor/tune_simple_subgraph.py index dfd36e89fd4c..08d5628ad8a2 100644 --- a/tutorials/ansor/tune_simple_subgraph.py +++ b/tutorials/ansor/tune_simple_subgraph.py @@ -148,7 +148,7 @@ def matmul_add(N, L, M, dtype): # you can do more trials according to your time budget. # :code:`ansor.LogToFile` callback will log the tuning results into a # log file, which can be used to get the best config later. -# :code:`ansor.PreLoadMeasuredStates` callback will load measured states +# :code:`ansor.PreloadMeasuredStates` callback will load measured states # from the history log before schedule search. We can add this callback to make # sure the same schedule is never measured multiple times. @@ -161,7 +161,7 @@ def matmul_add(N, L, M, dtype): tune_option = ansor.TuneOption(n_trials=5, measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=[ansor.PreLoadMeasuredStates(log_file)]) + pre_search_callbacks=[ansor.PreloadMeasuredStates(log_file)]) ################################################################ # Then just call :code:`ansor.auto_schedule` and Ansor will try to find a high From 0794875b61cea652fede1599b49dd64c81807ce5 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 20 Jun 2020 06:34:49 -0700 Subject: [PATCH 32/78] Fix xgb error & Simplify dispatcher (#35) --- python/tvm/ansor/__init__.py | 2 +- python/tvm/ansor/auto_schedule.py | 1 - python/tvm/ansor/compute_dag.py | 19 +- python/tvm/ansor/cost_model/cost_model.py | 5 +- python/tvm/ansor/cost_model/xgb_model.py | 12 +- python/tvm/ansor/dispatcher.py | 233 ++------------------ python/tvm/ansor/env.py | 18 ++ python/tvm/ansor/feature.py | 1 - python/tvm/ansor/measure.py | 8 +- python/tvm/ansor/serialization.py | 1 + python/tvm/ansor/task_scheduler.py | 5 +- python/tvm/ansor/workload_registry.py | 1 - src/ansor/serialization.cc | 3 +- tests/python/unittest/test_ansor_feature.py | 1 - 14 files changed, 70 insertions(+), 240 deletions(-) diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 977e100e63c6..90a11820d159 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -40,7 +40,7 @@ workload_key_to_dag, make_workload_key_func from .task_scheduler import TaskScheduler, SimpleTaskScheduler from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest as apply_history_best, \ - FallbackContext, clear_fallback_cache, ApplyGraphBest + FallbackContext from .relay_integration import extract_from_program, extract_from_multiple_program, \ finish_layout_rewrite, prepare_layout_rewrite, auto_schedule_topi from .env import GLOBAL_SCOPE diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index acf8982d6e89..e8108a067b2e 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -97,7 +97,6 @@ class MetaTileRewritePolicy(SearchPolicy): seed: int Random seed """ - def __init__(self, program_cost_model, params=None,
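The compute_dag.py docstrings added below describe a small replay API around states. A minimal sketch of how those methods fit together, assuming a workload key `wkl_key` taken from one of the tuning logs (hypothetical):

from tvm import ansor

dag = ansor.workload_key_to_dag(wkl_key)           # rebuild the ComputeDAG for the workload
state = dag.get_init_state()                       # a fresh state with an empty transform history
state = dag.infer_bound_from_state(state)          # re-run bound inference, returns a State
sch, arg_bufs = dag.apply_steps_from_state(state)  # materialize a schedule from the steps
code = dag.print_python_code_from_state(state)     # the same steps as schedule primitives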
diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index f35c9d8221f3..6304c7bb0e0a 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -53,6 +53,8 @@ def get_init_state(self): def apply_steps_from_state(self, state, layout_rewrite_level=LayoutRewriteLevel.NO_REWRITE): """ + Apply transform steps according to the history of a state + Parameters ---------- state : StateObject @@ -68,6 +70,8 @@ def apply_steps_from_state(self, state, layout_rewrite_level=LayoutRewriteLevel. def print_python_code_from_state(self, state): """ + Print transform steps in the history of a state as TVM's python schedule primitives + Parameters ---------- state : StateObject @@ -81,16 +85,29 @@ def print_python_code_from_state(self, state): def infer_bound_from_state(self, state): """ + Infer bounds for a state + Parameters ---------- state : StateObject Returns ------- - state : StateObject + state : State """ state_obj = state if isinstance(state, StateObject) else state.state_object return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self) def rewrite_layout_from_state(self, state: State): + """ + Rewrite the layout according to the transform steps in the history of a state + + Parameters + ---------- + state : StateObject + + Returns + ------- + state : StateObject + """ + return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state) diff --git a/python/tvm/ansor/cost_model/cost_model.py b/python/tvm/ansor/cost_model/cost_model.py index 47ea5092b302..57cc53853b2e 100644 --- a/python/tvm/ansor/cost_model/cost_model.py +++ b/python/tvm/ansor/cost_model/cost_model.py @@ -26,18 +26,20 @@ @tvm._ffi.register_object("ansor.CostModel") class CostModel(Object): + """The base class for cost models""" pass @tvm._ffi.register_object("ansor.RandomModel") class RandomModel(Object): + """A model that returns random estimations for all inputs""" def __init__(self): self.__init_handle_by_constructor__(_ffi_api.RandomModel) -# A random number generator func for c++'s RandomModel @tvm._ffi.register_func("ansor.cost_model.random_number") def random_number(n, return_ptr): + """ A random number generator func for c++'s RandomModel """ if n == 0: return return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float)) @@ -47,6 +49,7 @@ def random_number(n, return_ptr): @tvm._ffi.register_object("ansor.PythonBasedModel") class PythonBasedModel(CostModel): + """Base class for cost models implemented in python""" def __init__(self): def update_func(inputs, results): self.update(inputs, results) diff --git a/python/tvm/ansor/cost_model/xgb_model.py b/python/tvm/ansor/cost_model/xgb_model.py index fce3f16d18ba..42af17daae2c 100644 --- a/python/tvm/ansor/cost_model/xgb_model.py +++ b/python/tvm/ansor/cost_model/xgb_model.py @@ -16,16 +16,14 @@ # under the License. 
"""Cost model based on xgboost""" -from typing import List import multiprocessing import logging -import time from collections import defaultdict import numpy as np import xgboost as xgb -from ...autotvm.tuner.xgboost_cost_model import get_rank, recall_curve, max_curve +from tvm.autotvm.tuner.xgboost_cost_model import get_rank, recall_curve, max_curve from .cost_model import PythonBasedModel from ..feature import get_per_stmt_features_from_measure_pairs, get_per_stmt_features_from_states from ..serialization import LogReader @@ -65,8 +63,8 @@ def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None): # todo(lmzheng): automatically decrease learning rate when the loss is too large 'n_gpus': 0, - 'n_threads': multiprocessing.cpu_count() / 2, - 'silent': 0, + 'nthread': multiprocessing.cpu_count() // 2, + 'verbosity': 0, 'seed': seed or 43, 'disable_default_eval_metric': 1 } @@ -180,7 +178,7 @@ def pack_sum_xgbmatrix_for_prediction(xs): x_flatten.append(row) pack_ids.append(ct) - return xgb.DMatrix(x_flatten), pack_ids + return xgb.DMatrix(np.array(x_flatten)), pack_ids def pack_sum_xgbmatrix(xs, ys, gids=None, weights=None): @@ -214,7 +212,7 @@ def pack_sum_xgbmatrix(xs, ys, gids=None, weights=None): y_flatten.append(y) pack_ids.append(ct) - ret = xgb.DMatrix(x_flatten, y_flatten) + ret = xgb.DMatrix(np.array(x_flatten), y_flatten) if weights is not None: ret.set_weight(weights_flatten) dmatrix_context.put('pack_ids', ret, np.array(pack_ids)) diff --git a/python/tvm/ansor/dispatcher.py b/python/tvm/ansor/dispatcher.py index 0ef07197ea92..0c07fd141bd2 100644 --- a/python/tvm/ansor/dispatcher.py +++ b/python/tvm/ansor/dispatcher.py @@ -15,16 +15,7 @@ # specific language governing permissions and limitations # under the License. """ -Template dispatcher module. - -A dispatcher is a function that can contains multiple behaviors. -Its specific behavior is can be controlled by DispatchContext. - -DispatchContext is used in two ways, usually via different implementation -of the DispatchContext base class. - -- During search, we can use it to pass the current proposal from tuner. -- During evaluation, we can use it to set pick the best policy. +The global context that dispatches best configurations to workloads """ # pylint: disable=invalid-name @@ -33,9 +24,7 @@ import logging import numpy as np -from decorator import decorate -from tvm import target as _target from tvm.tir.expr import FloatImm logger = logging.getLogger('auto_scheduler') @@ -44,9 +33,6 @@ class DispatchContext(object): """ Base class of dispatch context. - - DispatchContext enables the target and workload - specific dispatch mechanism for templates. """ current = None @@ -55,7 +41,7 @@ def __init__(self): def query(self, target, workload): """ - Query the context to get the specific config for a template. + Query the context to get the specific config for a workload. If cannot find the result inside this context, this function will query it from the upper contexts. @@ -63,22 +49,20 @@ def query(self, target, workload): ---------- target: Target The current target - workload : Workload - The current workload. + workload : str + The current workload Returns ------- - cfg : State or str - The specific state for auto scheduler. + cfg : State + The schedule configuration for the workload """ ret = self._query_inside(target, workload) - #if ret is None: - # ret = self._old_ctx.query(target, workload) return ret def update(self, target, workload, cfg): """ - Update context with a specific config. 
+ Update the config for a workload Parameters ---------- @@ -86,46 +70,14 @@ def update(self, target, workload, cfg): The current target workload : Workload The current workload. - cfg : State or str - The specific state for auto scheduler. - - Note - ---- - This interface is for cases when TVM decides to replace an operator in the graph. - For example, `AlterOpLayout` pass (enables when `opt_level = 3`) replaces `NCHW` - convolution with `NCHW[x]c` implementation on x86 CPUs. - Thus in TOPI, we first query schedule using original `NCHW` workload, - then update the dispatcher with the new `NCHW[x]c` workload. - So that later on, `NCHW[x]c` convolution can get schedule from the dispatcher using - its own workload directly. - - .. code-block:: python - - @conv2d_alter_layout.register("cpu") - def _alter_conv2d_layout(attrs, inputs, tinfo): - workload = get_conv2d_workload(...) - dispatch_ctx = auto_scheduler.DispatchContext.current - target = tvm.target.current_target() - config = dispatch_ctx.query(target, workload) - - # Get conv2d_NCHWc workload from config - # new_workload = ... - # new_inputs = ... - # new_attrs = ... - - # Store altered operator's config - dispatch_ctx.update(target, new_workload, config) - return sym.contrib.conv2d_NCHWc(*new_inputs, **new_attrs) - - We directly store `config` back because `conv2d_NCHW` and `conv2d_NCHWc` - share the same schedule parameters. - One can construct a new `State` if this is not the case. + cfg : State + The schedule configuration for the workload """ raise NotImplementedError() def _query_inside(self, target, workload): """ - Query the context to get the specific config for a template. + Query the context to get the specific config for a workload. This function only query config inside this context. Parameters @@ -138,7 +90,7 @@ def _query_inside(self, target, workload): Returns ------- cfg : State or str - The specific state for auto scheduler. + The schedule configuration for the workload """ raise NotImplementedError() @@ -151,78 +103,13 @@ def __exit__(self, ptype, value, trace): DispatchContext.current = self._old_ctx -def dispatcher(fworkload): - """Wrap a workload dispatcher function. - - Parameters - ---------- - fworkload : function - The workload extraction function from arguments. - - Returns - ------- - fdispatcher : function - A wrapped dispatcher function, which will - dispatch based on DispatchContext and - the current workload. - """ - dispatch_dict = {} - func_name = fworkload.__name__ - - def register(key, func=None, override=False): - """Register template function. - - Parameters - ---------- - key : str or List of str - The template key to identify the template - under this dispatcher. - func : function - The function to be registered. - The first argument of the function is always - cfg returned by DispatchContext, - the rest arguments are the same as the fworkload. - override : bool - Whether override existing registration. - - Returns - ------- - The register function if necessary. 
- """ - if isinstance(key, str): - key = [key] - - def _do_reg(myf): - for x in key: - if x in dispatch_dict and not override: - raise ValueError( - "Key %s is already registered for %s" % (x, func_name)) - dispatch_dict[x] = myf - return myf - - if func: - return _do_reg(func) - return _do_reg - - def dispatch_func(func, *args, **kwargs): - """The wrapped dispatch function""" - tgt = _target.current_target() - workload = func(*args, **kwargs) - cfg = DispatchContext.current.query(tgt, workload) - return dispatch_dict['direct'](cfg, *args, **kwargs) - - fdecorate = decorate(fworkload, dispatch_func) - fdecorate.register = register - return fdecorate - - class ApplyConfig(DispatchContext): - """Apply a deterministic config entity for all queries. + """Apply a deterministic config for all queries. Parameters ---------- config : State - The specific state for auto scheduler. + The schedule configuration """ def __init__(self, config): super(ApplyConfig, self).__init__() @@ -361,9 +248,7 @@ def update(self, target, workload, state): class FallbackContext(DispatchContext): """ A fallback dispatch context. - - Any tunable template can be called under this context. - This is the root context. + This is used as the root context. """ def __init__(self): @@ -387,7 +272,7 @@ def _query_inside(self, target, workload): logger.warning(msg) cfg = None - # cache this config + # cache this config to avoid duplicated warning message self.memory[key] = cfg return cfg @@ -412,91 +297,3 @@ def update(self, target, workload, cfg): DispatchContext.current = FallbackContext() - - -def clear_fallback_cache(target, workload): - """Clear fallback cache. Pass the same argument as _query_inside to this function - to clean the cache. - - Parameters - ---------- - target: Target - The current target - workload : Workload - The current workload. - - Note - ---- - This is used in alter_op_layout to clear the bad cache created before call topi compute function - """ - context = DispatchContext.current - while not isinstance(context, FallbackContext): - context = context._old_ctx - context.clear_cache(target, workload) - - -class ApplyGraphBest(DispatchContext): - """Load the graph level tuning optimal schedules. - - The input records should be in the ascending order of - node index for target operator. Usually this can be obtained - with graph tuner. - - This context maintains an internal counter to indicate the current - node index. - """ - def __init__(self, records): - """ - Parameters - ---------- - records : str or iterator of (MeasureInput, MeasureResult) - Collection of tuning records. - If is str, then it should be the filename of a records log file. - Each row of this file is an encoded record pair. - Otherwise, it is an iterator. - """ - from . import load_from_file - - super(ApplyGraphBest, self).__init__() - if isinstance(records, str): - records = load_from_file(records) - self._records = list(records) - self._counter = 0 - self._global_cfg_dict = {} - - def _query_inside(self, target, workload): - """ - Query the context to get config from records. - - Parameters - ---------- - target : Target - The current target - workload : Workload - The current workload. - - Returns - ------- - cfg : State or str - The specific state for auto scheduler. 
- """ - if self._counter < len(self._records): - cfg = self._records[self._counter][0].config - self._counter += 1 - self.update(target, workload, cfg) - return cfg - key = (str(target), workload) - if key not in self._global_cfg_dict: - msg = "Config for target=%s, workload=%s is missing in ApplyGraphBest context. " \ - "A fallback configuration is used, which may bring great performance " \ - "regression." % (target, workload) - logger.warning(msg) - cfg = None - self._global_cfg_dict[key] = cfg - else: - cfg = self._global_cfg_dict[key] - return cfg - - def update(self, target, workload, cfg): - key = (str(target), workload) - self._global_cfg_dict[key] = cfg diff --git a/python/tvm/ansor/env.py b/python/tvm/ansor/env.py index 9e44ad66048b..0f35f92acbbc 100644 --- a/python/tvm/ansor/env.py +++ b/python/tvm/ansor/env.py @@ -1,5 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + """ The scope to store global variables in ansor """ + class AutoschedulerGlobalScope(object): def __init__(self): self.topi_in_compute_rewrite_mode = False diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py index 9496533da6cc..d9f6d297f1af 100644 --- a/python/tvm/ansor/feature.py +++ b/python/tvm/ansor/feature.py @@ -17,7 +17,6 @@ """" Python API for Feature extraction. -The specification of features can be found in `autoscheduler_doc/per_stage_feature.md` """ from typing import List, Tuple diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 3d9c33860cae..f00fe672505d 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -230,7 +230,8 @@ def __init__(self, key, host, port, priority=1, class LocalRPCMeasureContext: - """ A context wrapper for RPCRunner. + """ A context wrapper for running RPCRunner locally. + This will launch a local RPC Tracker and local RPC Server. Parameters ---------- @@ -276,10 +277,10 @@ class MeasureErrorNo(object): """Error type for MeasureResult""" NO_ERROR = 0 # No error INSTANTIATION_ERROR = 1 # Errors happen when apply transform steps from init state - # Errors happen when compiling code on host (e.g. tvm.build) + # Errors happen when compiling code on host (e.g. tvm.build) COMPILE_HOST = 2 COMPILE_DEVICE = 3 # Errors happen when compiling code on device - # (e.g. OpenCL JIT on the device) + # (e.g. 
OpenCL JIT on the device) RUNTIME_DEVICE = 4 # Errors happen when running the program on device WRONG_ANSWER = 5 # Answer is wrong when compared to a reference output BUILD_TIMEOUT = 6 # Timeout during compilation @@ -288,6 +289,7 @@ class MeasureErrorNo(object): def make_error_msg(): + """Get the error message from traceback""" error_msg = str(traceback.format_exc()) if len(error_msg) > MAX_ERROR_MSG_LEN: error_msg = error_msg[:MAX_ERROR_MSG_LEN//2] + \ diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py index 97903b38bb0b..1bd9d8cf64e6 100644 --- a/python/tvm/ansor/serialization.py +++ b/python/tvm/ansor/serialization.py @@ -64,6 +64,7 @@ def __iter__(self): break yield ret[0], ret[1] # (input, result) + def load_from_file(filename: str): """Load measurement records from a file""" return zip(*LogReader(filename).read_lines()) diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py index 89b4afd84e86..3d4d9624d7c2 100644 --- a/python/tvm/ansor/task_scheduler.py +++ b/python/tvm/ansor/task_scheduler.py @@ -147,13 +147,12 @@ def __init__(self, def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPolicy]] = 'default'): """ Tune tasks. - Notice: This method does not have return value, make sure to set `LogToFile` - measure callback in `tune_option`. + Notice: This method has no return value; make sure to set a `LogToFile` + measure callback in `tune_option`. Parameters ---------- tune_option: TuneOption - search_policy: Str or List[SearchPolicy] """ # init members diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py index fccdcf8864be..bcf8269b9490 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/ansor/workload_registry.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. - """ Workload registration and serialization. diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index 454305c04ef5..2d8379f56a5f 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -55,7 +55,6 @@ template <> struct Handler > { inline static void Write(dmlc::JSONWriter* writer, const std::vector<::tvm::ansor::Stage> & data) { - // todo(lmzheng): support serialization of Stage writer->BeginArray(false); writer->EndArray(); } @@ -456,7 +455,7 @@ namespace ansor { TVM_REGISTER_OBJECT_TYPE(LogToFileNode); TVM_REGISTER_OBJECT_TYPE(LogReaderNode); -const std::string ANSOR_LOG_VERSION = "v0.1"; // NOLINT(*) +const std::string ANSOR_LOG_VERSION = "v0.2"; // NOLINT(*) MeasureCallback LogToFileNode::make(std::string filename) { auto node = make_object(); diff --git a/tests/python/unittest/test_ansor_feature.py b/tests/python/unittest/test_ansor_feature.py index bb19b84a970d..bcc7683b3f4a 100644 --- a/tests/python/unittest/test_ansor_feature.py +++ b/tests/python/unittest/test_ansor_feature.py @@ -148,4 +148,3 @@ def test_gpu_feature(): test_cpu_matmul() test_cpu_fusion() test_gpu_feature() -
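Two fixes in the patch above are easy to miss: xgboost's parameter spellings (`nthread`, `verbosity`) and wrapping Python lists in `np.array(...)` before building a `DMatrix`. A standalone sketch of the corrected xgboost usage, with hypothetical toy data:

import multiprocessing
import numpy as np
import xgboost as xgb

params = {
    'nthread': multiprocessing.cpu_count() // 2,  # 'n_threads' is not a valid xgboost key
    'verbosity': 0,                               # replaces the removed 'silent' option
}
# DMatrix expects an array-like, not a plain list of lists
dtrain = xgb.DMatrix(np.array([[0.0, 1.0], [1.0, 0.0]]), label=np.array([0.1, 0.9]))
bst = xgb.train(params, dtrain, num_boost_round=2)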
From a4c4548f2d1da651c8f13f8552e9cc9df2f167eb Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 20 Jun 2020 08:58:41 -0700 Subject: [PATCH 33/78] Rename "MetaTileRewritePolicy" to "SketchPolicy". (#36) * Rename "MetaTileRewritePolicy" to "SketchPolicy". * Add a new class for auto_unroll_max_step, storage_offset in StageNode * fix tune_op_subgraph.py --- python/tvm/ansor/__init__.py | 6 +- python/tvm/ansor/auto_schedule.py | 28 ++-- python/tvm/ansor/relay_integration.py | 7 +- python/tvm/ansor/task_scheduler.py | 18 +-- python/tvm/ansor/workload_registry.py | 14 +- scripts/common.py | 38 ++--- scripts/shape_configs.py | 24 +-- scripts/tune_network.py | 137 ++++++++--------- scripts/tune_op_subgraph.py | 144 ++++++++---------- scripts/tune_test.py | 97 ++++++------ src/ansor/auto_schedule.cc | 2 +- src/ansor/compute_dag.cc | 3 +- src/ansor/loop_state.cc | 37 +++-- src/ansor/loop_state.h | 15 +- src/ansor/search_policy/search_policy.h | 1 + ...rite_policy.cc => sketch_search_policy.cc} | 132 ++++++++-------- ...ewrite_policy.h => sketch_search_policy.h} | 53 ++++--- tests/python/unittest/test_ansor_common.py | 2 +- .../unittest/test_ansor_relay_integration.py | 3 +- .../unittest/test_ansor_search_policy.py | 15 +- tutorials/ansor/tune_conv2d_cuda.py | 4 +- tutorials/ansor/tune_simple_subgraph.py | 4 +- 22 files changed, 386 insertions(+), 398 deletions(-) rename src/ansor/search_policy/{meta_tile_rewrite_policy.cc => sketch_search_policy.cc} (91%) rename src/ansor/search_policy/{meta_tile_rewrite_policy.h => sketch_search_policy.h} (66%) diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 90a11820d159..c629c1049a87 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -29,14 +29,14 @@ # Shortcut from .compute_dag import ComputeDAG, LayoutRewriteLevel -from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams, \ - PreloadMeasuredStates, PreAddCustomRule, auto_schedule +from .auto_schedule import SearchTask, SketchSearchPolicy, TuneOption, HardwareParams, \ + PreloadMeasuredStates, PreloadCustomSketchRule, auto_schedule from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext from .cost_model import RandomModel from .cost_model.xgb_model import XGBModel from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \ load_from_file, write_measure_records_to_file -from .workload_registry import register_auto_scheduler_workload_func, \ +from .workload_registry import register_workload_func, \ workload_key_to_dag, make_workload_key_func from .task_scheduler import TaskScheduler, SimpleTaskScheduler from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest as apply_history_best, \ diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index e8108a067b2e..a03d9fdacbc2 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -83,17 +83,19 @@ def run_callbacks(self, callbacks): _ffi_api.SearchPolicyRunCallbacks(self, callbacks) -@tvm._ffi.register_object("ansor.MetaTileRewritePolicy") -class MetaTileRewritePolicy(SearchPolicy): - """ The search policy that searches with meta tiling and random rewrite +@tvm._ffi.register_object("ansor.SketchSearchPolicy") +class SketchSearchPolicy(SearchPolicy): + """ The search policy that searches in a hierarchical search space defined by sketches. + The policy randomly samples programs from the space defined by sketches + and uses evolutionary search to fine-tune them. Parameters ---------- program_cost_model: CostModel Cost model for programs params: int - Parameters of the search policy, go meta_tile_rewrite_policy.h to find the - definitions. See code below to find the default values + Parameters of the search policy. See `src/ansor/search_policy/sketch_search_policy.h` + to find the definitions. See code below to find the default values. seed: int Random seed """ @@ -124,7 +126,7 @@ def __init__(self, params[key] = value self.__init_handle_by_constructor__( - _ffi_api.MetaTileRewritePolicy, program_cost_model, params, + _ffi_api.SketchSearchPolicy, program_cost_model, params, seed or random.randint(1, 1 << 30))
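For reference, the renamed policy can also be constructed by hand instead of going through the 'default' string. A short sketch under the API of this patch (`task` is assumed to be an ansor.SearchTask built elsewhere):

from tvm import ansor

policy = ansor.SketchSearchPolicy(ansor.XGBModel(), seed=42)
s, arg_bufs = ansor.auto_schedule(task, search_policy=policy,
                                  tune_option=ansor.TuneOption(n_trials=64))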
@@ -148,16 +150,16 @@ def __init__(self, filename: str): _ffi_api.PreloadMeasuredStates, filename) -@tvm._ffi.register_object("ansor.PreAddCustomRule") -class PreAddCustomRule(SearchCallback): +@tvm._ffi.register_object("ansor.PreloadCustomSketchRule") +class PreloadCustomSketchRule(SearchCallback): """ - A SearchCallback for MetaTileRewritePolicy that allowing users to add + A SearchCallback for SketchSearchPolicy that allows users to add custom sketch rules. Notes ----- This is an advanced feature. Make sure you're clear about how it - works and this should only be used in MetaTileRewritePolicy. + works; it should only be used in SketchSearchPolicy. Parameters ---------- @@ -168,7 +170,7 @@ class PreAddCustomRule(SearchCallback): """ def __init__(self, meet_condition_func, apply_func): self.__init_handle_by_constructor__( - _ffi_api.PreAddCustomRule, meet_condition_func, apply_func) + _ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func) @tvm._ffi.register_object("ansor.TuneOption") @@ -197,7 +199,7 @@ class TuneOption(Object): Callback functions called before the search process Candidates: - ansor.PreloadMeasuredStates - - ansor.PreAddCustomRule + - ansor.PreloadCustomSketchRule """ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, verbose=1, builder='local', runner='local', measure_callbacks=None, @@ -249,7 +251,7 @@ def auto_schedule(workload, target=None, """ if isinstance(search_policy, str): if search_policy == 'default': - search_policy = MetaTileRewritePolicy(RandomModel()) + search_policy = SketchSearchPolicy(RandomModel()) else: raise ValueError("Invalid search policy: " + search_policy) diff --git a/python/tvm/ansor/relay_integration.py b/python/tvm/ansor/relay_integration.py index 85c4d8813f69..3c2eabd3dfac 100644 --- a/python/tvm/ansor/relay_integration.py +++ b/python/tvm/ansor/relay_integration.py @@ -28,7 +28,7 @@ from tvm import target, te, transform from tvm.te.tensor import PlaceholderOp, ComputeOp from .dispatcher import DispatchContext -from .workload_registry import register_auto_scheduler_workload_bufs, compute_dag_hash +from .workload_registry import register_workload_bufs, compute_dag_hash from .compute_dag import ComputeDAG, LayoutRewriteLevel from .env import GLOBAL_SCOPE @@ -203,11 +203,14 @@ def traverse(t): def auto_schedule_topi(outs): """ Use ansor to auto-schedule a topi compute declaration """ io_tensors, has_layout_free = traverse_to_get_io_tensors(outs) - key = register_auto_scheduler_workload_bufs(io_tensors) + key = register_workload_bufs(io_tensors) env = TracingEnvironment.current if env is None: # in the final build mode state = DispatchContext.current.query(target.Target.current(), key) + if state is None: + return te.create_schedule([x.op for x in outs]) + dag = ComputeDAG(io_tensors) # Only update compute body, layout_rewrite_level = LayoutRewriteLevel.COMPUTE_REWRITE, # Since kernel layout has already been rewritten in relay pass diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py index 
3d4d9624d7c2..587fe3121e88 100644 --- a/python/tvm/ansor/task_scheduler.py +++ b/python/tvm/ansor/task_scheduler.py @@ -21,7 +21,7 @@ import numpy as np -from .auto_schedule import SearchTask, SearchPolicy, MetaTileRewritePolicy, TuneOption +from .auto_schedule import SearchTask, SearchPolicy, SketchSearchPolicy, TuneOption from .cost_model import RandomModel, XGBModel from .measure import ProgramMeasurer from .utils import array_mean, to_str_round @@ -42,7 +42,7 @@ def compute_score(self, costs: List[float]) -> float: def get_search_policies(search_policy: Union[str, List[SearchPolicy]], tasks: List[SearchTask], num_measure_per_iter, load_model_file=None, load_log_file=None): if search_policy == 'default': - search_policy = 'meta-rewrite.xgb' + search_policy = 'sketch.xgb' if isinstance(search_policy, str): policy_type, model_type = search_policy.split('.') @@ -58,16 +58,16 @@ def get_search_policies(search_policy: Union[str, List[SearchPolicy]], tasks: Li else: raise ValueError("Invalid search policy: " + search_policy) - if policy_type == 'meta-rewrite': - search_policies = [MetaTileRewritePolicy(cost_model) for _ in range(len(tasks))] + if policy_type == 'sketch': + search_policies = [SketchSearchPolicy(cost_model) for _ in range(len(tasks))] elif policy_type == 'limit-space': - search_policies = [MetaTileRewritePolicy(cost_model, - params={'cpu_multi_level_tiling_structure': 'SRS', - 'disable_change_compute_location': 1}) + search_policies = [SketchSearchPolicy(cost_model, + params={'cpu_multi_level_tiling_structure': 'SRS', + 'disable_change_compute_location': 1}) for _ in range(len(tasks))] elif policy_type == 'beam-search': - search_policies = [MetaTileRewritePolicy(cost_model, - params={'use_beam_search': 1}) + search_policies = [SketchSearchPolicy(cost_model, + params={'use_beam_search': 1}) for _ in range(len(tasks))] else: raise ValueError("Invalid search policy: " + search_policy) diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py index bcf8269b9490..e706c0ec4cf9 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/ansor/workload_registry.py @@ -42,19 +42,19 @@ WORKLOAD_FUNC_REGISTRY = {} -def register_auto_scheduler_workload_func(func: Callable): +def register_workload_func(func: Callable): """Register a workload generation function The input function should take hashable and jsonable arguments (int, float, tuple of int, tvm.tensor.Tensor, ...) and return a list of tvm.tensor.Tensor. 
Examples -------- - @register_auto_scheduler_workload_func + @register_workload_func def matmul(N, M, K): - A = tvm.placeholder((N, K), name='A') - B = tvm.placeholder((K, M), name='B') - k = tvm.reduce_axis((0, K), name='k') - C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C') + A = te.placeholder((N, K), name='A') + B = te.placeholder((K, M), name='B') + k = te.reduce_axis((0, K), name='k') + C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C') return [A, B, C] """ func_name = func.__name__ @@ -84,7 +84,7 @@ def compute_dag_hash(dag: ComputeDAG): return hashlib.md5(str_key).hexdigest() -def register_auto_scheduler_workload_bufs(bufs: List[Tensor]) -> str: +def register_workload_bufs(bufs: List[Tensor]) -> str: """Directly register buffers of a workload and return the workload_key The buffers can be looked up with workload_key_to_tensors by the workload_key """ diff --git a/scripts/common.py b/scripts/common.py index 84fbf8d6c731..8f4fbec09dd0 100644 --- a/scripts/common.py +++ b/scripts/common.py @@ -14,7 +14,7 @@ import tvm from tvm import te from tvm.ansor import (LogReader, make_workload_key_func, - register_auto_scheduler_workload_func, + register_workload_func, write_measure_records_to_file) from tvm.contrib import ndk, util ############################################################ ###################### Test Workloads #################### ############################################################ -@register_auto_scheduler_workload_func +@register_workload_func def min_mn(M, N): A = te.placeholder((M, N), name='A') B = topi.min(A, axis=1) return [A, B] -@register_auto_scheduler_workload_func +@register_workload_func def argmin_mn(M, N): A = te.placeholder((M, N), name='A') B = topi.argmin(A, axis=1) return [A, B] -@register_auto_scheduler_workload_func +@register_workload_func def softmax_mn(M, N): A = te.placeholder((M, N), name='A') B = topi.nn.softmax(A, axis=1) return [A, B] -@register_auto_scheduler_workload_func +@register_workload_func def norm_bmn(B, M, N): A = te.placeholder((B, M, N), name='A') i = te.reduce_axis((0, M)) @@ -53,7 +53,7 @@ def norm_bmn(B, M, N): return [A, D] -@register_auto_scheduler_workload_func +@register_workload_func def add_mn(M, N): A = te.placeholder((M, N), name='A') B = te.placeholder((M, N), name='B') @@ -61,7 +61,7 @@ def add_mn(M, N): return [A, B, C] -@register_auto_scheduler_workload_func +@register_workload_func def matmul_nkkm(N, M, K, in_type='float32', out_type='float32', tensor_core_support=False): A = te.placeholder((N, K), name='A', dtype=in_type) @@ -73,7 +73,7 @@ def matmul_nkkm(N, M, K, in_type='float32', out_type='float32', C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C', - attrs={"auto_scheduler_tensor_core_support": "True" if tensor_core_support else "False"}) + attrs={"ansor_tensor_core_support": "True" if tensor_core_support else "False"}) else: if not ((in_type == 'float16' and out_type == 'float32') or \ (in_type == 'int8' and out_type == 'int32')): @@ -82,11 +82,11 @@ def matmul_nkkm(N, M, K, in_type='float32', out_type='float32', lambda i, j: te.sum(A[i][k].astype(out_type) * B[k][j].astype(out_type), axis=[k]), name='C', - attrs={"auto_scheduler_tensor_core_support": "True" if tensor_core_support else "False"}) + attrs={"ansor_tensor_core_support": "True" if tensor_core_support else "False"}) return [A, B, C] -@register_auto_scheduler_workload_func +@register_workload_func def dense_layer(batch, in_dim, out_dim): A = te.placeholder((batch, in_dim), name='A') B 
= te.placeholder((out_dim, in_dim), name='B') @@ -95,7 +95,7 @@ def dense_layer(batch, in_dim, out_dim): return [A, B, C] -@register_auto_scheduler_workload_func +@register_workload_func def max_pool_2d_nchw(N, C, H, W): data = te.placeholder((N, C, H, W), name='data') out = topi.nn.pool(data, (2, 2), (1, 1), (0, 0, 0, 0), pool_type='max', ceil_mode=True, @@ -103,7 +103,7 @@ def max_pool_2d_nchw(N, C, H, W): return [data, out] -@register_auto_scheduler_workload_func +@register_workload_func def add_min_relu(M, N): A = te.placeholder((M, N), name='A') B = te.placeholder((M, N), name='B') @@ -112,7 +112,7 @@ def add_min_relu(M, N): out = topi.nn.relu(D) return [A, B, out] -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_relu_softmax_min(N, H, W, CI, CO, KH, KW, strides, padding, dilation): data = te.placeholder((N, CI, H, W), name='data') kernel = te.placeholder((CO, CI, KH, KW), name='kernel') @@ -123,7 +123,7 @@ def conv2d_relu_softmax_min(N, H, W, CI, CO, KH, KW, strides, padding, dilation) return [data, kernel, out] -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_nchw_bias(N, H, W, CI, CO, KH, KW, strides, padding, dilation): data = te.placeholder((N, CI, H, W), name='data') kernel = te.placeholder((CO, CI, KH, KW), name='kernel') @@ -190,7 +190,7 @@ def conv2d_nhwc_without_layout_rewrite(Input, Filter, stride, padding, dilation, return Output -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, padding, dilation): data = te.placeholder((N, H, W, CI), name='data') kernel = te.placeholder((KH, KW, CI, CO), name='kernel') @@ -199,7 +199,7 @@ def conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, padding, dil out = topi.add(conv, bias) return [data, kernel, bias, out] -@register_auto_scheduler_workload_func +@register_workload_func def depthwise_conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, padding, dilation): data = te.placeholder((N, H, W, CI), name='data') kernel = te.placeholder((KH, KW, CI, 1), name='kernel') @@ -208,7 +208,7 @@ def depthwise_conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, pa out = topi.add(conv, bias) return [data, kernel, bias, out] -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_nhwc_bias(N, H, W, CI, CO, KH, KW, strides, padding, dilation): data = te.placeholder((N, H, W, CI), name='data') kernel = te.placeholder((KH, KW, CI, CO), name='kernel') @@ -218,7 +218,7 @@ def conv2d_nhwc_bias(N, H, W, CI, CO, KH, KW, strides, padding, dilation): return [data, kernel, bias, out] -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): data = te.placeholder((N, CI, H, W), name='data') kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='kernel') @@ -243,7 +243,7 @@ def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation return [data, kernel, bias, bn_offset, bn_scale, out] -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_nhwc_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): data = te.placeholder((N, H, W, CI), name='data') kernel = te.placeholder((kernel_size, kernel_size, CI, CO), name='kernel') diff --git a/scripts/shape_configs.py b/scripts/shape_configs.py index 95a1ba69634d..244638f5b29c 100644 --- a/scripts/shape_configs.py +++ b/scripts/shape_configs.py @@ -1,5 +1,5 @@ -""" Shape 
configurations for single operator evaluation -This file is shared by tune_all_single_op.py and scripts in baseline/ +""" Shape configurations for single operator / subgraph evaluation +This file is shared by tune_op_subgraph.py and scripts in scripts/baseline/ """ matmul_shapes = [ @@ -142,13 +142,6 @@ (1, 4096, 1024), ] -softmax_shapes = [ - (1, 1024), - (1, 4096), - (1, 16384), - (1, 65536), -] - single_op_shape_dict = { 'C1D': conv1d_shapes, 'C2D': conv2d_shapes, @@ -160,12 +153,11 @@ 'T2D': conv2d_transpose_shapes, 'CAP': conv2d_capsule_shapes, 'NRM': norm_shapes, - #'SMX': softmax_shapes, # The following workloads are not in our sinle op evaluation plan. # They should be moved to `common.py` and be used by `tune_wkl.py`. # 'C2D_NCHW': conv2d_nchw_shapes, - 'C2DWG_NHWC': conv2d_winograd_nhwc_shapes, +# 'C2DWG_NHWC': conv2d_winograd_nhwc_shapes, # 'C2DWG_NCHW': conv2d_winograd_nchw_shapes, # 'GMM_TC': matmul_tensor_core_shapes, } @@ -192,19 +184,9 @@ (16, 128, 12, 128), ] - -batch_norm_shapes = [ - (16, 256), - (16, 1024), - (16, 4096), - (16, 16384), - (16, 65536), -] - subgraph_shape_dict = { "conv2d_bn_relu": conv2d_bn_relu_shapes, "transpose_batch_matmul": transpose_batch_matmul_shapes, - #"batch_norm": batch_norm_shapes, } resnet_shapes = [ diff --git a/scripts/tune_network.py b/scripts/tune_network.py index d4f1afd95572..1905d8132003 100644 --- a/scripts/tune_network.py +++ b/scripts/tune_network.py @@ -1,13 +1,12 @@ -"""Tune all workloads in a network""" +"""Tune a whole neural network""" import argparse import logging import random import os -import time import numpy as np import tvm -from tvm import _ffi, ansor, relay +from tvm import ansor, relay import tvm.contrib.graph_runtime as runtime from tvm.contrib.debugger import debug_runtime from tvm.contrib import util, ndk @@ -20,8 +19,8 @@ dtype = "float32" -def get_network(name, model_path, batch_size, layout): - """Get the symbol definition and random weight of a network""" +def get_network(name, network_path, batch_size, layout): + """Get the relay module and random weights for a network""" input_shape = (batch_size, 3, 224, 224) output_shape = (batch_size, 1000) input_name = 'data' @@ -95,7 +94,7 @@ def get_network(name, model_path, batch_size, layout): input_shape = (1, 224, 224, 3) output_shape = (1, 1001) input_dtype = "float32" - tflite_model_buf = open(model_path, "rb").read() + tflite_model_buf = open(network_path, "rb").read() tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) mod, params = relay.frontend.from_tflite(tflite_model, shape_dict={input_name: input_shape}, @@ -144,21 +143,17 @@ def get_network(name, model_path, batch_size, layout): def create_module(data_shape, graph, lib, target, input_name, params, debug_profile, - local_measure, ndk_cc, device_key, host, port, run_timeout, num_threads, seed=43): - # Upload parameters to device + local_measure, ndk_cc, rpc_device_key, rpc_host, rpc_port, rpc_num_threads, seed=43): if local_measure: if target.target_name == "cuda": ctx = tvm.gpu() else: ctx = tvm.cpu() - if num_threads: - config_threadpool = _ffi.get_global_func('runtime.config_threadpool') - config_threadpool(0, num_threads) else: print("=============== Request Remote ===============") if 'TVM_NDK_CC' not in os.environ: os.environ['TVM_NDK_CC'] = ndk_cc - remote = request_remote(device_key, host, port, timeout=run_timeout) + remote = request_remote(rpc_device_key, rpc_host, rpc_port) print("=============== Export ===============") ctx = remote.cpu() @@ -171,9 +166,10 @@ def 
create_module(data_shape, graph, lib, target, input_name, params, debug_prof print("=============== Load ===============") lib = remote.load_module("deploy_lib.so") - if num_threads: + + if rpc_num_threads: config_threadpool = remote.get_function('runtime.config_threadpool') - config_threadpool(0, num_threads) + config_threadpool(0, rpc_num_threads) np.random.seed(seed) data_tvm = tvm.nd.array(100 * (np.random.uniform(size=data_shape)).astype(dtype), ctx=ctx) @@ -181,6 +177,7 @@ def create_module(data_shape, graph, lib, target, input_name, params, debug_prof module = debug_runtime.create(graph, lib, ctx) else: module = runtime.create(graph, lib, ctx) + if type(input_name) == list: for name in input_name: module.set_input(name, data_tvm) @@ -192,19 +189,20 @@ def create_module(data_shape, graph, lib, target, input_name, params, debug_prof return module, ctx -def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, - debug_profile, check_correctness, network_parameters, - task_scheduler_parameters, tune_parameters, module_parameters): - # Extract workloads from relay program - mod, params, input_name, data_shape, out_shape = get_network(**network_parameters) +def tune_and_evaluate(network_arguments, target, target_host, + search_policy, task_scheduler_arguments, tune_option_arguments, + tune, debug_profile, check_correctness, log_n_lines): + # Extract tasks from relay program + mod, params, input_name, data_shape, out_shape = get_network(**network_arguments) + # Tune all if tune: - print("=============== Extracting workloads ===============") + print("=============== Extract Workloads ===============") workloads, wkl_weights = ansor.extract_from_program(mod, target=target, params=params) - print("Totally %d workload extracted." % (len(workloads))) + print("Extract %d workloads in total" % (len(workloads))) # Tune workloads with auto scheduler - print("=============== Tuning ===============") + print("=============== Tune ===============") tasks = [] for i, wkl_key in enumerate(workloads): dag = ansor.workload_key_to_dag(wkl_key) @@ -212,24 +210,24 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, tasks.append(ansor.SearchTask(dag, wkl_key, target, target_host)) tuner = ansor.SimpleTaskScheduler(tasks, - lambda costs: sum(c * w for c, w in zip(costs, wkl_weights)), - **task_scheduler_parameters) - tune_option, measure_ctx = create_tune_option(target, **tune_parameters) + lambda costs: sum(c * w for c, w in zip(costs, wkl_weights)), + **task_scheduler_arguments) + tune_option, measure_ctx = create_tune_option(target, **tune_option_arguments) - if tune_parameters['local_measure'] and target.target_name != 'cuda': + if tune_option_arguments['local_measure'] and target.target_name != 'cuda': os.environ['TVM_BIND_MASTER_CORE_0'] = "1" tuner.tune(tune_option, search_policy) if measure_ctx: del measure_ctx - kernel_layout_rewrite = False + kernel_layout_rewrite = True # Compile graph with best states found by auto-scheduler print("=============== Compile ===============") - with ansor.apply_history_best(tune_parameters['log_file'], log_n_lines): + with ansor.apply_history_best(tune_option_arguments['log_file'], log_n_lines): os.environ['TVM_AUTO_CACHE_FLUSH'] = "0" - os.environ['TVM_BIND_MASTER_CORE_0'] = "1" + if kernel_layout_rewrite: ansor.prepare_layout_rewrite(mod, target=target, params=params) else: @@ -245,12 +243,13 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, print("=============== Compile Finish 
===============") module, ctx = create_module(data_shape, graph, lib, target, input_name, - opt_params, debug_profile, **module_parameters) + opt_params, debug_profile, **common_measure_parameters) # Evaluate print("========== Evaluate ==========") ftimer = module.module.time_evaluator("run", ctx, number=10, repeat=3) prof_res = np.array(ftimer().results) + # display profile information if debug_profile or check_correctness: module.run() @@ -273,12 +272,12 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE ansor.LayoutRewriteLevel.COMPUTE_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE target = tvm.target.create('llvm') - with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): + with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): graph, lib, opt_params = relay.build_module.build( mod, target=target, params=params) module, _ = create_module(data_shape, graph, lib, target, input_name, - opt_params, debug_profile, **module_parameters) + opt_params, debug_profile, **common_measure_parameters) module.run() expected_output = module.get_output(0).asnumpy() @@ -287,58 +286,58 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, if __name__ == "__main__": parser = argparse.ArgumentParser() - # Task related options + + # Search task related arguments parser.add_argument("--network", type=str, required=True) - parser.add_argument("--model-path", type=str, default=None, help="The path of tflite model") + parser.add_argument("--network-path", type=str, default=None, help="The path of tflite model") parser.add_argument("--batch-size", type=int, default=1) parser.add_argument("--layout", type=str, default='NHWC') parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') parser.add_argument("--target-host", type=str, default=None) - parser.add_argument("--n-trials", type=int, default=1000) - parser.add_argument("--num-measure-per-iter", type=int, default=48, - help="The number of programs to be measured at each iteration") - parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) parser.add_argument("--check-correctness", type=str2bool, nargs='?', const=True, default=False) parser.add_argument("--debug-profile", type=str2bool, nargs='?', const=True, default=False) + parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) - # Strategy related options - parser.add_argument("--seed", type=int, default=0, help='random seed') - parser.add_argument("--policy", type=str, choices=['multi-stage', 'meta-rewrite'], - default='meta-rewrite') + # Search strategy related arguments + parser.add_argument("--n-trials", type=int, default=1000) + parser.add_argument("--policy", type=str, choices=['sketch'], default='sketch') parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') parser.add_argument("--task-scheduler", type=str, default='gradient', choices=['no', 'gradient', 'round-robin'], help='The strategy of task scheduler') + parser.add_argument("--seed", type=int, default=0, help='random seed') - # File related options - parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") + # Log file related arguments + parser.add_argument("--log-file", type=str, help="Write measurement records to this log file") + parser.add_argument("--load-log", type=str, help="Load history log to resume the status of 
search") + parser.add_argument("--log-n-lines", type=int, help="Only load the first n lines for history log") parser.add_argument("--load-model", type=str, help="Load pre trained cost model file") - parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") - parser.add_argument("--out-file", type=str, default='results.tsv') - parser.add_argument("--log-n-lines", type=int) - # Detailed control options + # Measurement related and other arguments + parser.add_argument("--num-measure-per-iter", type=int, default=48, + help="The number of programs to be measured at each iteration") parser.add_argument("--build-timeout", type=int, default=10) parser.add_argument("--run-timeout", type=int, default=10) parser.add_argument("--early-stopping", type=int, default=-1) parser.add_argument("--verbose", type=int, default=1) parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) - parser.add_argument("--device-key", type=str, default=None) - parser.add_argument("--host", type=str, default='0.0.0.0') - parser.add_argument("--port", type=int, default=9190) + parser.add_argument("--rpc-device-key", type=str, default=None) + parser.add_argument("--rpc-host", type=str, default='0.0.0.0') + parser.add_argument("--rpc-port", type=int, default=9190) + parser.add_argument("--rpc-num-threads", type=int, default=None) parser.add_argument("--n-parallel", type=int, default=1) parser.add_argument("--ndk-cc", type=str, default=None) - parser.add_argument("--num-threads", type=int, default=None) args = parser.parse_args() np.random.seed(args.seed) random.seed(args.seed) logging.basicConfig() logging.getLogger('ansor').setLevel(logging.DEBUG) + os.environ["TOPHUB_LOCATION"] = "NONE" # disable autotvm target = tvm.target.create(args.target) - log_file = args.log_file or "%s-B%d-%s.json" % (args.network, args.batch_size, - target.target_name) + log_file = args.log_file or "%s-B%d-%s.json" % (args.network, args.batch_size, + target.target_name) load_log_file = args.load_log or log_file search_policy = "%s.%s" % (args.policy, args.model_type) if args.layout: @@ -348,9 +347,9 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, else: layout = "NHWC" - network_parameters = { + network_arguments = { 'name': args.network, - 'model_path': args.model_path, + 'network_path': args.network_path, 'batch_size': args.batch_size, 'layout': layout } @@ -362,15 +361,16 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, 'verbose': args.verbose, } - control_parameters = { + common_measure_parameters = { 'local_measure': args.local_measure, - 'device_key': args.device_key, - 'host': args.host, - 'port': args.port, + 'rpc_device_key': args.rpc_device_key, + 'rpc_host': args.rpc_host, + 'rpc_port': args.rpc_port, + 'rpc_num_threads': args.rpc_num_threads, 'ndk_cc': args.ndk_cc, } - tune_parameters = { + tune_option_arguments = { 'log_file': log_file, 'n_trials': args.n_trials, 'num_measure_per_iter': args.num_measure_per_iter, @@ -379,17 +379,10 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, 'build_timeout': args.build_timeout, 'run_timeout': args.run_timeout, 'early_stopping': args.early_stopping, - **control_parameters - } - - module_parameters = { - 'run_timeout': args.run_timeout, - 'num_threads': args.num_threads, - **control_parameters + **common_measure_parameters } - os.environ["TOPHUB_LOCATION"] = "NONE" - tune_and_evaluate(target, args.target_host, args.log_n_lines, 
search_policy, + tune_and_evaluate(network_arguments, target, args.target_host, + search_policy, task_scheduler_parameters, tune_option_arguments, args.tune, args.debug_profile, args.check_correctness, - network_parameters, task_scheduler_parameters, tune_parameters, - module_parameters) + args.log_n_lines) diff --git a/scripts/tune_op_subgraph.py b/scripts/tune_op_subgraph.py index bf5cbe83c952..6574bb77e510 100644 --- a/scripts/tune_op_subgraph.py +++ b/scripts/tune_op_subgraph.py @@ -1,7 +1,6 @@ -"""Tune all operators for single op & subgraph evaluation""" +"""Tune all workloads for single op & subgraph evaluation""" import argparse import logging -import os import random import numpy as np @@ -12,14 +11,13 @@ from topi.nn.winograd_util import winograd_transform_matrices from topi.util import get_const_tuple -from common import measure_schedule, str2bool, \ - norm_bmn, softmax_mn, conv2d_nhwc_bn_relu, conv2d_nchw_bn_relu +from common import measure_schedule, str2bool, norm_bmn, conv2d_nhwc_bn_relu, conv2d_nchw_bn_relu from shape_configs import single_op_shape_dict, subgraph_shape_dict from tune_test import tune_workloads_jointly, replay_workload, create_tune_option # ========================== Single Ops ========================== -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def batch_matmul_nkkm(B, N, M, K): X = te.placeholder((B, N, K), name='A') Y = te.placeholder((B, K, M), name='B') @@ -27,7 +25,7 @@ def batch_matmul_nkkm(B, N, M, K): Z = te.compute((B, N, M), lambda b, i, j: te.sum(X[b][i][k] * Y[b][k][j], axis=[k]), name='C') return [X, Y, Z] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv1d_nlc(N, L, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): inputs = te.placeholder((N, L, CI), name='inputs') weight = te.placeholder((kernel_size, CI//groups, CO), name='weight') @@ -49,7 +47,7 @@ def conv1d_nlc(N, L, CI, CO, kernel_size, stride=1, padding=0, dilation=1, group ) return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_nhwc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): inputs = te.placeholder((N, H, W, CI), name='inputs') weight = te.placeholder((kernel_size, kernel_size, CI//groups, CO), name='weight') @@ -75,7 +73,7 @@ def conv2d_nhwc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation=1, g ) return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_nchw(N, CI, H, W, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): inputs = te.placeholder((N, CI, H, W), name='inputs') weight = te.placeholder((CO, CI//groups, kernel_size, kernel_size), name='weight') @@ -101,7 +99,7 @@ def conv2d_nchw(N, CI, H, W, CO, kernel_size, stride=1, padding=0, dilation=1, g ) return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv3d_ndhwc(N, D, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): inputs = te.placeholder((N, D, H, W, CI)) weight = te.placeholder((kernel_size, kernel_size, kernel_size, CI//groups, CO)) @@ -131,7 +129,7 @@ def conv3d_ndhwc(N, D, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation= ) return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def depthwise_conv2d_nhwc(N, H, W, C, kernel_size, stride=1, padding=0, dilation=1, factor=1): inputs = te.placeholder((N, H, W, C)) weight = 
te.placeholder((factor, kernel_size, kernel_size, C)) @@ -159,7 +157,7 @@ def depthwise_conv2d_nhwc(N, H, W, C, kernel_size, stride=1, padding=0, dilation ) return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_transpose_nhwc(N, H, W, CI, CO, kernel_size, stride=1, padding=0): inputs = te.placeholder((N, H, W, CI), name='inputs') weight = te.placeholder((kernel_size, kernel_size, CI, CO), name='weight') @@ -222,12 +220,12 @@ def _dilate(*indices): weight[filter_h - 1 - rh, filter_w - 1 - rw, rc, co], axis=[rh, rw, rc]), name="conv2d_transpose_nhwc", - attrs={"auto_scheduler_always_unroll_inner": ["h", "w", "rh", "rw", "h_c", "w_c"]}) + attrs={"ansor_always_unroll_inner": ["h", "w", "rh", "rw", "h_c", "w_c"]}) # todo(lmzheng): add constraints on the tile size of h and w return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_capsule_nhwijc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, capsule_size=4): inputs = te.placeholder((N, H, W, capsule_size, capsule_size, CI), name='inputs') weight = te.placeholder((kernel_size, kernel_size, capsule_size, capsule_size, CI, CO), name='weight') @@ -254,7 +252,7 @@ def conv2d_capsule_nhwijc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, cap return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_winograd_nhwc(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, dilation=1): # TODO: implement tile_size tile_size = 4 #_infer_tile_size(data, kernel) @@ -304,10 +302,10 @@ def conv2d_winograd_nhwc(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, di data_pack = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci: te.sum(input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]), name='data_pack', - attrs={"auto_scheduler_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], - "auto_scheduler_last_split_is_one": ["ci", "p"], - "auto_scheduler_always_unroll": ["eps", "nu", "r_a", "r_b"], - "auto_scheduler_no_cache_write": "True", + attrs={"ansor_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], + "ansor_last_split_is_one": ["ci", "p"], + "ansor_always_unroll": ["eps", "nu", "r_a", "r_b"], + "ansor_no_cache_write": "True", }) # do batch gemm @@ -323,10 +321,10 @@ def conv2d_winograd_nhwc(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, di inverse = te.compute((m, m, P, CO), lambda vh, vw, p, co: te.sum(bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]), name='inverse', - attrs={"auto_scheduler_no_split_at_inner": ["vh", "vw", "r_a", "r_b"], - "auto_scheduler_always_unroll": ["vh", "vw", "r_a", "r_b"], - "auto_scheduler_last_split_is_one": ["co", "p"], - "auto_scheduler_no_cache_write": "True", + attrs={"ansor_no_split_at_inner": ["vh", "vw", "r_a", "r_b"], + "ansor_always_unroll": ["vh", "vw", "r_a", "r_b"], + "ansor_last_split_is_one": ["co", "p"], + "ansor_no_cache_write": "True", }) # output @@ -337,10 +335,10 @@ def conv2d_winograd_nhwc(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, di co], name='conv2d_winograd', tag='conv2d_winograd_nhwc', - attrs={"auto_scheduler_no_split_at_outer": ["n", "h", "w", "co"],}) + attrs={"ansor_no_split_at_outer": ["n", "h", "w", "co"],}) return [inputs, kernel_pack, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_winograd_nchw(N, CI, H, W, CO, kernel_size=3, stride=1, padding=0, dilation=1, precompute=False): # TODO: implement tile_size 
tile_size = 4 #_infer_tile_size(data, kernel) @@ -390,10 +388,10 @@ def conv2d_winograd_nchw(N, CI, H, W, CO, kernel_size=3, stride=1, padding=0, di data_pack = te.compute((alpha, alpha, CI, P), lambda eps, nu, ci, p: te.sum(input_tile[ci][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]), name='data_pack', - attrs={"auto_scheduler_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], - "auto_scheduler_no_split_at_outer": ["ci", "p"], - "auto_scheduler_always_unroll": ["eps", "nu", "r_a", "r_b"], - "auto_scheduler_no_cache_write": "True", + attrs={"ansor_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], + "ansor_no_split_at_outer": ["ci", "p"], + "ansor_always_unroll": ["eps", "nu", "r_a", "r_b"], + "ansor_no_cache_write": "True", }) # do batch gemm @@ -409,9 +407,9 @@ def conv2d_winograd_nchw(N, CI, H, W, CO, kernel_size=3, stride=1, padding=0, di inverse = te.compute((CO, P, m, m), lambda co, p, vh, vw: te.sum(bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]), name='inverse', - attrs={"auto_scheduler_no_split_at_outer": ["co", "p", "vh", "vw", "r_a", "r_b"], - "auto_scheduler_always_unroll": ["vh", "vw", "r_a", "r_b"], - "auto_scheduler_no_cache_write": "True"}) + attrs={"ansor_no_split_at_outer": ["co", "p", "vh", "vw", "r_a", "r_b"], + "ansor_always_unroll": ["vh", "vw", "r_a", "r_b"], + "ansor_no_cache_write": "True"}) # output output = te.compute((N, CO, H, W), lambda n, co, h, w: @@ -419,12 +417,12 @@ def conv2d_winograd_nchw(N, CI, H, W, CO, kernel_size=3, stride=1, padding=0, di idxmod(h, m), idxmod(w, m)], name='conv2d_winograd', - attrs={"auto_scheduler_no_split_at_outer": ["n", "co", "h", "w"],}) + attrs={"ansor_no_split_at_outer": ["n", "co", "h", "w"],}) return [inputs, kernel_pack, output] # ========================== Subgraphs ========================== -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def transpose_batch_matmul(batch, seq_len, n_head, n_dim): query = te.placeholder((batch, seq_len, n_head, n_dim), name='query') value = te.placeholder((batch, seq_len, n_head, n_dim), name='value') @@ -433,23 +431,12 @@ def transpose_batch_matmul(batch, seq_len, n_head, n_dim): value_T = te.compute((batch, n_head, n_dim, seq_len), lambda b, h, d, l: value[b, l, h, d], name="value_T") k = te.reduce_axis((0, n_dim), name='k') - out = te.compute((batch, n_head, seq_len, seq_len), lambda b, h, i, j: te.sum(query_T[b][h][i][k] * value_T[b][h][k][j], axis=[k]), name='C') + out = te.compute((batch, n_head, seq_len, seq_len), + lambda b, h, i, j: te.sum(query_T[b][h][i][k] * value_T[b][h][k][j], axis=[k]), + name='C') return [query, value, out] -@ansor.register_auto_scheduler_workload_func -def batch_norm(M, N, eps=1e-5): - A = te.placeholder((M, N), name='A') - k1 = te.reduce_axis((0, M), name='k1') - k2 = te.reduce_axis((0, M), name='k2') - mean = te.compute((N,), lambda j: te.sum(A[k1][j] / M, axis=k1), name="mean") - var = te.compute((N,), - lambda j: te.sum((A[k2][j] - mean[j]) * (A[k2][j] - mean[j]) / (M - 1), k2), - name="var") - B = te.compute((M, N), lambda i, j: (A[i][j] - mean[j]) / te.sqrt(var[j] + eps), name='B') - - return [A, B] - -# ========================== Tune func & Dicts ========================== +# ========================== Tune function & Task dicts ========================== def tune_wkl(task_func_dict, shape_dict, wkl_type, args): target = tvm.target.create(args.target) @@ -464,8 +451,8 @@ def tune_wkl(task_func_dict, shape_dict, wkl_type, args): if shape[0] == 1: shape = list(shape) shape[0] = args.batch_size 
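        # A workload key pairs the registered workload function with a concrete
        # shape tuple; it is the identifier used below for logging, tuning, and replay.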
- wkl_key = ansor.make_workload_key_func(func, shape) + wkl_key = ansor.make_workload_key_func(func, shape) wkl_keys.append(wkl_key) if args.fast_check: break @@ -473,9 +460,8 @@ if not args.tune: cost, gflops = replay_workload( wkl_key, target, args.target_host, log_file, - args.local_measure, args.device_key, args.host, - args.port, args.ndk_cc, False) - # TODO(): Add log record + args.local_measure, args.rpc_device_key, args.rpc_host, + args.rpc_port, args.rpc_num_threads, args.ndk_cc, False) # log_line(BenchmarkRecord(target.name, 'gpu' if target.name == 'cuda' else 'cpu', 'subgraph', # workload_name, "AutoSchedule", "default", # {"costs": [cost]}, time.time()), args.out_file) @@ -489,7 +475,8 @@ tune_option, measure_ctx = create_tune_option(target, log_file, n_trials, args.num_measure_per_iter, args.verbose, args.n_parallel, args.build_timeout, args.local_measure, - args.device_key, args.host, args.port, args.ndk_cc) + args.rpc_device_key, args.rpc_host, args.rpc_port, + args.rpc_num_threads, args.ndk_cc) # tune workloads jointly using JointTuner tune_workloads_jointly(wkl_keys, np.ones(len(wkl_keys)), args.task_scheduler, @@ -516,7 +503,7 @@ # The following workloads are not in our single op evaluation plan. # They should be moved to `common.py` and be used by `tune_wkl.py`. # 'C2D_NCHW': conv2d_nchw, - 'C2DWG_NHWC': conv2d_winograd_nhwc, +# 'C2DWG_NHWC': conv2d_winograd_nhwc, # 'C2DWG_NCHW': conv2d_winograd_nchw, # 'GMM_TC': matmul_nkkm, } @@ -529,44 +516,43 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() - # Task related options - parser.add_argument("--wkl", type=str, required=True, - help="all - For all workloads; \ - op - For all single ops; \ - subgraph - For all subgraphs; \ - Or specific wkl name") + # Search task related arguments + parser.add_argument("--wkl", type=str, required=True, + help="all - Tune all workloads; \ op - Tune all single ops; \ subgraph - Tune all subgraphs; \ specific wkl name - Tune a specific workload") + parser.add_argument("--batch-size", type=int, default=1) parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') parser.add_argument("--target-host", type=str, default=None) - parser.add_argument("--n-trials-per-shape", type=int, default=1000) - parser.add_argument("--num-measure-per-iter", type=int, default=48, - help="The number of programs to be measured at each iteration") - parser.add_argument("--batch-size", type=int, default=1) parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) parser.add_argument("--fast-check", action='store_true', help='Only run one shape for each workload. 
This is used for fast checking') - # Strategy related options - parser.add_argument("--seed", type=int, default=0, help='random seed') - parser.add_argument("--policy", type=str, choices=['meta-rewrite', 'beam-search'], default='meta-rewrite') + # Search strategy related arguments + parser.add_argument("--n-trials-per-shape", type=int, default=1000) + parser.add_argument("--policy", type=str, choices=['sketch', 'beam-search'], default='sketch') parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') - parser.add_argument("--task-scheduler", type=str, default='gradient', - choices=['no', 'gradient', 'round-robin'], - help='The strategy of task scheduler') + parser.add_argument("--task-scheduler", type=str, default='round-robin', + choices=['no', 'gradient', 'round-robin'], help='The strategy of task scheduler') + parser.add_argument("--seed", type=int, default=0, help='random seed') - # File related options - parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") - parser.add_argument("--load-model", type=str, help="Load pre trained cost model file") - parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") - parser.add_argument("--out-file", type=str, default='results.tsv') + # Log file related arguments + parser.add_argument("--log-file", type=str, help="Write measurement records to this log file") + parser.add_argument("--load-log", type=str, help="Load history log to resume the status of search") + parser.add_argument("--load-model", type=str, help="Load pre-trained cost model from this file") - # Detailed control options + # Measurement related and other arguments + parser.add_argument("--num-measure-per-iter", type=int, default=48, + help="The number of programs to be measured at each iteration") parser.add_argument("--build-timeout", type=int, default=10) parser.add_argument("--run-timeout", type=int, default=60) parser.add_argument("--verbose", type=int, default=1) parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) - parser.add_argument("--device-key", type=str, default=None) - parser.add_argument("--host", type=str, default='0.0.0.0') - parser.add_argument("--port", type=int, default=9190) + parser.add_argument("--rpc-device-key", type=str, default=None) + parser.add_argument("--rpc-host", type=str, default='0.0.0.0') + parser.add_argument("--rpc-port", type=int, default=9190) + parser.add_argument("--rpc-num-threads", type=int, default=None) parser.add_argument("--n-parallel", type=int, default=1) parser.add_argument("--ndk-cc", type=str, default=None) args = parser.parse_args() diff --git a/scripts/tune_test.py b/scripts/tune_test.py index 86f055caf889..67c0526dd624 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -13,8 +13,8 @@ from common import get_workload_keys, get_workload_weights, measure_schedule, str2bool def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose, - n_parallel, build_timeout, local_measure, device_key, host, - port, ndk_cc, early_stopping=-1, run_timeout=10): + n_parallel, build_timeout, local_measure, rpc_device_key, rpc_host, + rpc_port, rpc_num_threads, ndk_cc, early_stopping=-1, run_timeout=10): builder = runner = measure_ctx = None if local_measure: builder = ansor.LocalBuilder(timeout=build_timeout) @@ -27,8 +27,13 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose else: os.environ['TVM_NDK_CC'] = ndk_cc 
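            # Cross-compile with the NDK toolchain configured above, then measure
            # on the remote device registered under rpc_device_key via the RPC tracker.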
builder = ansor.LocalBuilder(timeout=build_timeout, build_func='ndk') - runner = ansor.RPCRunner(key=device_key, host=host, port=port, timeout=run_timeout, - n_parallel=n_parallel, repeat=1, min_repeat_ms=400) + runner = ansor.RPCRunner(key=rpc_device_key, host=rpc_host, port=rpc_port, + timeout=run_timeout, n_parallel=n_parallel, + repeat=1, min_repeat_ms=200) + remote = request_remote(rpc_device_key, rpc_host, rpc_port) + if rpc_num_threads: + config_threadpool = remote.get_function('runtime.config_threadpool') + config_threadpool(0, rpc_num_threads) tune_option = ansor.TuneOption(n_trials=n_trials, early_stopping=early_stopping, num_measure_per_iter=num_measure_per_iter, @@ -42,16 +47,17 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose def replay_workload(wkl_key, target, target_host, log_file, - local_measure=True, device_key=None, host="0.0.0.0", - port=9190, ndk_cc=None, show_lower_result=True): + local_measure=True, rpc_device_key=None, rpc_host="0.0.0.0", + rpc_port=9190, rpc_num_threads=None, ndk_cc=None, + show_lower_result=True): cost = gflops = None inp, res = ansor.best_measure_pair_in_file(log_file, wkl_key, target) if inp is None: - print("Cannot find log for: %s" % (wkl_key)) + print("Cannot find log for: %s" % wkl_key) else: dag = ansor.workload_key_to_dag(inp.task.workload_key) - print("Found schedule for: %s" % (wkl_key)) + print("Found schedule for: %s" % wkl_key) s, bufs = dag.apply_steps_from_state(inp.state) if show_lower_result: @@ -60,18 +66,21 @@ def replay_workload(wkl_key, target, target_host, log_file, if local_measure: remote = None else: - remote = request_remote(device_key, host, port, 1) + remote = request_remote(rpc_device_key, rpc_host, rpc_port) + if rpc_num_threads: + config_threadpool = remote.get_function('runtime.config_threadpool') + config_threadpool(0, rpc_num_threads) - cost = np.mean((measure_schedule(s, bufs, target, remote=remote, ndk_cc=ndk_cc))) + cost = np.mean((measure_schedule(s, bufs, target, target_host, + remote=remote, ndk_cc=ndk_cc))) gflops = ansor.ComputeDAG(bufs).flop_ct / cost / 1e9 - print("Best schedule: %.2f GFLOPS\tcost: %.3f ms" % - (gflops, cost * 1e3)) + print("Best schedule: %.2f GFLOPS\tcost: %.3f ms" % (gflops, cost * 1e3)) return cost, gflops -def tune_workload(wkl_key, target, target_host, policy, model_type, load_model_file, - load_log_file, tune_option): +def tune_workload(wkl_key, target, target_host, policy, model_type, + load_model_file, load_log_file, tune_option): """Tune a workload""" if False: @@ -92,11 +101,11 @@ def tune_workload(wkl_key, target, target_host, policy, model_type, load_model_f else: raise ValueError("Invalid model: " + model_type) - if policy == 'meta-rewrite': - policy = ansor.MetaTileRewritePolicy(program_cost_model=model) + if policy == 'sketch': + policy = ansor.SketchSearchPolicy(program_cost_model=model) elif policy == 'beam-search': - policy = ansor.MetaTileRewritePolicy(program_cost_model=model, - params={'use_beam_search': 1}) + policy = ansor.SketchSearchPolicy(program_cost_model=model, + params={'use_beam_search': 1}) else: raise ValueError("Invalid search policy: " + policy) @@ -105,12 +114,10 @@ def tune_workload(wkl_key, target, target_host, policy, model_type, load_model_f search_policy=policy, tune_option=tune_option) - def tune_workloads_jointly(wkl_keys, weights, task_scheduler, target, target_host, search_policy, model_type, load_model_file, load_log_file, tune_option): - """Tune for multiple workloads jointly""" - + """Tune for multiple 
workloads together with TaskScheduler""" tasks = [] for wkl_key in wkl_keys: dag = ansor.workload_key_to_dag(wkl_key) @@ -127,36 +134,37 @@ def objective_func(costs): if __name__ == "__main__": parser = argparse.ArgumentParser() - # Task related options + # Search task related arguments parser.add_argument("--wkl", type=str, required=True) parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') parser.add_argument("--target-host", type=str, default=None) - parser.add_argument("--n-trials", type=int, default=1000) - parser.add_argument("--num-measure-per-iter", type=int, default=48, - help="The number of programs to be measured at each iteration") parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) - # Strategy related options - parser.add_argument("--seed", type=int, default=0, help='random seed') - parser.add_argument("--policy", type=str, choices=['meta-rewrite', 'beam-search'], default='meta-rewrite') + # Search strategy related arguments + parser.add_argument("--n-trials", type=int, default=1000) + parser.add_argument("--policy", type=str, choices=['sketch', 'beam-search'], default='sketch') parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') parser.add_argument("--task-scheduler", type=str, default='no', choices=['no', 'gradient', 'round-robin'], help='The strategy of task scheduler') + parser.add_argument("--seed", type=int, default=0, help='random seed') - # File related options - parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") - parser.add_argument("--load-model", type=str, help="Load pre trained cost model file") - parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") + # Log file related arguments + parser.add_argument("--log-file", type=str, help="Write measurement records to this log file") + parser.add_argument("--load-log", type=str, help="Load history log to resume the status of search") + parser.add_argument("--load-model", type=str, help="Load pre-trained cost model from this file") - # Detailed control options + # Measurement related and other arguments + parser.add_argument("--num-measure-per-iter", type=int, default=48, + help="The number of programs to be measured at each iteration") parser.add_argument("--build-timeout", type=int, default=10) parser.add_argument("--run-timeout", type=int, default=60) parser.add_argument("--verbose", type=int, default=1) parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) - parser.add_argument("--device-key", type=str, default=None) - parser.add_argument("--host", type=str, default='0.0.0.0') - parser.add_argument("--port", type=int, default=9190) + parser.add_argument("--rpc-device-key", type=str, default=None) + parser.add_argument("--rpc-host", type=str, default='0.0.0.0') + parser.add_argument("--rpc-port", type=int, default=9190) + parser.add_argument("--rpc-num-threads", type=int, default=None) parser.add_argument("--n-parallel", type=int, default=1) parser.add_argument("--ndk-cc", type=str, default=None) args = parser.parse_args() @@ -170,14 +178,16 @@ def objective_func(costs): target = tvm.target.create(args.target) log_file = args.log_file or args.wkl + ".json" + # Tune workloads if args.tune: load_log_file = args.load_log or log_file weights = get_workload_weights(args.wkl) tune_option, measure_ctx = create_tune_option(target, log_file, - args.n_trials, args.num_measure_per_iter, args.verbose, - 
args.n_parallel, args.build_timeout, args.local_measure, - args.device_key, args.host, args.port, args.ndk_cc) + args.n_trials, args.num_measure_per_iter, args.verbose, + args.n_parallel, args.build_timeout, args.local_measure, + args.rpc_device_key, args.rpc_host, args.rpc_port, args.rpc_num_threads, + args.ndk_cc) if args.task_scheduler == 'no': # tune workloads one by one @@ -186,7 +196,7 @@ def objective_func(costs): args.model_type, args.load_model, load_log_file, tune_option) else: - # tune workloads jointly using JointTuner + # tune workloads jointly with TaskScheduler tune_workloads_jointly(wkl_keys, weights, args.task_scheduler, target, args.target_host, args.policy, args.model_type, args.load_model, load_log_file, @@ -194,8 +204,9 @@ def objective_func(costs): if measure_ctx: del measure_ctx - if not args.tune or len(wkl_keys) == 1: + # Replay the best found schedule + if len(wkl_keys) == 1 or not args.tune: for wkl_key in wkl_keys: replay_workload(wkl_key, target, args.target_host, log_file, - args.local_measure, args.device_key, args.host, - args.port, args.ndk_cc) + args.local_measure, args.rpc_device_key, args.rpc_host, + args.rpc_port, args.rpc_num_threads, args.ndk_cc) diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index 200118cf708b..7ffc63a03917 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -26,7 +26,7 @@ #include #include #include -#include "search_policy/meta_tile_rewrite_policy.h" +#include "search_policy/sketch_search_policy.h" namespace tvm { namespace ansor { diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 6269b9f16f71..95e744a0e777 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -1147,8 +1147,7 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { } pstate->stages[i] = StageNode::make(stage->op, stage->op_type, - std::move(new_iters), stage->compute_at, - stage->auto_unroll_max_step, stage->storage_offset); + std::move(new_iters), stage->compute_at, stage->attrs); } } diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 7569c91e3368..239f4e6988ac 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -76,35 +76,32 @@ Stage StageNode::make(te::Operation op) { node->compute_at = kRoot; node->op = std::move(op); - node->auto_unroll_max_step = 0; - node->storage_offset = 0; + node->attrs.auto_unroll_max_step = 0; + node->attrs.storage_offset = 0; return Stage(node); } Stage StageNode::make(te::Operation op, StageType op_type, const std::vector& iters, - ComputeAtType compute_at, int auto_unroll_max_step, - int storage_offset) { + ComputeAtType compute_at, StageAttributes attrs) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; node->iters = iters; node->compute_at = compute_at; - node->auto_unroll_max_step = auto_unroll_max_step; - node->storage_offset = storage_offset; + node->attrs = attrs; return Stage(node); } Stage StageNode::make(te::Operation op, StageType op_type, std::vector&& iters, ComputeAtType compute_at, - int auto_unroll_max_step, int storage_offset) { + StageAttributes attrs) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; node->iters = std::move(iters); node->compute_at = compute_at; - node->auto_unroll_max_step = auto_unroll_max_step; - node->storage_offset = storage_offset; + node->attrs = attrs; return Stage(node); } @@ -333,7 +330,7 @@ void State::DoReorderStep(const ReorderStep& step) { StateNode* pstate = CopyOnWrite(); 
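  // Note: with the StageAttributes refactor introduced by this patch, the
  // per-stage scalars (auto_unroll_max_step, storage_offset) travel as one
  // struct, so the StageNode::make call sites below just forward stage->attrs.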
pstate->stages[step->stage_id] = StageNode::make( stage->op, stage->op_type, std::move(iters), stage->compute_at, - stage->auto_unroll_max_step, stage->storage_offset); + stage->attrs); } // common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep @@ -400,7 +397,7 @@ std::vector State::DoSplitStepCommon( StateNode* pstate = CopyOnWrite(); pstate->stages[stage_id] = StageNode::make( stage->op, stage->op_type, std::move(new_iters), stage->compute_at, - stage->auto_unroll_max_step, stage->storage_offset); + stage->attrs); // we have to replace the iterators in attach map, // these two vectors keep the replacement mapping @@ -494,7 +491,7 @@ Iterator State::DoFuseStep(const FuseStep& step) { StateNode* pstate = CopyOnWrite(); pstate->stages[stage_id] = StageNode::make( stage->op, stage->op_type, std::move(new_iters), stage->compute_at, - stage->auto_unroll_max_step, stage->storage_offset); + stage->attrs); // we have to replace the iterators in attach map, // these two vectors keep the replacement mapping @@ -559,7 +556,7 @@ void State::DoComputeAtStep(const ComputeAtStep& step) { StateNode* pstate = CopyOnWrite(); pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, std::move(new_iters), kIter, - stage->auto_unroll_max_step, stage->storage_offset); + stage->attrs); pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, step->target_iter_id); } @@ -581,7 +578,7 @@ void State::DoComputeRootStep(const ComputeRootStep& step) { StateNode* pstate = CopyOnWrite(); pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, std::move(new_iters), kRoot, - stage->auto_unroll_max_step, stage->storage_offset); + stage->attrs); pstate->attach_map.DeleteStage(step->stage_id); } @@ -716,7 +713,7 @@ void State::DoPragmaStep(const PragmaStep& step) { StateNode* pstate = CopyOnWrite(); StageNode* stage = pstate->stages[step->stage_id].CopyOnWrite(); size_t pos = step->pragma_type.find('$'); - stage->auto_unroll_max_step = atoi(step->pragma_type.c_str() + pos + 1); + stage->attrs.auto_unroll_max_step = atoi(step->pragma_type.c_str() + pos + 1); } else if (step->pragma_type == "tensor_core") { // Nothing needs to be done here } else { @@ -759,7 +756,7 @@ int State::DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag) { void State::DoStorageAlignStep(const StorageAlignStep& step) { StateNode* pstate = CopyOnWrite(); StageNode* stage = pstate->stages[step->stage_id].CopyOnWrite(); - stage->storage_offset = step->offset; + stage->attrs.storage_offset = step->offset; } Iterator State::DoTensorizeStep(const TensorizeStep& step) { @@ -831,19 +828,19 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t base_indent, bool delete_trivial_loop) { const Stage& stage = state->stages[stage_id]; - if (stage->auto_unroll_max_step != 0) { + if (stage->attrs.auto_unroll_max_step != 0) { for (size_t j = 0; j < base_indent; ++j) { *os << " "; } *os << stage->op->func_name() - << " auto_unroll: " << stage->auto_unroll_max_step << "\n"; + << " auto_unroll: " << stage->attrs.auto_unroll_max_step << "\n"; } - if (stage->storage_offset != 0) { + if (stage->attrs.storage_offset != 0) { for (size_t j = 0; j < base_indent; ++j) { *os << " "; } *os << stage->op->func_name() - << " storage_offset: " << stage->storage_offset << "\n"; + << " storage_offset: " << stage->attrs.storage_offset << "\n"; } size_t indent = 0; diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 6eef404ae272..31ed5274184d 100644 --- 
a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -121,6 +121,12 @@ class CacheReadStep; class CacheWriteStep; class PragmaStep; class RfactorStep; class StorageAlignStep; class TensorizeStep; +/*! \brief Stage-level attributes */ +struct StageAttributes { + int auto_unroll_max_step; + int storage_offset; +}; + /*! * \brief A stage in the compute declaration * Similar to te::Stage in `include/schedule.h` @@ -131,8 +137,7 @@ class StageNode : public Object { StageType op_type; std::vector iters; ComputeAtType compute_at; - int auto_unroll_max_step; - int storage_offset; + StageAttributes attrs; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("op", &op); @@ -141,12 +146,10 @@ static Stage make(te::Operation op); static Stage make(te::Operation op, StageType op_type, const std::vector& iters, - ComputeAtType compute_at, int auto_unroll_max_step, - int storage_offset); + ComputeAtType compute_at, StageAttributes attrs); static Stage make(te::Operation op, StageType op_type, std::vector&& iters, - ComputeAtType compute_at, int auto_unroll_max_step, - int storage_offset); + ComputeAtType compute_at, StageAttributes attrs); static constexpr const char *_type_key = "ansor.Stage"; TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object); diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index f1f6f45fce9a..4710cc05ae7f 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -43,6 +43,7 @@ class SearchPolicyNode; class SearchCallbackNode : public Object { public: virtual void callback(SearchPolicyNode* policy) = 0; + static constexpr const char *_type_key = "ansor.SearchCallback"; TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object); }; diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/sketch_search_policy.cc similarity index 91% rename from src/ansor/search_policy/meta_tile_rewrite_policy.cc rename to src/ansor/search_policy/sketch_search_policy.cc index 8b5b97224c08..7e4c3999dce3 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ b/src/ansor/search_policy/sketch_search_policy.cc @@ -18,11 +18,13 @@ */ /*! - * \file ansor/search_policy/meta_tile_rewrite_policy.h - * \brief The search policy that searches by program sampling and evolutionary search + * \file ansor/search_policy/sketch_search_policy.h + * \brief The search policy that searches in a hierarchical search space defined by sketches. + * The policy randomly samples programs from the space defined by sketches + * and uses evolutionary search to fine-tune them. 
*/ -#include "meta_tile_rewrite_policy.h" +#include "sketch_search_policy.h" #include #include #include @@ -41,23 +43,23 @@ namespace tvm { namespace ansor { -TVM_REGISTER_NODE_TYPE(MetaTileRewritePolicyNode); -TVM_REGISTER_OBJECT_TYPE(PreAddCustomRuleNode); +TVM_REGISTER_NODE_TYPE(SketchSearchPolicyNode); +TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode); // All possible candidates for auto_unroll -const std::vector MetaTileRewritePolicyNode::auto_unroll_configs{0, 16, 64, 512, 1024}; +const std::vector SketchSearchPolicyNode::auto_unroll_configs{0, 16, 64, 512, 1024}; -SearchPolicy MetaTileRewritePolicyNode::make(CostModel program_cost_model, +SearchPolicy SketchSearchPolicyNode::make(CostModel program_cost_model, Map params, int seed) { - auto node = make_object(); + auto node = make_object(); node->program_cost_model = std::move(program_cost_model); node->rand_gen_ = std::mt19937(seed); node->params = std::move(params); return SearchPolicy(node); } -State MetaTileRewritePolicyNode::Search(SearchTask task, int n_trials, +State SketchSearchPolicyNode::Search(SearchTask task, int n_trials, int early_stopping, int num_measure_per_iter, int verbose, ProgramMeasurer measurer, Array pre_search_callbacks) { @@ -129,7 +131,7 @@ State MetaTileRewritePolicyNode::Search(SearchTask task, int n_trials, } std::pair, Array > - MetaTileRewritePolicyNode::ContinueSearchOneRound( + SketchSearchPolicyNode::ContinueSearchOneRound( SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) { if (cur_task.defined()) { CHECK_EQ(cur_task, task); @@ -176,7 +178,7 @@ std::pair, Array > return std::make_pair(std::move(inputs_arr), std::move(results_arr)); } -void MetaTileRewritePolicyNode::PickStatesWithEpsGreedy( +void SketchSearchPolicyNode::PickStatesWithEpsGreedy( std::vector* inputs, const std::vector& best_states, const std::vector& random_states, @@ -224,7 +226,7 @@ void MetaTileRewritePolicyNode::PickStatesWithEpsGreedy( } } -void MetaTileRewritePolicyNode::SearchOneRound(std::vector* best_states, +void SketchSearchPolicyNode::SearchOneRound(std::vector* best_states, int num_random_states, std::vector* random_states) { best_states->clear(); random_states->clear(); @@ -240,16 +242,16 @@ void MetaTileRewritePolicyNode::SearchOneRound(std::vector* best_states, num_use_measured = 0; } - // Synthesize meta structure - std::vector meta_structures; - GenerateMetaSketch(&meta_structures); + // Generate sketches + std::vector sketches; + GenerateSketch(&sketches); - // PrintAllStates(meta_structures); + // PrintAllStates(sketches); // exit(0); // Sample the init population std::vector init_population; - SampleInitPopulation(meta_structures, population - num_use_measured, &init_population); + SampleInitPopulation(sketches, population - num_use_measured, &init_population); // PrintAllStates(init_population); // exit(0); @@ -273,21 +275,21 @@ void MetaTileRewritePolicyNode::SearchOneRound(std::vector* best_states, RandomSampleStates(init_population, &rand_gen_, num_random_states * 10, random_states); } -// The baseclass of derivation rules used in meta sketch generation +// The baseclass of derivation rules used in sketch generation class SketchGenerationRule { public: enum ConditionEnum { kPass, kApply, kApplyAndSkipRest }; - virtual ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + virtual ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) = 0; - virtual std::vector > Apply(const MetaTileRewritePolicyNode* policy, + virtual 
std::vector > Apply(const SketchSearchPolicyNode* policy, const State& state, int stage_id) = 0; }; static inline bool ShouldBeCacheRead( - const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) { + const SketchSearchPolicyNode* policy, const State& state, int stage_id) { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -319,7 +321,7 @@ static inline bool ShouldBeCacheRead( } static inline bool ShouldAlwaysBeInlined( - const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) { + const SketchSearchPolicyNode* policy, const State& state, int stage_id) { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -348,13 +350,13 @@ static inline bool ShouldAlwaysBeInlined( // The rule that inlines simple elementwise ops class RuleAlwaysInline : public SketchGenerationRule { public: - ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { return ShouldAlwaysBeInlined(policy, state, stage_id) ? kApplyAndSkipRest : kPass; } - std::vector > Apply(const MetaTileRewritePolicyNode* policy, + std::vector > Apply(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { State tmp_s = state; tmp_s.compute_inline(stage_id); @@ -365,7 +367,7 @@ class RuleAlwaysInline : public SketchGenerationRule { // The rule that simply skip the current stage class RuleSkipStage : public SketchGenerationRule { public: - ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -381,7 +383,7 @@ class RuleSkipStage : public SketchGenerationRule { return kApply; } - std::vector > Apply(const MetaTileRewritePolicyNode* policy, + std::vector > Apply(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { return {std::make_pair(state, stage_id - 1)}; } @@ -390,7 +392,7 @@ class RuleSkipStage : public SketchGenerationRule { // The rule that performs multi-level tiling class RuleMultiLevelTiling : public SketchGenerationRule { public: - ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -399,7 +401,7 @@ class RuleMultiLevelTiling : public SketchGenerationRule { (IS_GPU(policy->cur_task) ? kApplyAndSkipRest : kApply) : kPass; } - std::vector > Apply(const MetaTileRewritePolicyNode* policy, + std::vector > Apply(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { std::string multi_level_tiling_structure = IS_GPU(policy->cur_task) ? 
GetStringParam(policy->params, "gpu_multi_level_tiling_structure") : @@ -416,7 +418,7 @@ class RuleMultiLevelTiling : public SketchGenerationRule { // The rule that performs multi-level tiling and fuses later consumers class RuleMultiLevelTilingWithFusion : public SketchGenerationRule { public: - ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -438,7 +440,7 @@ class RuleMultiLevelTilingWithFusion : public SketchGenerationRule { kApply : kPass; } - std::vector > Apply(const MetaTileRewritePolicyNode* policy, + std::vector > Apply(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -485,7 +487,7 @@ class RuleMultiLevelTilingWithFusion : public SketchGenerationRule { // The rule that adds a cache write stage class RuleAddCacheWrite : public SketchGenerationRule { public: - ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -503,7 +505,7 @@ class RuleAddCacheWrite : public SketchGenerationRule { kApply : kPass; } - std::vector > Apply(const MetaTileRewritePolicyNode* policy, + std::vector > Apply(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; @@ -518,13 +520,13 @@ class RuleAddCacheWrite : public SketchGenerationRule { // Currently only support 1 to 1 match cache read class RuleAddCacheRead : public SketchGenerationRule { public: - ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { return ShouldBeCacheRead(policy, state, stage_id) ? 
kApplyAndSkipRest : kPass; } - std::vector > Apply(const MetaTileRewritePolicyNode* policy, + std::vector > Apply(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -549,7 +551,7 @@ class RuleAddCacheRead : public SketchGenerationRule { // The rule that adds rfactor stage class RuleAddRfactor : public SketchGenerationRule { public: - ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -559,7 +561,7 @@ class RuleAddRfactor : public SketchGenerationRule { kApply : kPass; } - std::vector > Apply(const MetaTileRewritePolicyNode* policy, + std::vector > Apply(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -611,7 +613,7 @@ class RuleAddRfactor : public SketchGenerationRule { } }; -void MetaTileRewritePolicyNode::GenerateMetaSketch( +void SketchSearchPolicyNode::GenerateSketch( std::vector* out_states) { State init_state = cur_task->compute_dag.GetInitState(); std::string cpu_multi_level_tiling_structure = @@ -705,10 +707,10 @@ void MetaTileRewritePolicyNode::GenerateMetaSketch( } } - StdCout(verbose) << "Synthesize Meta Structure\t\t#s: " << out_states->size() << std::endl; + StdCout(verbose) << "Generate Sketches\t\t#s: " << out_states->size() << std::endl; } -int InitPopulationFillTileSize(const MetaTileRewritePolicyNode* policy, +int InitPopulationFillTileSize(const SketchSearchPolicyNode* policy, State* state, std::mt19937* rand_gen, SplitFactorizationMemo* split_memo) { for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) { @@ -741,7 +743,7 @@ int InitPopulationFillTileSize(const MetaTileRewritePolicyNode* policy, return 0; } -int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, +int InitPopulationThreadBind(const SketchSearchPolicyNode* policy, State* state) { for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { const Stage& stage = (*state)->stages[stage_id]; @@ -853,7 +855,7 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, return 0; } -int InitPopulationCooperativeFetching(const MetaTileRewritePolicyNode* policy, +int InitPopulationCooperativeFetching(const SketchSearchPolicyNode* policy, State* state) { for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { // Do cooperative fetching with cache read stage @@ -898,7 +900,7 @@ int InitPopulationCooperativeFetching(const MetaTileRewritePolicyNode* policy, return 0; } -int InitPopulationChangeComputeLocation(const MetaTileRewritePolicyNode* policy, +int InitPopulationChangeComputeLocation(const SketchSearchPolicyNode* policy, State* state, std::mt19937* rand_gen) { if(GetIntParam(policy->params, "disable_change_compute_location")) { return 0; @@ -1060,12 +1062,12 @@ int InitPopulationChangeComputeLocation(const MetaTileRewritePolicyNode* policy, return 0; } -int InitPopulationParallel(const MetaTileRewritePolicyNode* policy, +int InitPopulationParallel(const SketchSearchPolicyNode* policy, State* state) { - std::function annotate_parallel; + std::function annotate_parallel; annotate_parallel = [&annotate_parallel]( - const MetaTileRewritePolicyNode* policy, State* state, int 
stage_id, int iter_offset) { + const SketchSearchPolicyNode* policy, State* state, int stage_id, int iter_offset) { const Stage& stage = (*state)->stages[stage_id]; std::vector to_fuse; @@ -1125,7 +1127,7 @@ int InitPopulationParallel(const MetaTileRewritePolicyNode* policy, return 0; } -int InitPopulationVectorization(const MetaTileRewritePolicyNode* policy, +int InitPopulationVectorization(const SketchSearchPolicyNode* policy, State* state, std::mt19937* rand_gen) { for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { const Stage& stage = (*state)->stages[stage_id]; @@ -1202,7 +1204,7 @@ int InitPopulationVectorization(const MetaTileRewritePolicyNode* policy, return 0; } -int InitPopulationUnroll(const MetaTileRewritePolicyNode* policy, +int InitPopulationUnroll(const SketchSearchPolicyNode* policy, State* state, std::mt19937* rand_gen) { for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { const Stage& stage = (*state)->stages[stage_id]; @@ -1266,7 +1268,7 @@ int InitPopulationUnroll(const MetaTileRewritePolicyNode* policy, return 0; } -void MetaTileRewritePolicyNode::SampleInitPopulation(const std::vector& meta_structures, +void SketchSearchPolicyNode::SampleInitPopulation(const std::vector& sketches, int out_size, std::vector* out_states) { std::uniform_real_distribution<> dis(0.0, 1.0); int continue_count = 0; @@ -1274,7 +1276,7 @@ // TODO(...): Maybe try multi-threading here while (static_cast(out_states->size()) < out_size && continue_count < out_size * 10) { - State tmp_s = meta_structures[rand_gen_() % meta_structures.size()]; + State tmp_s = sketches[rand_gen_() % sketches.size()]; InitPopulationFillTileSize(this, &tmp_s, &rand_gen_, &split_memo_); @@ -1305,11 +1307,11 @@ out_states->push_back(std::move(tmp_s)); } - StdCout(verbose) << "Sample Initial Population\t\t#s: " + StdCout(verbose) << "Sample Initial Population\t#s: " << out_states->size() << std::endl; } -void MetaTileRewritePolicyNode::EvolutionarySearch( +void SketchSearchPolicyNode::EvolutionarySearch( const std::vector& init_population, int num_best_states, std::vector* best_states) { auto tic_begin = std::chrono::high_resolution_clock::now(); @@ -1473,10 +1475,10 @@ class RuleCustomSketch : public SketchGenerationRule { RuleCustomSketch(PackedFunc meet_condition_func, PackedFunc apply_func) : meet_condition_func_(meet_condition_func), apply_func_(apply_func) {} - inline ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + inline ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { auto ret = meet_condition_func_( - tvm::runtime::GetRef(policy), state, stage_id); + tvm::runtime::GetRef(policy), state, stage_id); if (ret.type_code() == 0) { return ConditionEnum(static_cast(ret)); } else { @@ -1485,12 +1487,12 @@ class RuleCustomSketch : public SketchGenerationRule { } inline std::vector > Apply( - const MetaTileRewritePolicyNode* policy, + const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { std::vector > ret; Array> apply_ret = apply_func_( - tvm::runtime::GetRef(policy), state, stage_id); + tvm::runtime::GetRef(policy), state, stage_id); for (const auto& item : apply_ret) { CHECK_EQ(item.size(), 2); @@ -1506,32 +1508,32 @@ class RuleCustomSketch : public SketchGenerationRule { PackedFunc apply_func_; }; -SearchCallback 
PreAddCustomRuleNode::make(PackedFunc meet_condition_func, +SearchCallback PreloadCustomSketchRuleNode::make(PackedFunc meet_condition_func, PackedFunc apply_func) { - auto node = make_object(); + auto node = make_object(); node->meet_condition_func = meet_condition_func; node->apply_func = apply_func; return SearchCallback(node); } -void PreAddCustomRuleNode::callback(SearchPolicyNode* policy) { - CHECK(policy->IsInstance()); - auto meta_policy = dynamic_cast(policy); - meta_policy->sketch_rules.emplace_back( +void PreloadCustomSketchRuleNode::callback(SearchPolicyNode* policy) { + CHECK(policy->IsInstance()); + auto sketch_policy = dynamic_cast(policy); + sketch_policy->sketch_rules.emplace_back( new RuleCustomSketch(meet_condition_func, apply_func)); StdCout(policy->verbose) << "Custom sketch rule added." << std::endl; } -TVM_REGISTER_GLOBAL("ansor.MetaTileRewritePolicy") +TVM_REGISTER_GLOBAL("ansor.SketchSearchPolicy") .set_body_typed([](CostModel program_cost_model, Map params, int seed){ - return MetaTileRewritePolicyNode::make(program_cost_model, params, seed); + return SketchSearchPolicyNode::make(program_cost_model, params, seed); }); -TVM_REGISTER_GLOBAL("ansor.PreAddCustomRule") +TVM_REGISTER_GLOBAL("ansor.PreloadCustomSketchRule") .set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func) { - return PreAddCustomRuleNode::make(meet_condition_func, apply_func); + return PreloadCustomSketchRuleNode::make(meet_condition_func, apply_func); }); } // namespace ansor diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/sketch_search_policy.h similarity index 66% rename from src/ansor/search_policy/meta_tile_rewrite_policy.h rename to src/ansor/search_policy/sketch_search_policy.h index 6930a71038a3..60920c5c1fdd 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.h +++ b/src/ansor/search_policy/sketch_search_policy.h @@ -18,12 +18,14 @@ */ /*! - * \file ansor/search_policy/meta_tile_rewrite_policy.h - * \brief The search policy that searches by program sampling and evolutionary search + * \file ansor/search_policy/sketch_search_policy.h + * \brief The search policy that searches in a hierarchical search space defined by sketches. + * The policy randomly samples programs from the space defined by sketches + * and uses evolutionary search to fine-tune them. */ -#ifndef TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ -#define TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ +#ifndef TVM_ANSOR_SEARCH_POLICY_SKETCH_SEARCH_POLICY_H_ +#define TVM_ANSOR_SEARCH_POLICY_SKETCH_SEARCH_POLICY_H_ #include #include @@ -40,12 +42,17 @@ namespace ansor { class SketchGenerationRule; -/*! Multi stage search policy */ -class MetaTileRewritePolicyNode: public SearchPolicyNode { +/*! + * \brief The search policy that searches in a hierarchical search space defined by sketches. + * The policy randomly samples programs from the space defined by sketches + * and uses evolutionary search to fine-tune them. + */ +class SketchSearchPolicyNode: public SearchPolicyNode { public: + /*! \brief The cost model for complete programs */ CostModel program_cost_model; - /* this->params is used to store the following arguments + /*! \brief The parameters for search. 
It stores the following parameters: * int evolutionary_search_population // The population size for evolutionary search * int evolutionary_search_mutation_prob // The probability of mutation for evolutionary search * int evolutionary_search_num_iters; // The number of iterations for evolutionary search * ... @@ -56,30 +63,33 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { * str gpu_multi_level_tiling_structure // The structure of multi-level tiling for GPU */ Map params; + + /*! \brief The rules to generate sketches */ std::vector sketch_rules; static SearchPolicy make(CostModel program_cost_model, Map params, int seed); - // Search and make n_trails measurements - // Return the best state + /*! \brief Search and make n_trials measurements. + * \returns the best state */ State Search(SearchTask task, int n_trials, int early_stopping, int num_measure_per_iter, int verbose, ProgramMeasurer measurer, Array pre_search_callbacks) final; - // Continue search. This is used by JointTuner + /*! \brief Continue search for one round. This is used by the TaskScheduler + * \returns the measurement pairs */ std::pair, Array > ContinueSearchOneRound( SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) final; - static constexpr const char *_type_key = "ansor.MetaTileRewritePolicy"; + static constexpr const char *_type_key = "ansor.SketchSearchPolicy"; static const std::vector auto_unroll_configs; - TVM_DECLARE_FINAL_OBJECT_INFO(MetaTileRewritePolicyNode, SearchPolicyNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SketchSearchPolicyNode, SearchPolicyNode); protected: - // Pick states from best states and random states with eps-greedy policy + /*! \brief Pick states from best states and random states with eps-greedy policy */ void PickStatesWithEpsGreedy(std::vector* inputs, const std::vector& best_states, const std::vector& random_states, int remaining_n_trials); @@ -89,11 +99,11 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { void SearchOneRound(std::vector* best_states, int num_random_states, std::vector* random_states); - // Synthesize meta tiling structure without tile size - void GenerateMetaSketch(std::vector* out_states); + // Generate sketches without tile size + void GenerateSketch(std::vector* out_states); // Sample init population - void SampleInitPopulation(const std::vector& meta_structures, + void SampleInitPopulation(const std::vector& sketches, int out_size, std::vector* out_states); // Perform evolutionary search @@ -104,9 +114,10 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { std::mt19937 rand_gen_; // Random generator int num_measure_per_iter_; // The number of states to measure per iteration }; -TVM_DEFINE_MUTABLE_OBJECT_REF(MetaTileRewritePolicy, MetaTileRewritePolicyNode); +TVM_DEFINE_MUTABLE_OBJECT_REF(SketchSearchPolicy, SketchSearchPolicyNode); -class PreAddCustomRuleNode : public SearchCallbackNode { +/*! \brief Pre-search callback function to load custom rules for sketch generation */ +class PreloadCustomSketchRuleNode : public SearchCallbackNode { public: // TODO(jcf94): Use tvm::runtime::TypedPackedFunc? 
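  // As RuleCustomSketch above suggests: meet_condition_func is invoked as
  // (SketchSearchPolicy, State, stage_id) -> ConditionEnum (returned as an int),
  // and apply_func as (SketchSearchPolicy, State, stage_id) -> an array of
  // (next State, next stage_id) pairs.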
PackedFunc meet_condition_func; @@ -117,11 +128,11 @@ class PreAddCustomRuleNode : public SearchCallbackNode { void callback(SearchPolicyNode* policy) final; - static constexpr const char *_type_key = "ansor.PreAddCustomRule"; - TVM_DECLARE_FINAL_OBJECT_INFO(PreAddCustomRuleNode, SearchCallbackNode); + static constexpr const char *_type_key = "ansor.PreloadCustomSketchRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(PreloadCustomSketchRuleNode, SearchCallbackNode); }; } // namespace ansor } // namespace tvm -#endif // TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ +#endif // TVM_ANSOR_SEARCH_POLICY_SKETCH_SEARCH_POLICY_H_ diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index 083bd2721cb6..485679d6aa4e 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -21,7 +21,7 @@ import topi -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def matmul_ansor_test(N, M, K): A = te.placeholder((N, K), name='A') B = te.placeholder((K, M), name='B') diff --git a/tests/python/unittest/test_ansor_relay_integration.py b/tests/python/unittest/test_ansor_relay_integration.py index f3f424ab321b..1ad507e2f371 100644 --- a/tests/python/unittest/test_ansor_relay_integration.py +++ b/tests/python/unittest/test_ansor_relay_integration.py @@ -84,7 +84,6 @@ def dense_graph(N, dtype="float32"): def test_tune_dqn(): mod, params = dqn.get_workload(1, image_shape=(84, 84, 4), layout='NHWC') target = tvm.target.create('llvm') - ctx = tvm.context("llvm") wkl_keys, wkl_weights = ansor.extract_from_program(mod, params, target) @@ -100,7 +99,7 @@ def test_tune_dqn(): with tempfile.NamedTemporaryFile() as fp: tuner.tune(ansor.TuneOption(n_trials=len(tasks), runner=measure_ctx.runner, measure_callbacks=[ansor.LogToFile('tmp.json')]), - search_policy='meta-rewrite.random') + search_policy='sketch.random') with ansor.apply_history_best('tmp.json'): ansor.prepare_layout_rewrite(mod, params, target) with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index 9b1716175b5a..deff561a4547 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -42,8 +42,7 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' with tempfile.NamedTemporaryFile() as fp: log_file = fp.name - search_policy = ansor.MetaTileRewritePolicy(cost_model, params=params, - seed=seed) + search_policy = ansor.SketchSearchPolicy(cost_model, params=params, seed=seed) tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner, measure_callbacks=[ansor.LogToFile(log_file)], pre_search_callbacks=pre_search_callbacks) @@ -74,8 +73,8 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' def test_search_basic(): - # Ansor search process with local runner has some modification on thread - # binding, wrap this to a subprocess to eliminate the impacts to other tests + # wrap the search in a new thread to avoid the conflict + # between python's multiprocessing and tvm's thread pool t = threading.Thread(target=search_common, kwargs={'seed': 944563397}) t.start() t.join() @@ -152,12 +151,12 @@ def apply_func2(meta_policy, state, stage_id): measure_ctx = ansor.LocalRPCMeasureContext() search_common(seed=887823438, runner=measure_ctx.runner, - 
pre_search_callbacks=[ansor.PreAddCustomRule(meet_condition_func, - apply_func1)], + pre_search_callbacks=[ansor.PreloadCustomSketchRule( + meet_condition_func, apply_func1)], params={'disable_change_compute_location': 1}) search_common(seed=887823438, runner=measure_ctx.runner, - pre_search_callbacks=[ansor.PreAddCustomRule(meet_condition_func, - apply_func2)], + pre_search_callbacks=[ansor.PreloadCustomSketchRule( + meet_condition_func, apply_func2)], params={'disable_change_compute_location': 1}) diff --git a/tutorials/ansor/tune_conv2d_cuda.py b/tutorials/ansor/tune_conv2d_cuda.py index 437323d79791..03f1b24a768e 100644 --- a/tutorials/ansor/tune_conv2d_cuda.py +++ b/tutorials/ansor/tune_conv2d_cuda.py @@ -80,7 +80,7 @@ # recommended. # Use an extra function decorator to register this workload -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding): data = te.placeholder((N, CI, H, W), name='data') kernel = te.placeholder((CO, CI, KH, KW), name='kernel') @@ -111,7 +111,7 @@ def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding): seed = 0 random.seed(seed) cost_model = ansor.XGBModel(seed=seed) -search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) +search_policy = ansor.SketchSearchPolicy(cost_model, seed=seed) ######################################################################### # The :code:`ansor.LocalRPCMeasureContext` is used to create an RPC runner environment. diff --git a/tutorials/ansor/tune_simple_subgraph.py b/tutorials/ansor/tune_simple_subgraph.py index 08d5628ad8a2..00bef82cf855 100644 --- a/tutorials/ansor/tune_simple_subgraph.py +++ b/tutorials/ansor/tune_simple_subgraph.py @@ -142,7 +142,7 @@ def matmul_add(N, L, M, dtype): ################################################################ # Next, we choose a random model and create a default search policy: -# :code:`ansor.MetaTileRewritePolicy`. +# :code:`ansor.SketchSearchPolicy`. # # We only make 5 trials in this tutorial for demonstration. In practice, # you can do more trials according to your time budget. 
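For readers skimming the patch, here is a minimal sketch of how the renamed user-facing pieces fit together, assuming they behave as in the scripts and tests above; the matmul workload, its shape, and the log file name are illustrative placeholders, and the call that actually launches tuning is elided in the surrounding hunks, so it is not reproduced here:

    from tvm import te, ansor

    # Register a workload function; its name plus concrete arguments form the workload key.
    @ansor.register_workload_func
    def matmul(N, M, K):
        A = te.placeholder((N, K), name='A')
        B = te.placeholder((K, M), name='B')
        k = te.reduce_axis((0, K), name='k')
        C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
        return [A, B, C]

    wkl_key = ansor.make_workload_key_func(matmul, (128, 128, 128))
    cost_model = ansor.RandomModel()
    search_policy = ansor.SketchSearchPolicy(cost_model, seed=0)
    tune_option = ansor.TuneOption(n_trials=5,
                                   measure_callbacks=[ansor.LogToFile('matmul.json')])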
@@ -157,7 +157,7 @@ def matmul_add(N, L, M, dtype): seed = 0 random.seed(seed) cost_model = ansor.RandomModel() -search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) +search_policy = ansor.SketchSearchPolicy(cost_model, seed=seed) tune_option = ansor.TuneOption(n_trials=5, measure_callbacks=[ansor.LogToFile(log_file)], From 593a2c7f43ed157ca1c1d0be04955e7b3ad9efcd Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 20 Jun 2020 09:15:44 -0700 Subject: [PATCH 34/78] rebase --- src/ansor/compute_dag.cc | 24 +++++------ src/ansor/feature.cc | 2 +- src/ansor/loop_state.cc | 6 +-- src/ansor/transform_step.cc | 54 ++++++++++++------------ src/relay/op/tensor/transform.cc | 56 +++++++++++++++++++++++++ src/tir/transforms/unroll_loop.cc | 2 +- topi/include/topi/transform.h | 69 +++++++++++++++++++++++++++++++ 7 files changed, 169 insertions(+), 44 deletions(-) diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 95e744a0e777..7b4857b34d76 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -863,9 +863,9 @@ void ComputeDAG::RewriteLayout( te::Operation new_placeholder_op; if (rewrite_placeholder) { new_placeholder_op = - te::PlaceholderOpNode::make(placeholder_op->name, - new_shape, - placeholder_op.as()->dtype); + te::PlaceholderOp(placeholder_op->name, + new_shape, + placeholder_op.as()->dtype); } else { new_placeholder_op = placeholder_op; } @@ -890,7 +890,7 @@ void ComputeDAG::RewriteLayout( } old_compute_op = op; CHECK(!new_compute_op.defined()); - new_compute_op = te::ComputeOpNode::make( + new_compute_op = te::ComputeOp( pop->name, pop->tag, pop->attrs, pop->axis, new_body); } } @@ -1028,8 +1028,8 @@ std::string ComputeDAG::PrintStepsAsPython(const std::vector& transform_st ss << ", "; } } - ss << " = " << "tuple(" << stage->op->func_name() << ".op.axis)" - << " + " << "tuple(" << stage->op->func_name() << ".op.reduce_axis)\n"; + ss << " = " << "tuple(" << stage->op->name << ".op.axis)" + << " + " << "tuple(" << stage->op->name << ".op.reduce_axis)\n"; } } @@ -1231,10 +1231,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) for (const auto& op : node->ops) { if (op->IsInstance()) { - ss << op->func_name() << " = PLACEHOLDER " << op.output(0)->shape << "\n"; + ss << op->name << " = PLACEHOLDER " << op.output(0)->shape << "\n"; } else if (auto pop = op.as()) { for (size_t k = 0; k < pop->body.size(); ++k) { - ss << op->func_name() << "("; + ss << op->name << "("; for (size_t i = 0; i < pop->axis.size(); i++) { ss << pop->axis[i]->var->name_hint; if (i != pop->axis.size() - 1) { @@ -1288,14 +1288,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "Read from:\t"; for (const auto& pair : node->read_from.at(op)) { for (const auto& index : pair.second) { - p->stream << pair.first->func_name() << Array(index) << ", "; + p->stream << pair.first->name << Array(index) << ", "; } } p->stream << "\n"; p->stream << "Read by:\t"; for (const auto& pair : node->read_by.at(op)) { for (const auto& index : pair.second) { - p->stream << pair.first->func_name() << Array(index) << ", "; + p->stream << pair.first->name << Array(index) << ", "; } } p->stream << "\n"; @@ -1310,8 +1310,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) if (i == j) { continue; } if (ana.ElementWiseMatch(node->ops_topo_order[i], node->ops_topo_order[j])) { - p->stream << node->ops_topo_order[i]->func_name() << " -> " - << node->ops_topo_order[j]->func_name() << "\n"; + p->stream << node->ops_topo_order[i]->name << " -> " + << node->ops_topo_order[j]->name << "\n"; } } } diff 
--git a/src/ansor/feature.cc b/src/ansor/feature.cc index 3c6976a0e25a..3b5849e22262 100644 --- a/src/ansor/feature.cc +++ b/src/ansor/feature.cc @@ -568,7 +568,7 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { is_gpu = true; // make a fake for node for blockIdx.x or threadIdx.x - Stmt fake_for_node = ForNode::make(var, 0, extent, ForType::Parallel, + Stmt fake_for_node = For(var, 0, extent, ForType::Parallel, DeviceAPI::None, node->body); outer_loop_prod *= extent; diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 239f4e6988ac..23e005503873 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -832,14 +832,14 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, for (size_t j = 0; j < base_indent; ++j) { *os << " "; } - *os << stage->op->func_name() + *os << stage->op->name << " auto_unroll: " << stage->attrs.auto_unroll_max_step << "\n"; } if (stage->attrs.storage_offset != 0) { for (size_t j = 0; j < base_indent; ++j) { *os << " "; } - *os << stage->op->func_name() + *os << stage->op->name << " storage_offset: " << stage->attrs.storage_offset << "\n"; } @@ -915,7 +915,7 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, for (size_t j = 0; j < base_indent + indent; ++j) { *os << " "; } - *os << stage->op->func_name() << " = ...\n"; + *os << stage->op->name << " = ...\n"; } void PrintState(std::ostream* os, const StateNode* node, diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index b0e67a481ae3..857f3e570de0 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -63,7 +63,7 @@ std::string ReorderStepNode::PrintAsPythonAPI(std::vector *stages, const te::Stage& stage = (*stages)[stage_id]; std::stringstream ss; - ss << "s[" << CleanName(stage->op->func_name()) << "].reorder("; + ss << "s[" << CleanName(stage->op->name) << "].reorder("; for (size_t i = 0; i < after_ids.size(); ++i) { ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint); if (i != after_ids.size() - 1) { @@ -126,7 +126,7 @@ std::string PrintSplitAsPythonAPI(std::vector *stages, bool inner_to_outer) { te::Stage& stage = (*stages)[stage_id]; auto to_split = (*stage_to_axes)[stage][iter_id]; - const auto& func_name = CleanName(stage->op->func_name()); + const auto& func_name = CleanName(stage->op->name); const auto& outs = ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); @@ -330,7 +330,7 @@ std::string FuseStepNode::PrintAsPythonAPI(std::vector *stages, const auto& fused = ApplyToSchedule(stages, stage_to_axes); ss << CleanName(fused->var->name_hint) << " = s[" - << CleanName(stage->op->func_name()) << "].fuse(" + << CleanName(stage->op->name) << "].fuse(" << to_fuse.str() << ")\n"; return ss.str(); @@ -385,7 +385,7 @@ std::string AnnotationStepNode::PrintAsPythonAPI(std::vector *stages, ss << "thread_x = tvm.thread_axis(\"threadIdx.x\")\n"; } - ss << "s[" << CleanName(stage->op->func_name()) << "]."; + ss << "s[" << CleanName(stage->op->name) << "]."; switch (annotation) { case kUnroll: ss << "unroll("; break; case kVectorize: ss << "vectorize("; break; @@ -417,7 +417,7 @@ std::string AnnotationStepNode::PrintAsPythonAPI(std::vector *stages, ss << ")\n"; if (bind_reduce_iter) { - ss << "s[" << CleanName(stage->op->func_name()) << "]" + ss << "s[" << CleanName(stage->op->name) << "]" << ".set_store_predicate(thread_x.var.equal(0))\n"; } @@ -450,8 +450,8 @@ std::string ComputeAtStepNode::PrintAsPythonAPI(std::vector *stages, const auto& 
stage = (*stages)[stage_id]; const auto& target_stage = (*stages)[target_stage_id]; - ss << "s[" << CleanName(stage->op->func_name()) << "].compute_at(s[" - << CleanName(target_stage->op->func_name()) << "], " + ss << "s[" << CleanName(stage->op->name) << "].compute_at(s[" + << CleanName(target_stage->op->name) << "], " << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint); ss << ")\n"; @@ -478,7 +478,7 @@ std::string ComputeRootStepNode::PrintAsPythonAPI(std::vector *stages std::stringstream ss; const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->func_name()) << "].compute_root()\n"; + ss << "s[" << CleanName(stage->op->name) << "].compute_root()\n"; ApplyToSchedule(stages, stage_to_axes); return ss.str(); @@ -504,7 +504,7 @@ std::string ComputeInlineStepNode::PrintAsPythonAPI( std::stringstream ss; const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->func_name()) << "].compute_inline()\n"; + ss << "s[" << CleanName(stage->op->name) << "].compute_inline()\n"; ApplyToSchedule(stages, stage_to_axes); return ss.str(); @@ -551,12 +551,12 @@ std::string CacheReadStepNode::PrintAsPythonAPI(std::vector *stages, auto out = ApplyToSchedule(stages, stage_to_axes, schedule); - ss << CleanName(out->op->func_name()) << " = " - << "s.cache_read(" << CleanName(stage->op->func_name()) << ", \"" + ss << CleanName(out->op->name) << " = " + << "s.cache_read(" << CleanName(stage->op->name) << ", \"" << scope_name << "\", [" - << CleanName(reader_stages[0]->op->func_name()); + << CleanName(reader_stages[0]->op->name); for (size_t i = 1; i < reader_stage_ids.size(); ++i) { - ss << ", " << CleanName(reader_stages[i]->op->func_name()); + ss << ", " << CleanName(reader_stages[i]->op->name); } ss << "])\n"; @@ -567,7 +567,7 @@ std::string CacheReadStepNode::PrintAsPythonAPI(std::vector *stages, ss << ", "; } } - ss << " = " << "tuple(" << CleanName(out->op->func_name()) + ss << " = " << "tuple(" << CleanName(out->op->name) << ".op.axis)\n"; return ss.str(); @@ -615,7 +615,7 @@ std::string CacheWriteStepNode::PrintAsPythonAPI(std::vector *stages, auto outs = ApplyToSchedule(stages, stage_to_axes, schedule); for (size_t i = 0; i < outs.size(); ++i) { - ss << CleanName(outs[i]->op->func_name()) << ", "; + ss << CleanName(outs[i]->op->name) << ", "; } ss << "= " << "s.cache_write([" << CleanName(stage->op.output(0)->op->name); @@ -632,9 +632,9 @@ std::string CacheWriteStepNode::PrintAsPythonAPI(std::vector *stages, ss << ", "; } } - ss << " = " << "tuple(" << CleanName(out->op->func_name()) + ss << " = " << "tuple(" << CleanName(out->op->name) << ".op.axis)" - << " + " << "tuple(" << CleanName(out->op->func_name()) + << " + " << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n"; } @@ -675,14 +675,14 @@ std::string PragmaStepNode::PrintAsPythonAPI(std::vector *stages, if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { size_t pos = pragma_type.find('$'); int value = atoi(pragma_type.c_str() + pos + 1); - ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + ss << "s[" << CleanName(stage->op->name) << "].pragma(" << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"auto_unroll_max_step\", " << value << ")\n"; - ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + ss << "s[" << CleanName(stage->op->name) << "].pragma(" << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"unroll_explicit\", True)\n"; } else { - ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + 
ss << "s[" << CleanName(stage->op->name) << "].pragma(" << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" << pragma_type << "\")\n"; } @@ -731,7 +731,7 @@ std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, const auto& outs = ApplyToSchedule(stages, stage_to_axes, schedule); for (size_t i = 0; i < outs.size(); ++i) { - ss << CleanName(outs[i]->op->func_name()); + ss << CleanName(outs[i]->op->name); if (i != outs.size() - 1) { ss << ", "; } @@ -749,9 +749,9 @@ std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, ss << ", "; } } - ss << " = " << "tuple(" << CleanName(out->op->func_name()) + ss << " = " << "tuple(" << CleanName(out->op->name) << ".op.axis)" - << " + " << "tuple(" << CleanName(out->op->func_name()) + << " + " << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n"; } @@ -763,9 +763,9 @@ std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, ss << ", "; } } - ss << " = " << "tuple(s[" << CleanName(output->op->func_name()) + ss << " = " << "tuple(s[" << CleanName(output->op->name) << "].op.axis)" - << " + " << "tuple(s[" << CleanName(output->op->func_name()) + << " + " << "tuple(s[" << CleanName(output->op->name) << "].op.reduce_axis)\n"; return ss.str(); @@ -794,7 +794,7 @@ std::string StorageAlignStepNode::PrintAsPythonAPI( te::Schedule *schedule, const std::vector& transform_steps) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->func_name()) << "].storage_align(" + ss << "s[" << CleanName(stage->op->name) << "].storage_align(" << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " << factor << ", " << offset << ")\n"; @@ -829,7 +829,7 @@ std::string TensorizeStepNode::PrintAsPythonAPI( te::Schedule *schedule, const std::vector& transform_steps) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->func_name()) << "].tensorize(" + ss << "s[" << CleanName(stage->op->name) << "].tensorize(" << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " << ti_func_name << "())\n"; diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index ee5e291e3d53..18ace14a0b75 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2455,6 +2455,62 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] .set_support_level(5) .set_attr("FTVMCompute", LayoutTransformCompute); +// relay.kernel_layout_transform +TVM_REGISTER_NODE_TYPE(KernelLayoutTransformAttrs); + +Array KernelLayoutTransformCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type) { + //const Target& target) { + const auto* param = attrs.as(); + CHECK(param != nullptr); + return Array{ + topi::kernel_layout_transform(inputs[0], param->src_layout, param->dst_layout) + }; +} + +bool KernelLayoutTransformRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + + const auto* data = types[0].as(); + CHECK(data != nullptr); + const KernelLayoutTransformAttrs* params = attrs.as(); + + Array dst_shape; + std::vector dst_axes; + + topi::parse_kernel_layout(params->dst_layout, &dst_shape, &dst_axes); + + reporter->Assign(types[1], TensorType(dst_shape, data->dtype)); + return true; +} + +Expr MakeKernelLayoutTransform(Expr data, + String src_layout, + String dst_layout) { + auto attrs = make_object(); + attrs->src_layout = std::move(src_layout); + attrs->dst_layout = std::move(dst_layout); + static 
const Op& op = Op::Get("kernel_layout_transform"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.kernel_layout_transform") +.set_body_typed(MakeKernelLayoutTransform); + +RELAY_REGISTER_OP("kernel_layout_transform") + .describe(R"code(Transform the input kernel layout. +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("kernel_layout_transform", KernelLayoutTransformRel) + .set_support_level(5) + .set_attr("FTVMCompute", KernelLayoutTransformCompute); + + /* relay._contrib_reverse_reshape */ Expr MakeReverseReshape(Expr data, Array newshape) { auto attrs = make_object(); diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 1c84304fb0e7..3876d67b7b11 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -173,7 +173,7 @@ class LoopUnroller : public StmtExprMutator { if (explicit_unroll_max_extent_ > 0 && value > explicit_unroll_max_extent_ && explicit_unroll_) { // Do not unroll too long loops ForType for_type = op->for_type == ForType::Unrolled ? ForType::Serial : op->for_type; - return ForNode::make(op->loop_var, op->min, op->extent, for_type, op->device_api, op->body); + return For(op->loop_var, op->min, op->extent, for_type, op->device_api, op->body); } Stmt body = op->body; Map vmap; diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index e0e455667889..7dd782f5b622 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1295,6 +1295,75 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, name, tag); } +/*! + * \brief utility function for kernel_layout_transform + */ +inline void parse_kernel_layout(const String& layout, + Array* shape, + std::vector* axes) { + int32_t factor = 0; + std::string axis = ""; + for (char c : std::string(layout)) { + if (c >= 'A' && c <= 'z') { + axis += c; + if (factor != 0) { + shape->push_back(factor); + factor = 0; + } + } else if (c >= '0' && c <= '9') { + factor = factor * 10 + c - '0'; + if (!axis.empty()) { + axes->push_back(axis); + axis = ""; + } + } else { + LOG(FATAL) << "Invalid layout " << layout; + } + } + if (!axis.empty()) { + axes->push_back(axis); + } +} + +/*! + * \brief Transform the kernel layout according to \p src_layout and \p dst_layout + * \param src the source input. + * \param src_layout the source layout. + * \param dst_layout the destination layout. + * \param name output tensor name. + * \param tag output tensor tag. 
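A note on the layout strings consumed by parse_kernel_layout above: the parser follows a factor-before-axis convention, where digits accumulate into an integer factor, letters accumulate into an axis name, and the pending factor is flushed into *shape when the first letter of the next axis arrives. Below is a minimal sketch tracing the code as written; it assumes a TVM checkout with this patch applied, and the layout string "4o16i" is a made-up example rather than one Ansor necessarily generates:

    #include <topi/transform.h>
    #include <iostream>

    int main() {
      tvm::Array<tvm::PrimExpr> shape;
      std::vector<std::string> axes;
      // Tracing "4o16i" through parse_kernel_layout:
      //   '4'  -> factor = 4
      //   'o'  -> axis "o" starts; pending factor 4 is flushed into shape
      //   '16' -> "o" is pushed into axes; factor = 16
      //   'i'  -> axis "i" starts; 16 is flushed into shape
      //   end  -> "i" is pushed into axes
      // Result: shape == [4, 16], axes == {"o", "i"}
      topi::parse_kernel_layout("4o16i", &shape, &axes);
      for (size_t i = 0; i < axes.size(); ++i) {
        std::cout << axes[i] << " extent=" << shape[i] << "\n";
      }
      return 0;
    }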
+ * \return A tensor with shape in \p dst_layout + */ +inline Tensor kernel_layout_transform(const Tensor& src, + const String& src_layout, + const String& dst_layout, + const String name = "T_kernel_layout_trans", + const String tag = kInjective) { + Array src_shape; + std::vector src_axes; + Array dst_shape; + std::vector dst_axes; + + parse_kernel_layout(src_layout, &src_shape, &src_axes); + parse_kernel_layout(dst_layout, &dst_shape, &dst_axes); + return compute( + dst_shape, [&](const Array& dst_indices) { + Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); + Array src_indices; + for (const std::string& src_axis : src_axes) { + PrimExpr src_index = 0; + CHECK_EQ(dst_indices_expr.size(), dst_axes.size()); + for (size_t i = 0; i < dst_axes.size(); ++i) { + if (dst_axes[i] == src_axis) { + src_index = src_index * dst_shape[i] + dst_indices_expr[i]; + } + } + src_indices.push_back(src_index); + } + return src(src_indices); + }, name, tag); +} + /*! * \brief Get the shape of input tensor. * \param src the input tensor. From 53bd591167959a5ae0d85ca27988a826e73c8dcc Mon Sep 17 00:00:00 2001 From: Chenfan Date: Mon, 22 Jun 2020 15:22:23 +0800 Subject: [PATCH 35/78] Migrate all node::make to noderef's construct function (#37) * Start to move xxxnode::make to noderef() * Update * Update * Finish transform_step * Finish comute dag & auto schedule * Update * Update * Update * Update * Update * Code refine * Code refine * Code refine * Update * Update --- src/ansor/auto_schedule.cc | 26 +- src/ansor/auto_schedule.h | 22 +- src/ansor/compute_dag.cc | 39 ++- src/ansor/compute_dag.h | 20 +- src/ansor/cost_model/cost_model.cc | 33 ++- src/ansor/cost_model/cost_model.h | 67 ++++- src/ansor/feature.cc | 18 +- src/ansor/loop_state.cc | 182 ++++++------- src/ansor/loop_state.h | 125 +++++---- src/ansor/measure.cc | 92 ++++--- src/ansor/measure.h | 133 ++++++--- src/ansor/search_policy/search_policy.cc | 8 +- src/ansor/search_policy/search_policy.h | 17 +- .../search_policy/sketch_search_policy.cc | 34 ++- .../search_policy/sketch_search_policy.h | 43 ++- src/ansor/search_policy/utils.cc | 4 +- src/ansor/search_task.cc | 35 ++- src/ansor/search_task.h | 39 ++- src/ansor/serialization.cc | 54 ++-- src/ansor/serialization.h | 28 +- src/ansor/transform_step.cc | 92 ++++--- src/ansor/transform_step.h | 257 +++++++++++++----- tests/cpp/ansor_test.cc | 2 +- 23 files changed, 839 insertions(+), 531 deletions(-) diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index 7ffc63a03917..05cb95c2c451 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -33,11 +33,10 @@ namespace ansor { TVM_REGISTER_NODE_TYPE(TuneOptionNode); -TuneOption TuneOptionNode::make(int n_trials, int early_stopping, - int num_measure_per_iter, int verbose, - Builder builder, Runner runner, - Array measure_callbacks, - Array pre_search_callbacks) { +TuneOption::TuneOption(int n_trials, int early_stopping, + int num_measure_per_iter, int verbose, Builder builder, + Runner runner, Array measure_callbacks, + Array pre_search_callbacks) { auto node = make_object(); node->n_trials = n_trials; node->early_stopping = early_stopping; @@ -47,16 +46,16 @@ TuneOption TuneOptionNode::make(int n_trials, int early_stopping, node->runner = std::move(runner); node->measure_callbacks = std::move(measure_callbacks); node->pre_search_callbacks = std::move(pre_search_callbacks); - return TuneOption(node); + data_ = std::move(node); } std::pair > AutoSchedule(SearchTask task, SearchPolicy search_policy, 
TuneOption tune_option) { // Search for the best schedule ProgramMeasurer measurer = - ProgramMeasurerNode::make(tune_option->builder, tune_option->runner, - tune_option->measure_callbacks, - tune_option->verbose); + ProgramMeasurer(tune_option->builder, tune_option->runner, + tune_option->measure_callbacks, + tune_option->verbose); State state = search_policy->Search( task, tune_option->n_trials, tune_option->early_stopping, @@ -70,8 +69,8 @@ std::pair > AutoSchedule( std::string workload_key, Target target, Target target_host, SearchPolicy search_policy, HardwareParams hardware_params, TuneOption tune_option) { - ComputeDAG dag = ComputeDAGNode::make_by_workload_key(workload_key); - SearchTask task = SearchTaskNode::make( + ComputeDAG dag = ComputeDAG(workload_key); + SearchTask task = SearchTask( std::move(dag), std::move(workload_key), std::move(target), std::move(target_host), std::move(hardware_params)); return AutoSchedule(std::move(task), std::move(search_policy), @@ -83,9 +82,8 @@ TVM_REGISTER_GLOBAL("ansor.TuneOption") int num_measure_per_iter, int verbose, Builder builder, Runner runner, Array measure_callbacks, Array pre_search_callbacks) { - return TuneOptionNode::make(n_trials, early_stopping, - num_measure_per_iter, verbose, builder, - runner, measure_callbacks, pre_search_callbacks); + return TuneOption(n_trials, early_stopping, num_measure_per_iter, verbose, + builder, runner, measure_callbacks, pre_search_callbacks); }); TVM_REGISTER_GLOBAL("ansor.AutoScheduleBySearchTask") diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index 4e70ac0b577a..f17c043cfadd 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -34,7 +34,6 @@ namespace tvm { namespace ansor { /*! \brief Tuning and measurement options */ -class TuneOption; class TuneOptionNode : public Object { public: int n_trials; // Number of total measurement trials @@ -61,15 +60,24 @@ class TuneOptionNode : public Object { v->Visit("pre_search_callbacks", &pre_search_callbacks); } - static TuneOption make(int n_trials, int early_stopping, - int num_measure_per_iter, int verbose, Builder builder, - Runner runner, Array measure_callbacks, - Array pre_search_callbacks); - static constexpr const char* _type_key = "ansor.TuneOption"; TVM_DECLARE_FINAL_OBJECT_INFO(TuneOptionNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(TuneOption, ObjectRef, TuneOptionNode); + +/*! + * \brief Managed reference to TuneOptionNode. + * \sa TuneOptionNode + */ +class TuneOption : public ObjectRef { + public: + TuneOption(int n_trials, int early_stopping, int num_measure_per_iter, + int verbose, Builder builder, Runner runner, + Array measure_callbacks, + Array pre_search_callbacks); + + TVM_DEFINE_OBJECT_REF_METHODS(TuneOption, ObjectRef, TuneOptionNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(TuneOptionNode); +}; /*! 
\brief Auto schedule for a compute declaration */ std::pair > AutoSchedule( diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 7b4857b34d76..13f64b2bdc89 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -241,7 +241,7 @@ static bool HasExpensiveOp(const PrimExpr& expr) { return found; } -AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) { +AccessAnalyzer::AccessAnalyzer(const Array& tensors) { auto node = make_object(); OperationMap has_branch; @@ -290,8 +290,8 @@ AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) { for (const auto& pair : node->read_from[op]) { const std::vector >& access = pair.second; for (const auto& index : access) { - if (!IsInjective(op, index, &axis_missing, &axis_duplicated, - &same_order)) { + if (!ansor::IsInjective(op, index, &axis_missing, &axis_duplicated, + &same_order)) { is_injective = false; is_strict_inlineable = false; break; @@ -356,7 +356,7 @@ AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) { } } - return AccessAnalyzer(node); + data_ = std::move(node); } bool AccessAnalyzer::NeedsMultiLevelTiling(const te::Operation &op) const { @@ -554,7 +554,6 @@ class FlopEstimator: public ExprFunctor { return ret; } - double VisitExprDefault_(const Object* op) final { fail = true; return -1.0; @@ -567,20 +566,20 @@ State ComputeDAG::GetInitState() const { return Downcast(operator->()->init_state); } -ComputeDAG ComputeDAGNode::make(Array tensors) { +ComputeDAG::ComputeDAG(Array tensors) { auto node = make_object(); FlopEstimator estimator; node->tensors = std::move(tensors); - node->access_analyzer = AccessAnalyzerNode::make(node->tensors); + node->access_analyzer = AccessAnalyzer(node->tensors); node->ops = Array(node->access_analyzer->ops_topo_order); node->flop_ct = estimator.EstimateFlop(node->ops); - node->init_state = StateNode::make(node->ops); + node->init_state = State(node->ops); - return ComputeDAG(node); + data_ = std::move(node); } -ComputeDAG ComputeDAGNode::make_by_workload_key(const std::string& workload_key) { +ComputeDAG::ComputeDAG(const std::string& workload_key) { Array tens; // Call python function to decode the workload_key and get the I/O tensors if (const auto* f = runtime::Registry::Get("ansor.workload_key_to_tensors")) { @@ -588,7 +587,7 @@ ComputeDAG ComputeDAGNode::make_by_workload_key(const std::string& workload_key) } else { LOG(FATAL) << "ansor.workload_key_to_tensors is not registered"; } - return ComputeDAGNode::make(std::move(tens)); + ComputeDAG(std::move(tens)); } std::string BaseName(const std::string& str) { @@ -938,7 +937,7 @@ void ComputeDAG::RewriteLayout( } } - pdag->init_state = StateNode::make(pdag->ops); + pdag->init_state = State(pdag->ops); Array old_tensors = pdag->tensors; ArrayNode* ptensors = pdag->tensors.CopyOnWrite(); @@ -1105,7 +1104,7 @@ void ComputeDAG::ReplayAndGetDAG(const std::vector &transform_steps, } } - *task_dag = ComputeDAGNode::make(new_tensors); + *task_dag = ComputeDAG(new_tensors); } @@ -1136,18 +1135,16 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { auto find_res = bounds.find(axis); if (find_res != bounds.end()) { - new_iters.push_back(IteratorNode::make(iter->name, (*find_res).second, - iter->iter_type, - iter->annotation, - &iter->ori_iters, - iter->attr)); + new_iters.push_back(Iterator(iter->name, (*find_res).second, + iter->iter_type, iter->annotation, + &iter->ori_iters, iter->attr)); } else { LOG(FATAL) << "Infer bound fails"; } } - pstate->stages[i] = StageNode::make(stage->op, 
stage->op_type, - std::move(new_iters), stage->compute_at, stage->attrs); + pstate->stages[i] = Stage(stage->op, stage->op_type, std::move(new_iters), + stage->compute_at, stage->attrs); } } @@ -1319,7 +1316,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_GLOBAL("ansor.ComputeDAG") .set_body_typed([](Array tensors) { - return ComputeDAGNode::make(tensors); + return ComputeDAG(tensors); }); TVM_REGISTER_GLOBAL("ansor.ComputeDAGGetInitState") diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 8da71f005f19..b1b60e678904 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -37,7 +37,6 @@ namespace tvm { namespace ansor { -class ComputeDAG; class AccessAnalyzer; class StateNode; class State; class Step; /*! \brief Read/Write access static analysis result */ @@ -54,15 +53,17 @@ class AccessAnalyzerNode : public Object { OperationMap is_output; std::vector ops_topo_order; - static AccessAnalyzer make(const Array& tensors); - static constexpr const char* _type_key = "ansor.AccessAnalyzer"; TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object); }; -/*! \brief Read/Write access static analysis result */ +/*! + * \brief Managed reference to AccessAnalyzerNode. + * \sa AccessAnalyzerNode + */ class AccessAnalyzer : public ObjectRef { public: + explicit AccessAnalyzer(const Array& tensors); // read/write access analysis bool NeedsMultiLevelTiling(const te::Operation& op) const; bool IsInjective(const te::Operation& op) const; @@ -121,9 +122,6 @@ class ComputeDAGNode : public Object { v->Visit("access_analyzer", &access_analyzer); } - static ComputeDAG make(Array tensors); - static ComputeDAG make_by_workload_key(const std::string& workload_key); - static constexpr const char* _type_key = "ansor.ComputeDAG"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object); }; @@ -135,9 +133,15 @@ enum LayoutRewriteLevel { kBothRewrite = 3, // Rewrite both placeholder and compute body in the compute dag }; -/*! \brief Compute declaration graph */ +/*! + * \brief Managed reference to ComputeDAGNode. + * \sa ComputeDAGNode + */ class ComputeDAG: public ObjectRef { public: + explicit ComputeDAG(Array tensors); + explicit ComputeDAG(const std::string& workload_key); + // Apply transform steps to the init state of this DAG, and get the equivalent tvm::schedule. 
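The migration pattern in this commit is uniform: every static XXXNode::make(...) becomes a constructor on the corresponding reference class, with make_by_workload_key folded into an overloaded ComputeDAG constructor. (As a side note, that workload_key overload appears to end with the statement `ComputeDAG(std::move(tens));`, which constructs and discards a temporary rather than initializing `data_` of the object under construction; a delegating constructor looks like the intended behavior.) A minimal sketch of the resulting call sites, assuming a build with this patch applied and a hypothetical helper name:

    #include "compute_dag.h"

    namespace tvm {
    namespace ansor {

    // Sketch: building a DAG and its initial loop state post-migration.
    State BuildInitState(const Array<te::Tensor>& tensors) {
      ComputeDAG dag(tensors);     // was: ComputeDAGNode::make(tensors)
      return dag.GetInitState();   // init_state is computed in the constructor
    }

    }  // namespace ansor
    }  // namespace tvm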
// The return values can be used as arguments to tvm.build or tvm.lower std::pair > ApplySteps( diff --git a/src/ansor/cost_model/cost_model.cc b/src/ansor/cost_model/cost_model.cc index bbf15a241974..ee7bf8b26053 100644 --- a/src/ansor/cost_model/cost_model.cc +++ b/src/ansor/cost_model/cost_model.cc @@ -48,7 +48,7 @@ void RandomNumber(TVMArgs args, TVMRetValue* rv) { } } -CostModel RandomModelNode::make() { +RandomModel::RandomModel() { ObjectPtr node = make_object(); node->random_number_func = runtime::Registry::Get("ansor.cost_model.random_number"); @@ -58,7 +58,7 @@ CostModel RandomModelNode::make() { static PackedFunc cost_model_random_number(RandomNumber); node->random_number_func = &cost_model_random_number; } - return CostModel(node); + data_ = std::move(node); } void RandomModelNode::Update(const Array& inputs, @@ -71,11 +71,11 @@ void RandomModelNode::Predict(const SearchTask& task, (*random_number_func)(states.size(), static_cast(scores->data())); } -CostModel MeasureModelNode::make(Builder builder, Runner runner) { +MeasureModel::MeasureModel(Builder builder, Runner runner) { ObjectPtr node = make_object(); - node->measurer = ProgramMeasurerNode::make( - std::move(builder), std::move(runner), Array(), 0); - return CostModel(node); + node->measurer = ProgramMeasurer(std::move(builder), std::move(runner), + Array(), 0); + data_ = std::move(node); } void MeasureModelNode::Update(const Array& inputs, @@ -90,7 +90,7 @@ void MeasureModelNode::Predict(const SearchTask& task, inputs.clear(); inputs.reserve(states.size()); for (const auto& state : states) { - inputs.push_back(MeasureInputNode::make(task, state)); + inputs.push_back(MeasureInput(task, state)); } measurer->SilentMeasure(task, inputs, &results); @@ -101,14 +101,14 @@ void MeasureModelNode::Predict(const SearchTask& task, } } -CostModel PythonBasedModelNode::make(PackedFunc update_func, - PackedFunc predict_func, - PackedFunc predict_stage_func) { +PythonBasedModel::PythonBasedModel(PackedFunc update_func, + PackedFunc predict_func, + PackedFunc predict_stage_func) { auto node = make_object(); node->update_func = std::move(update_func); node->predict_func = std::move(predict_func); node->predict_stage_func = std::move(predict_stage_func); - return CostModel(node); + data_ = std::move(node); } void PythonBasedModelNode::Update(const Array& inputs, @@ -124,9 +124,8 @@ void PythonBasedModelNode::Predict(const SearchTask& task, static_cast(scores->data())); } -void PythonBasedModelNode::PredictStages( - const SearchTask& task, const std::vector& states, - std::vector* state_scores, +void PythonBasedModelNode::PredictStages(const SearchTask& task, + const std::vector& states, std::vector* state_scores, std::vector>* stage_scores) { int n_states = states.size(); int n_stages = task->compute_dag.GetInitState()->stages.size(); @@ -185,14 +184,14 @@ void PythonBasedModelNode::PredictStages( } TVM_REGISTER_GLOBAL("ansor.RandomModel").set_body_typed([]() { - return RandomModelNode::make(); + return RandomModel(); }); TVM_REGISTER_GLOBAL("ansor.PythonBasedModel") .set_body_typed([](PackedFunc update_func, PackedFunc predict_func, PackedFunc predict_stage_func) { - return PythonBasedModelNode::make(update_func, predict_func, - predict_stage_func); + return PythonBasedModel(update_func, predict_func, + predict_stage_func); }); } // namespace ansor diff --git a/src/ansor/cost_model/cost_model.h b/src/ansor/cost_model/cost_model.h index 472a3c201068..f38624a3572c 100644 --- a/src/ansor/cost_model/cost_model.h +++ 
b/src/ansor/cost_model/cost_model.h @@ -36,20 +36,20 @@ namespace ansor { using runtime::PackedFunc; -class CostModel; - /*! \brief The base class for cost model */ class CostModelNode: public Object { public: // Update the cost model according to new measurement pairs - virtual void Update(const Array& inputs, const Array& results) = 0; + virtual void Update(const Array& inputs, + const Array& results) = 0; // Predict the scores of states virtual void Predict(const SearchTask& task, const std::vector& states, std::vector* scores) = 0; // Predict the scores of all stages in states - virtual void PredictStages(const SearchTask& task, const std::vector& states, + virtual void PredictStages(const SearchTask& task, + const std::vector& states, std::vector* state_scores, std::vector>* stage_scores) { LOG(FATAL) << "Not Implemented"; @@ -65,9 +65,8 @@ class RandomModelNode: public CostModelNode { public: const PackedFunc* random_number_func; - static CostModel make(); - - void Update(const Array& inputs, const Array& results) final; + void Update(const Array& inputs, + const Array& results) final; void Predict(const SearchTask& task, const std::vector& states, std::vector* scores) final; @@ -75,14 +74,31 @@ class RandomModelNode: public CostModelNode { TVM_DECLARE_FINAL_OBJECT_INFO(RandomModelNode, CostModelNode); }; +/*! + * \brief Managed reference to RandomModelNode. + * \sa RandomModelNode + */ +class RandomModel : public CostModel { + public: + RandomModel(); + explicit RandomModel(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) + : CostModel(n) {} + + RandomModelNode* operator->() const { + return static_cast(data_.get()); + } + + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(RandomModel); + using ContainerType = RandomModelNode; +}; + /*! \brief The cost model returns actual cost by measurement */ class MeasureModelNode : public CostModelNode { public: ProgramMeasurer measurer; - static CostModel make(Builder builder, Runner runner); - - void Update(const Array& inputs, const Array& results) final; + void Update(const Array& inputs, + const Array& results) final; void Predict(const SearchTask& task, const std::vector& states, std::vector* scores) final; @@ -90,6 +106,18 @@ class MeasureModelNode : public CostModelNode { TVM_DECLARE_FINAL_OBJECT_INFO(MeasureModelNode, CostModelNode); }; +/*! + * \brief Managed reference to MeasureModelNode. + * \sa MeasureModelNode + */ +class MeasureModel : public CostModel { + public: + MeasureModel(Builder builder, Runner runner); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureModel, CostModel, + MeasureModelNode); +}; + /*! \brief A wrapper for cost model defined by python code * This class will call python's function */ class PythonBasedModelNode: public CostModelNode { @@ -98,10 +126,8 @@ class PythonBasedModelNode: public CostModelNode { PackedFunc predict_func; PackedFunc predict_stage_func; - static CostModel make(PackedFunc update_func, PackedFunc predict_func, - PackedFunc predict_stage_func); - - void Update(const Array& inputs, const Array& results) final; + void Update(const Array& inputs, + const Array& results) final; void Predict(const SearchTask& task, const std::vector& states, std::vector* scores) final; void PredictStages(const SearchTask& task, const std::vector& states, @@ -112,6 +138,19 @@ class PythonBasedModelNode: public CostModelNode { TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedModelNode, CostModelNode); }; +/*! + * \brief Managed reference to PythonBasedModelNode. 
+ * \sa PythonBasedModelNode + */ +class PythonBasedModel : public CostModel { + public: + PythonBasedModel(PackedFunc update_func, PackedFunc predict_func, + PackedFunc predict_stage_func); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedModel, CostModel, + PythonBasedModelNode); +}; + } // namespace ansor } // namespace tvm diff --git a/src/ansor/feature.cc b/src/ansor/feature.cc index 3b5849e22262..73f6bad0d432 100644 --- a/src/ansor/feature.cc +++ b/src/ansor/feature.cc @@ -1297,7 +1297,7 @@ void GetPerStmtFeaturesFromFile(const std::string& filename, std::vector min_costs; // read from file - LogReader reader = LogReaderNode::make(filename); + LogReader reader = LogReader(filename); auto cur_inp = make_object(); auto cur_res = make_object(); while (reader->ReadNext(cur_inp.get(), cur_res.get())) { @@ -1310,11 +1310,9 @@ void GetPerStmtFeaturesFromFile(const std::string& filename, auto find_res = task_cache.find(key); if (find_res == task_cache.end()) { // rebuild task - task = SearchTaskNode::make(ComputeDAGNode::make_by_workload_key(workload_key), - workload_key, - cur_inp->task->target, - cur_inp->task->target_host, - cur_inp->task->hardware_params); + task = SearchTask(ComputeDAG(workload_key), workload_key, + cur_inp->task->target, cur_inp->task->target_host, + cur_inp->task->hardware_params); task_id = task_cache.size(); // compute min cost for each task @@ -1378,11 +1376,9 @@ void GetPerStmtFeaturesFromMeasurePairs(const Array& inputs, task = inputs[i]->task; } else { // the measure input is incomplete // rebuild task for incomplete measure pairs read from file - task = SearchTaskNode::make(ComputeDAGNode::make_by_workload_key(workload_key), - workload_key, - inputs[i]->task->target, - inputs[i]->task->target_host, - inputs[i]->task->hardware_params); + task = SearchTask(ComputeDAG(workload_key), workload_key, + inputs[i]->task->target, inputs[i]->task->target_host, + inputs[i]->task->hardware_params); } task_id = task_cache.size(); diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 23e005503873..ef4c4632e9bf 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -37,10 +37,10 @@ TVM_REGISTER_NODE_TYPE(StateNode); TVM_REGISTER_NODE_TYPE(IteratorNode); // Maker for other classes -Iterator IteratorNode::make(std::string name, Range range, - IteratorType iter_type, IteratorAnnotation annotation, - const std::vector* ori_iters, - std::string attr) { +Iterator::Iterator(std::string name, Range range, IteratorType iter_type, + IteratorAnnotation annotation, + const std::vector* ori_iters, + std::string attr) { auto node = make_object(); node->name = std::move(name); node->range = std::move(range); @@ -50,23 +50,22 @@ Iterator IteratorNode::make(std::string name, Range range, node->ori_iters = *ori_iters; } node->attr = std::move(attr); - return Iterator(node); + data_ = std::move(node); } - -Stage StageNode::make(te::Operation op) { +Stage::Stage(te::Operation op) { auto node = make_object(); if (op->IsInstance()) { node->op_type = kCompute; auto* pop = op.as(); for (const auto& axis : pop->axis) { - node->iters.push_back(IteratorNode::make(CleanName(axis->var->name_hint), - axis->dom, kSpace, kNone)); + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), + axis->dom, kSpace, kNone)); } for (const auto& axis : pop->reduce_axis) { - node->iters.push_back(IteratorNode::make(CleanName(axis->var->name_hint), - axis->dom, kReduce, kNone)); + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), + axis->dom, kReduce, kNone)); } } 
else if (op->IsInstance()) { node->op_type = kPlaceholder; @@ -78,67 +77,53 @@ Stage StageNode::make(te::Operation op) { node->op = std::move(op); node->attrs.auto_unroll_max_step = 0; node->attrs.storage_offset = 0; - return Stage(node); + data_ = std::move(node); } -Stage StageNode::make(te::Operation op, StageType op_type, - const std::vector& iters, - ComputeAtType compute_at, StageAttributes attrs) { +Stage::Stage(te::Operation op, StageType op_type, + const std::vector& iters, ComputeAtType compute_at, + StageAttributes attrs) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; node->iters = iters; node->compute_at = compute_at; node->attrs = attrs; - return Stage(node); + data_ = std::move(node); } -Stage StageNode::make(te::Operation op, StageType op_type, - std::vector&& iters, ComputeAtType compute_at, - StageAttributes attrs) { +Stage::Stage(te::Operation op, StageType op_type, std::vector&& iters, + ComputeAtType compute_at, StageAttributes attrs) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; node->iters = std::move(iters); node->compute_at = compute_at; node->attrs = attrs; - return Stage(node); -} - -State StateNode::make_empty_state() { - auto node = make_object(); - node->attach_map = AttachMapNode::make(); - node->complete = false; - node->aux_info = ObjectRef(); - return State(node); + data_ = std::move(node); } -State StateNode::make(const Array& ops) { +State::State(const Array& ops) { auto node = make_object(); for (const auto& op : ops) { - node->stages.push_back(StageNode::make(op)); + node->stages.push_back(Stage(op)); } - node->attach_map = AttachMapNode::make(); + node->attach_map = AttachMap(make_object()); node->complete = true; node->aux_info = ObjectRef(); - return State(node); + data_ = std::move(node); } -State StateNode::make(const std::vector& stages, - const std::vector& transform_steps, bool complete, - ObjectRef aux_info) { +State::State(const std::vector& stages, + const std::vector& transform_steps, bool complete, + ObjectRef aux_info) { auto node = make_object(); node->stages = stages; node->transform_steps = transform_steps; - node->attach_map = AttachMapNode::make(); + node->attach_map = AttachMap(make_object()); node->complete = complete; node->aux_info = std::move(aux_info); - return State(node); -} - -AttachMap AttachMapNode::make() { - auto node = make_object(); - return AttachMap(node); + data_ = std::move(node); } // Schedule primitives api @@ -149,7 +134,7 @@ void State::reorder(int stage_id, const std::vector& order) { "should be specified"; std::vector after_ids; GetIndices(stage->iters, order, &after_ids); - ReorderStep step = ReorderStepNode::make(stage_id, after_ids); + ReorderStep step = ReorderStep(stage_id, after_ids); CopyOnWrite()->transform_steps.push_back(step); DoReorderStep(step); } @@ -160,9 +145,9 @@ std::vector State::split(int stage_id, const Iterator& it, const Stage& stage = operator->()->stages[stage_id]; SplitStep step = - SplitStepNode::make(stage_id, GetIndex(stage->iters, it), - it->range.defined() ? it->range->extent : PrimExpr(), - lengths, inner_to_outer); + SplitStep(stage_id, GetIndex(stage->iters, it), + it->range.defined() ? 
it->range->extent : PrimExpr(), + lengths, inner_to_outer); CopyOnWrite()->transform_steps.push_back(step); return DoSplitStep(step); } @@ -171,7 +156,7 @@ std::vector State::follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split) { const Stage& stage = operator->()->stages[stage_id]; - FollowSplitStep step = FollowSplitStepNode::make( + FollowSplitStep step = FollowSplitStep( stage_id, GetIndex(stage->iters, it), src_step_id, n_split); CopyOnWrite()->transform_steps.push_back(step); return DoFollowSplitStep(step); @@ -183,8 +168,8 @@ std::vector State::follow_fused_split( const Stage& stage = operator->()->stages[stage_id]; FollowFusedSplitStep step = - FollowFusedSplitStepNode::make(stage_id, GetIndex(stage->iters, it), - src_step_ids, level, factor_or_nparts); + FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), + src_step_ids, level, factor_or_nparts); CopyOnWrite()->transform_steps.push_back(step); return DoFollowFusedSplitStep(step); } @@ -193,14 +178,14 @@ Iterator State::fuse(int stage_id, const std::vector& iters) { const Stage& stage = operator->()->stages[stage_id]; std::vector indices; GetIndices(stage->iters, iters, &indices); - FuseStep step = FuseStepNode::make(stage_id, indices); + FuseStep step = FuseStep(stage_id, indices); CopyOnWrite()->transform_steps.push_back(step); return DoFuseStep(step); } Iterator State::vectorize(int stage_id, const Iterator& it) { const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = AnnotationStepNode::make( + AnnotationStep step = AnnotationStep( stage_id, GetIndex(stage->iters, it), kVectorize); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); @@ -209,7 +194,7 @@ Iterator State::vectorize(int stage_id, const Iterator& it) { Iterator State::parallel(int stage_id, const Iterator& it) { const Stage& stage = operator->()->stages[stage_id]; AnnotationStep step = - AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), kParallel); + AnnotationStep(stage_id, GetIndex(stage->iters, it), kParallel); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); } @@ -217,7 +202,7 @@ Iterator State::parallel(int stage_id, const Iterator& it) { Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { const Stage& stage = operator->()->stages[stage_id]; AnnotationStep step = - AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), kUnroll); + AnnotationStep(stage_id, GetIndex(stage->iters, it), kUnroll); // don't unroll if the extent is larger than max_unroll if (max_unroll != -1 && it->range.defined()) { @@ -235,20 +220,20 @@ Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { const Stage& target_stage = operator->()->stages[target_stage_id]; - ComputeAtStep step = ComputeAtStepNode::make( + ComputeAtStep step = ComputeAtStep( stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter)); CopyOnWrite()->transform_steps.push_back(step); return DoComputeAtStep(step); } void State::compute_root(int stage_id) { - ComputeRootStep step = ComputeRootStepNode::make(stage_id); + ComputeRootStep step = ComputeRootStep(stage_id); CopyOnWrite()->transform_steps.push_back(step); return DoComputeRootStep(step); } void State::compute_inline(int stage_id) { - ComputeInlineStep step = ComputeInlineStepNode::make(stage_id); + ComputeInlineStep step = ComputeInlineStep(stage_id); 
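  // Note: every schedule primitive in this file follows this same
  // record-then-apply shape: build the step object, append it to
  // transform_steps (so the schedule can be serialized and replayed
  // later), then call the matching DoXxxStep to update the cached
  // loop structure immediately.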
CopyOnWrite()->transform_steps.push_back(step); return DoComputeInlineStep(step); } @@ -257,10 +242,10 @@ Iterator State::bind_thread(int stage_id, const Iterator& it, IteratorAnnotation thread_type) { const Stage& stage = operator->()->stages[stage_id]; if (thread_type < kVThread || thread_type > kThreadY) { - LOG(FATAL) << "thread_type error, valide: kVThread, kBlockX, kThreadX, " - << "kThreadY"; + LOG(FATAL) << "thread_type error, valide: kVThread, kBlockX, kBlockY, " + << "kThreadX, kThreadY"; } - AnnotationStep step = AnnotationStepNode::make( + AnnotationStep step = AnnotationStep( stage_id, GetIndex(stage->iters, it), thread_type); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); @@ -270,14 +255,14 @@ int State::cache_read(int stage_id, const std::string& scope_name, const std::vector& reader_stage_ids, const ComputeDAG& task_dag) { CacheReadStep step = - CacheReadStepNode::make(stage_id, scope_name, reader_stage_ids); + CacheReadStep(stage_id, scope_name, reader_stage_ids); CopyOnWrite()->transform_steps.push_back(step); return DoCacheReadStep(step, task_dag); } int State::cache_write(int stage_id, const std::string& scope_name, const ComputeDAG& task_dag) { - CacheWriteStep step = CacheWriteStepNode::make(stage_id, scope_name); + CacheWriteStep step = CacheWriteStep(stage_id, scope_name); CopyOnWrite()->transform_steps.push_back(step); return DoCacheWriteStep(step, task_dag); } @@ -286,7 +271,7 @@ void State::pragma(int stage_id, const Iterator& it, const std::string& pragma_type) { const Stage& stage = operator->()->stages[stage_id]; PragmaStep step = - PragmaStepNode::make(stage_id, GetIndex(stage->iters, it), pragma_type); + PragmaStep(stage_id, GetIndex(stage->iters, it), pragma_type); CopyOnWrite()->transform_steps.push_back(step); return DoPragmaStep(step); } @@ -294,8 +279,8 @@ void State::pragma(int stage_id, const Iterator& it, int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& task_dag) { const Stage& stage = operator->()->stages[stage_id]; - RfactorStep step = RfactorStepNode::make(stage_id, GetIndex(stage->iters, it), - factor_iter_id); + RfactorStep step = RfactorStep(stage_id, GetIndex(stage->iters, it), + factor_iter_id); CopyOnWrite()->transform_steps.push_back(step); return DoRfactorStep(step, task_dag); } @@ -303,7 +288,7 @@ int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, void State::storage_align(int stage_id, const Iterator& it, int factor, int offset) { const Stage& stage = operator->()->stages[stage_id]; - StorageAlignStep step = StorageAlignStepNode::make( + StorageAlignStep step = StorageAlignStep( stage_id, GetIndex(stage->iters, it), factor, offset); CopyOnWrite()->transform_steps.push_back(step); return DoStorageAlignStep(step); @@ -312,7 +297,7 @@ void State::storage_align(int stage_id, const Iterator& it, int factor, Iterator State::tensorize(int stage_id, const Iterator& it, std::string ti_func_name) { const Stage& stage = operator->()->stages[stage_id]; - TensorizeStep step = TensorizeStepNode::make( + TensorizeStep step = TensorizeStep( stage_id, GetIndex(stage->iters, it), ti_func_name); CopyOnWrite()->transform_steps.push_back(step); return DoTensorizeStep(step); @@ -328,7 +313,7 @@ void State::DoReorderStep(const ReorderStep& step) { } StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = StageNode::make( + pstate->stages[step->stage_id] = Stage( stage->op, stage->op_type, std::move(iters), stage->compute_at, stage->attrs); } @@ -362,12 
+347,12 @@ std::vector State::DoSplitStepCommon( } Iterator res; if (l.defined() && tosplit_min.defined() && tosplit_extent.defined()) { - res = IteratorNode::make(name, Range::make_by_min_extent(tosplit_min, l), - it->iter_type, kNone); + res = Iterator(name, Range::make_by_min_extent(tosplit_min, l), + it->iter_type, kNone); tosplit_min = 0; tosplit_extent = indexdiv(tosplit_extent + l - 1, l); } else { - res = IteratorNode::make(name, Range(), it->iter_type, kNone); + res = Iterator(name, Range(), it->iter_type, kNone); tosplit_min = tosplit_extent = PrimExpr(); } outs.push_back(std::move(res)); @@ -379,12 +364,12 @@ std::vector State::DoSplitStepCommon( } if (inner_to_outer) { outs.push_back( - IteratorNode::make(it->name + ".0", range, it->iter_type, kNone)); + Iterator(it->name + ".0", range, it->iter_type, kNone)); std::reverse(outs.begin(), outs.end()); } else { outs.push_back( - IteratorNode::make(it->name + "." + std::to_string(lengths.size()), - range, it->iter_type, kNone)); + Iterator(it->name + "." + std::to_string(lengths.size()), + range, it->iter_type, kNone)); } std::vector new_iters; @@ -395,7 +380,7 @@ std::vector State::DoSplitStepCommon( stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages[stage_id] = StageNode::make( + pstate->stages[stage_id] = Stage( stage->op, stage->op_type, std::move(new_iters), stage->compute_at, stage->attrs); @@ -479,7 +464,7 @@ Iterator State::DoFuseStep(const FuseStep& step) { range = Range::make_by_min_extent(0, new_extent); } Iterator new_it = - IteratorNode::make(new_name, range, new_iter_type, kNone, &ori_iters); + Iterator(new_name, range, new_iter_type, kNone, &ori_iters); std::vector new_iters; new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + step->fused_ids.front()); @@ -489,7 +474,7 @@ Iterator State::DoFuseStep(const FuseStep& step) { stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages[stage_id] = StageNode::make( + pstate->stages[stage_id] = Stage( stage->op, stage->op_type, std::move(new_iters), stage->compute_at, stage->attrs); @@ -518,9 +503,9 @@ Iterator State::DoAnnotationStep(const AnnotationStep& step) { Iterator it = stage->iters[step->iter_id]; CHECK_EQ(it->annotation, IteratorAnnotation::kNone); - Iterator new_it = IteratorNode::make(it->name, it->range, it->iter_type, - step->annotation, &it->ori_iters, - it->attr); + Iterator new_it = Iterator(it->name, it->range, it->iter_type, + step->annotation, &it->ori_iters, + it->attr); Stage new_stage = stage; new_stage.CopyOnWrite()->iters[step->iter_id] = new_it; StateNode* pstate = CopyOnWrite(); @@ -547,15 +532,14 @@ void State::DoComputeAtStep(const ComputeAtStep& step) { // We do this to keep the AnnotateCPU pass to annotate more efficiently. 
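  // (After compute_at, this stage's loop extents depend on the attach
  //  position, so the stale ranges in the else-branch below are reset to
  //  an undefined Range() and recomputed later by bound inference; see
  //  ComputeDAG::ReplayAndInferBound.)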
new_iters.push_back(it); } else { - new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters, - it->attr)); + new_iters.push_back(Iterator(it->name, Range(), it->iter_type, + it->annotation, &it->ori_iters, it->attr)); } } StateNode* pstate = CopyOnWrite(); pstate->stages[step->stage_id] = - StageNode::make(stage->op, stage->op_type, std::move(new_iters), kIter, + Stage(stage->op, stage->op_type, std::move(new_iters), kIter, stage->attrs); pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, step->target_iter_id); @@ -569,16 +553,15 @@ void State::DoComputeRootStep(const ComputeRootStep& step) { // ComputeDAG::ReplayAndInferBound std::vector new_iters; for (const Iterator& it : stage->iters) { - new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters, - it->attr)); + new_iters.push_back(Iterator(it->name, Range(), it->iter_type, + it->annotation, &it->ori_iters, it->attr)); } // update attach map StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = - StageNode::make(stage->op, stage->op_type, std::move(new_iters), kRoot, - stage->attrs); + pstate->stages[step->stage_id] = Stage(stage->op, stage->op_type, + std::move(new_iters), kRoot, + stage->attrs); pstate->attach_map.DeleteStage(step->stage_id); } @@ -647,7 +630,7 @@ int State::DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag) { operator->()->task_dag->ops[step->stage_id]; pstate->stages.insert( pstate->stages.begin() + step->stage_id + 1, - StageNode::make(operator->()->task_dag->ops[step->stage_id + 1])); + Stage(operator->()->task_dag->ops[step->stage_id + 1])); for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; } @@ -667,9 +650,8 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { } } - int last_dag_op_size = pstate->task_dag.defined() - ? pstate->task_dag->ops.size() - : dag->ops.size(); + int last_dag_op_size = pstate->task_dag.defined() ? 
+ pstate->task_dag->ops.size() : dag->ops.size(); dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag)); int added_ops = pstate->task_dag->ops.size() - last_dag_op_size; CHECK_GE(added_ops, 1); @@ -679,9 +661,9 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { // Should insert new stage, update target stage, update the later stage's op pstate->stages.insert( pstate->stages.begin() + step->stage_id, - StageNode::make(operator->()->task_dag->ops[step->stage_id])); + Stage(operator->()->task_dag->ops[step->stage_id])); pstate->stages[step->stage_id + 1] = - StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]); + Stage(operator->()->task_dag->ops[step->stage_id + 1]); int next_stage_id = step->stage_id + 2; // Notice: added_ops should actually assert to be 1 // branch of 2 here is somehow a hack to TVM's cache_write bug with @@ -691,7 +673,7 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { if (added_ops == 2) { pstate->stages.insert( pstate->stages.begin() + next_stage_id, - StageNode::make(operator->()->task_dag->ops[next_stage_id])); + Stage(operator->()->task_dag->ops[next_stage_id])); next_stage_id++; } else if (added_ops > 2) { LOG(ERROR) << "Unexpected behavior of CacheWrite."; @@ -737,10 +719,10 @@ int State::DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag) { // Should insert new stage, update target stage, update the later stage's op pstate->stages.insert( pstate->stages.begin() + step->stage_id, - StageNode::make(operator->()->task_dag->ops[step->stage_id])); + Stage(operator->()->task_dag->ops[step->stage_id])); // maintain the compute_at type of target stage Stage target_stage = - StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]); + Stage(operator->()->task_dag->ops[step->stage_id + 1]); target_stage.CopyOnWrite()->compute_at = compute_at_type; pstate->stages[step->stage_id + 1] = target_stage; @@ -762,7 +744,7 @@ void State::DoStorageAlignStep(const StorageAlignStep& step) { Iterator State::DoTensorizeStep(const TensorizeStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; Iterator it = stage->iters[step->iter_id]; - Iterator new_it = IteratorNode::make(it->name, it->range, it->iter_type, + Iterator new_it = Iterator(it->name, it->range, it->iter_type, IteratorAnnotation::kTensorized, &it->ori_iters, step->ti_func_name); Stage new_stage = stage; new_stage.CopyOnWrite()->iters[step->iter_id] = new_it; @@ -1017,7 +999,7 @@ void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) { } AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const { - AttachMap map = AttachMapNode::make(); + AttachMap map = AttachMap(make_object()); auto pmap = map.CopyOnWrite(); for (const auto& x : operator->()->stage_to_attach_iter) { auto key = x.first; diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 31ed5274184d..2d64db11fc18 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -93,11 +93,6 @@ class IteratorNode : public Object { std::vector ori_iters; // The original iterators before fusion std::string attr; - static Iterator make(std::string name, Range range, - IteratorType iter_type, IteratorAnnotation annotation, - const std::vector* ori_iters = nullptr, - std::string attr = ""); - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); v->Visit("range", &range); @@ -107,19 +102,21 @@ class IteratorNode : public Object { static constexpr const char *_type_key = "ansor.Iterator"; 
TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(Iterator, ObjectRef, IteratorNode); -// Forward decelerations -class Stage; class State; -class AttachMap; +/*! + * \brief Managed reference to IteratorNode. + * \sa IteratorNode + */ +class Iterator : public ObjectRef { + public: + Iterator(std::string name, Range range, IteratorType iter_type, + IteratorAnnotation annotation, + const std::vector* ori_iters = nullptr, + std::string attr = ""); -class ReorderStep; class SplitStep; class FollowSplitStep; -class FollowFusedSplitStep; -class FuseStep; class AnnotationStep; -class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep; -class CacheReadStep; class CacheWriteStep; -class PragmaStep; class RfactorStep; class StorageAlignStep; -class TensorizeStep; + TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(IteratorNode); +}; /*! \brief Stage-level attributes */ struct StageAttributes { @@ -143,23 +140,34 @@ class StageNode : public Object { v->Visit("op", &op); } - static Stage make(te::Operation op); - static Stage make(te::Operation op, StageType op_type, - const std::vector& iters, - ComputeAtType compute_at, StageAttributes attrs); - static Stage make(te::Operation op, StageType op_type, - std::vector&& iters, - ComputeAtType compute_at, StageAttributes attrs); - static constexpr const char *_type_key = "ansor.Stage"; TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(Stage, ObjectRef, StageNode); + +/*! + * \brief Managed reference to StageNode. + * \sa StageNode + */ +class Stage : public ObjectRef { + public: + explicit Stage(te::Operation op); + Stage(te::Operation op, StageType op_type, + const std::vector& iters, + ComputeAtType compute_at, StageAttributes attrs); + Stage(te::Operation op, StageType op_type, + std::vector&& iters, + ComputeAtType compute_at, StageAttributes attrs); + + TVM_DEFINE_OBJECT_REF_METHODS(Stage, ObjectRef, StageNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode); +}; /*! \brief stores the compute_at relation between stages * This stores a bi-directional mapping from stages and iter: - * 1. Stage to its attached iterator 2. Iterator to the stage attached to it - */ + * 1. Stage to its attached iterator 2. Iterator to the stage attached to it + * + * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages + * to query the relations */ class AttachMapNode: public Object { public: using StageKey = int; @@ -168,18 +176,14 @@ class AttachMapNode: public Object { std::unordered_map stage_to_attach_iter; std::unordered_map> iter_to_attached_stages; - static AttachMap make(); - static constexpr const char* _type_key = "ansor.AttachMap"; TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object); }; -/*! \brief stores the compute_at relation between stages - * This stores a bi-directional mapping from stages and iter: - * 1. Stage to its attached iterator 2. Iterator to the stage attached to it - * - * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages - * to query the relations */ +/*! + * \brief Managed reference to AttachMapNode. + * \sa AttachMapNode + */ class AttachMap : public ObjectRef { public: using StageKey = int; @@ -214,7 +218,17 @@ class StepNode: public Object { }; TVM_DEFINE_MUTABLE_OBJECT_REF(Step, StepNode); -/*! 
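Every conversion in this patch follows the same shape: the static `Node::make(...)` factory becomes a constructor on the reference class that fills `data_`, plus the `TVM_DEFINE_OBJECT_REF_METHODS` (and, where copy-on-write is needed, `TVM_DEFINE_OBJECT_REF_COW_METHOD`) macros. A minimal sketch of the pattern with a hypothetical `FooNode`/`Foo` pair that is not part of this patch:

  class FooNode : public Object {
   public:
    int value;

    static constexpr const char* _type_key = "ansor.Foo";
    TVM_DECLARE_FINAL_OBJECT_INFO(FooNode, Object);
  };

  class Foo : public ObjectRef {
   public:
    explicit Foo(int value) {
      auto node = make_object<FooNode>();  // allocate the underlying node
      node->value = value;
      data_ = std::move(node);             // the constructor fills data_ itself
    }

    TVM_DEFINE_OBJECT_REF_METHODS(Foo, ObjectRef, FooNode);
  };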
\brief The loop state and corresponding history steps to reach this state */ +// Step forward declarations +class ReorderStep; class SplitStep; class FollowSplitStep; +class FollowFusedSplitStep; +class FuseStep; class AnnotationStep; +class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep; +class CacheReadStep; class CacheWriteStep; +class PragmaStep; class RfactorStep; class StorageAlignStep; +class TensorizeStep; + +/*! \brief A state in the search process. + * It consists of the current loop structure and the history steps to reach this state. */ class StateNode: public Object { public: std::vector stages; // Current stages and loop structures @@ -232,22 +246,29 @@ class StateNode: public Object { v->Visit("task_dag", &task_dag); } - static State make_empty_state(); - static State make(const Array& ops); - static State make(const std::vector& stages, - const std::vector& transform_steps, bool complete, - ObjectRef aux_info); - static constexpr const char* _type_key = "ansor.State"; TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object); }; -/*! \brief A state in the search process. - * It consists of the current loop structure and the history steps to reach this state. */ +/*! + * \brief Managed reference to StateNode. + * \sa StateNode + */ class State : public ObjectRef { public: + explicit State(const Array& ops); + State(const std::vector& stages, + const std::vector& transform_steps, bool complete, + ObjectRef aux_info); + // Schedule primitives void reorder(int stage_id, const std::vector& order); + void compute_at(int stage_id, int target_stage_id, + const Iterator& target_iter); + void compute_root(int stage_id); + void compute_inline(int stage_id); + void pragma(int stage_id, const Iterator& it, const std::string& pragma_type); + void storage_align(int stage_id, const Iterator& it, int factor, int offset); std::vector split(int stage_id, const Iterator& it, const std::vector& lengths, bool inner_to_outer = true); @@ -264,12 +285,6 @@ class State : public ObjectRef { IteratorAnnotation thread_type); Iterator tensorize(int stage_id, const Iterator& it, std::string ti_func_name); - void compute_at(int stage_id, int target_stage_id, - const Iterator& target_iter); - void compute_root(int stage_id); - void compute_inline(int stage_id); - void pragma(int stage_id, const Iterator& it, const std::string& pragma_type); - void storage_align(int stage_id, const Iterator& it, int factor, int offset); int cache_read(int stage_id, const std::string& scope_name, const std::vector& reader_stage_ids, const ComputeDAG& task_dag); @@ -283,17 +298,17 @@ class State : public ObjectRef { * We separate these functions out, * so you can call them for replay easily given history steps */ void DoReorderStep(const ReorderStep& step); + void DoComputeAtStep(const ComputeAtStep& step); + void DoComputeRootStep(const ComputeRootStep& step); + void DoComputeInlineStep(const ComputeInlineStep& step); + void DoPragmaStep(const PragmaStep& step); + void DoStorageAlignStep(const StorageAlignStep& step); std::vector DoSplitStep(const SplitStep& step); std::vector DoFollowSplitStep(const FollowSplitStep& step); std::vector DoFollowFusedSplitStep(const FollowFusedSplitStep& step); Iterator DoFuseStep(const FuseStep& step); Iterator DoAnnotationStep(const AnnotationStep& step); Iterator DoTensorizeStep(const TensorizeStep& step); - void DoComputeAtStep(const ComputeAtStep& step); - void DoComputeRootStep(const ComputeRootStep& step); - void DoComputeInlineStep(const ComputeInlineStep& step); - void 
DoPragmaStep(const PragmaStep& step); - void DoStorageAlignStep(const StorageAlignStep& step); int DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag); int DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag); int DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag); diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 474ea048ebad..4ae35fb410a9 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -58,11 +58,11 @@ const char* ErrorNoToStr[] = { }; // Measure input and result -MeasureInput MeasureInputNode::make(SearchTask task, State state) { +MeasureInput::MeasureInput(SearchTask task, State state) { auto node = make_object(); node->task = std::move(task); node->state = std::move(state); - return MeasureInput(node); + data_ = std::move(node); } MeasureInput MeasureInputNode::copy() const { @@ -72,28 +72,28 @@ MeasureInput MeasureInputNode::copy() const { return MeasureInput(node); } -BuildResult BuildResultNode::make(std::string filename, Array args, - int error_no, std::string error_msg, - double time_cost) { +BuildResult::BuildResult(std::string filename, Array args, + int error_no, std::string error_msg, + double time_cost) { auto node = make_object(); node->filename = std::move(filename); node->args = std::move(args); node->error_no = error_no; node->error_msg = std::move(error_msg); node->time_cost = time_cost; - return BuildResult(node); + data_ = std::move(node); } -MeasureResult MeasureResultNode::make(Array costs, int error_no, - std::string error_msg, double all_cost, - double timestamp) { +MeasureResult::MeasureResult(Array costs, int error_no, + std::string error_msg, double all_cost, + double timestamp) { auto node = make_object(); node->costs = std::move(costs); node->error_no = error_no; node->error_msg = std::move(error_msg); node->all_cost = all_cost; node->timestamp = timestamp; - return MeasureResult(node); + data_ = std::move(node); } MeasureResult MeasureResultNode::copy() const { @@ -107,13 +107,13 @@ MeasureResult MeasureResultNode::copy() const { } // LocalBuilder -Builder LocalBuilderNode::make(int timeout, int n_parallel, - const std::string& build_func) { +LocalBuilder::LocalBuilder(int timeout, int n_parallel, + const std::string& build_func) { auto node = make_object(); node->timeout = timeout; node->n_parallel = n_parallel; node->build_func = build_func; - return Builder(node); + data_ = std::move(node); } Array LocalBuilderNode::Build(const Array& inputs, @@ -129,10 +129,9 @@ Array LocalBuilderNode::Build(const Array& inputs, } // RPC Runner -Runner RPCRunnerNode::make(const std::string& key, const std::string& host, - int port, int priority, int timeout, int n_parallel, - int number, int repeat, int min_repeat_ms, - double cooldown_interval) { +RPCRunner::RPCRunner(const std::string& key, const std::string& host, int port, + int priority, int timeout, int n_parallel, int number, + int repeat, int min_repeat_ms, double cooldown_interval) { auto node = make_object(); node->key = key; node->host = host; @@ -144,7 +143,7 @@ Runner RPCRunnerNode::make(const std::string& key, const std::string& host, node->repeat = repeat; node->min_repeat_ms = min_repeat_ms; node->cooldown_interval = cooldown_interval; - return Runner(node); + data_ = std::move(node); } Array RPCRunnerNode::Run(const Array& inputs, @@ -162,15 +161,15 @@ Array RPCRunnerNode::Run(const Array& inputs, } // Local Runner -Runner LocalRunnerNode::make(int timeout, int number, int repeat, - int min_repeat_ms, double cooldown_interval) { 
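The converted measurement records are used exactly as before, only through constructors; a brief sketch, assuming a SearchTask `task` and a State `state` are in scope:

  MeasureInput input(task, state);        // was MeasureInputNode::make(task, state)
  MeasureInput snapshot = input->copy();  // deep copy still goes through the node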
+LocalRunner::LocalRunner(int timeout, int number, int repeat, + int min_repeat_ms, double cooldown_interval) { ObjectPtr node = make_object(); node->timeout = timeout; node->number = number; node->repeat = repeat; node->min_repeat_ms = min_repeat_ms; node->cooldown_interval = cooldown_interval; - return Runner(node); + data_ = std::move(node); } Array LocalRunnerNode::Run( @@ -188,19 +187,17 @@ Array LocalRunnerNode::Run( } // Program Measurer -ProgramMeasurer ProgramMeasurerNode::make(Builder builder, Runner runner, - Array callbacks, - int verbose, - int max_continous_error) { +ProgramMeasurer::ProgramMeasurer(Builder builder, Runner runner, + Array callbacks, int verbose, + int max_continous_error) { auto node = make_object(); node->builder = std::move(builder); node->runner = std::move(runner); node->callbacks = std::move(callbacks); node->verbose = verbose; - node->max_continous_error = max_continous_error < 0 - ? DEFAULT_MAX_CONTINOUS_ERROR - : max_continous_error; - return ProgramMeasurer(node); + node->max_continous_error = max_continous_error < 0 ? + ProgramMeasurerNode::DEFAULT_MAX_CONTINOUS_ERROR : max_continous_error; + data_ = std::move(node); } void ProgramMeasurerNode::Reset() { @@ -346,13 +343,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); TVM_REGISTER_GLOBAL("ansor.MeasureInput") -.set_body_typed(MeasureInputNode::make); +.set_body_typed([](SearchTask task, State state) { + return MeasureInput(task, state); +}); TVM_REGISTER_GLOBAL("ansor.BuildResult") -.set_body_typed(BuildResultNode::make); +.set_body_typed([](std::string filename, Array args, + int error_no, std::string error_msg, double time_cost) { + return BuildResult(filename, args, error_no, error_msg, time_cost); +}); TVM_REGISTER_GLOBAL("ansor.MeasureResult") -.set_body_typed(MeasureResultNode::make); +.set_body_typed([](Array costs, int error_no, std::string error_msg, + double all_cost, double timestamp) { + return MeasureResult(costs, error_no, error_msg, all_cost, timestamp); +}); TVM_REGISTER_GLOBAL("ansor.BuilderBuild") .set_body_typed([](const Builder& builder, @@ -367,16 +372,31 @@ TVM_REGISTER_GLOBAL("ansor.RunnerRun") }); TVM_REGISTER_GLOBAL("ansor.LocalBuilder") -.set_body_typed(LocalBuilderNode::make); +.set_body_typed([](int timeout, int n_parallel, const std::string& build_func) { + return LocalBuilder(timeout, n_parallel, build_func); +}); TVM_REGISTER_GLOBAL("ansor.LocalRunner") -.set_body_typed(LocalRunnerNode::make); +.set_body_typed([](int timeout, int number, int repeat, + int min_repeat_ms, double cooldown_interval) { + return LocalRunner(timeout, number, repeat, min_repeat_ms, cooldown_interval); +}); TVM_REGISTER_GLOBAL("ansor.RPCRunner") -.set_body_typed(RPCRunnerNode::make); +.set_body_typed([](const std::string& key, const std::string& host, int port, + int priority, int timeout, int n_parallel, int number, + int repeat, int min_repeat_ms, double cooldown_interval){ + return RPCRunner(key, host, port, priority, timeout, n_parallel, number, + repeat, min_repeat_ms, cooldown_interval); +}); TVM_REGISTER_GLOBAL("ansor.ProgramMeasurer") -.set_body_typed(ProgramMeasurerNode::make); +.set_body_typed([](Builder builder, Runner runner, + Array callbacks, int verbose, + int max_continous_error = -1) { + return ProgramMeasurer(builder, runner, callbacks, verbose, + max_continous_error); +}); } // namespace ansor diff --git a/src/ansor/measure.h b/src/ansor/measure.h index 6e432ba9c88b..a6db55f6181e 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -56,7 +56,7 @@ extern const 
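Putting the converted builder/runner/measurer classes together; a hedged sketch where every numeric argument and the build_func name are illustrative:

  LocalBuilder builder(/*timeout=*/15, /*n_parallel=*/4, /*build_func=*/"default");
  LocalRunner runner(/*timeout=*/10, /*number=*/3, /*repeat=*/1,
                     /*min_repeat_ms=*/100, /*cooldown_interval=*/0.0);
  ProgramMeasurer measurer(builder, runner, Array<MeasureCallback>(), /*verbose=*/1);
  // max_continous_error stays at -1, i.e. DEFAULT_MAX_CONTINOUS_ERROR is used.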
char *ErrorNoToStr[]; // Inputs and results of one measurement -/* \brief Store the input of a measurement */ +/*! \brief Store the input of a measurement */ class MeasureInputNode: public Object { public: SearchTask task; // The search task State state; // The program state to be measured @@ -67,20 +67,30 @@ class MeasureInputNode: public Object { v->Visit("state", &state); } - static MeasureInput make(SearchTask task, State state); MeasureInput copy() const; // Do deep copy static constexpr const char* _type_key = "ansor.MeasureInput"; TVM_DECLARE_FINAL_OBJECT_INFO(MeasureInputNode, Object); }; -TVM_DEFINE_OBJECT_REF(MeasureInput, MeasureInputNode); -/* \brief Store the input of a build */ +/*! + * \brief Managed reference to MeasureInputNode. + * \sa MeasureInputNode + */ +class MeasureInput : public ObjectRef { + public: + MeasureInput(SearchTask task, State state); + + TVM_DEFINE_OBJECT_REF_METHODS(MeasureInput, ObjectRef, MeasureInputNode); +}; + +/*! \brief Store the result of a build */ class BuildResultNode: public Object { public: std::string filename; // The filename of built binary file Array args; // The arguments - int error_no; // The error code (see MeasureErrorNO). 0 means no error. + int error_no; // The error code (see MeasureErrorNO). + // 0 means no error. std::string error_msg; // The error message if there is any error double time_cost; // The time cost of build @@ -92,19 +102,27 @@ class BuildResultNode: public Object { v->Visit("time_cost", &time_cost); } - static BuildResult make(std::string filename, Array args, - int error_no, std::string error_msg, double time_cost); - static constexpr const char* _type_key = "ansor.BuildResult"; TVM_DECLARE_FINAL_OBJECT_INFO(BuildResultNode, Object); }; -TVM_DEFINE_OBJECT_REF(BuildResult, BuildResultNode); -/* \brief Store the results of a measurement */ +/*! + * \brief Managed reference to BuildResultNode. + * \sa BuildResultNode + */ +class BuildResult : public ObjectRef { + public: + BuildResult(std::string filename, Array args, + int error_no, std::string error_msg, double time_cost); + TVM_DEFINE_OBJECT_REF_METHODS(BuildResult, ObjectRef, BuildResultNode); +}; + +/*! \brief Store the results of a measurement */ class MeasureResultNode: public Object { public: Array costs; // The time costs of execution - int error_no; // The error code (see MeasureErrorNO). 0 means no error. + int error_no; // The error code (see MeasureErrorNO). + // 0 means no error. std::string error_msg; // The error message if there is any error double all_cost; // The time cost of build and run double timestamp; // The time stamps of this measurement @@ -119,16 +137,23 @@ class MeasureResultNode: public Object { MeasureResult copy() const; // Do deep copy - static MeasureResult make(Array costs, int error_no, std::string error_msg, - double all_cost, double timestamp); - static constexpr const char* _type_key = "ansor.MeasureResult"; TVM_DECLARE_FINAL_OBJECT_INFO(MeasureResultNode, Object); }; -TVM_DEFINE_OBJECT_REF(MeasureResult, MeasureResultNode); +/*! + * \brief Managed reference to MeasureResultNode. + * \sa MeasureResultNode + */ +class MeasureResult : public ObjectRef { + public: + MeasureResult(Array costs, int error_no, std::string error_msg, + double all_cost, double timestamp); + + TVM_DEFINE_OBJECT_REF_METHODS(MeasureResult, ObjectRef, MeasureResultNode); +}; -/* \brief Bass class of measurement callbacks */ +/*! \brief Base class of measurement callbacks */ class MeasureCallbackNode: public Object { public: /*! 
\brief Callback function that will be called on measurement input/result pairs @@ -141,10 +166,8 @@ class MeasureCallbackNode: public Object { }; TVM_DEFINE_MUTABLE_OBJECT_REF(MeasureCallback, MeasureCallbackNode); - // Base class for builder and runner - -/* \brief Builder that builds the programs */ +/*! \brief Builder that builds the programs */ class BuilderNode: public Object { public: int n_parallel; // The number of tasks to run in parallel int timeout; // Timeout of a build @@ -158,7 +181,7 @@ class BuilderNode: public Object { }; TVM_DEFINE_MUTABLE_OBJECT_REF(Builder, BuilderNode); -/* \brief Runner that runs the built programs and measure the time cost */ +/*! \brief Runner that runs the built programs and measures the time cost */ class RunnerNode: public Object { public: int timeout; // Timeout of a run @@ -175,20 +198,30 @@ TVM_DEFINE_MUTABLE_OBJECT_REF(Runner, RunnerNode); // Implementation of various builders and runners -/* \brief LocalBuilder use local CPU cores to build programs in parallel */ +/*! \brief LocalBuilder uses local CPU cores to build programs in parallel */ class LocalBuilderNode: public BuilderNode { public: std::string build_func; // Build function - static Builder make(int timeout, int n_parallel, const std::string& build_func); - Array Build(const Array& inputs, int verbose) final; static constexpr const char* _type_key = "ansor.LocalBuilder"; TVM_DECLARE_FINAL_OBJECT_INFO(LocalBuilderNode, BuilderNode); }; -/* \brief RPCRunner that uses RPC call to measures the time cost of programs on remote devices */ +/*! + * \brief Managed reference to LocalBuilderNode. + * \sa LocalBuilderNode + */ +class LocalBuilder: public Builder { + public: + LocalBuilder(int timeout, int n_parallel, const std::string& build_func); + + TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, Builder, LocalBuilderNode); +}; + +/*! \brief RPCRunner that uses RPC calls to measure the time cost of programs + * on remote devices */ class RPCRunnerNode : public RunnerNode { public: std::string key; @@ -201,10 +234,6 @@ class RPCRunnerNode : public RunnerNode { int min_repeat_ms; double cooldown_interval; - static Runner make(const std::string& key, const std::string& host, int port, - int priority, int timeout, int n_parallel, int number, - int repeat, int min_repeat_ms, double cooldown_interval); - /*! \brief Run measurement and return results */ Array Run(const Array& inputs, const Array& build_results, @@ -214,7 +243,20 @@ class RPCRunnerNode : public RunnerNode { TVM_DECLARE_FINAL_OBJECT_INFO(RPCRunnerNode, RunnerNode); }; -/* \brief LocalRunner that uses local CPU/GPU to measures the time cost of programs */ +/*! + * \brief Managed reference to RPCRunnerNode. + * \sa RPCRunnerNode + */ +class RPCRunner : public Runner { + public: + RPCRunner(const std::string& key, const std::string& host, int port, + int priority, int timeout, int n_parallel, int number, + int repeat, int min_repeat_ms, double cooldown_interval); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RPCRunner, Runner, RPCRunnerNode); +}; + +/*! \brief LocalRunner that uses local CPU/GPU to measure the time cost of programs */ class LocalRunnerNode: public RunnerNode { public: int number; @@ -222,9 +264,6 @@ class LocalRunnerNode: public RunnerNode { int repeat; int min_repeat_ms; double cooldown_interval; - static Runner make(int timeout, int number, int repeat, - int min_repeat_ms, double cooldown_interval); - /*! 
\brief Run measurement and return results */ Array Run(const Array& inputs, const Array& build_results, @@ -234,6 +273,18 @@ class LocalRunnerNode: public RunnerNode { TVM_DECLARE_FINAL_OBJECT_INFO(LocalRunnerNode, RunnerNode); }; +/*! + * \brief Managed reference to LocalRunnerNode. + * \sa LocalRunnerNode + */ +class LocalRunner: public Runner { + public: + LocalRunner(int timeout, int number, int repeat, + int min_repeat_ms, double cooldown_interval); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LocalRunner, Runner, + LocalRunnerNode); +}; /*! * \brief Measurer that measures the time costs of tvm programs @@ -254,11 +305,6 @@ class ProgramMeasurerNode: public Object { int verbose; int max_continous_error; - static ProgramMeasurer make(Builder builder, Runner runner, - Array callbacks, - int verbose, - int max_continous_error = -1); - /*! \brief Reset bookkeeping variables */ void Reset(); @@ -277,8 +323,19 @@ class ProgramMeasurerNode: public Object { static constexpr const char* _type_key = "ansor.ProgramMeasurer"; TVM_DECLARE_FINAL_OBJECT_INFO(ProgramMeasurerNode, Object); }; -TVM_DEFINE_MUTABLE_OBJECT_REF(ProgramMeasurer, ProgramMeasurerNode); +/*! + * \brief Managed reference to ProgramMeasurerNode. + * \sa ProgramMeasurerNode + */ +class ProgramMeasurer : public ObjectRef { + public: + ProgramMeasurer(Builder builder, Runner runner, + Array callbacks, + int verbose, int max_continous_error = -1); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramMeasurer, ObjectRef, ProgramMeasurerNode); +}; } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index c9bccfdce806..51a48780813a 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -33,7 +33,7 @@ TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); TVM_REGISTER_OBJECT_TYPE(PreloadMeasuredStatesNode); void SearchPolicyNode::PreloadMeasuredStates(const std::string& log_file) { - LogReader reader = LogReaderNode::make(log_file); + LogReader reader = LogReader(log_file); const auto& res = reader->ReadLines(-1); size_t log_size = res.first.size(); CHECK_EQ(log_size, res.second.size()); @@ -84,10 +84,10 @@ void SearchPolicyNode::RunCallbacks(const Array& callbacks) { } } -SearchCallback PreloadMeasuredStatesNode::make(std::string filename) { +PreloadMeasuredStates::PreloadMeasuredStates(std::string filename) { auto node = make_object(); node->filename = std::move(filename); - return SearchCallback(node); + data_ = std::move(node); } void PreloadMeasuredStatesNode::callback(SearchPolicyNode* policy) { @@ -121,7 +121,7 @@ TVM_REGISTER_GLOBAL("ansor.SearchPolicySetVerbose") TVM_REGISTER_GLOBAL("ansor.PreloadMeasuredStates") .set_body_typed([](std::string filename) { - return PreloadMeasuredStatesNode::make(filename); + return PreloadMeasuredStates(filename); }); } // namespace ansor diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index 4710cc05ae7f..03e7c3f025df 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -36,10 +36,9 @@ namespace tvm { namespace ansor { -class SearchPolicy; class SearchPolicyNode; -/*! Callback function to be called before or after the search process */ +/*! 
\brief Callback function to be called before or after the search process */ class SearchCallbackNode : public Object { public: virtual void callback(SearchPolicyNode* policy) = 0; @@ -55,14 +54,24 @@ class PreloadMeasuredStatesNode : public SearchCallbackNode { public: std::string filename; - static SearchCallback make(std::string filename); - void callback(SearchPolicyNode* policy) final; static constexpr const char *_type_key = "ansor.PreloadMeasuredStates"; TVM_DECLARE_FINAL_OBJECT_INFO(PreloadMeasuredStatesNode, SearchCallbackNode); }; +/*! + * \brief Managed reference to PreloadMeasuredStatesNode. + * \sa PreloadMeasuredStatesNode + */ +class PreloadMeasuredStates : public SearchCallback { + public: + explicit PreloadMeasuredStates(std::string filename); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadMeasuredStates, SearchCallback, + PreloadMeasuredStatesNode); +}; + /*! \brief The base class for search policy */ class SearchPolicyNode : public Object { public: diff --git a/src/ansor/search_policy/sketch_search_policy.cc b/src/ansor/search_policy/sketch_search_policy.cc index 7e4c3999dce3..5b2c10c08c81 100644 --- a/src/ansor/search_policy/sketch_search_policy.cc +++ b/src/ansor/search_policy/sketch_search_policy.cc @@ -49,20 +49,19 @@ TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode); // All possible candidates for auto_unroll const std::vector SketchSearchPolicyNode::auto_unroll_configs{0, 16, 64, 512, 1024}; -SearchPolicy SketchSearchPolicyNode::make(CostModel program_cost_model, - Map params, - int seed) { +SketchSearchPolicy::SketchSearchPolicy(CostModel program_cost_model, + Map params, + int seed) { auto node = make_object(); node->program_cost_model = std::move(program_cost_model); node->rand_gen_ = std::mt19937(seed); node->params = std::move(params); - return SearchPolicy(node); + data_ = std::move(node); } State SketchSearchPolicyNode::Search(SearchTask task, int n_trials, - int early_stopping, int num_measure_per_iter, - int verbose, ProgramMeasurer measurer, - Array pre_search_callbacks) { + int early_stopping, int num_measure_per_iter, int verbose, + ProgramMeasurer measurer, Array pre_search_callbacks) { std::vector best_states, random_states; this->cur_task = task; this->verbose = verbose; @@ -221,7 +220,7 @@ void SketchSearchPolicyNode::PickStatesWithEpsGreedy( if (measured_states_set_.count(state_str)) { continue; } measured_states_set_.insert(state_str); - inputs->push_back(MeasureInputNode::make(cur_task, *pstate)); + inputs->push_back(MeasureInput(cur_task, *pstate)); measured_states_vector_.push_back(std::move(*pstate)); } } @@ -701,8 +700,8 @@ void SketchSearchPolicyNode::GenerateSketch( auto step = pstate->transform_steps[split_step_id].as(); CHECK(step != nullptr); pstate->transform_steps[split_step_id] - = SplitStepNode::make(step->stage_id, step->iter_id, step->extent, {PrimExpr()}, - step->inner_to_outer); + = SplitStep(step->stage_id, step->iter_id, step->extent, {PrimExpr()}, + step->inner_to_outer); } } } @@ -733,7 +732,7 @@ int InitPopulationFillTileSize(const SketchSearchPolicyNode* policy, policy->cur_task->hardware_params->max_innermost_split_factor); StateNode* pstate = state->CopyOnWrite(); - pstate->transform_steps[step_id] = SplitStepNode::make( + pstate->transform_steps[step_id] = SplitStep( ps->stage_id, ps->iter_id, ps->extent, candidate_lens[(*rand_gen)() % candidate_lens.size()], ps->inner_to_outer); @@ -1508,12 +1507,12 @@ class RuleCustomSketch : public SketchGenerationRule { PackedFunc apply_func_; }; -SearchCallback 
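A usage sketch for the converted callback, assuming a policy reference `policy` is already constructed (e.g. the SketchSearchPolicy converted below) and the log path is illustrative:

  policy->RunCallbacks({PreloadMeasuredStates("measured_states.json")});
  // The callback replays the log into the policy's cache of measured states.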
PreloadCustomSketchRuleNode::make(PackedFunc meet_condition_func, - PackedFunc apply_func) { +PreloadCustomSketchRule::PreloadCustomSketchRule(PackedFunc meet_condition_func, + PackedFunc apply_func) { auto node = make_object(); node->meet_condition_func = meet_condition_func; node->apply_func = apply_func; - return SearchCallback(node); + data_ = std::move(node); } void PreloadCustomSketchRuleNode::callback(SearchPolicyNode* policy) { @@ -1525,15 +1524,14 @@ void PreloadCustomSketchRuleNode::callback(SearchPolicyNode* policy) { } TVM_REGISTER_GLOBAL("ansor.SketchSearchPolicy") -.set_body_typed([](CostModel program_cost_model, - Map params, +.set_body_typed([](CostModel program_cost_model, Map params, int seed){ - return SketchSearchPolicyNode::make(program_cost_model, params, seed); + return SketchSearchPolicy(program_cost_model, params, seed); }); TVM_REGISTER_GLOBAL("ansor.PreloadCustomSketchRule") .set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func) { - return PreloadCustomSketchRuleNode::make(meet_condition_func, apply_func); + return PreloadCustomSketchRule(meet_condition_func, apply_func); }); } // namespace ansor diff --git a/src/ansor/search_policy/sketch_search_policy.h b/src/ansor/search_policy/sketch_search_policy.h index 60920c5c1fdd..54a5cdd1fa4e 100644 --- a/src/ansor/search_policy/sketch_search_policy.h +++ b/src/ansor/search_policy/sketch_search_policy.h @@ -51,7 +51,8 @@ class SketchSearchPolicyNode: public SearchPolicyNode { public: /*! \brief The cost model for complete programs */ CostModel program_cost_model; - + /*! \brief Random generator */ + std::mt19937 rand_gen_; /*! \brief The parameters for search. It stores the following parameters: * int evolutionary_search_population // The population size for evolutionary search * int evolutionary_search_mutation_prob // The probability of mutation for evolutionary search @@ -63,14 +64,9 @@ class SketchSearchPolicyNode: public SearchPolicyNode { * str gpu_multi_level_tiling_structure // The structure of multi-level tiling for GPU */ Map params; - /*! \brief The rules to generate sketches */ std::vector sketch_rules; - static SearchPolicy make(CostModel program_cost_model, - Map params, - int seed); - /*! \brief Search and make n_trails measurements. * \returns the best state */ State Search(SearchTask task, int n_trials, @@ -92,7 +88,8 @@ class SketchSearchPolicyNode: public SearchPolicyNode { /*! \brief Pick states from best states and random states with eps-greedy policy */ void PickStatesWithEpsGreedy(std::vector* inputs, const std::vector& best_states, - const std::vector& random_states, int remaining_n_trials); + const std::vector& random_states, + int remaining_n_trials); private: // Run one round of the search pipeline @@ -111,10 +108,22 @@ class SketchSearchPolicyNode: public SearchPolicyNode { int num_best_states, std::vector* best_states); SplitFactorizationMemo split_memo_; // Memorize split space for Split - std::mt19937 rand_gen_; // Random generator int num_measure_per_iter_; // The number of states to measure per iteration }; -TVM_DEFINE_MUTABLE_OBJECT_REF(SketchSearchPolicy, SketchSearchPolicyNode); + +/*! + * \brief Managed reference to SketchSearchPolicyNode. + * \sa SketchSearchPolicyNode + */ +class SketchSearchPolicy : public SearchPolicy { + public: + SketchSearchPolicy(CostModel program_cost_model, + Map params, + int seed); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SketchSearchPolicy, SearchPolicy, + SketchSearchPolicyNode); +}; /*! 
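Custom sketch rules are preloaded through the same path; a hedged sketch where `cost_model` and `params` are assumed in scope and `meet_condition_func`/`apply_func` are assumed PackedFuncs, typically registered from Python:

  SketchSearchPolicy policy(cost_model, params, /*seed=*/42);
  policy->RunCallbacks({PreloadCustomSketchRule(meet_condition_func, apply_func)});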
\brief Pre-search callback function to load custom rules for sketch generation */ class PreloadCustomSketchRuleNode : public SearchCallbackNode { @@ -123,15 +132,25 @@ class PreloadCustomSketchRuleNode : public SearchCallbackNode { PackedFunc meet_condition_func; PackedFunc apply_func; - static SearchCallback make(PackedFunc meet_condition_func, - PackedFunc apply_func); - void callback(SearchPolicyNode* policy) final; static constexpr const char *_type_key = "ansor.PreloadCustomSketchRule"; TVM_DECLARE_FINAL_OBJECT_INFO(PreloadCustomSketchRuleNode, SearchCallbackNode); }; +/*! + * \brief Managed reference to PreloadCustomSketchRuleNode. + * \sa PreloadCustomSketchRuleNode + */ +class PreloadCustomSketchRule : public SearchCallback { + public: + PreloadCustomSketchRule(PackedFunc meet_condition_func, + PackedFunc apply_func); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadCustomSketchRule, SearchCallback, + PreloadCustomSketchRuleNode); +}; + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/utils.cc b/src/ansor/search_policy/utils.cc index ba42ca55611c..412d0afcca98 100644 --- a/src/ansor/search_policy/utils.cc +++ b/src/ansor/search_policy/utils.cc @@ -371,7 +371,7 @@ State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split auto pstate = tmp_s.CopyOnWrite(); pstate->transform_steps[step_id] = - SplitStepNode::make(ps->stage_id, ps->iter_id, ps->extent, new_lengths, ps->inner_to_outer); + SplitStep(ps->stage_id, ps->iter_id, ps->extent, new_lengths, ps->inner_to_outer); return tmp_s; } @@ -401,7 +401,7 @@ State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen auto val = std::to_string(auto_unroll_configs[(*random_gen)() % auto_unroll_configs.size()]); auto pstate = tmp_s.CopyOnWrite(); - pstate->transform_steps[step_id] = PragmaStepNode::make( + pstate->transform_steps[step_id] = PragmaStep( ps->stage_id, ps->iter_id, std::string("auto_unroll_max_step") + "$" + val); return tmp_s; } diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index c65516150f30..17ab73efb6aa 100644 --- a/src/ansor/search_task.cc +++ b/src/ansor/search_task.cc @@ -35,28 +35,27 @@ namespace ansor { TVM_REGISTER_NODE_TYPE(HardwareParamsNode); TVM_REGISTER_NODE_TYPE(SearchTaskNode); -HardwareParams HardwareParamsNode::make(int num_cores, int vector_unit_bytes, - int cache_line_bytes, - int max_unroll_vec, - int max_innermost_split_factor) { +HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, + int cache_line_bytes, int max_unroll_vec, + int max_innermost_split_factor) { auto node = make_object(); node->num_cores = num_cores; node->vector_unit_bytes = vector_unit_bytes; node->cache_line_bytes = cache_line_bytes; node->max_unroll_vec = max_unroll_vec; node->max_innermost_split_factor = max_innermost_split_factor; - return HardwareParams(node); + data_ = std::move(node); } HardwareParams HardwareParamsNode::GetDefaultHardwareParams( const Target& target, const Target& target_host) { if (target->target_name == "llvm") { - return HardwareParamsNode::make(tvm::runtime::threading::MaxConcurrency(), - 32, 64, 16, 64); + return HardwareParams(tvm::runtime::threading::MaxConcurrency(), + 32, 64, 16, 64); } else if (target->device_type == kDLGPU) { // TODO(jcf94): temp implementation, max vectorize size in GPU is related // to the data type - auto hardware_params = HardwareParamsNode::make(100000, 16, 64, 4, 64); + auto hardware_params = HardwareParams(100000, 16, 64, 4, 64); auto* p_hardware_params = 
hardware_params.CopyOnWrite(); auto ctx = TVMContext{kDLGPU, 0}; @@ -87,7 +86,7 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams( return hardware_params; } else if (target->device_type == kDLOpenCL) { // TODO(jcf94): temp implementation - auto hardware_params = HardwareParamsNode::make(100000, 16, 64, 4, 64); + auto hardware_params = HardwareParams(100000, 16, 64, 4, 64); auto p_hardware_params = hardware_params.CopyOnWrite(); auto ctx = TVMContext{kDLOpenCL, 0}; @@ -118,10 +117,9 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams( return HardwareParams(); } -SearchTask SearchTaskNode::make(ComputeDAG compute_dag, - std::string workload_key, Target target, - Target target_host, - HardwareParams hardware_params) { +SearchTask::SearchTask(ComputeDAG compute_dag, std::string workload_key, + Target target, Target target_host, + HardwareParams hardware_params) { auto node = make_object(); node->compute_dag = std::move(compute_dag); node->workload_key = std::move(workload_key); @@ -133,24 +131,23 @@ SearchTask SearchTaskNode::make(ComputeDAG compute_dag, node->hardware_params = HardwareParamsNode::GetDefaultHardwareParams( node->target, node->target_host); } - return SearchTask(node); + data_ = std::move(node); } TVM_REGISTER_GLOBAL("ansor.HardwareParams") .set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes, int max_unroll_vec, int max_innermost_split_factor) { - return HardwareParamsNode::make(num_cores, vector_unit_bytes, - cache_line_bytes, max_unroll_vec, - max_innermost_split_factor); + return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes, + max_unroll_vec, max_innermost_split_factor); }); TVM_REGISTER_GLOBAL("ansor.SearchTask") .set_body_typed([](ComputeDAG compute_dag, std::string workload_key, Target target, Target target_host, HardwareParams hardware_params) { - return SearchTaskNode::make(compute_dag, workload_key, target, - target_host, hardware_params); + return SearchTask(compute_dag, workload_key, target, target_host, + hardware_params); }); } // namespace ansor diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h index cfa5500c39f4..c53fdcd0f792 100644 --- a/src/ansor/search_task.h +++ b/src/ansor/search_task.h @@ -32,7 +32,7 @@ namespace tvm { namespace ansor { -class HardwareParams; class SearchTask; +class HardwareParams; /*! \brief Hardware related parameters */ class HardwareParamsNode : public Object { @@ -69,17 +69,25 @@ class HardwareParamsNode : public Object { v->Visit("warp_size", &warp_size); } - static HardwareParams make(int num_cores, int vector_unit_bytes, - int cache_line_bytes, int max_unroll_vec, - int max_innermost_split_factor); - static HardwareParams GetDefaultHardwareParams(const Target& target, const Target& target_host); static constexpr const char* _type_key = "ansor.HardwareParams"; TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(HardwareParams, ObjectRef, HardwareParamsNode); + +/*! + * \brief Managed reference to HardwareParamsNode. + * \sa HardwareParamsNode + */ +class HardwareParams : public ObjectRef { + public: + HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes, + int max_unroll_vec, int max_innermost_split_factor); + + TVM_DEFINE_OBJECT_REF_METHODS(HardwareParams, ObjectRef, HardwareParamsNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(HardwareParamsNode); +}; /*! 
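How the two converted classes compose; a hedged sketch where the numbers are illustrative and `dag`, `target`, and `target_host` are assumed in scope. Passing a default-constructed HardwareParams() instead falls back to GetDefaultHardwareParams, as shown above:

  HardwareParams hw(/*num_cores=*/8, /*vector_unit_bytes=*/32,
                    /*cache_line_bytes=*/64, /*max_unroll_vec=*/16,
                    /*max_innermost_split_factor=*/64);
  SearchTask task(dag, /*workload_key=*/"matmul_512", target, target_host, hw);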
\brief Meta-info for a search task */ class SearchTaskNode : public Object { @@ -98,14 +106,23 @@ class SearchTaskNode : public Object { v->Visit("hardware_params", &hardware_params); } - static SearchTask make(ComputeDAG compute_dag, std::string workload_key, - Target target, Target target_host, - HardwareParams hardware_params); - static constexpr const char* _type_key = "ansor.SearchTask"; TVM_DECLARE_FINAL_OBJECT_INFO(SearchTaskNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(SearchTask, ObjectRef, SearchTaskNode); + +/*! + * \brief Managed reference to SearchTaskNode. + * \sa SearchTaskNode + */ +class SearchTask : public ObjectRef { + public: + SearchTask(ComputeDAG compute_dag, std::string workload_key, + Target target, Target target_host, + HardwareParams hardware_params); + + TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SearchTaskNode); +}; } // namespace ansor } // namespace tvm diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index 2d8379f56a5f..71fba764506f 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -199,7 +199,7 @@ struct Handler > { reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); - data->push_back(::tvm::ansor::ReorderStepNode::make(stage_id, int_list)); + data->push_back(::tvm::ansor::ReorderStep(stage_id, int_list)); } else if (name == "SP") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -211,7 +211,7 @@ struct Handler > { reader->Read(&int_list); s = reader->NextArrayItem(); CHECK(s); reader->Read(&inner_to_outer); - data->push_back(::tvm::ansor::SplitStepNode::make( + data->push_back(::tvm::ansor::SplitStep( stage_id, iter_id, extent, std::vector<::tvm::PrimExpr>(int_list.begin(), int_list.end()), inner_to_outer)); @@ -224,7 +224,7 @@ struct Handler > { reader->Read(&src_step_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&n_split); - data->push_back(::tvm::ansor::FollowSplitStepNode::make( + data->push_back(::tvm::ansor::FollowSplitStep( stage_id, iter_id, src_step_id, n_split)); } else if (name == "FFSP") { s = reader->NextArrayItem(); CHECK(s); @@ -237,14 +237,14 @@ struct Handler > { reader->Read(&level); s = reader->NextArrayItem(); CHECK(s); reader->Read(&factor_or_nparts); - data->push_back(::tvm::ansor::FollowFusedSplitStepNode::make( + data->push_back(::tvm::ansor::FollowFusedSplitStep( stage_id, iter_id, int_list, level, factor_or_nparts)); } else if (name == "FU") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); - data->push_back(::tvm::ansor::FuseStepNode::make(stage_id, int_list)); + data->push_back(::tvm::ansor::FuseStep(stage_id, int_list)); } else if (name == "AN") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -252,7 +252,7 @@ struct Handler > { reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&ann); - data->push_back(::tvm::ansor::AnnotationStepNode::make(stage_id, + data->push_back(::tvm::ansor::AnnotationStep(stage_id, iter_id, ::tvm::ansor::IteratorAnnotation(ann))); } else if (name == "CA") { s = reader->NextArrayItem(); CHECK(s); @@ -261,16 +261,16 @@ struct Handler > { reader->Read(&target_stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&iter_id); - data->push_back(::tvm::ansor::ComputeAtStepNode::make( + data->push_back(::tvm::ansor::ComputeAtStep( stage_id, target_stage_id, iter_id)); } else if (name == "CR") { s = 
reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); - data->push_back(::tvm::ansor::ComputeRootStepNode::make(stage_id)); + data->push_back(::tvm::ansor::ComputeRootStep(stage_id)); } else if (name == "CI") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); - data->push_back(::tvm::ansor::ComputeInlineStepNode::make(stage_id)); + data->push_back(::tvm::ansor::ComputeInlineStep(stage_id)); } else if (name == "CHR") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -278,14 +278,14 @@ struct Handler > { reader->Read(&scope_name); s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); - data->push_back(::tvm::ansor::CacheReadStepNode::make( + data->push_back(::tvm::ansor::CacheReadStep( stage_id, scope_name, int_list)); } else if (name == "CHW") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&scope_name); - data->push_back(::tvm::ansor::CacheWriteStepNode::make( + data->push_back(::tvm::ansor::CacheWriteStep( stage_id, scope_name)); } else if (name == "PR") { s = reader->NextArrayItem(); CHECK(s); @@ -294,7 +294,7 @@ struct Handler > { reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&pragma_type); - data->push_back(::tvm::ansor::PragmaStepNode::make( + data->push_back(::tvm::ansor::PragmaStep( stage_id, iter_id, pragma_type)); } else if (name == "RF") { s = reader->NextArrayItem(); CHECK(s); @@ -303,7 +303,7 @@ struct Handler > { reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&factor_iter_id); - data->push_back(::tvm::ansor::RfactorStepNode::make( + data->push_back(::tvm::ansor::RfactorStep( stage_id, iter_id, factor_iter_id)); } else if (name == "SA") { s = reader->NextArrayItem(); CHECK(s); @@ -314,7 +314,7 @@ struct Handler > { reader->Read(&factor); s = reader->NextArrayItem(); CHECK(s); reader->Read(&offset); - data->push_back(::tvm::ansor::StorageAlignStepNode::make( + data->push_back(::tvm::ansor::StorageAlignStep( stage_id, iter_id, factor, offset)); } else if (name == "TS") { s = reader->NextArrayItem(); CHECK(s); @@ -323,7 +323,7 @@ struct Handler > { reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&ti_func_name); - data->push_back(::tvm::ansor::TensorizeStepNode::make( + data->push_back(::tvm::ansor::TensorizeStep( stage_id, iter_id, ti_func_name)); } else { LOG(FATAL) << "Invalid step format"; @@ -457,10 +457,10 @@ TVM_REGISTER_OBJECT_TYPE(LogReaderNode); const std::string ANSOR_LOG_VERSION = "v0.2"; // NOLINT(*) -MeasureCallback LogToFileNode::make(std::string filename) { +LogToFile::LogToFile(std::string filename) { auto node = make_object(); node->filename = std::move(filename); - return MeasureCallback(node); + data_ = std::move(node); } void WriteMeasureRecords(std::ostream* os, @@ -506,11 +506,11 @@ void LogToFileNode::callback(const SearchPolicy& policy, WriteMeasureRecords(&ofs, inputs, results); } -LogReader LogReaderNode::make(std::string filename) { +LogReader::LogReader(std::string filename) { auto node = make_object(); node->filename = filename; node->infile.open(filename, std::ifstream::in); - return LogReader(node); + data_ = std::move(node); } LogReaderNode::~LogReaderNode() { @@ -556,15 +556,15 @@ std::pair, Array > LogReaderNode::ReadLines( return std::make_pair(inputs, results); } -std::pair BestMeasurePairInFile(const std::string& filename, - const std::string& workload_key, - const Target& target) { +std::pair BestMeasurePairInFile( + const std::string& 
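For reference, the readers above consume each step as a JSON array whose first element is the step tag; the field order implied by the parsing code, with the concrete layouts written out as hedged examples:

  // ["SP", stage_id, iter_id, extent, [length, ...], inner_to_outer]
  // ["FU", stage_id, [fused_id, ...]]
  // ["CHR", stage_id, scope_name, [reader_stage_id, ...]]
  // ["TS", stage_id, iter_id, ti_func_name]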
filename, const std::string& workload_key, + const Target& target) { std::pair best_pair; double best_cost = 1e30; auto inp = make_object(); auto res = make_object(); - LogReader reader = LogReaderNode::make(filename); + LogReader reader = LogReader(filename); while (reader->ReadNext(inp.get(), res.get())) { if (res->error_no != kNoError || inp->task->workload_key != workload_key @@ -594,12 +594,12 @@ TVM_REGISTER_GLOBAL("ansor.WriteMeasureRecordsToFile") TVM_REGISTER_GLOBAL("ansor.LogToFile") .set_body_typed([](const std::string& filename) { - return LogToFileNode::make(filename); + return LogToFile(filename); }); TVM_REGISTER_GLOBAL("ansor.LogReader") .set_body_typed([](const std::string& filename) { - return LogReaderNode::make(filename); + return LogReader(filename); }); TVM_REGISTER_GLOBAL("ansor.LogReaderReadLines") @@ -648,8 +648,8 @@ TVM_REGISTER_GLOBAL("ansor.GetStatesFromMeasureInputs") ptask = inp->task.operator->(); } else { // the measure input is incomplete // rebuild task for incomplete measure pairs read from file - SearchTask new_task = SearchTaskNode::make( - ComputeDAGNode::make_by_workload_key(workload_key), + SearchTask new_task = SearchTask( + ComputeDAG(workload_key), workload_key, inp->task->target, inp->task->target_host, diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h index d877717db9cb..82dd036991e6 100644 --- a/src/ansor/serialization.h +++ b/src/ansor/serialization.h @@ -38,8 +38,6 @@ class LogToFileNode : public MeasureCallbackNode { public: std::string filename; - static MeasureCallback make(std::string filename); - /*! \brief Log measure pairs to file. This is called by the search policy */ void callback(const SearchPolicy& policy, const Array& inputs, @@ -49,15 +47,23 @@ class LogToFileNode : public MeasureCallbackNode { TVM_DECLARE_FINAL_OBJECT_INFO(LogToFileNode, MeasureCallbackNode); }; -class LogReader; +/*! + * \brief Managed reference to LogToFileNode. + * \sa LogToFileNode + */ +class LogToFile : public MeasureCallback { + public: + explicit LogToFile(std::string filename); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LogToFile, MeasureCallback, LogToFileNode); +}; -/*! \brief Log reader */ +/*! \brief Log reader to load step logs from a target file.*/ class LogReaderNode : public Object { public: std::string filename; std::ifstream infile; - static LogReader make(std::string filename); ~LogReaderNode(); /*! \brief Read next line in the log file @@ -76,7 +82,17 @@ class LogReaderNode : public Object { private: std::string cur_line; }; -TVM_DEFINE_MUTABLE_OBJECT_REF(LogReader, LogReaderNode); + +/*! + * \brief Managed reference to LogReaderNode. + * \sa LogReaderNode + */ +class LogReader : public ObjectRef { + public: + explicit LogReader(std::string filename); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LogReader, ObjectRef, LogReaderNode); +}; /*! 
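A read-loop sketch mirroring BestMeasurePairInFile above; the file name is illustrative:

  LogReader reader("ansor_records.json");
  auto inp = make_object<MeasureInputNode>();
  auto res = make_object<MeasureResultNode>();
  while (reader->ReadNext(inp.get(), res.get())) {
    if (res->error_no != kNoError) continue;  // skip failed measurements
    // inspect inp->task, inp->state, res->costs ...
  }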
\brief Write measure records to an output stream */ void WriteMeasureRecords(std::ostream* os, diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index 857f3e570de0..bd0a7f7165f6 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -34,11 +34,11 @@ namespace tvm { namespace ansor { /********** Reorder **********/ -ReorderStep ReorderStepNode::make(int stage_id, const std::vector& after_ids) { +ReorderStep::ReorderStep(int stage_id, const std::vector& after_ids) { auto node = make_object(); node->stage_id = stage_id; node->after_ids = after_ids; - return ReorderStep(node); + data_ = std::move(node); } void ReorderStepNode::ApplyToSchedule(std::vector *stages, @@ -155,9 +155,9 @@ std::string PrintSplitAsPythonAPI(std::vector *stages, return ss.str(); } -SplitStep SplitStepNode::make(int stage_id, int iter_id, - PrimExpr extent, const std::vector& lengths, - bool inner_to_outer) { +SplitStep::SplitStep(int stage_id, int iter_id, PrimExpr extent, + const std::vector& lengths, + bool inner_to_outer) { auto node = make_object(); node->stage_id = stage_id; // Extent can be a unreducible expression in some special cases @@ -167,7 +167,7 @@ SplitStep SplitStepNode::make(int stage_id, int iter_id, node->iter_id = iter_id; node->lengths = lengths; node->inner_to_outer = inner_to_outer; - return SplitStep(node); + data_ = std::move(node); } std::vector SplitStepNode::ApplyToSchedule( @@ -184,18 +184,19 @@ std::string SplitStepNode::PrintAsPythonAPI( } /********** Follow Split **********/ -FollowSplitStep FollowSplitStepNode::make(int stage_id, int iter_id, - int src_step_id, int n_split) { +FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, + int src_step_id, int n_split) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; node->src_step_id = src_step_id; node->n_split = n_split; - return FollowSplitStep(node); + data_ = std::move(node); } -void FollowSplitStepNode::ExtractSplitLengths(const std::vector& transform_steps, - std::vector* lengths) const { +void FollowSplitStepNode::ExtractSplitLengths( + const std::vector& transform_steps, + std::vector* lengths) const { CHECK_LT(src_step_id, transform_steps.size()); auto ps = transform_steps[src_step_id].as(); CHECK(ps != nullptr); @@ -237,15 +238,15 @@ std::string FollowSplitStepNode::PrintAsPythonAPI( } /********** Follow Fused Split **********/ -FollowFusedSplitStep FollowFusedSplitStepNode::make(int stage_id, int iter_id, - const std::vector& src_step_ids, int level, bool factor_or_nparts) { +FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id, + const std::vector& src_step_ids, int level, bool factor_or_nparts) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; node->src_step_ids = src_step_ids;; node->level = level; node->factor_or_nparts = factor_or_nparts; - return FollowFusedSplitStep(node); + data_ = std::move(node); } PrimExpr FollowFusedSplitStepNode::ExtractSplitLength( @@ -279,16 +280,16 @@ std::string FollowFusedSplitStepNode::PrintAsPythonAPI( te::Schedule *schedule, const std::vector& transform_steps) const { const PrimExpr& length = ExtractSplitLength(transform_steps); return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, - {length}, factor_or_nparts); + {length}, factor_or_nparts); } /********** Fuse **********/ -FuseStep FuseStepNode::make(int stage_id, const std::vector& fused_ids) { +FuseStep::FuseStep(int stage_id, const std::vector& fused_ids) { auto node = make_object(); 
node->stage_id = stage_id; node->fused_ids = fused_ids; - return FuseStep(node); + data_ = std::move(node); } IterVar FuseStepNode::ApplyToSchedule(std::vector *stages, @@ -306,7 +307,7 @@ IterVar FuseStepNode::ApplyToSchedule(std::vector *stages, new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids[0]); new_axes.push_back(fused_axis); new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, - axes.end()); + axes.end()); (*stage_to_axes)[stage] = std::move(new_axes); return fused_axis; @@ -337,12 +338,13 @@ std::string FuseStepNode::PrintAsPythonAPI(std::vector *stages, } /********** Annotation **********/ -AnnotationStep AnnotationStepNode::make(int stage_id, int iter_id, IteratorAnnotation ann) { +AnnotationStep::AnnotationStep(int stage_id, int iter_id, + IteratorAnnotation ann) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; node->annotation = ann; - return AnnotationStep(node); + data_ = std::move(node); } void AnnotationStepNode::ApplyToSchedule(std::vector *stages, @@ -426,12 +428,13 @@ std::string AnnotationStepNode::PrintAsPythonAPI(std::vector *stages, } /********** Compute At **********/ -ComputeAtStep ComputeAtStepNode::make(int stage_id, int target_stage_id, int target_iter_id) { +ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, + int target_iter_id) { auto node = make_object(); node->stage_id = stage_id; node->target_stage_id = target_stage_id; node->target_iter_id = target_iter_id; - return ComputeAtStep(node); + data_ = std::move(node); } void ComputeAtStepNode::ApplyToSchedule(std::vector *stages, @@ -460,10 +463,10 @@ std::string ComputeAtStepNode::PrintAsPythonAPI(std::vector *stages, } /********** Compute Root **********/ -ComputeRootStep ComputeRootStepNode::make(int stage_id) { +ComputeRootStep::ComputeRootStep(int stage_id) { auto node = make_object(); node->stage_id = stage_id; - return ComputeRootStep(node); + data_ = std::move(node); } void ComputeRootStepNode::ApplyToSchedule(std::vector *stages, @@ -485,10 +488,10 @@ std::string ComputeRootStepNode::PrintAsPythonAPI(std::vector *stages } /********** Compute Inline **********/ -ComputeInlineStep ComputeInlineStepNode::make(int stage_id) { +ComputeInlineStep::ComputeInlineStep(int stage_id) { auto node = make_object(); node->stage_id = stage_id; - return ComputeInlineStep(node); + data_ = std::move(node); } void ComputeInlineStepNode::ApplyToSchedule(std::vector *stages, @@ -511,13 +514,13 @@ std::string ComputeInlineStepNode::PrintAsPythonAPI( } /********** Cache Read **********/ -CacheReadStep CacheReadStepNode::make(int stage_id, std::string scope_name, - const std::vector& reader_stage_ids) { +CacheReadStep::CacheReadStep(int stage_id, std::string scope_name, + const std::vector& reader_stage_ids) { auto node = make_object(); node->stage_id = stage_id; node->scope_name = std::move(scope_name); node->reader_stage_ids = reader_stage_ids; - return CacheReadStep(node); + data_ = std::move(node); } te::Tensor CacheReadStepNode::ApplyToSchedule(std::vector* stages, @@ -574,11 +577,11 @@ std::string CacheReadStepNode::PrintAsPythonAPI(std::vector *stages, } /********** Cache Write **********/ -CacheWriteStep CacheWriteStepNode::make(int stage_id, std::string scope_name) { +CacheWriteStep::CacheWriteStep(int stage_id, std::string scope_name) { auto node = make_object(); node->stage_id = stage_id; node->scope_name = std::move(scope_name); - return CacheWriteStep(node); + data_ = std::move(node); } Array CacheWriteStepNode::ApplyToSchedule( @@ 
-642,13 +645,12 @@ std::string CacheWriteStepNode::PrintAsPythonAPI(std::vector *stages, } /********** Pragma **********/ -PragmaStep PragmaStepNode::make(int stage_id, int iter_id, - std::string pragma_type) { +PragmaStep::PragmaStep(int stage_id, int iter_id, std::string pragma_type) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; node->pragma_type = std::move(pragma_type); - return PragmaStep(node); + data_ = std::move(node); } void PragmaStepNode::ApplyToSchedule(std::vector *stages, @@ -692,12 +694,12 @@ std::string PragmaStepNode::PrintAsPythonAPI(std::vector *stages, } /********** Rfactor **********/ -RfactorStep RfactorStepNode::make(int stage_id, int iter_id, int factor_iter_id) { +RfactorStep::RfactorStep(int stage_id, int iter_id, int factor_iter_id) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; node->factor_iter_id = factor_iter_id; - return RfactorStep(node); + data_ = std::move(node); } Array RfactorStepNode::ApplyToSchedule(std::vector *stages, @@ -719,9 +721,9 @@ Array RfactorStepNode::ApplyToSchedule(std::vector *stage } std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; @@ -772,14 +774,14 @@ std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, } /********** Storage Align **********/ -StorageAlignStep StorageAlignStepNode::make(int stage_id, int iter_id, - int factor, int offset) { +StorageAlignStep::StorageAlignStep(int stage_id, int iter_id, + int factor, int offset) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; node->factor = factor; node->offset = offset; - return StorageAlignStep(node); + data_ = std::move(node); } void StorageAlignStepNode::ApplyToSchedule(std::vector *stages, @@ -803,13 +805,13 @@ std::string StorageAlignStepNode::PrintAsPythonAPI( } /********** Tensorize **********/ -TensorizeStep TensorizeStepNode::make(int stage_id, int iter_id, - std::string ti_func_name) { +TensorizeStep::TensorizeStep(int stage_id, int iter_id, + std::string ti_func_name) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; node->ti_func_name = ti_func_name; - return TensorizeStep(node); + data_ = std::move(node); } void TensorizeStepNode::ApplyToSchedule(std::vector *stages, diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index 9af14429bf61..3eb023eb75c8 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -21,15 +21,15 @@ * \file ansor/transform_step.h * \brief Transformation steps. For each schedule primitive, there is a corresponding transform step. * - * \Note How to add a new transform step. + * \note How to add a new transform step. * Take fuse for example: - * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its make function - * `FuseStepNode::make(...)` in `transform_steps.cc` + * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its construction + * function `FuseStep::FuseStep(...)` in `transform_steps.cc` * 2. Implement `FuseStepNode::ApplyToSchedule` and `FuseStepNode::PrintAsPythonAPI`. * - In these two functions you need to lower this step with tvm's te schedule API * 3. Implement `State::fuse` and `State::DoFuseStep`. 
* - In these two functions you need to incrementally update all data structures in State with - * CopyOnWrite style + * CopyOnWrite style * 4. Add you step to `ComputeDAG::ReplaySteps` and make sure it works. * 5. Add serialization support in `struct Handler >` * in `serialization.cc` @@ -56,8 +56,6 @@ class ReorderStepNode: public StepNode { std::vector after_ids; // The iterator ids after reorder. // This array should specify the order of all iterators. - static ReorderStep make(int stage_id, const std::vector& after_ids); - void ApplyToSchedule(std::vector *stages, StageToAxesMap *stage_to_axes) const; @@ -69,7 +67,18 @@ class ReorderStepNode: public StepNode { static constexpr const char* _type_key = "ansor.ReorderStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(ReorderStep, Step, ReorderStepNode); + +/*! + * \brief Managed reference to ReorderStepNode. + * \sa ReorderStepNode + */ +class ReorderStep : public Step { + public: + ReorderStep(int stage_id, const std::vector& after_ids); + + TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ReorderStepNode); +}; /*! \brief Split step that corresponds to te::Stage::split with additional * support of multiple-level of factors */ @@ -81,10 +90,6 @@ class SplitStepNode: public StepNode { bool inner_to_outer; // If true, the `lengths` denote the lengths of // iterators from inner level to outer level - static SplitStep make(int stage_id, int iter_id, PrimExpr extent, - const std::vector& lengths, - bool inner_to_outer); - std::vector ApplyToSchedule(std::vector *stages, StageToAxesMap *stage_to_axes) const; @@ -96,7 +101,20 @@ class SplitStepNode: public StepNode { static constexpr const char* _type_key = "ansor.SplitStep"; TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(SplitStep, Step, SplitStepNode); + +/*! + * \brief Managed reference to SplitStepNode. + * \sa SplitStepNode + */ +class SplitStep : public Step { + public: + SplitStep(int stage_id, int iter_id, PrimExpr extent, + const std::vector& lengths, + bool inner_to_outer); + + TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SplitStepNode); +}; /*! \brief Similar to SplitStepNode, but use split factor from another step * (i.e. Follow another split step) */ @@ -106,9 +124,6 @@ class FollowSplitStepNode: public StepNode { int src_step_id; // The index of the split step to follow in the history int n_split; // The number of split level - static FollowSplitStep make(int stage_id, int iter_id, - int src_step_id, int n_split); - void ExtractSplitLengths(const std::vector& transform_steps, std::vector* lengths) const; @@ -124,7 +139,19 @@ class FollowSplitStepNode: public StepNode { static constexpr const char* _type_key = "ansor.FollowSplitStep"; TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(FollowSplitStep, Step, FollowSplitStepNode); + +/*! + * \brief Managed reference to FollowSplitStepNode. + * \sa FollowSplitStepNode + */ +class FollowSplitStep : public Step { + public: + FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); + + TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(FollowSplitStepNode); +}; + /*! \brief Similar to FollowSplitStep, but use split factors from multiple steps. 
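For readers following the note above: the convention this series adopts is that each step exposes a real constructor on its reference class instead of a static `make()` on the node. As a concrete illustration, a hypothetical new step could look like the sketch below. `ReverseStep`, its field, and the imagined primitive it mirrors are invented for this example and appear nowhere in the patch; only the surrounding pattern (node class, reference class, constructor filling `data_`, and the two object-ref macros) is taken from the code in this file. The real patch declares constructors in the header and defines them in transform_step.cc; the inline definition here is just to keep the sketch self-contained.

  /*! \brief Hypothetical step, shown only to illustrate the pattern above. */
  class ReverseStepNode : public StepNode {
   public:
    std::vector<int> reversed_ids;  // The ids of iterators to reverse (invented field)

    void ApplyToSchedule(std::vector<te::Stage>* stages,
                         StageToAxesMap* stage_to_axes) const;

    std::string PrintAsPythonAPI(std::vector<te::Stage>* stages,
                                 StageToAxesMap* stage_to_axes) const;

    static constexpr const char* _type_key = "ansor.ReverseStep";
    TVM_DECLARE_FINAL_OBJECT_INFO(ReverseStepNode, Object);
  };

  /*! \brief Managed reference to ReverseStepNode (hypothetical). */
  class ReverseStep : public Step {
   public:
    ReverseStep(int stage_id, const std::vector<int>& reversed_ids) {
      auto node = make_object<ReverseStepNode>();
      node->stage_id = stage_id;
      node->reversed_ids = reversed_ids;
      data_ = std::move(node);  // the pattern used by every constructor in this patch
    }

    TVM_DEFINE_OBJECT_REF_METHODS(ReverseStep, Step, ReverseStepNode);
    TVM_DEFINE_OBJECT_REF_COW_METHOD(ReverseStepNode);
  };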
* \Note This can be used for the split in cooperative fetching @@ -136,10 +163,6 @@ class FollowFusedSplitStepNode: public StepNode { int level; // Use the length in this split level bool factor_or_nparts; // If this is true, use factor. Otherwise, use nparts - static FollowFusedSplitStep make(int stage_id, int iter_id, - const std::vector& src_step_ids, - int level, bool factor_or_nparts); - PrimExpr ExtractSplitLength(const std::vector& transform_steps) const; std::vector ApplyToSchedule(std::vector *stages, @@ -154,15 +177,26 @@ class FollowFusedSplitStepNode: public StepNode { static constexpr const char* _type_key = "ansor.FollowFusedSplitStep"; TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); + +/*! + * \brief Managed reference to FollowFusedSplitStepNode. + * \sa FollowFusedSplitStepNode + */ +class FollowFusedSplitStep : public Step { + public: + FollowFusedSplitStep(int stage_id, int iter_id, + const std::vector& src_step_ids, + int level, bool factor_or_nparts); + + TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(FollowFusedSplitStepNode); +}; /*! \brief Fuse step that corresponds to te::Stage::fuse */ class FuseStepNode: public StepNode { public: std::vector fused_ids; // The ids of iterators to fuse - static FuseStep make(int stage_id, const std::vector& fused_ids); - IterVar ApplyToSchedule(std::vector *stages, StageToAxesMap *stage_to_axes) const; @@ -174,7 +208,18 @@ class FuseStepNode: public StepNode { static constexpr const char* _type_key = "ansor.FuseStep"; TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(FuseStep, Step, FuseStepNode); + +/*! + * \brief Managed reference to FuseStepNode. + * \sa FuseStepNode + */ +class FuseStep : public Step { + public: + FuseStep(int stage_id, const std::vector& fused_ids); + + TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(FuseStepNode); +}; /*! \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding. * (i.e. te::Stage::vectorize, te::Stage::parallel, te::Stage::vectorize, te::Stage::bind) @@ -184,8 +229,6 @@ class AnnotationStepNode: public StepNode { int iter_id; IteratorAnnotation annotation; - static AnnotationStep make(int stage_id, int iter_id, IteratorAnnotation ann); - void ApplyToSchedule(std::vector *stages, StageToAxesMap *stage_to_axes) const; @@ -197,7 +240,18 @@ class AnnotationStepNode: public StepNode { static constexpr const char* _type_key = "ansor.AnnotationStep"; TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(AnnotationStep, Step, AnnotationStepNode); + +/*! + * \brief Managed reference to AnnotationStepNode. + * \sa AnnotationStepNode + */ +class AnnotationStep : public Step { + public: + AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann); + + TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(AnnotationStepNode); +}; /*! 
\brief Compute at step that corresponds to te::Stage::compute_at */
class ComputeAtStepNode: public StepNode {
@@ -205,9 +259,6 @@
  int target_stage_id;
  int target_iter_id;

-  static ComputeAtStep make(int stage_id, int target_stage_id,
-                            int target_iter_id);
-
  void ApplyToSchedule(std::vector<te::Stage> *stages,
                       StageToAxesMap *stage_to_axes) const;

@@ -219,12 +270,22 @@
  static constexpr const char* _type_key = "ansor.ComputeAtStep";
  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object);
};
-TVM_DEFINE_COW_OBJECT_REF(ComputeAtStep, Step, ComputeAtStepNode);
+
+/*!
+ * \brief Managed reference to ComputeAtStepNode.
+ * \sa ComputeAtStepNode
+ */
+class ComputeAtStep : public Step {
+ public:
+  ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeAtStepNode);
+};

/*! \brief Compute root step that corresponds to te::Stage::compute_root */
class ComputeRootStepNode: public StepNode {
 public:
-  static ComputeRootStep make(int stage_id);

  void ApplyToSchedule(std::vector<te::Stage> *stages,
                       StageToAxesMap *stage_to_axes) const;

@@ -237,13 +298,22 @@
  static constexpr const char* _type_key = "ansor.ComputeRootStep";
  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object);
};
-TVM_DEFINE_COW_OBJECT_REF(ComputeRootStep, Step, ComputeRootStepNode);
+
+/*!
+ * \brief Managed reference to ComputeRootStepNode.
+ * \sa ComputeRootStepNode
+ */
+class ComputeRootStep : public Step {
+ public:
+  explicit ComputeRootStep(int stage_id);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeRootStepNode);
+};

/*! \brief Compute inline step that corresponds to te::Stage::compute_inline */
class ComputeInlineStepNode: public StepNode {
 public:
-  static ComputeInlineStep make(int stage_id);
-
  void ApplyToSchedule(std::vector<te::Stage> *stages,
                       StageToAxesMap *stage_to_axes) const;

@@ -255,7 +325,18 @@
  static constexpr const char* _type_key = "ansor.ComputeInlineStep";
  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object);
};
-TVM_DEFINE_COW_OBJECT_REF(ComputeInlineStep, Step, ComputeInlineStepNode);
+
+/*!
+ * \brief Managed reference to ComputeInlineStepNode.
+ * \sa ComputeInlineStepNode
+ */
+class ComputeInlineStep : public Step {
+ public:
+  explicit ComputeInlineStep(int stage_id);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeInlineStepNode);
+};

/*! \brief Cache read step that corresponds to te::Schedule::cache_read */
class CacheReadStepNode: public StepNode {
@@ -263,11 +344,9 @@
  std::string scope_name;
  std::vector<int> reader_stage_ids;

-  static CacheReadStep make(int stage_id, std::string scope_name,
-                            const std::vector<int>& reader_stage_id);
-
  te::Tensor ApplyToSchedule(std::vector<te::Stage> *stages,
-                            StageToAxesMap *stage_to_axes, te::Schedule *schedule) const;
+                            StageToAxesMap *stage_to_axes,
+                            te::Schedule *schedule) const;

  std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
                               StageToAxesMap *stage_to_axes,
@@ -277,7 +356,19 @@
  static constexpr const char* _type_key = "ansor.CacheReadStep";
  TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object);
};
-TVM_DEFINE_COW_OBJECT_REF(CacheReadStep, Step, CacheReadStepNode);
+
+/*!
+ * \brief Managed reference to CacheReadStepNode.
+ * \sa CacheReadStepNode
+ */
+class CacheReadStep : public Step {
+ public:
+  CacheReadStep(int stage_id, std::string scope_name,
+                const std::vector<int>& reader_stage_id);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(CacheReadStep, Step, CacheReadStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(CacheReadStepNode);
+};

/*! \brief Cache write step that corresponds to te::Schedule::cache_write
 * \note This step will cache_write all output tensors of target stage */
class CacheWriteStepNode: public StepNode {
@@ -285,10 +376,9 @@
 public:
  std::string scope_name;

-  static CacheWriteStep make(int stage_id, std::string scope_name);
-
  Array<te::Tensor> ApplyToSchedule(std::vector<te::Stage> *stages,
-                                    StageToAxesMap *stage_to_axes, te::Schedule *schedule) const;
+                                    StageToAxesMap *stage_to_axes,
+                                    te::Schedule *schedule) const;

  std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
                               StageToAxesMap *stage_to_axes,
@@ -298,7 +388,18 @@
  static constexpr const char* _type_key = "ansor.CacheWriteStep";
  TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object);
};
-TVM_DEFINE_COW_OBJECT_REF(CacheWriteStep, Step, CacheWriteStepNode);
+
+/*!
+ * \brief Managed reference to CacheWriteStepNode.
+ * \sa CacheWriteStepNode
+ */
+class CacheWriteStep : public Step {
+ public:
+  CacheWriteStep(int stage_id, std::string scope_name);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(CacheWriteStepNode);
+};

/*! \brief Pragma step that corresponds to te::Stage::pragma */
class PragmaStepNode: public StepNode {
@@ -306,8 +407,6 @@
  int iter_id;
  std::string pragma_type;

-  static PragmaStep make(int stage_id, int iter_id, std::string pragma_type);
-
  void ApplyToSchedule(std::vector<te::Stage> *stages,
                       StageToAxesMap *stage_to_axes) const;

@@ -319,7 +418,18 @@
  static constexpr const char* _type_key = "ansor.PragmaStep";
  TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object);
};
-TVM_DEFINE_COW_OBJECT_REF(PragmaStep, Step, PragmaStepNode);
+
+/*!
+ * \brief Managed reference to PragmaStepNode.
+ * \sa PragmaStepNode
+ */
+class PragmaStep : public Step {
+ public:
+  PragmaStep(int stage_id, int iter_id, std::string pragma_type);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(PragmaStep, Step, PragmaStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(PragmaStepNode);
+};

/*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */
class RfactorStepNode: public StepNode {
@@ -327,11 +437,9 @@
  int iter_id;
  int factor_iter_id;

-  static RfactorStep make(int stage_id, int iter_id, int factor_iter_id);
-
  Array<te::Tensor> ApplyToSchedule(std::vector<te::Stage> *stages,
-                                    StageToAxesMap *stage_to_axes,
-                                    te::Schedule *schedule) const;
+                                    StageToAxesMap *stage_to_axes,
+                                    te::Schedule *schedule) const;

  std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
                               StageToAxesMap *stage_to_axes,
@@ -341,7 +449,18 @@
  static constexpr const char* _type_key = "ansor.RfactorStep";
  TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object);
};
-TVM_DEFINE_COW_OBJECT_REF(RfactorStep, Step, RfactorStepNode);
+
+/*!
+ * \brief Managed reference to RfactorStepNode.
+ * \sa RfactorStepNode + */ +class RfactorStep : public Step { + public: + RfactorStep(int stage_id, int iter_id, int factor_iter_id); + + TVM_DEFINE_OBJECT_REF_METHODS(RfactorStep, Step, RfactorStepNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(RfactorStepNode); +}; /*! \brief Storage align step that corresponds to te::Schedule::storage_align */ class StorageAlignStepNode: public StepNode { @@ -350,9 +469,6 @@ class StorageAlignStepNode: public StepNode { int factor; int offset; - static StorageAlignStep make(int stage_id, int iter_id, int factor, - int offset); - void ApplyToSchedule(std::vector *stages, StageToAxesMap *stage_to_axes) const; @@ -364,7 +480,18 @@ class StorageAlignStepNode: public StepNode { static constexpr const char* _type_key = "ansor.StorageAlignStep"; TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(StorageAlignStep, Step, StorageAlignStepNode); + +/*! + * \brief Managed reference to StorageAlignStepNode. + * \sa StorageAlignStepNode + */ +class StorageAlignStep : public Step { + public: + StorageAlignStep(int stage_id, int iter_id, int factor, int offset); + + TVM_DEFINE_OBJECT_REF_METHODS(StorageAlignStep, Step, StorageAlignStepNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(StorageAlignStepNode); +}; /*! \brief Tensorize step that corresponds to te::Schedule::tensorize * \Note This step takes a global registered function name as input. */ @@ -373,9 +500,6 @@ class TensorizeStepNode: public StepNode { int iter_id; std::string ti_func_name; - static TensorizeStep make(int stage_id, int iter_id, - std::string ti_func_name); - void ApplyToSchedule(std::vector *stages, StageToAxesMap *stage_to_axes) const; @@ -387,7 +511,18 @@ class TensorizeStepNode: public StepNode { static constexpr const char* _type_key = "ansor.TensorizeStep"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeStepNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(TensorizeStep, Step, TensorizeStepNode); + +/*! + * \brief Managed reference to TensorizeStepNode. 
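Step 4 of the note at the top of this header refers to `ComputeDAG::ReplaySteps`, whose body lies outside this diff. For orientation, the replay loop is essentially a type dispatch over the recorded history; the skeleton below is an illustrative reconstruction under that assumption, not the actual implementation. It assumes `transform_steps`, `stages`, `stage_to_axes`, and `schedule` are the loop's local state, and it only uses the `ApplyToSchedule` signatures declared in this header.

  // Illustrative only: a skeleton in the spirit of ComputeDAG::ReplaySteps.
  for (const Step& step : transform_steps) {
    if (auto ps = step.as<ReorderStepNode>()) {
      ps->ApplyToSchedule(&stages, &stage_to_axes);
    } else if (auto ps = step.as<FuseStepNode>()) {
      ps->ApplyToSchedule(&stages, &stage_to_axes);
    } else if (auto ps = step.as<CacheReadStepNode>()) {
      // Stage-creating steps additionally need the schedule object.
      ps->ApplyToSchedule(&stages, &stage_to_axes, &schedule);
    } else {
      LOG(FATAL) << "Unsupported step type";
    }
  }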
+ * \sa TensorizeStepNode + */ +class TensorizeStep : public Step { + public: + TensorizeStep(int stage_id, int iter_id, std::string ti_func_name); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorizeStep, Step, TensorizeStepNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(TensorizeStepNode); +}; } // namespace ansor } // namespace tvm diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc index 5f1dea0f1ea5..36ac46f49551 100644 --- a/tests/cpp/ansor_test.cc +++ b/tests/cpp/ansor_test.cc @@ -79,7 +79,7 @@ using namespace tvm::ansor; // Test Access Analyzer TEST(ComputeDAG, GetProducersConsumers) { const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); - const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors); + const auto& dag = tvm::ansor::ComputeDAG(tensors); int data = 0, padding = 1, kernel = 2, conv = 3, bias = 4, bias_add = 5; int bn_scale = 6, bn_mul = 7, bn_offset = 8, bn_add = 9, relu = 10; From 8e53d125d9fcdeed6ab5c422d151438a422a12a0 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Tue, 23 Jun 2020 16:28:58 +0800 Subject: [PATCH 36/78] Some lint fix & Recover the double constructor of tvm::PrimExpr (#39) * lint fix * clang-format-fix * pylint fix * Update * Recover the double constructor of tvm::PrimExpr * Fix pylint * pylint fix * pylint fix --- include/tvm/ir/expr.h | 5 -- python/tvm/ansor/__init__.py | 1 - python/tvm/ansor/auto_schedule.py | 8 +- python/tvm/ansor/cost_model/cost_model.py | 1 - python/tvm/ansor/dispatcher.py | 2 +- python/tvm/ansor/env.py | 1 - python/tvm/ansor/feature.py | 46 +++++------ python/tvm/ansor/loop_state.py | 81 ++++++++++--------- python/tvm/ansor/measure.py | 37 ++++++--- python/tvm/ansor/relay_integration.py | 27 ++++--- python/tvm/ansor/task_scheduler.py | 12 ++- python/tvm/ansor/workload_registry.py | 8 +- python/tvm/relay/backend/compile_engine.py | 2 +- python/tvm/relay/op/strategy/x86.py | 1 - python/tvm/relay/testing/dqn.py | 6 +- python/tvm/relay/testing/resnet.py | 3 +- python/tvm/te/tensor.py | 4 +- scripts/common.py | 17 ++++ scripts/shape_configs.py | 17 ++++ scripts/tune_network.py | 17 ++++ scripts/tune_op_subgraph.py | 17 ++++ scripts/tune_test.py | 17 ++++ src/ansor/measure.cc | 1 - src/ansor/measure.h | 1 - .../search_policy/sketch_search_policy.cc | 10 ++- src/ansor/serialization.cc | 2 +- src/ansor/transform_step.h | 5 +- src/ir/expr.cc | 2 - src/relay/op/tensor/transform.cc | 2 - src/relay/transforms/defuse_ops.cc | 33 +++----- .../transforms/kernel_layout_transform.cc | 25 +++--- .../transforms/kernel_layout_transform.h | 60 ++++++++++---- src/runtime/rpc/rpc_module.cc | 5 +- src/tir/transforms/unroll_loop.cc | 5 +- 34 files changed, 300 insertions(+), 181 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index b3e527ca6fd9..b2ce50d91f58 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -112,11 +112,6 @@ class PrimExpr : public BaseExpr { * \param value The value to be constructed. */ TVM_DLL PrimExpr(float value); // NOLINT(*) - /*! - * \brief construct from double. - * \param value The value to be constructed. - */ - TVM_DLL PrimExpr(double value); // NOLINT(*) /*! \return the data type of this expression. */ DataType dtype() const { return static_cast(get())->dtype; } diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index c629c1049a87..edade490018c 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -21,7 +21,6 @@ from . import measure from . import serialization from . import loop_state -from . 
import auto_schedule from . import utils from . import feature from . import workload_registry diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index a03d9fdacbc2..4497bb400703 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -22,7 +22,7 @@ import tvm._ffi from tvm.runtime import Object from .measure import LocalBuilder, LocalRunner -from .cost_model import RandomModel, XGBModel +from .cost_model import RandomModel from . import _ffi_api @@ -133,7 +133,6 @@ def __init__(self, @tvm._ffi.register_object("ansor.SearchCallback") class SearchCallback(Object): """Callback function before or after search process""" - pass @tvm._ffi.register_object("ansor.PreloadMeasuredStates") @@ -262,8 +261,7 @@ def auto_schedule(workload, target=None, sch, tensors = _ffi_api.AutoScheduleByWorkloadKey( workload, target, target_host, search_policy, hardware_params, tune_option) return sch, tensors - elif isinstance(workload, SearchTask): + if isinstance(workload, SearchTask): sch, tensors = _ffi_api.AutoScheduleBySearchTask(workload, search_policy, tune_option) return sch, tensors - else: - raise ValueError("Invalid workload: " + workload + ". Expect a string or SearchTask") + raise ValueError("Invalid workload: " + workload + ". Expect a string or SearchTask") diff --git a/python/tvm/ansor/cost_model/cost_model.py b/python/tvm/ansor/cost_model/cost_model.py index 57cc53853b2e..fbfc8242488b 100644 --- a/python/tvm/ansor/cost_model/cost_model.py +++ b/python/tvm/ansor/cost_model/cost_model.py @@ -27,7 +27,6 @@ @tvm._ffi.register_object("ansor.CostModel") class CostModel(Object): """The base class for cost model""" - pass @tvm._ffi.register_object("ansor.RandomModel") diff --git a/python/tvm/ansor/dispatcher.py b/python/tvm/ansor/dispatcher.py index 0c07fd141bd2..3a5dc4e9e206 100644 --- a/python/tvm/ansor/dispatcher.py +++ b/python/tvm/ansor/dispatcher.py @@ -34,7 +34,7 @@ class DispatchContext(object): """ Base class of dispatch context. """ - current = None + current = None def __init__(self): self._old_ctx = DispatchContext.current diff --git a/python/tvm/ansor/env.py b/python/tvm/ansor/env.py index 0f35f92acbbc..56e76e26ee4f 100644 --- a/python/tvm/ansor/env.py +++ b/python/tvm/ansor/env.py @@ -23,4 +23,3 @@ def __init__(self): self.topi_in_compute_rewrite_mode = False GLOBAL_SCOPE = AutoschedulerGlobalScope() - diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py index d9f6d297f1af..fa1b2cb07dcc 100644 --- a/python/tvm/ansor/feature.py +++ b/python/tvm/ansor/feature.py @@ -40,21 +40,20 @@ def unpack_feature(byte_arr: bytearray) -> Tuple[np.ndarray, np.ndarray, np.ndar size_of_int = 4 size_of_float = 4 - """ - The format for n records is: - { - int n; - int[n+2] sizes - - float[sizes[0]] feature for record 1 - float[sizes[1]] feature for record 2 - ... feature for record i... - float[sizes[n-1]] feature for record n - - float[sizes[n]] normalized throughput for n records - int[sizes[n+1]] task id for n records - } - """ + # The format for n records is: + # { + # int n; + # int[n+2] sizes + + # float[sizes[0]] feature for record 1 + # float[sizes[1]] feature for record 2 + # ... feature for record i... 
+ # float[sizes[n-1]] feature for record n + + # float[sizes[n]] normalized throughput for n records + # int[sizes[n+1]] task id for n records + # } + vec_len = DEFAULT_FEATURE_VEC_LEN # unpack sizes @@ -70,15 +69,14 @@ def unpack_feature(byte_arr: bytearray) -> Tuple[np.ndarray, np.ndarray, np.ndar for size in sizes[:-2]: row = [] - """ - Now we need to unpack the feature for multiple statements. - The format is: - { - int n_stmts - float[n_stmt][vec_len] feature_vecs - } - where vec_len can be calculated by `(size - 1) / n_stmts` - """ + # Now we need to unpack the feature for multiple statements. + # The format is: + # { + # int n_stmts + # float[n_stmt][vec_len] feature_vecs + # } + # where vec_len can be calculated by `(size - 1) / n_stmts` + if size == 0: # failed during lowering features.append(np.zeros((1, vec_len))) diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index 3c60c3f09a8d..8560a57bc902 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -42,7 +42,6 @@ @tvm._ffi.register_object("ansor.Iterator") class Iterator(Object): """A for loop iterator""" - pass @tvm._ffi.register_object("ansor.Stage") @@ -90,8 +89,7 @@ def __getitem__(self, k): self.stages_cache = _ffi_api.StateGetStages(self.state_object) if isinstance(k, tvm.te.Tensor): return self.stages_cache[self.stage_id_map[k.op]] - else: - raise ValueError("Item must be Tensor") + raise ValueError("Item must be Tensor") def __update_tensor_stage_map(self): if not self.stages_cache: @@ -164,13 +162,13 @@ def reorder(self, stage_id, order): self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order) self.clear_cache() - def split(self, stage_id, it, lengths, inner_to_outer=True): + def split(self, stage_id, iterator, lengths, inner_to_outer=True): """ Parameters ---------- stage_id : Int The index of the stage to split - it : Iterator + iterator : Iterator The iterator to split lengths: List[Int] The split factors @@ -188,18 +186,18 @@ def split(self, stage_id, it, lengths, inner_to_outer=True): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, it, lengths, + self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, iterator, lengths, inner_to_outer) self.clear_cache() return res - def follow_split(self, stage_id, it, src_step_id, n_split): + def follow_split(self, stage_id, iterator, src_step_id, n_split): """ Parameters ---------- stage_id : Int The index of the stage to split - it : Iterator + iterator : Iterator The iterator to split src_step_id : Int The index of the split step to follow in the history @@ -216,19 +214,19 @@ def follow_split(self, stage_id, it, src_step_id, n_split): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, stage_id, it, + self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, stage_id, iterator, src_step_id, n_split) self.clear_cache() return res - def follow_fused_split(self, stage_id, it, src_step_ids, level, + def follow_fused_split(self, stage_id, iterator, src_step_ids, level, factor_or_nparts): """ Parameters ---------- stage_id : Int The index of the stage to split - it : Iterator + iterator : Iterator The iterator to split src_step_ids : List[Int] The indices of the split steps to follow in the history @@ -248,8 +246,8 @@ def follow_fused_split(self, stage_id, 
it, src_step_ids, level, elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, stage_id, it, - src_step_ids, level, + self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, stage_id, + iterator, src_step_ids, level, factor_or_nparts) self.clear_cache() return res @@ -277,13 +275,13 @@ def fuse(self, stage_id, iters): self.clear_cache() return res - def vectorize(self, stage_id, it): + def vectorize(self, stage_id, iterator): """ Parameters ---------- stage_id : Int The index of the stage to vectorize - it : Iterator + iterator : Iterator The iterator to be vectorized Returns @@ -296,17 +294,17 @@ def vectorize(self, stage_id, it): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object, res = _ffi_api.StateVectorize(self.state_object, stage_id, it) + self.state_object, res = _ffi_api.StateVectorize(self.state_object, stage_id, iterator) self.clear_cache() return res - def parallel(self, stage_id, it): + def parallel(self, stage_id, iterator): """ Parameters ---------- stage_id : Int The index of the stage to parallel - it : Iterator + iterator : Iterator The iterator to be parallelized Returns @@ -319,17 +317,17 @@ def parallel(self, stage_id, it): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object, res = _ffi_api.StateParallel(self.state_object, stage_id, it) + self.state_object, res = _ffi_api.StateParallel(self.state_object, stage_id, iterator) self.clear_cache() return res - def unroll(self, stage_id, it, max_unroll=-1): + def unroll(self, stage_id, iterator, max_unroll=-1): """ Parameters ---------- stage_id : Int The index of the stage to unroll - it : Iterator + iterator : Iterator The iterator to be unrolled max_unroll: Int The maximum length of the iterator that can be unrolled @@ -344,17 +342,18 @@ def unroll(self, stage_id, it, max_unroll=-1): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id, it, max_unroll) + self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id, iterator, + max_unroll) self.clear_cache() return res - def bind_thread(self, stage_id, it, thread_name): + def bind_thread(self, stage_id, iterator, thread_name): """ Parameters ---------- stage_id : Int The index of the stage to bind - it : Iterator + iterator : Iterator The iterator to be bound thread_name : str The name of the thread (e.g. 
"blockIdx.x", "threadIdx.y", "vthread") @@ -378,7 +377,8 @@ def bind_thread(self, stage_id, it, thread_name): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id, it, thread_id) + self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id, iterator, + thread_id) self.clear_cache() return res @@ -403,7 +403,7 @@ def compute_at(self, stage_id, target_stage_id, target_iter): raise ValueError("target_stage_id must be Tensor or Int") self.state_object = _ffi_api.StateComputeAt(self.state_object, stage_id, - target_stage_id, target_iter) + target_stage_id, target_iter) self.clear_cache() def compute_root(self, stage_id): @@ -494,13 +494,13 @@ def cache_write(self, stage_id, scope_name): scope_name, self.compute_dag) return self.__insert_new_stage(new_stage_id) - def pragma(self, stage_id, it, pragma_type): + def pragma(self, stage_id, iterator, pragma_type): """ Parameters ---------- stage_id : Int The index of the stage to add pragma - it : Iterator + iterator : Iterator The iterator to add pragma pragma_type : Str """ @@ -509,16 +509,17 @@ def pragma(self, stage_id, it, pragma_type): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object = _ffi_api.StatePragma(self.state_object, stage_id, it, pragma_type) + self.state_object = _ffi_api.StatePragma(self.state_object, stage_id, iterator, + pragma_type) self.clear_cache() - def rfactor(self, stage_id, it, factor_iter_id): + def rfactor(self, stage_id, iterator, factor_iter_id): """ Parameters ---------- stage_id : Int The index of the stage to do reduction factor - it : Iterator + iterator : Iterator factor_iter_id : Int Returns @@ -531,17 +532,18 @@ def rfactor(self, stage_id, it, factor_iter_id): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object, stage_id, it, - factor_iter_id, self.compute_dag) + self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object, stage_id, + iterator, factor_iter_id, + self.compute_dag) return self.__insert_new_stage(new_stage_id) - def storage_align(self, stage_id, it, factor, offset): + def storage_align(self, stage_id, iterator, factor, offset): """ Parameters ---------- stage_id : Int The index of the stage to do storage align - it : Iterator + iterator : Iterator factor : Int offset : Int """ @@ -550,10 +552,11 @@ def storage_align(self, stage_id, it, factor, offset): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, it, factor, offset) + self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, iterator, + factor, offset) self.clear_cache() - def tensorize(self, stage_id, it, ti_func_name): + def tensorize(self, stage_id, iterator, ti_func_name): """ The `ti_func_name` corresponds to a global registered funcion that returns a TensorIntrin @@ -561,7 +564,7 @@ def tensorize(self, stage_id, it, ti_func_name): ---------- stage_id : Int The index of the stage to do storage align - it : Iterator + iterator : Iterator The target iterator ti_func_name : Str Tensorize intrinsic function name @@ -577,7 +580,7 @@ def tensorize(self, stage_id, it, ti_func_name): raise ValueError("stage_id must be Tensor or Int") self.state_object, res = _ffi_api.StateTensorize(self.state_object, - 
stage_id, it, + stage_id, iterator, ti_func_name) self.clear_cache() return res diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index f00fe672505d..be7d69e5ed3a 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -40,10 +40,11 @@ from tvm.autotvm.measure.measure_methods import set_cuda_target_arch from tvm.contrib import tar, ndk from . import _ffi_api -from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, check_remote +from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, \ + check_remote from .compute_dag import LayoutRewriteLevel -logger = logging.getLogger('ansor') +LOGGER = logging.getLogger('ansor') # The maximum length of error message MAX_ERROR_MSG_LEN = 512 @@ -52,7 +53,7 @@ @tvm._ffi.register_object("ansor.MeasureCallback") class MeasureCallback(Object): """Base class for measurement callback function""" - pass + @tvm._ffi.register_object("ansor.MeasureInput") class MeasureInput(Object): @@ -105,6 +106,8 @@ def __init__(self, costs, error_no, error_msg, all_cost, timestamp): @tvm._ffi.register_object("ansor.Builder") class Builder(Object): + """ Base class of Builder + """ def build(self, measure_inputs, verbose=1): """ Parameters @@ -121,6 +124,8 @@ def build(self, measure_inputs, verbose=1): @tvm._ffi.register_object("ansor.Runner") class Runner(Object): + """ Base class of Runner + """ def run(self, measure_inputs, build_results, verbose=1): """ Parameters @@ -221,7 +226,7 @@ def __init__(self, key, host, port, priority=1, number, repeat, min_repeat_ms, cooldown_interval) if check_remote(key, host, port, priority, timeout): - logger.info("Get devices for measurement successfully!") + LOGGER.info("Get devices for measurement successfully!") else: raise RuntimeError("Cannot get remote devices from the tracker. " "Please check the status of tracker by " @@ -260,7 +265,7 @@ def __init__(self, self.tracker = Tracker(host, port=9000, port_end=10000, silent=True) device_key = '$local$device$%d' % self.tracker.port self.server = Server(host, port=self.tracker.port, port_end=10000, - key=device_key, use_popen=True, silent=True, + key=device_key, use_popen=True, silent=True, tracker_addr=(self.tracker.host, self.tracker.port)) self.runner = RPCRunner(device_key, host, self.tracker.port, priority, n_parallel, timeout, number, repeat, @@ -302,6 +307,8 @@ def make_error_msg(): def local_build_worker(index): + """ Local builder function + """ # We use fork to copy arguments from a global variable. # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool measure_inputs, build_func, timeout, verbose = global_build_arguments @@ -362,7 +369,10 @@ def timed_func(): @tvm._ffi.register_func("ansor.local_builder.build") -def local_builder_build(inputs: List[MeasureInput], timeout: float, n_parallel: int, build_func: str, verbose: int): +def local_builder_build(inputs: List[MeasureInput], timeout: float, n_parallel: int, + build_func: str, verbose: int): + """ Local builder build function + """ # We use fork to copy arguments from a global variable. # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool global global_build_arguments @@ -409,6 +419,8 @@ def rpc_runner_run(inputs: List[MeasureInput], build_results: List[BuildResult], def rpc_run_worker(index): + """ ... 
+ """ inputs, build_results, key, host, port, priority, timeout, number, \ repeat, min_repeat_ms, cooldown_interval, verbose = global_run_arguments @@ -417,7 +429,8 @@ def rpc_run_worker(index): build_res = build_results[index] if build_res.error_no != MeasureErrorNo.NO_ERROR: - return (MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, time.time() + return (MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, \ + time.time() def timed_func(): tic = time.time() @@ -478,6 +491,8 @@ def timed_func(): def local_run(inputs: List[MeasureInput], build_results: List[BuildResult], timeout: float, number: int, repeat: int, min_repeat_ms: int, cooldown_interval: float, verbose: int): + """ ... + """ MAX_FLOAT = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log def timed_func(inp, build_res): @@ -522,16 +537,16 @@ def timed_func(inp, build_res): "Measure input size should be equal to build results" for inp, build_res in zip(inputs, build_results): if build_res.error_no != 0: - res = ( - MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, time.time() + res = (MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, \ + time.time() else: res = call_func_with_timeout( timeout, timed_func, args=(inp, build_res)) if isinstance(res, TimeoutError): if verbose >= 1: print("*T", end="") # Run timeout - res = ( - MAX_FLOAT,), MeasureErrorNo.RUN_TIMEOUT, None, build_res.time_cost + timeout, time.time() + res = (MAX_FLOAT,), MeasureErrorNo.RUN_TIMEOUT, None, \ + build_res.time_cost + timeout, time.time() measure_results.append(MeasureResult(*res)) if verbose >= 1: diff --git a/python/tvm/ansor/relay_integration.py b/python/tvm/ansor/relay_integration.py index 3c2eabd3dfac..f2873f8c72fd 100644 --- a/python/tvm/ansor/relay_integration.py +++ b/python/tvm/ansor/relay_integration.py @@ -25,21 +25,22 @@ import json import threading -from tvm import target, te, transform +import tvm +from tvm import te, transform from tvm.te.tensor import PlaceholderOp, ComputeOp from .dispatcher import DispatchContext from .workload_registry import register_workload_bufs, compute_dag_hash from .compute_dag import ComputeDAG, LayoutRewriteLevel from .env import GLOBAL_SCOPE -def call_all_topi_funcs(mod, target, params): +def call_all_topi_funcs(mod, target, params, target_host=None): """Call all TOPI compute + schedule to extract tasks in a relay program""" # pylint: disable=import-outside-toplevel from tvm import relay with transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): bld_mod = relay.build_module.BuildModule() - bld_mod.call_all_topi_funcs(mod, target=target, params=params) + bld_mod.call_all_topi_funcs(mod, target=target, params=params, target_host=target_host) def extract_from_program(mod, params, target, target_host=None): """ Extract tuning tasks from a relay program. @@ -95,7 +96,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None): # wrap build call in a new thread to avoid the conflict # between python's multiprocessing and tvm's thread pool build_thread = threading.Thread(target=call_all_topi_funcs, - args=(mod, target, param)) + args=(mod, target, param, target_host)) build_thread.start() build_thread.join() relay.backend.compile_engine.get().clear() @@ -112,7 +113,8 @@ def extract_from_multiple_program(mods, params, target, target_host=None): def prepare_layout_rewrite(mod, params, target): """ - Prepare for kernel layout rewrite. 
This function will write layout infos to a global static variable. + Prepare for kernel layout rewrite. This function will write layout infos to a global static + variable. Then these layout info will be used by a relay pass `kernel_layout_transform`. """ # pylint: disable=import-outside-toplevel @@ -207,26 +209,26 @@ def auto_schedule_topi(outs): env = TracingEnvironment.current if env is None: # in the final build mode - state = DispatchContext.current.query(target.Target.current(), key) + state = DispatchContext.current.query(tvm.target.Target.current(), key) if state is None: return te.create_schedule([x.op for x in outs]) dag = ComputeDAG(io_tensors) # Only update compute body, layout_rewrite_level = LayoutRewriteLevel.COMPUTE_REWRITE, # Since kernel layout has already been rewritten in relay pass - schedule, _ = dag.apply_steps_from_state(state, - layout_rewrite_level=LayoutRewriteLevel.COMPUTE_REWRITE) + schedule, _ = dag.apply_steps_from_state( + state, layout_rewrite_level=LayoutRewriteLevel.COMPUTE_REWRITE) return schedule - elif env.tracing_mode == TracingMode.EXTRACT_TASK: # in the task extraction mode + if env.tracing_mode == TracingMode.EXTRACT_TASK: # in the task extraction mode env.add_workload_key(key) return te.create_schedule([x.op for x in outs]) - elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE: + if env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE: # in prepare_layout_rewrite mode if has_layout_free: # Rewrite the DAG and update the transform history for # the new dag in DispatchContext dispatch_ctx = DispatchContext.current - tgt = target.Target.current() + tgt = tvm.target.Target.current() state = dispatch_ctx.query(tgt, key) assert state is not None dag = ComputeDAG(outs) @@ -236,5 +238,4 @@ def auto_schedule_topi(outs): if new_key != key: env.layout_rewrite_success_ct += 1 return te.create_schedule([x.op for x in outs]) - else: - raise ValueError("Invalid tracing mode: " + env.tracing_mode) + raise ValueError("Invalid tracing mode: " + env.tracing_mode) diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py index 587fe3121e88..5b916ed39769 100644 --- a/python/tvm/ansor/task_scheduler.py +++ b/python/tvm/ansor/task_scheduler.py @@ -41,6 +41,8 @@ def compute_score(self, costs: List[float]) -> float: def get_search_policies(search_policy: Union[str, List[SearchPolicy]], tasks: List[SearchTask], num_measure_per_iter, load_model_file=None, load_log_file=None): + """ ... + """ if search_policy == 'default': search_policy = 'sketch.xgb' @@ -98,7 +100,8 @@ class SimpleTaskScheduler(TaskScheduler): load_log_file: str Load history log file to pre-train cost model eps-random: float - Always allocate this percent of n_trials to select tasks randomly. This is for encouraging exploration. + Always allocate this percent of n_trials to select tasks randomly. + This is for encouraging exploration. verbose: int The level of verbosity. 0 means silent. alpha: float @@ -144,7 +147,8 @@ def __init__(self, self.sequential_now_task_idx = 0 self.sequential_now_task_begin_ct = 0 - def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPolicy]] = 'default'): + def tune(self, tune_option: TuneOption, + search_policy: Union[str, List[SearchPolicy]] = 'default'): """ Tune tasks. Notice: This method does not have return value, make sure to set `LogToFile` @@ -252,6 +256,8 @@ def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPol self.tune_task(task_idx) def tune_task(self, task_idx): + """ ... 
+ """ if self.use_debug_measurement_simulator is not None: measure_inputs, measure_results = \ self.use_debug_measurement_simulator.get_next_batch( @@ -282,7 +288,7 @@ def tune_task(self, task_idx): if self.verbose >= 1: print(("TaskScheduler\tct: %d\testimated cost (ms): %.3f\ttime elapsed: %.2f\t" + - "best_costs (ms): %s\ttask_ct: %s") % + "best_costs (ms): %s\ttask_ct: %s") % (self.ct, self.cur_score * 1e3, time.time() - self.tic, to_str_round(self.best_costs * 1e3, decimal=3), self.task_cts)) diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py index e706c0ec4cf9..025b5f03c661 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/ansor/workload_registry.py @@ -23,7 +23,8 @@ The dag should be the return value of this `func_name(*args)`. Rationale: The workload is actually a compute dag defined by tvm dsl. But serializing compute dags -and matching them efficiently is not easy. Therefore, we use the above string to encode a compute dag. +and matching them efficiently is not easy. Therefore, we use the above string to encode a compute +dag. These strings are efficient for serialization/matching and wont' be too long. When we need the dag, we decode the string and call the function, which will return the dag. """ @@ -65,6 +66,8 @@ def matmul(N, M, K): def compute_dag_hash(dag: ComputeDAG): + """ Get hash value for a ComputeDAG + """ # todo: implement this more carefully and move this to c++ as a member function of ComputeDAG str_key = '' for op in dag.ops: @@ -139,8 +142,7 @@ def workload_key_to_tensors(workload_key: str) -> List[Tensor]: if callable(lookup): args = deserialize_args(workload[1:]) return lookup(*args) - else: - return lookup + return lookup @ tvm._ffi.register_func("ansor.workload_key_to_dag") diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 66ef5cd4c852..b6bedb411540 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -18,10 +18,10 @@ """Backend code generation engine.""" from __future__ import absolute_import +import os import logging import numpy as np import tvm -import os from tvm import te from tvm.runtime import Object from ... import target as _target diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 2a0ddd1329b5..3453b089f373 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -17,7 +17,6 @@ """Definition of x86 operator strategy.""" # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import -import os from tvm.te import SpecializedCondition from tvm import ansor from .generic import * diff --git a/python/tvm/relay/testing/dqn.py b/python/tvm/relay/testing/dqn.py index b65e0ad5cae9..3d6883362c9b 100644 --- a/python/tvm/relay/testing/dqn.py +++ b/python/tvm/relay/testing/dqn.py @@ -63,7 +63,8 @@ def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32" return relay.Function(args, dense2) -def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32", layout="NCHW"): +def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32", + layout="NCHW"): """Get benchmark workload for a Deep Q Network Parameters ---------- @@ -82,5 +83,6 @@ def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="flo params : dict of str to NDArray The parameters. 
""" - net = get_net(batch_size, num_actions=num_actions, image_shape=image_shape, dtype=dtype, layout=layout) + net = get_net(batch_size, num_actions=num_actions, image_shape=image_shape, dtype=dtype, + layout=layout) return create_workload(net) diff --git a/python/tvm/relay/testing/resnet.py b/python/tvm/relay/testing/resnet.py index 4383157d9f06..ac63afde4cba 100644 --- a/python/tvm/relay/testing/resnet.py +++ b/python/tvm/relay/testing/resnet.py @@ -163,7 +163,8 @@ def resnet(units, num_unit = len(units) assert num_unit == num_stages data = relay.var("data", shape=data_shape, dtype=dtype) - data = layers.batch_norm_infer(data=data, epsilon=2e-5, axis=bn_axis, scale=False, name='bn_data') + data = layers.batch_norm_infer(data=data, epsilon=2e-5, axis=bn_axis, scale=False, + name='bn_data') (_, _, height, _) = data_shape if layout == "NHWC": (_, height, _, _) = data_shape diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 6539aabaa48f..6a2120817eb1 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -56,9 +56,9 @@ class Tensor(DataProducer, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" def __call__(self, *indices): - ndim = self.ndim + # ndim = self.ndim # After ansor kernel layout rewrite, len(indices) <= ndim, - # and the indices will get modified by Ansor during schedule generation. + # and the indices will get modified by Ansor during schedule generation. # if len(indices) != ndim: # raise ValueError("Need to provide %d index in tensor slice" % ndim) indices = convert_to_object(indices) diff --git a/scripts/common.py b/scripts/common.py index 8f4fbec09dd0..ac25b28e55b1 100644 --- a/scripts/common.py +++ b/scripts/common.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + """Common utility for scripts""" import argparse import math diff --git a/scripts/shape_configs.py b/scripts/shape_configs.py index 244638f5b29c..db6b3b9dc9aa 100644 --- a/scripts/shape_configs.py +++ b/scripts/shape_configs.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + """ Shape configurations for single operator / subgraph evaluation This file is shared by tune_op_subgraph.py and scripts in scripts/baseline/ """ diff --git a/scripts/tune_network.py b/scripts/tune_network.py index 1905d8132003..188da6cbe6e6 100644 --- a/scripts/tune_network.py +++ b/scripts/tune_network.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + """Tune a whole neural network""" import argparse import logging diff --git a/scripts/tune_op_subgraph.py b/scripts/tune_op_subgraph.py index 6574bb77e510..d3e70501873e 100644 --- a/scripts/tune_op_subgraph.py +++ b/scripts/tune_op_subgraph.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + """Tune all workloads for single op & subgraph evaluation""" import argparse import logging diff --git a/scripts/tune_test.py b/scripts/tune_test.py index 67c0526dd624..c98da3eca53b 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
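One side effect of PATCH 36 worth calling out before the remaining hunks: it removes the `PrimExpr(double)` convenience constructor (see the include/tvm/ir/expr.h hunk above), so raw doubles no longer convert implicitly to `PrimExpr`. The serialization.cc hunk further below shows the required call-site change; the sketch here only restates that idiom in isolation and is not itself part of the patch.

  #include <tvm/ir/expr.h>

  // With PrimExpr(double) gone, wrap raw doubles explicitly as FP64 immediates.
  tvm::PrimExpr MakeCost(double seconds) {
    return tvm::FloatImm(tvm::DataType::Float(64), seconds);
  }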
+ """Use auto scheduler to tune workloads""" import argparse import logging diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 4ae35fb410a9..a044acfe5395 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2020 by Contributors * \file ansor/measure.cc * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs */ diff --git a/src/ansor/measure.h b/src/ansor/measure.h index a6db55f6181e..760a1542944f 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2020 by Contributors * \file ansor/measure.h * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs */ diff --git a/src/ansor/search_policy/sketch_search_policy.cc b/src/ansor/search_policy/sketch_search_policy.cc index 5b2c10c08c81..63f75cad1c83 100644 --- a/src/ansor/search_policy/sketch_search_policy.cc +++ b/src/ansor/search_policy/sketch_search_policy.cc @@ -901,7 +901,7 @@ int InitPopulationCooperativeFetching(const SketchSearchPolicyNode* policy, int InitPopulationChangeComputeLocation(const SketchSearchPolicyNode* policy, State* state, std::mt19937* rand_gen) { - if(GetIntParam(policy->params, "disable_change_compute_location")) { + if (GetIntParam(policy->params, "disable_change_compute_location")) { return 0; } @@ -1063,7 +1063,8 @@ int InitPopulationChangeComputeLocation(const SketchSearchPolicyNode* policy, int InitPopulationParallel(const SketchSearchPolicyNode* policy, State* state) { - std::function annotate_parallel; + std::function + annotate_parallel; annotate_parallel = [&annotate_parallel]( const SketchSearchPolicyNode* policy, State* state, int stage_id, int iter_offset) { @@ -1095,7 +1096,8 @@ int InitPopulationParallel(const SketchSearchPolicyNode* policy, } if (parallel_degree == 1) { - auto res = (*state)->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id)); + auto res = + (*state)->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id)); if (res != (*state)->attach_map->iter_to_attached_stages.end()) { for (int attached_stage_id : res->second) { annotate_parallel(policy, state, attached_stage_id, 0); @@ -1188,7 +1190,7 @@ int InitPopulationVectorization(const SketchSearchPolicyNode* policy, } if (num_fusible > 1) { - num_fusible = 1 + (*rand_gen)() % (num_fusible - 1); // Select a random range to fuse + num_fusible = 1 + (*rand_gen)() % (num_fusible - 1); // Select a random range to fuse } if (num_fusible == 1) { diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index 71fba764506f..c026b9b6251a 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -434,7 +434,7 @@ struct Handler<::tvm::ansor::MeasureResultNode> { reader->Read(&tmp); data->costs.clear(); for (const auto& i : tmp) { - data->costs.push_back(i); + data->costs.push_back(::tvm::FloatImm(::tvm::DataType::Float(64), i)); } s = reader->NextArrayItem(); CHECK(s); reader->Read(&data->error_no); diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index 3eb023eb75c8..edd71732b3e2 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -277,7 +277,7 @@ class ComputeAtStepNode: public StepNode { */ class ComputeAtStep : public Step { public: - ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id); + ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id); TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode); 
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeAtStepNode); @@ -286,7 +286,6 @@ class ComputeAtStep : public Step { /*! \brief Fuse step that corresponds to te::Stage::compute_root */ class ComputeRootStepNode: public StepNode { public: - void ApplyToSchedule(std::vector *stages, StageToAxesMap *stage_to_axes) const; @@ -550,8 +549,8 @@ struct hash<::tvm::ansor::Step> { } else { ret = ::dmlc::HashCombine(ret, 0x5D); // a magic number } - return ret; } + return ret; } else if (auto ps = step.as<::tvm::ansor::FollowSplitStepNode>()) { return ::dmlc::HashCombine(3, ::dmlc::HashCombine(std::hash()(ps->stage_id), diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 6e898dd5ddb4..fd380aa33f86 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -38,8 +38,6 @@ PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) { PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} -PrimExpr::PrimExpr(double value) : PrimExpr(FloatImm(DataType::Float(64), value)) {} - PrimExpr PrimExpr::FromObject_(ObjectRef ref) { using runtime::ObjectTypeChecker; if (auto* ptr = ref.as()) { diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 18ace14a0b75..30269b85795f 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2461,7 +2461,6 @@ TVM_REGISTER_NODE_TYPE(KernelLayoutTransformAttrs); Array KernelLayoutTransformCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - //const Target& target) { const auto* param = attrs.as(); CHECK(param != nullptr); return Array{ @@ -2473,7 +2472,6 @@ bool KernelLayoutTransformRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - const auto* data = types[0].as(); CHECK(data != nullptr); const KernelLayoutTransformAttrs* params = attrs.as(); diff --git a/src/relay/transforms/defuse_ops.cc b/src/relay/transforms/defuse_ops.cc index f7c9037df687..1a108fb08888 100644 --- a/src/relay/transforms/defuse_ops.cc +++ b/src/relay/transforms/defuse_ops.cc @@ -17,19 +17,19 @@ * under the License. 
*/ -#include #include -#include -#include #include +#include +#include #include #include -#include -#include + #include #include -#include +#include #include +#include +#include #include "pattern_util.h" @@ -38,14 +38,11 @@ namespace relay { class DefuseOpsMutator : public ExprMutator { public: - class FuncBodyMutator : public ExprMutator { public: Array args_; - FuncBodyMutator(const Array& args) : ExprMutator() { - args_ = args; - } + explicit FuncBodyMutator(const Array& args) : ExprMutator() { args_ = args; } Expr VisitExpr_(const VarNode* n) { const std::string& name = n->name_hint(); @@ -74,23 +71,19 @@ class DefuseOpsMutator : public ExprMutator { } }; -Expr DeFuseOps(const Expr& expr) { - return DefuseOpsMutator().Mutate(expr); -} +Expr DeFuseOps(const Expr& expr) { return DefuseOpsMutator().Mutate(expr); } namespace transform { Pass DeFuseOps() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::DeFuseOps(f)); - }; - return CreateFunctionPass(pass_func, 3, "DeFuseOps", - {"InferType"}); + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::DeFuseOps(f)); + }; + return CreateFunctionPass(pass_func, 3, "DeFuseOps", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.DeFuseOps") -.set_body_typed(DeFuseOps); +TVM_REGISTER_GLOBAL("relay._transform.DeFuseOps").set_body_typed(DeFuseOps); } // namespace transform diff --git a/src/relay/transforms/kernel_layout_transform.cc b/src/relay/transforms/kernel_layout_transform.cc index 681785c8123c..421968b8a6b9 100644 --- a/src/relay/transforms/kernel_layout_transform.cc +++ b/src/relay/transforms/kernel_layout_transform.cc @@ -17,13 +17,17 @@ * under the License. */ +#include "kernel_layout_transform.h" + +#include #include -#include #include -#include +#include #include + +#include #include -#include "kernel_layout_transform.h" +#include namespace tvm { namespace relay { @@ -36,7 +40,8 @@ Expr KernelLayoutTransform(const Expr& expr) { KernelLayoutVisitor visitor; // Do a pre-order DFS to gather the optimal kernel layouts for all conv2d nodes. - // These layouts were written to global static variables in python function `prepare_layout_rewrite` + // These layouts were written to global static variables in python function + // `prepare_layout_rewrite` visitor.VisitExpr(expr); // Do a post-order DSF to mutate layout for all conv2d nodes @@ -47,15 +52,13 @@ namespace transform { Pass KernelLayoutTransform() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::KernelLayoutTransform(f)); - }; - return CreateFunctionPass(pass_func, 3, "KernelLayoutTransform", - {"InferType"}); + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::KernelLayoutTransform(f)); + }; + return CreateFunctionPass(pass_func, 3, "KernelLayoutTransform", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.KernelLayoutTransform") -.set_body_typed(KernelLayoutTransform); +TVM_REGISTER_GLOBAL("relay._transform.KernelLayoutTransform").set_body_typed(KernelLayoutTransform); } // namespace transform diff --git a/src/relay/transforms/kernel_layout_transform.h b/src/relay/transforms/kernel_layout_transform.h index c82a96b30612..c6c38fb71cf4 100644 --- a/src/relay/transforms/kernel_layout_transform.h +++ b/src/relay/transforms/kernel_layout_transform.h @@ -1,11 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RELAY_TRANSFORMS_KERNEL_LAYOUT_TRANSFORM_H_ +#define TVM_RELAY_TRANSFORMS_KERNEL_LAYOUT_TRANSFORM_H_ + #include #include + +#include +#include #include #include - -#include "pattern_util.h" +#include #include "../../ansor/compute_dag.h" +#include "pattern_util.h" namespace tvm { namespace relay { @@ -13,10 +37,11 @@ namespace relay { /*! \brief A visitor to gather the optimal kernel layout for all conv2d nodes. */ class KernelLayoutVisitor : public ExprVisitor { public: - void VisitExpr_(const CallNode *n) { + void VisitExpr_(const CallNode* n) { if (n && n->op.as() && (std::find(op_white_lists.begin(), op_white_lists.end(), n->op.as()->name) != - op_white_lists.end()) && n->args[1]->type_as()->shape[3].as()->value > 1 && + op_white_lists.end()) && + n->args[1]->type_as()->shape[3].as()->value > 1 && !global_ori_layouts_queue.empty() && !global_new_layouts_queue.empty()) { ori_layouts_map[n] = global_ori_layouts_queue.front(); new_layouts_map[n] = global_new_layouts_queue.front(); @@ -28,30 +53,31 @@ class KernelLayoutVisitor : public ExprVisitor { ExprVisitor::VisitExpr_(n); } - std::unordered_map ori_layouts_map; - std::unordered_map new_layouts_map; - std::vector op_white_lists {"nn.contrib_conv2d_winograd_without_weight_transform", - "nn.conv2d", "nn.conv3d"}; + std::unordered_map ori_layouts_map; + std::unordered_map new_layouts_map; + std::vector op_white_lists{"nn.contrib_conv2d_winograd_without_weight_transform", + "nn.conv2d", "nn.conv3d"}; static std::deque global_ori_layouts_queue; static std::deque global_new_layouts_queue; }; - /*! 
\brief A mutator to rewrite kernel layout for all conv2d nodes */ class KernelLayoutTransformer : public ExprMutator { public: - KernelLayoutTransformer(KernelLayoutVisitor* visitor): ExprMutator(), visitor_(visitor) {} + explicit KernelLayoutTransformer(KernelLayoutVisitor* visitor) + : ExprMutator(), visitor_(visitor) {} Expr VisitExpr_(const CallNode* n) { auto new_n = ExprMutator::VisitExpr_(n); const auto* call = new_n.as(); - std::vector op_white_lists {"nn.contrib_conv2d_winograd_without_weight_transform", - "nn.conv2d", "nn.conv3d"}; + std::vector op_white_lists{"nn.contrib_conv2d_winograd_without_weight_transform", + "nn.conv2d", "nn.conv3d"}; if (call && call->op.as() && (std::find(op_white_lists.begin(), op_white_lists.end(), n->op.as()->name) != - op_white_lists.end() && n->args[1]->type_as()->shape[3].as()->value > 1)) { + op_white_lists.end() && + n->args[1]->type_as()->shape[3].as()->value > 1)) { auto ori_layout_iter = visitor_->ori_layouts_map.find(n); auto new_layout_iter = visitor_->new_layouts_map.find(n); if (ori_layout_iter != visitor_->ori_layouts_map.end() && @@ -60,8 +86,7 @@ class KernelLayoutTransformer : public ExprMutator { const std::string& new_layout = new_layout_iter->second; Expr updated_kernel = MakeKernelLayoutTransform(call->args[1], ori_layout, new_layout); Array updated_args = {call->args[0], updated_kernel}; - new_n = Call(call->op, updated_args, - call->attrs); + new_n = Call(call->op, updated_args, call->attrs); } } return new_n; @@ -71,6 +96,7 @@ class KernelLayoutTransformer : public ExprMutator { KernelLayoutVisitor* visitor_; }; +} // namespace relay +} // namespace tvm -} // namespace relay -} // namespace tvm +#endif // TVM_RELAY_TRANSFORMS_KERNEL_LAYOUT_TRANSFORM_H_ diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index b95d5ba25926..d58130d700f4 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -306,8 +306,7 @@ std::shared_ptr RPCModuleGetSession(Module mod) { } inline void CacheFlush(const char* p, unsigned int allocation_size) { -// TODO: (FrozenGene) -// Support ARM. +// TODO(FrozenGene): Support ARM. 
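+// Flushing the input tensors out of the data cache before each run (see the
+// CHECK_EQ(number, 1) branch in WrapTimeEvaluator below) prevents later repeats
+// from reusing a warm cache, so every measurement approximates a cold start.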
#if (defined(_M_X64) || defined(__x86_64__)) size_t cache_line = 64; @@ -346,7 +345,7 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repe CHECK_EQ(number, 1); // we want to keep input data for (int j = 1; j < args.size(); j++) { - CacheFlush((char*)(args[j].operator DLTensor*()->data), + CacheFlush(reinterpret_cast(args[j].operator DLTensor*()->data), GetDataSize(*(args[j].operator DLTensor*()))); } } diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 3876d67b7b11..4f1078165f34 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -59,7 +59,7 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { .describe("Whether to explicitly unroll the loop instead of setting a pragma") .set_default(true); TVM_ATTR_FIELD(explicit_unroll_max_extent) - .describe("The maximum extent of a loop that can be unrolled explicitly (-1 means infinite)") + .describe("The maximum extent of a loop that can be unrolled explicitly (-1 for infinite)") .set_default(32); } }; @@ -170,7 +170,8 @@ class LoopUnroller : public StmtExprMutator { // For loop must have a constant integer extent CHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; if (value == 0) return Evaluate(0); - if (explicit_unroll_max_extent_ > 0 && value > explicit_unroll_max_extent_ && explicit_unroll_) { + if (explicit_unroll_max_extent_ > 0 && value > explicit_unroll_max_extent_ && + explicit_unroll_) { // Do not unroll too long loops ForType for_type = op->for_type == ForType::Unrolled ? ForType::Serial : op->for_type; return For(op->loop_var, op->min, op->extent, for_type, op->device_api, op->body); From cd5c5ad71dea12d1dd51b9db913c525329949dcf Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 23 Jun 2020 12:09:05 -0700 Subject: [PATCH 37/78] Add MutateComputeLocation and MutateParallel in evolutionary search (#40) * Add MutateComputeLocation and MutateParallel in evolutionary search * fix lint --- src/ansor/auto_schedule.h | 11 +- src/ansor/compute_dag.cc | 66 ---- src/ansor/loop_state.cc | 5 +- src/ansor/loop_state.h | 27 +- src/ansor/measure.cc | 7 +- src/ansor/search_policy/search_policy.cc | 1 - .../search_policy/sketch_search_policy.cc | 9 +- src/ansor/search_policy/utils.cc | 345 +++++++++++++++++- src/ansor/search_policy/utils.h | 21 +- src/ansor/search_task.h | 1 - src/ansor/serialization.cc | 25 +- src/ansor/transform_step.cc | 3 +- src/ansor/transform_step.h | 25 +- src/ansor/utils.h | 18 - 14 files changed, 389 insertions(+), 175 deletions(-) diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index f17c043cfadd..7ffd2c4d3a70 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -37,14 +37,11 @@ namespace ansor { class TuneOptionNode : public Object { public: int n_trials; // Number of total measurement trials - int early_stopping; // Stops early the tuning if no improvement after n - // measurements - int num_measure_per_iter; // The number of programs to be measured at each - // iteration + int early_stopping; // Stops early the tuning if no improvement after n measurements + int num_measure_per_iter; // The number of programs to be measured at each iteration int verbose; // Verbosity level. 0 means silent. 
Builder builder; // Builder which builds the program - Runner runner; // Runner which runs the program and measure time - // costs + Runner runner; // Runner which runs the program and measure time costs Array measure_callbacks; // MeasureCallback functions Array pre_search_callbacks; // SearchCallback functions // run before search @@ -76,13 +73,13 @@ class TuneOption : public ObjectRef { Array pre_search_callbacks); TVM_DEFINE_OBJECT_REF_METHODS(TuneOption, ObjectRef, TuneOptionNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(TuneOptionNode); }; /*! \brief Auto schedule for a compute declaration */ std::pair > AutoSchedule( SearchTask task, SearchPolicy search_policy, TuneOption tune_option); +/*! \brief Auto schedule for a compute declaration */ std::pair > AutoSchedule( std::string workload_key, Target target, Target target_host, SearchPolicy search_policy, HardwareParams hardware_params, diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 13f64b2bdc89..ee87318cdd84 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -653,63 +653,6 @@ class IndexRewriter : public StmtExprMutator { return GetRef(op); } - /* - PrimExpr Mutate_(const Call* op, const PrimExpr& e) { - PrimExpr op_ = IRMutator::Mutate_(op, e); - - const Call* call = op_.as(); - - if (call->call_type == Call::CallType::Halide) { - te::Tensor t = Downcast(call->func).output(call->value_index); - auto it = placeholder_new_names_.find(t->op); - if (it != placeholder_new_names_.end()) { - const std::vector& new_names = it->second; - const Array& new_shape = placeholder_new_shapes_.at(t->op); - std::unordered_map name_to_arg; - for (const auto& arg : call->args) { - std::string axis_name; - if (const auto* pimm = arg.as()) { - CHECK_EQ(pimm->value, 0); - axis_name = "IntImm"; - } else { - axis_name = BaseName(CleanName(Downcast(arg)->name_hint)); - CHECK_EQ(name_to_arg.count(axis_name), 0); - name_to_arg[axis_name] = arg; - } - } - - std::unordered_map div_factors; - std::vector r_new_args; - for (int i = new_names.size() - 1; i >= 0; --i) { - auto ori_iter_name = new_names[i]; - auto name_it = name_to_arg.find(ori_iter_name); - CHECK(name_it != name_to_arg.end()); - PrimExpr ori_arg = name_it->second; - - PrimExpr mod_factor = new_shape[i]; - - PrimExpr div_factor = 1; - if (div_factors.count(ori_iter_name)) { - div_factor = div_factors[ori_iter_name]; - } - div_factors[ori_iter_name] = div_factor * new_shape[i]; - - PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor); - - r_new_args.push_back(new_arg); - } - - Array new_args(std::make_move_iterator(r_new_args.rbegin()), - std::make_move_iterator(r_new_args.rend())); - - return Call::make(call->type, call->name, new_args, call->call_type, - call->func, call->value_index); - } - } - return op_; - } - */ - private: const OperationMap >& placeholder_new_names_; const OperationMap >& placeholder_new_shapes_; @@ -1345,15 +1288,6 @@ TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps, layout_rewrite_level); *ret = Array{sch, return_tensors}; }); -/* -TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") -.set_body_typed([](const ComputeDAG& dag, const State& state) { - te::Schedule sch; - Array return_tensors; - std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps); - return Array{sch, return_tensors}; -}); -*/ TVM_REGISTER_GLOBAL("ansor.ComputeDAGPrintPythonCodeFromState") .set_body_typed([](const ComputeDAG& dag, const State& state) { 
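The `ansor.ComputeDAGApplyStepsFromState` global kept above is the replay entry point that turns a state's recorded transform steps back into a concrete te::Schedule. A minimal sketch of how it is typically reached from Python; `apply_steps_from_state` and `get_init_state` are assumed wrapper names and may differ in this branch:

    import tvm
    from tvm import ansor

    # `tensors` stands for the placeholder/compute tensors of some workload
    # declared with tvm.te; it is not defined in this patch.
    dag = ansor.ComputeDAG(tensors)
    state = dag.get_init_state()                   # a State with an empty history
    sch, args = dag.apply_steps_from_state(state)  # replay the recorded steps
    print(tvm.lower(sch, args, simple_mode=True))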
diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc
index ef4c4632e9bf..010e5f3dc221 100644
--- a/src/ansor/loop_state.cc
+++ b/src/ansor/loop_state.cc
@@ -18,8 +18,9 @@
  */

 /*!
- * \file ansor/loop_state.h
- * \brief An IR (intermediate representation) for loop structures.
+ * \file ansor/loop_state.cc
+ * \brief A lightweight IR (intermediate representation) for loop structures.
+ * See ansor/loop_state.h for more explanation.
  */

 #include "loop_state.h"
diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h
index 2d64db11fc18..1b7bbc40bb31 100644
--- a/src/ansor/loop_state.h
+++ b/src/ansor/loop_state.h
@@ -27,10 +27,10 @@
  * Basically this is a simplified TVM IR with schedule primitives.
  * We don't use the existing TVM IR because
  * 1. We want fast incremental change to the loop structures
- * 2. We want serializable history for replay and backtracking
- * 3. We may create some Macro schedule primitives
+ * 2. We want serializable transformation history for replay, backtracking, and mutation.
+ * 3. We may create some macro schedule primitives
  *
- * After search is done, we will lower this IR to TVM IR with TVM schedule primitives.
+ * After the search is done, we will lower this IR to TVM IR with TVM schedule primitives.
  * Because we share a lot of common objects during search, the transformation is
  * implemented in copy on write style. All objects are immutable, which is
  * similar to TVM IR.
@@ -53,7 +53,8 @@ using namespace tvm::tir;

 /*! \brief The type of a stage */
 enum StageType {
-  kPlaceholder, kCompute
+  kPlaceholder,  // A placeholder stage
+  kCompute       // A compute stage
 };

 /*! \brief The type of compute location */
@@ -78,6 +79,7 @@ enum IteratorAnnotation {
   kTensorized
 };

+// forward declaration
 class Iterator;

 /*!
@@ -91,7 +93,7 @@ class IteratorNode : public Object {
   IteratorType iter_type;
   IteratorAnnotation annotation;
   std::vector ori_iters;  // The original iterators before fusion
-  std::string attr;
+  std::string attr;       // TODO(jcf94): Document this

   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("name", &name);
@@ -115,13 +117,12 @@ class Iterator : public ObjectRef {
            std::string attr = "");

   TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode);
-  TVM_DEFINE_OBJECT_REF_COW_METHOD(IteratorNode);
 };

 /*! \brief Stage-level attributes */
 struct StageAttributes {
-  int auto_unroll_max_step;
-  int storage_offset;
+  int auto_unroll_max_step;  // The maximum steps for the pragma `auto_unroll_max_step`
+  int storage_offset;        // The storage offset for the schedule primitive `storage_align`
 };

 /*!
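The design notes above are easiest to see from the Python side: every schedule primitive appends a step to the state's transform history instead of mutating TVM IR, and the same history can later be replayed, serialized to a log, or mutated by the evolutionary search. A small sketch, assuming the ComputeDAG/State wrappers from later in this patch series and a matmul compute declaration whose output tensor C sits at stage 2:

    from tvm import ansor

    dag = ansor.ComputeDAG([A, B, C])   # A, B, C come from a te compute declaration
    s0 = dag.get_init_state()

    # Each call only records a transform step; nothing is lowered yet.
    ko, ki = s0.split(2, s0.stages[2].iters[2], [32])
    s0.parallel(2, s0.stages[2].iters[0])
    print(s0.transform_steps_size())    # 2 steps recorded, ready to replay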
@@ -130,11 +131,11 @@ struct StageAttributes { */ class StageNode : public Object { public: - te::Operation op; - StageType op_type; - std::vector iters; - ComputeAtType compute_at; - StageAttributes attrs; + te::Operation op; // The operator of this stage + StageType op_type; // The type of this stage + std::vector iters; // The iterators in this stage + ComputeAtType compute_at; // The compute location of this stage + StageAttributes attrs; // Other stage-level attributes void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("op", &op); diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index a044acfe5395..e99f41725077 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -341,8 +341,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ", " << node->time_cost << ")"; }); -TVM_REGISTER_GLOBAL("ansor.MeasureInput") -.set_body_typed([](SearchTask task, State state) { +TVM_REGISTER_GLOBAL("ansor.MeasureInput").set_body_typed([](SearchTask task, State state) { return MeasureInput(task, state); }); @@ -359,8 +358,7 @@ TVM_REGISTER_GLOBAL("ansor.MeasureResult") }); TVM_REGISTER_GLOBAL("ansor.BuilderBuild") -.set_body_typed([](const Builder& builder, - const Array& inputs, int verbose) { +.set_body_typed([](const Builder& builder, const Array& inputs, int verbose) { return builder->Build(inputs, verbose); }); @@ -397,6 +395,5 @@ TVM_REGISTER_GLOBAL("ansor.ProgramMeasurer") max_continous_error); }); - } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index 51a48780813a..b86bf9490851 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -77,7 +77,6 @@ void SearchPolicyNode::PreloadMeasuredStates(const std::string& log_file) { void SearchPolicyNode::RunCallbacks(const Array& callbacks) { if (callbacks.defined() && callbacks.size()) { - PrintTitle("Call search callbacks", verbose); for (const auto& callback : callbacks) { callback->callback(this); } diff --git a/src/ansor/search_policy/sketch_search_policy.cc b/src/ansor/search_policy/sketch_search_policy.cc index 63f75cad1c83..c4365a391865 100644 --- a/src/ansor/search_policy/sketch_search_policy.cc +++ b/src/ansor/search_policy/sketch_search_policy.cc @@ -67,6 +67,7 @@ State SketchSearchPolicyNode::Search(SearchTask task, int n_trials, this->verbose = verbose; num_measure_per_iter_ = num_measure_per_iter; + PrintTitle("Call search callbacks", verbose); RunCallbacks(pre_search_callbacks); if (n_trials <= 1) { // no measurement is allowed @@ -94,7 +95,7 @@ State SketchSearchPolicyNode::Search(SearchTask task, int n_trials, PrintTitle("Search", verbose); SearchOneRound(&best_states, num_random, &random_states); - // Fill correct bound.This is necessary for computing the correct ToStr() for reduncency check + // Infer bound. 
This is necessary for computing the correct ToStr() for redundancy check cur_task->compute_dag.InferBound(&best_states); cur_task->compute_dag.InferBound(&random_states); @@ -218,10 +219,10 @@ void SketchSearchPolicyNode::PickStatesWithEpsGreedy( std::string state_str = pstate->ToStr(); if (measured_states_set_.count(state_str)) { continue; } - measured_states_set_.insert(state_str); + measured_states_set_.insert(std::move(state_str)); inputs->push_back(MeasureInput(cur_task, *pstate)); - measured_states_vector_.push_back(std::move(*pstate)); + measured_states_vector_.push_back(*pstate); } } @@ -274,7 +275,7 @@ void SketchSearchPolicyNode::SearchOneRound(std::vector* best_states, RandomSampleStates(init_population, &rand_gen_, num_random_states * 10, random_states); } -// The baseclass of derivation rules used in sketch generation +// The base class for derivation rules used in sketch generation class SketchGenerationRule { public: enum ConditionEnum { diff --git a/src/ansor/search_policy/utils.cc b/src/ansor/search_policy/utils.cc index 412d0afcca98..2d2f92ecbc20 100644 --- a/src/ansor/search_policy/utils.cc +++ b/src/ansor/search_policy/utils.cc @@ -32,9 +32,10 @@ void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatia auto pop = s->stages[stage_id]->op.as(); CHECK(pop != nullptr); - auto no_split_name_pair = QueryNoSplitAxis(s->stages[stage_id]); - std::set no_split_at_inner_name_set = no_split_name_pair.first; - std::set no_split_at_outer_name_set = no_split_name_pair.second; + const auto& no_split_name_pair = QueryNoSplitAxis(s->stages[stage_id]); + const std::set& no_split_at_inner_name_set = no_split_name_pair.first; + const std::set& no_split_at_outer_name_set = no_split_name_pair.second; + size_t reduce_count = 0; for (const auto axis : pop->reduce_axis) { if (!no_split_at_inner_name_set.count(axis->var->name_hint) && @@ -52,6 +53,8 @@ void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatia } } else if (auto ps = s->transform_steps[i].as()) { if (stage_id == ps->stage_id) { + // Assume SplitStep on reduction axes are always after SplitStep on spatial axes. 
+ // TODO(jcf94): do not rely on this assumption if (reduce_count) { reduce_count--; } else { @@ -75,7 +78,7 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo } else if (tolower(c) == 'r') { reduce_levels.emplace_back(); } else { - LOG(FATAL) << "Invalid multi level tiling format: " << format; + LOG(FATAL) << "Invalid multi-level tiling format: " << format; } } size_t n_space = space_levels.size(); @@ -85,10 +88,10 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo State tmp_s = state; const Stage& stage = state->stages[stage_id]; - auto no_split_name_pair = QueryNoSplitAxis(stage); // handle special split strategy - auto last_split_is_one_name_set = QueryLastSplitIsOneAxis(stage); - std::set no_split_at_inner_name_set = no_split_name_pair.first; - std::set no_split_at_outer_name_set = no_split_name_pair.second; + const auto& no_split_name_pair = QueryNoSplitAxis(stage); // handle special split strategy + const auto& last_split_is_one_name_set = QueryLastSplitIsOneAxis(stage); + const std::set& no_split_at_inner_name_set = no_split_name_pair.first; + const std::set& no_split_at_outer_name_set = no_split_name_pair.second; for (const auto& iter : state->stages[stage_id]->iters) { if (iter->iter_type == kSpace) { @@ -119,10 +122,10 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo } } } else if (iter->iter_type == kReduce) { - // for reduce iterator, split it into two iterators if (!no_split_at_inner_name_set.count(iter->name) && !no_split_at_outer_name_set.count(iter->name)) { CHECK_GE(n_reduce, 1); + if (n_reduce == 1) { reduce_levels[0].push_back(iter); } else { @@ -147,23 +150,27 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo if (!space_outer.empty()) { CHECK(!space_levels.empty()); space_levels.front().insert(space_levels.front().begin(), - space_outer.begin(), space_outer.end()); + std::make_move_iterator(space_outer.begin()), + std::make_move_iterator(space_outer.end())); } if (!space_inner.empty()) { CHECK(!space_levels.empty()); space_levels.back().insert(space_levels.back().begin(), - space_inner.begin(), space_inner.end()); + std::make_move_iterator(space_inner.begin()), + std::make_move_iterator(space_inner.end())); } if (!reduce_outer.empty()) { CHECK(!reduce_levels.empty()); reduce_levels.front().insert(reduce_levels.front().begin(), - reduce_outer.begin(), reduce_outer.end()); + std::make_move_iterator(reduce_outer.begin()), + std::make_move_iterator(reduce_outer.end())); } if (!reduce_inner.empty()) { CHECK(!reduce_levels.empty()); reduce_levels.back().insert(reduce_levels.back().begin(), - reduce_inner.begin(), reduce_inner.end()); + std::make_move_iterator(reduce_inner.begin()), + std::make_move_iterator(reduce_inner.end())); } std::vector order; @@ -198,7 +205,7 @@ State FollowTiling(const State& state, int stage_id, auto pop = state->stages[stage_id]->op.as(); CHECK(pop != nullptr); const Stage& stage = state->stages[stage_id]; - auto no_split_name_pair = QueryNoSplitAxis(stage); // handle special split strategy + const auto& no_split_name_pair = QueryNoSplitAxis(stage); // handle special split strategy const std::set& no_split_at_inner_name_set = no_split_name_pair.first; const std::set& no_split_at_outer_name_set = no_split_name_pair.second; int no_split_at_inner_name_in_stage_cnt = 0; @@ -266,6 +273,7 @@ State FollowTiling(const State& state, int stage_id, LOG(FATAL) << "Invalid iter type: " << iter->iter_type; } } + if (n_split == 3) 
{ ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2, &space_3); } else if (n_split == 2) { @@ -406,13 +414,320 @@ State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen return tmp_s; } +State RandomMutateParallel(const State& old_state, std::mt19937* random_gen, + const SearchTask& task, int verbose) { + // To make this mutation simple but promising, we only focus on a specific case that + // parallel was added to the outermost loop and the loop is generated by fusing other loops. + // In short, we mutate the step pattern of (fuse -> parallel). + + // Extract all parallel steps. + std::vector parallel_steps; + for (size_t s = 0; s < old_state->transform_steps.size(); ++s) { + auto ps = old_state->transform_steps[s].as(); + if (!ps || ps->annotation != kParallel) { + continue; + } + parallel_steps.push_back(s); + } + if (parallel_steps.empty()) { + StdCout(verbose) << "Parallel mutation failed: No parallel annotations" << std::endl; + return State(); + } + + // Randomly pick one step. + int retry_ct = 0; + size_t step_id = 0; + size_t stage_id = 0; + do { + step_id = parallel_steps[(*random_gen)() % parallel_steps.size()]; + auto step = old_state->transform_steps[step_id].as(); + stage_id = step->stage_id; + + // Check assumptions. + auto iter_id = step->iter_id; + if (iter_id == 0 && step_id > 0 && old_state->transform_steps[step_id - 1].as()) { + break; + } + retry_ct++; + } while (retry_ct <= 3); + + if (retry_ct > 3) { + StdCout(verbose) << "Parallel mutation failed: No valid parallel annotations" << std::endl; + return State(); + } + + // Replay a new state until the picked fuse step. + State tmp_s = task->compute_dag.GetInitState(); + for (size_t s = 0; s < step_id - 1; ++s) { + auto step = old_state->transform_steps[s]; + tmp_s.CopyOnWrite()->transform_steps.push_back(step); + tmp_s.DoStep(step, task->compute_dag); + } + + // Determine the fuse direction. + // 0: fuse less; 1: fuse more. + auto fuse_step = old_state->transform_steps[step_id - 1].as(); + std::vector fused_ids = fuse_step->fused_ids; + std::vector fuse_dir = {0.5, 1.0}; + + // The case we can only fuse more. + if (fused_ids.size() == 1) { + fuse_dir[0] = 0.0; + } + + // The cases that we cannot fuse the next iters. + if (old_state->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, 0)) > 0 || + tmp_s->stages[stage_id]->iters.size() == fused_ids.size() || + tmp_s->stages[stage_id]->iters[1]->iter_type == kReduce) { + // In case we cannot fuse less neither, give up. + if (fuse_dir[0] == 0.0) { + StdCout(verbose) << "Parallel mutation failed: Cannot fuse more or less iters" << std::endl; + return State(); + } + fuse_dir[0] = 1.0; + } + + int iter_offset = 0; + if (RandomChoose(fuse_dir, random_gen) == 0) { + StdCout(verbose) << "Parallel mutation: release iter " << fused_ids.back() << std::endl; + fused_ids.pop_back(); + iter_offset = 1; + } else { + StdCout(verbose) << "Parallel mutation: include iter " << fused_ids.back() + 1 << std::endl; + fused_ids.push_back(fused_ids.back() + 1); + iter_offset = -1; + } + + // Replay the mutated fused and annotation step. + auto new_fuse_step = FuseStep(stage_id, fused_ids); + tmp_s.CopyOnWrite()->transform_steps.push_back(new_fuse_step); + tmp_s.DoStep(new_fuse_step, task->compute_dag); + tmp_s.CopyOnWrite()->transform_steps.push_back(old_state->transform_steps[step_id]); + tmp_s.DoStep(old_state->transform_steps[step_id], task->compute_dag); + + // Replay the rest steps. 
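+  // After re-fusing, the mutated stage has one more or one fewer loop level, so
+  // every later annotation step that targets this stage must shift its iterator
+  // index by `iter_offset`; other step kinds are not fixed up and abort the
+  // mutation by returning an undefined State.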
+ for (size_t s = step_id + 1; s < old_state->transform_steps.size(); ++s) { + auto step = old_state->transform_steps[s]; + if (step->stage_id == static_cast(stage_id)) { + // Since we change the loop structure, iter ID in later steps to the same stage + // has to be adjusted. + auto ps = step.as(); + if (ps) { + if (ps->iter_id == 0) { + step = AnnotationStep(ps->stage_id, 0, ps->annotation); + } else { + CHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size()); + step = AnnotationStep(ps->stage_id, ps->iter_id + iter_offset, ps->annotation); + } + } else { + StdCout(verbose) << "Parallel mutation: Cannot apply " << step << " after fuse" + << std::endl; + return State(); + } + } + tmp_s.CopyOnWrite()->transform_steps.push_back(step); + tmp_s.DoStep(step, task->compute_dag); + } + return tmp_s; +} + + +State RandomMutateComputeLocation(const State& old_state, std::mt19937* random_gen, + const SearchTask& task) { + // Extract all compute_at steps. + std::vector compute_at_steps; + for (size_t s = 0; s < old_state->transform_steps.size(); ++s) { + if (auto ps = old_state->transform_steps[s].as()) { + const Stage& stage = old_state->stages[ps->stage_id]; + if (IsTiled(stage)) { + continue; + } + + if (NeedsMultilevelTiling(task, old_state, stage->op)) { + continue; + } + compute_at_steps.push_back(s); + } + } + if (compute_at_steps.empty()) { + return State(); + } + + // Randomly pick one step + size_t step_id = compute_at_steps[(*random_gen)() % compute_at_steps.size()]; + auto ps = old_state->transform_steps[step_id].as(); + CHECK(ps != nullptr); + const Stage& stage = old_state->stages[ps->stage_id]; + + // Randomly pick one tile level + int new_compute_at_stage_id; + int new_compute_at_iter_id; + + // Copied from InitPopulationChangeComputeLocation + { + std::unordered_set consumers; + GetConsumers(task, old_state, stage->op, &consumers); + if (consumers.empty()) { + return State(); + } + + int target_stage_id; + if (consumers.size() == 1) { + target_stage_id = OperationToStage(*consumers.begin(), old_state); + } else { + // check all consumers share a common root + int common_root_id = -1; + bool mismatch = false; + for (const auto& consumer : consumers) { + int consumer_stage_id = OperationToStage(consumer, old_state); + int root_id = -1; + if ((old_state)->stages[consumer_stage_id]->compute_at == kRoot) { + root_id = consumer_stage_id; + } else if ((old_state)->stages[consumer_stage_id]->compute_at == kIter) { + root_id = (old_state)->attach_map->stage_to_attach_iter.at(consumer_stage_id).first; + } else { + LOG(FATAL) << "Invalid case"; + } + + if (common_root_id == -1) { + common_root_id = root_id; + } else { + if (common_root_id != root_id) { + mismatch = true; + break; + } + } + } + + if (mismatch) { + return State(); + } + target_stage_id = common_root_id; + } + + const Stage& target_stage = old_state->stages[target_stage_id]; + std::set to_unroll_name_set; + if (target_stage->op->attrs.count(SearchPolicyNode::always_unroll_key)) { + to_unroll_name_set = GetIterNameSetParam(target_stage->op->attrs, + SearchPolicyNode::always_unroll_key); + } + + std::vector > candidates; + bool target_compute_at_other = target_stage->compute_at == kIter; + bool target_is_tiled = IsTiled(target_stage); + + bool visited_reduce = false; + // enumerate compute_at location at target_stage + int ct = 0; + for (size_t iter_id = 0; iter_id < target_stage->iters.size(); ++iter_id) { + const auto& target_iter = target_stage->iters[iter_id]; + if (target_iter->iter_type == kReduce) { + 
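+        // The first reduce iterator ends the search unless the target stage is
+        // tiled; attaching inside the reduction part of an untiled loop nest is
+        // not considered a candidate.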
visited_reduce = true; + if (!target_is_tiled) { // do not go into reduce iter + break; + } + } else if (target_iter->iter_type == kSpace) { + if (visited_reduce) { // do not go into inner tile + break; + } + } + + if (to_unroll_name_set.count(target_iter->name)) { + // Do not go into always unroll region + break; + } + + if (GetExtent(target_iter) == 1) { // skip iterators with length of 1 + continue; + } + if (target_compute_at_other && target_iter->iter_type == kSpace && + StrEndsWith(target_iter->name, ".0")) { + // skip the first level iterators if target stage compute_at another stage + // In this case, the lengths of first level iterators are always one + continue; + } + candidates.emplace_back(target_stage_id, iter_id); + + if ((old_state)->attach_map->iter_to_attached_stages.count( + std::make_pair(target_stage_id, ct++))) { + break; + } + } + + // if the target_stage is already compute_at another stage X, try also compute_at X + // We call stage X as `target_target_stage` + if (target_compute_at_other) { + int target_target_stage_id; + target_target_stage_id = (old_state)->attach_map->stage_to_attach_iter.at( + target_stage_id).first; + const Stage& target_target_stage = (old_state)->stages[target_target_stage_id]; + if (target_target_stage->op->attrs.count(SearchPolicyNode::always_unroll_key)) { + to_unroll_name_set = GetIterNameSetParam(target_target_stage->op->attrs, + SearchPolicyNode::always_unroll_key); + } else { + to_unroll_name_set.clear(); + } + + int ct = 0; + for (size_t iter_id = 0; iter_id < target_target_stage->iters.size(); ++iter_id) { + const auto& target_target_iter = target_target_stage->iters[iter_id]; + if (target_target_iter->iter_type == kReduce || + (old_state)->attach_map->iter_to_attached_stages.count( + std::make_pair(target_target_stage_id, ct++))) { + break; + } + + if (to_unroll_name_set.count(target_target_iter->name)) { + // Do not go into always unroll region + break; + } + + if (GetExtent(target_target_iter) == 1) { // skip iterators with length of 1 + continue; + } + + candidates.emplace_back(target_target_stage_id, iter_id); + } + } + + if (candidates.empty()) { + return State(); + } + + int choice = (*random_gen)() % (candidates.size()); + new_compute_at_stage_id = candidates[choice].first; + new_compute_at_iter_id = candidates[choice].second; + } + + // Replay a new state. 
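+  // Replaying the whole history with the mutated ComputeAtStep swapped in also
+  // validates it: if the new location makes a later step illegal, DoStep throws
+  // and the mutation is discarded by returning an undefined State.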
+  State tmp_s = task->compute_dag.GetInitState();
+  for (size_t s = 0; s < old_state->transform_steps.size(); ++s) {
+    if (s == step_id) {
+      tmp_s.CopyOnWrite()->transform_steps.push_back(
+          ComputeAtStep(ps->stage_id, new_compute_at_stage_id, new_compute_at_iter_id));
+    } else {
+      tmp_s.CopyOnWrite()->transform_steps.push_back(old_state->transform_steps[s]);
+    }
+    try {
+      tmp_s.DoStep(tmp_s->transform_steps.back(), task->compute_dag);
+    } catch (dmlc::Error &e) {
+      return State();
+    }
+  }
+
+  return tmp_s;
+}
+
 void PruneUndefined(std::vector* states) {
   size_t pt = 0;
   for (size_t i = 0; i < states->size(); ++i) {
     if (!(*states)[i].defined()) {
       continue;
     }
-    (*states)[pt++] = std::move((*states)[i]);
+    if (i != pt) {
+      (*states)[pt] = std::move((*states)[i]);
+    }
+    pt++;
   }

   if (pt == 0) {
diff --git a/src/ansor/search_policy/utils.h b/src/ansor/search_policy/utils.h
index 5f15397e7e90..107e2ee72521 100644
--- a/src/ansor/search_policy/utils.h
+++ b/src/ansor/search_policy/utils.h
@@ -79,8 +79,8 @@ inline std::set GetIterNameSetParam(const Map& attr_dict,
   CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
   auto names = attr_dict[key].as();
   CHECK(names != nullptr);
-  for (auto name = names->begin(); name != names->end(); name++) {
-    ret.insert(name->as()->value);
+  for (const auto& name : *names) {
+    ret.insert(name.as()->value);
   }
   return ret;
 }
@@ -284,9 +284,6 @@ inline bool HasCacheReadStage(const State& s, int stage_id) {
   return false;
 }

-// Get all split step on spatial iterators
-void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatial_split_step_ids);
-
 // Return whether the state did split/follow_split/follow_fused_split in stage_id
 inline bool HasSplitStep(const State& s, int stage_id) {
   for (int i = static_cast(s->transform_steps.size()) - 1; i >= 0; --i) {
@@ -441,6 +438,9 @@ inline void PrintAllStates(const std::vector& states) {
   }
 }

+// Get all split steps on spatial iterators for one stage
+void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatial_split_step_ids);
+
 // Apply multi-level tiling structure according to a string format,
 // where "S" stands for a space level, "R" stands for a reduction level.
 // For example, if the format is "SSRSRS", then we will
@@ -451,8 +451,7 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
                          std::vector* spatial_split_step_ids);

-// Apply tiling structure: space, space
-// But use tile sizes from other SplitStep
+// Apply tiling structure: space, space, space, ..., with tile sizes from other SplitStep
 State FollowTiling(const State& state, int stage_id,
                    const std::vector& split_step_ids, int n_split);

@@ -464,6 +463,14 @@ State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split
 State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen,
                                 const std::vector& auto_unroll_configs);

+// Randomly mutate the parallel degree of one stage.
+State RandomMutateParallel(const State& old_state, std::mt19937* random_gen,
+                           const SearchTask& task, int verbose = 0);
+
+// Randomly mutate the computation location of one stage.
+State RandomMutateComputeLocation(const State& old_state, std::mt19937* random_gen, + const SearchTask& task); + // GA: Crossover two states State CrossOverState(const State& p1, const State& p2); diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h index c53fdcd0f792..0f270d105d73 100644 --- a/src/ansor/search_task.h +++ b/src/ansor/search_task.h @@ -121,7 +121,6 @@ class SearchTask : public ObjectRef { HardwareParams hardware_params); TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(SearchTaskNode); }; } // namespace ansor diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index c026b9b6251a..d84c3c57dc86 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -583,22 +583,11 @@ std::pair BestMeasurePairInFile( return best_pair; } -TVM_REGISTER_GLOBAL("ansor.WriteMeasureRecordsToFile") -.set_body([](TVMArgs args, TVMRetValue *ret) { - std::string filename = args[0]; - Array in = args[1]; - Array res = args[2]; - std::ofstream ofs(filename, std::ofstream::app); - WriteMeasureRecords(&ofs, in, res); -}); - -TVM_REGISTER_GLOBAL("ansor.LogToFile") -.set_body_typed([](const std::string& filename) { +TVM_REGISTER_GLOBAL("ansor.LogToFile").set_body_typed([](const std::string& filename) { return LogToFile(filename); }); -TVM_REGISTER_GLOBAL("ansor.LogReader") -.set_body_typed([](const std::string& filename) { +TVM_REGISTER_GLOBAL("ansor.LogReader").set_body_typed([](const std::string& filename) { return LogReader(filename); }); @@ -619,6 +608,15 @@ TVM_REGISTER_GLOBAL("ansor.LogReaderReadNext") } }); +TVM_REGISTER_GLOBAL("ansor.WriteMeasureRecordsToFile") +.set_body([](TVMArgs args, TVMRetValue *ret) { + std::string filename = args[0]; + Array in = args[1]; + Array res = args[2]; + std::ofstream ofs(filename, std::ofstream::app); + WriteMeasureRecords(&ofs, in, res); +}); + TVM_REGISTER_GLOBAL("ansor.GetStatesFromMeasureInputs") .set_body([](TVMArgs args, TVMRetValue *ret) { Array inputs = args[0]; @@ -672,6 +670,5 @@ TVM_REGISTER_GLOBAL("ansor.GetStatesFromMeasureInputs") *ret = states; }); - } // namespace ansor } // namespace tvm diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index bd0a7f7165f6..e882a0495263 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -428,8 +428,7 @@ std::string AnnotationStepNode::PrintAsPythonAPI(std::vector *stages, } /********** Compute At **********/ -ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, - int target_iter_id) { +ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id) { auto node = make_object(); node->stage_id = stage_id; node->target_stage_id = target_stage_id; diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index edd71732b3e2..f8283b876f18 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -77,7 +77,6 @@ class ReorderStep : public Step { ReorderStep(int stage_id, const std::vector& after_ids); TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(ReorderStepNode); }; /*! \brief Split step that corresponds to te::Stage::split with additional @@ -113,7 +112,6 @@ class SplitStep : public Step { bool inner_to_outer); TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(SplitStepNode); }; /*! 
\brief Similar to SplitStepNode, but use split factor from another step @@ -149,7 +147,6 @@ class FollowSplitStep : public Step { FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(FollowSplitStepNode); }; @@ -189,7 +186,6 @@ class FollowFusedSplitStep : public Step { int level, bool factor_or_nparts); TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(FollowFusedSplitStepNode); }; /*! \brief Fuse step that corresponds to te::Stage::fuse */ @@ -218,7 +214,6 @@ class FuseStep : public Step { FuseStep(int stage_id, const std::vector& fused_ids); TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(FuseStepNode); }; /*! \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding. @@ -250,10 +245,9 @@ class AnnotationStep : public Step { AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann); TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(AnnotationStepNode); }; -/*! \brief Fuse step that corresponds to te::Stage::compute_at */ +/*! \brief Compute at step that corresponds to te::Stage::compute_at */ class ComputeAtStepNode: public StepNode { public: int target_stage_id; @@ -280,10 +274,9 @@ class ComputeAtStep : public Step { ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id); TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeAtStepNode); }; -/*! \brief Fuse step that corresponds to te::Stage::compute_root */ +/*! \brief Compute root step that corresponds to te::Stage::compute_root */ class ComputeRootStepNode: public StepNode { public: void ApplyToSchedule(std::vector *stages, @@ -307,10 +300,9 @@ class ComputeRootStep : public Step { explicit ComputeRootStep(int stage_id); TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeRootStepNode); }; -/*! \brief Fuse step that corresponds to te::Stage::compute_inline */ +/*! \brief Compute inline step that corresponds to te::Stage::compute_inline */ class ComputeInlineStepNode: public StepNode { public: void ApplyToSchedule(std::vector *stages, @@ -334,7 +326,6 @@ class ComputeInlineStep : public Step { explicit ComputeInlineStep(int stage_id); TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeInlineStepNode); }; /*! \brief Cache read step that corresponds to te::Schedule::cache_read */ @@ -366,10 +357,9 @@ class CacheReadStep : public Step { const std::vector& reader_stage_id); TVM_DEFINE_OBJECT_REF_METHODS(CacheReadStep, Step, CacheReadStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(CacheReadStepNode); }; -/*! \brief Cache read step that corresponds to te::Schedule::cache_write +/*! \brief Cache write step that corresponds to te::Schedule::cache_write * \Note This step will cache_write all output tensors of target stage */ class CacheWriteStepNode: public StepNode { public: @@ -397,10 +387,9 @@ class CacheWriteStep : public Step { CacheWriteStep(int stage_id, std::string scope_name); TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(CacheWriteStepNode); }; -/*! \brief Cache read step that corresponds to te::Schedule::pragma */ +/*! 
\brief Pragma step that corresponds to te::Schedule::pragma */ class PragmaStepNode: public StepNode { public: int iter_id; @@ -427,7 +416,6 @@ class PragmaStep : public Step { PragmaStep(int stage_id, int iter_id, std::string pragma_type); TVM_DEFINE_OBJECT_REF_METHODS(PragmaStep, Step, PragmaStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(PragmaStepNode); }; /*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */ @@ -458,7 +446,6 @@ class RfactorStep : public Step { RfactorStep(int stage_id, int iter_id, int factor_iter_id); TVM_DEFINE_OBJECT_REF_METHODS(RfactorStep, Step, RfactorStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(RfactorStepNode); }; /*! \brief Storage align step that corresponds to te::Schedule::storage_align */ @@ -489,7 +476,6 @@ class StorageAlignStep : public Step { StorageAlignStep(int stage_id, int iter_id, int factor, int offset); TVM_DEFINE_OBJECT_REF_METHODS(StorageAlignStep, Step, StorageAlignStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(StorageAlignStepNode); }; /*! \brief Tensorize step that corresponds to te::Schedule::tensorize @@ -520,7 +506,6 @@ class TensorizeStep : public Step { TensorizeStep(int stage_id, int iter_id, std::string ti_func_name); TVM_DEFINE_OBJECT_REF_METHODS(TensorizeStep, Step, TensorizeStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(TensorizeStepNode); }; } // namespace ansor diff --git a/src/ansor/utils.h b/src/ansor/utils.h index cb90364b01b5..4e98bb907af9 100644 --- a/src/ansor/utils.h +++ b/src/ansor/utils.h @@ -81,13 +81,6 @@ struct hash > { namespace tvm { namespace ansor { -/*! \brief Macro to make it easy to define object ref type given node */ -#define TVM_DEFINE_OBJECT_REF(TypeName, ObjectName) \ - class TypeName : public ObjectRef { \ - public: \ - TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ObjectRef, ObjectName); \ - }; \ - /*! \brief Macro to make it easy to define mutable object ref type given node */ #define TVM_DEFINE_MUTABLE_OBJECT_REF(TypeName, ObjectName) \ class TypeName : public ObjectRef { \ @@ -95,17 +88,6 @@ namespace ansor { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ObjectRef, ObjectName); \ }; \ -/*! - * \brief Macro to make it easy to define node ref type that - * has a CopyOnWrite member function. - */ -#define TVM_DEFINE_COW_OBJECT_REF(TypeName, BaseType, ObjectName) \ - class TypeName : public BaseType { \ - public: \ - TVM_DEFINE_OBJECT_REF_METHODS(TypeName, BaseType, ObjectName); \ - TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName); \ - }; - /********** Utilities for std::vector, std::set, std::string **********/ /*! \brief Get the first appearance index of elements in a vector */ template From 58601918b60ebf6bfcc57ec0a9c36c7da21c2de7 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 23 Jun 2020 13:57:32 -0700 Subject: [PATCH 38/78] Improve loop state python API (stage_tensors -> stage_ops) (#41) * improve loop state python API (stage_tensors -> stage_ops) * fix --- python/tvm/ansor/loop_state.py | 324 ++++++++---------- .../python/unittest/test_ansor_compute_dag.py | 6 +- tests/python/unittest/test_ansor_feature.py | 4 +- .../python/unittest/test_ansor_loop_state.py | 14 +- 4 files changed, 153 insertions(+), 195 deletions(-) diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index 8560a57bc902..7aa5de0e9c1d 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -25,16 +25,17 @@ Basically this is a simplified TVM IR with schedule primitives. We don't use the existing TVM IR because 1. 
We want fast incremental change to the loop structures -2. We want serializable history for replay and backtracking +2. We want serializable transformation history for replay, backtracking, and mutation 3. We may create some new macro schedule primitives -After search is done, we will lower this IR to TVM IR with TVM's schedule primitives. +After the search is done, we will lower this IR to TVM IR with TVM's schedule primitives. Because we share a lot common objects during search, the transformation is implemented in copy on write style. All objects are immutable, which is similar to TVM IR. """ import tvm._ffi +from tvm.te.tensor import Operation, Tensor from tvm.runtime import Object from . import _ffi_api @@ -80,43 +81,9 @@ def __init__(self, state_object, dag): self.state_object = state_object self.compute_dag = dag - self.stages_cache = None - self.stage_id_map = {} - self.__update_tensor_stage_map() - - def __getitem__(self, k): - if not self.stages_cache: - self.stages_cache = _ffi_api.StateGetStages(self.state_object) - if isinstance(k, tvm.te.Tensor): - return self.stages_cache[self.stage_id_map[k.op]] - raise ValueError("Item must be Tensor") - - def __update_tensor_stage_map(self): - if not self.stages_cache: - self.stages_cache = _ffi_api.StateGetStages(self.state_object) - for index, stage in enumerate(self.stages_cache): - self.stage_id_map[stage.op] = index - - def __insert_new_stage(self, new_stage_id): - new_stage_id = int(new_stage_id) - self.stages_cache = _ffi_api.StateGetStages(self.state_object) - added_stage_tensor = self.stages_cache[new_stage_id].op.output(0) - - for key, value in self.stage_id_map.items(): - if value >= new_stage_id: - self.stage_id_map[key] = value + 1 - self.stage_id_map[added_stage_tensor.op] = new_stage_id - self.__update_tensor_stage_map() - - return added_stage_tensor - - def clear_cache(self): - self.stages_cache = None - - def copy(self): - state = State(self.state_object, self.compute_dag) - state.stage_id_map = self.stage_id_map.copy() - return state + self.stages_cache = None # A list to cache all stages + self.stage_id_map = {} # A dict maps operation to stage id + self._update_stage_id_map() @property def stages(self): @@ -130,15 +97,15 @@ def stages(self): return self.stages_cache @property - def stage_tensors(self): + def stage_ops(self): """ Returns ------- - Tensor + ops: List[Operation] """ if not self.stages_cache: self.stages_cache = _ffi_api.StateGetStages(self.state_object) - return [stage.op.output(0) for stage in self.stages_cache] + return [stage.op for stage in self.stages_cache] def transform_steps_size(self): """ Return the size of transform_steps @@ -149,30 +116,27 @@ def reorder(self, stage_id, order): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to reorder order : List[Iterator] Iterators in the expected order """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order) - self.clear_cache() + self._clear_cache() def split(self, stage_id, iterator, lengths, inner_to_outer=True): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to split iterator : Iterator The iterator to split - lengths: List[Int] + lengths: List[int] The split factors - inner_to_outer: 
Bool
+        inner_to_outer: bool
             True to use `factor` to split from inner to outer,
             False to use `nparts` to split from outer to inner
@@ -181,27 +145,24 @@ def split(self, stage_id, iterator, lengths, inner_to_outer=True):
         res_its : List[Iterator]
             The new Iterators after split
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, iterator,
                                                      lengths, inner_to_outer)
-        self.clear_cache()
+        self._clear_cache()
         return res
 
     def follow_split(self, stage_id, iterator, src_step_id, n_split):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to split
         iterator : Iterator
             The iterator to split
-        src_step_id : Int
+        src_step_id : int
             The index of the split step to follow in the history
-        n_split : Int
+        n_split : int
             The number of split levels
 
         Returns
@@ -209,14 +170,11 @@
         res_its : List[Iterator]
             The new Iterators after split
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, stage_id,
                                                            iterator, src_step_id, n_split)
-        self.clear_cache()
+        self._clear_cache()
         return res
 
     def follow_fused_split(self, stage_id, iterator, src_step_ids, level,
@@ -224,15 +182,15 @@
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to split
         iterator : Iterator
             The iterator to split
-        src_step_ids : List[Int]
+        src_step_ids : List[int]
             The indices of the split steps to follow in the history
-        level : Int
+        level : int
             Use the length in this split level
-        factor_or_nparts : Bool
+        factor_or_nparts : bool
             True to use `factor` for split from inner to outer,
             False to use `nparts` for split from outer to inner
 
         Returns
@@ -241,22 +199,19 @@ def follow_fused_split(self, stage_id, iterator, src_step_ids, level,
         res_its : List[Iterator]
             The new Iterators after split
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, stage_id,
                                                                 iterator, src_step_ids,
                                                                 level, factor_or_nparts)
-        self.clear_cache()
+        self._clear_cache()
         return res
 
     def fuse(self, stage_id, iters):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to fuse
         iters : List[Iterator]
             The iterators to be fused
 
         Returns
         -------
         res_it : Iterator
             The fused Iterator
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters)
-        self.clear_cache()
+        self._clear_cache()
         return res
 
     def vectorize(self, stage_id, iterator):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to vectorize
         iterator : Iterator
             The iterator to be vectorized
@@ -289,20 +241,17 @@ def vectorize(self, stage_id, iterator):
         res_it : Iterator
             The vectorized Iterator
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object, res = _ffi_api.StateVectorize(self.state_object, stage_id, iterator)
-        self.clear_cache()
+        self._clear_cache()
         return res
 
     def parallel(self, stage_id, iterator):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to parallel
         iterator : Iterator
             The iterator to be parallelized
@@ -312,24 +261,21 @@
         res_it : Iterator
             The parallelized Iterator
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object, res = _ffi_api.StateParallel(self.state_object, stage_id, iterator)
-        self.clear_cache()
+        self._clear_cache()
         return res
 
     def unroll(self, stage_id, iterator, max_unroll=-1):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to unroll
         iterator : Iterator
             The iterator to be unrolled
-        max_unroll: Int
+        max_unroll: int
             The maximum length of the iterator that can be unrolled
 
         Returns
@@ -337,21 +283,18 @@
         res_it : Iterator
             The unrolled Iterator
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id,
                                                       iterator, max_unroll)
-        self.clear_cache()
+        self._clear_cache()
         return res
 
     def bind_thread(self, stage_id, iterator, thread_name):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to bind
         iterator : Iterator
             The iterator to be bound
@@ -372,201 +315,167 @@
         }
         thread_id = trans_table[thread_name]
 
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id,
                                                           iterator, thread_id)
-        self.clear_cache()
+        self._clear_cache()
         return res
 
     def compute_at(self, stage_id, target_stage_id, target_iter):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of source stage
-        target_stage_id : Int
+        target_stage_id : Union[int, Operation, Tensor]
             The index of the target stage of compute_at
         target_iter : Iterator
             The target Iterator of compute_at
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
-        if isinstance(target_stage_id, tvm.te.Tensor):
-            target_stage_id = self.stage_id_map[target_stage_id.op]
-        elif not isinstance(target_stage_id, int):
-            raise ValueError("target_stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
+        target_stage_id = self._resolve_stage_id(target_stage_id)
 
         self.state_object = _ffi_api.StateComputeAt(self.state_object, stage_id,
                                                     target_stage_id, target_iter)
-        self.clear_cache()
+        self._clear_cache()
 
     def compute_root(self, stage_id):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to compute root
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object = _ffi_api.StateComputeRoot(self.state_object, stage_id)
-        self.clear_cache()
+        self._clear_cache()
 
     def compute_inline(self, stage_id):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to compute inline
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object = _ffi_api.StateComputeInline(self.state_object, stage_id)
-        self.clear_cache()
+        self._clear_cache()
 
     def cache_read(self, stage_id, scope_name, reader_stage_ids):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to do cache_read
-        scope_name : Str
-        reader_stage_ids : List[Int]
+        scope_name : str
+        reader_stage_ids : List[int]
 
         Returns
         -------
-        new_stage_id : Int
+        new_stage_id : int
             The added stage id
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
+
         if isinstance(reader_stage_ids, list):
             tmp_list = []
             for reader_stage_id in reader_stage_ids:
-                if isinstance(reader_stage_id, tvm.te.Tensor):
-                    tmp_list.append(self.stage_id_map[reader_stage_id.op])
-                elif isinstance(reader_stage_id, int):
-                    tmp_list.append(reader_stage_id)
-                else:
-                    raise ValueError("reader_stage_id must be Tensor or Int")
+                tmp_list.append(self._resolve_stage_id(reader_stage_id))
             reader_stage_ids = tmp_list
         else:
-            raise ValueError("reader_stage_ids must be list of Tensor or Int")
+            raise ValueError("reader_stage_ids must be a list of int, Operation or Tensor")
 
         self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object, stage_id,
                                                                   scope_name, reader_stage_ids,
                                                                   self.compute_dag)
-        return self.__insert_new_stage(new_stage_id)
+        return self._insert_new_stage(new_stage_id)
 
     def cache_write(self, stage_id, scope_name):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to do cache_write
-        scope_name : Str
+        scope_name : str
 
         Returns
         -------
-        new_stage_id : Int
+        new_stage_id : int
             The added stage id
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object, new_stage_id = _ffi_api.StateCacheWrite(self.state_object, stage_id,
                                                                    scope_name, self.compute_dag)
-        return self.__insert_new_stage(new_stage_id)
+        return self._insert_new_stage(new_stage_id)
 
     def pragma(self, stage_id, iterator, pragma_type):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to add pragma
         iterator : Iterator
             The iterator to add pragma
-        pragma_type : Str
+        pragma_type : str
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object = _ffi_api.StatePragma(self.state_object, stage_id, iterator,
                                                  pragma_type)
-        self.clear_cache()
+        self._clear_cache()
 
     def rfactor(self, stage_id, iterator, factor_iter_id):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to do reduction factor
         iterator : Iterator
-        factor_iter_id : Int
+        factor_iter_id : int
 
         Returns
         -------
-        new_stage_id : Int
+        new_stage_id : int
             The added stage id
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object, stage_id,
                                                                 iterator, factor_iter_id,
                                                                 self.compute_dag)
-        return self.__insert_new_stage(new_stage_id)
+        return self._insert_new_stage(new_stage_id)
 
     def storage_align(self, stage_id, iterator, factor, offset):
         """
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to do storage align
         iterator : Iterator
-        factor : Int
-        offset : Int
+        factor : int
+        offset : int
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, iterator,
                                                        factor, offset)
-        self.clear_cache()
+        self._clear_cache()
 
     def tensorize(self, stage_id, iterator, ti_func_name):
         """ The `ti_func_name` corresponds to a global registered function
-        that returns a TensorIntrin
+        that returns a Tensorintrin
 
         Parameters
         ----------
-        stage_id : Int
+        stage_id : Union[int, Operation, Tensor]
             The index of the stage to be tensorized
         iterator : Iterator
-            The target iterator
-        ti_func_name : Str
+            The iterator to be tensorized
+        ti_func_name : str
             Tensorize intrinsic function name
 
         Returns
@@ -574,17 +483,66 @@
         res_it : Iterator
             The tensorized Iterator
         """
-        if isinstance(stage_id, tvm.te.Tensor):
-            stage_id = self.stage_id_map[stage_id.op]
-        elif not isinstance(stage_id, int):
-            raise ValueError("stage_id must be Tensor or Int")
+        stage_id = self._resolve_stage_id(stage_id)
 
         self.state_object, res = _ffi_api.StateTensorize(self.state_object, stage_id,
                                                          iterator, ti_func_name)
-        self.clear_cache()
+        self._clear_cache()
         return res
 
+    def _resolve_stage_id(self, stage_id):
+        if isinstance(stage_id, Operation):
+            return self.stage_id_map[stage_id]
+        elif isinstance(stage_id, tvm.te.Tensor):
+            return self.stage_id_map[stage_id.op]
+        elif isinstance(stage_id, int):
+            return stage_id
+        else:
+            raise ValueError("stage_id must be int, Operation or Tensor")
+
+    def _update_stage_id_map(self):
+        if not self.stages_cache:
+            self.stages_cache = _ffi_api.StateGetStages(self.state_object)
+        for index, stage in enumerate(self.stages_cache):
+            self.stage_id_map[stage.op] = index
+
+    def _insert_new_stage(self, new_stage_id):
+        new_stage_id = int(new_stage_id)
+        self.stages_cache = _ffi_api.StateGetStages(self.state_object)
+        added_op = self.stages_cache[new_stage_id].op
+
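+        # A worked example of the renumbering below, with hypothetical ops:
+        #   before: stage_id_map == {op_A: 0, op_B: 1, op_C: 2}, new_stage_id == 1
+        #   after:  stage_id_map == {op_A: 0, added_op: 1, op_B: 2, op_C: 3}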
+        # Adding a new stage shifts the index of every stage behind it, but we
+        # still want the old ops to work as lookup keys, so we update the
+        # existing entries in place instead of rebuilding the map.
+        for key, value in self.stage_id_map.items():
+            if value >= new_stage_id:
+                self.stage_id_map[key] = value + 1
+        self.stage_id_map[added_op] = new_stage_id
+
+        # Update stage_id_map for the new ops
+        self._update_stage_id_map()
+
+        return added_op
+
+    def _clear_cache(self):
+        self.stages_cache = None
+
+    def copy(self):
+        state = State(self.state_object, self.compute_dag)
+        state.stage_id_map = self.stage_id_map.copy()
+        return state
+
+    def __getitem__(self, key):
+        if not self.stages_cache:
+            self.stages_cache = _ffi_api.StateGetStages(self.state_object)
+        if isinstance(key, Tensor):
+            key = key.op
+        if isinstance(key, Operation):
+            return self.stages_cache[self.stage_id_map[key]]
+        raise ValueError("key must be a Tensor or Operation")
+
     def __str__(self):
         return str(self.state_object)
 
diff --git a/tests/python/unittest/test_ansor_compute_dag.py b/tests/python/unittest/test_ansor_compute_dag.py
index 313dc1f89902..0768f82b805a 100644
--- a/tests/python/unittest/test_ansor_compute_dag.py
+++ b/tests/python/unittest/test_ansor_compute_dag.py
@@ -34,9 +34,9 @@ def test_infer_bound():
     dag, s = get_tiled_matmul()
 
     s = dag.infer_bound_from_state(s)
-    A_global = s.stage_tensors[1]
-    B_global = s.stage_tensors[3]
-    C_global = s.stage_tensors[4]
+    A_global = s.stage_ops[1]
+    B_global = s.stage_ops[3]
+    C_global = s.stage_ops[4]
     assert s[B_global].iters[0].range.extent == 512
     assert s[B_global].iters[1].range.extent == 16
     assert s[A_global].iters[0].range.extent == 1
diff --git a/tests/python/unittest/test_ansor_feature.py b/tests/python/unittest/test_ansor_feature.py
index bcc7683b3f4a..705556c65edf 100644
--- a/tests/python/unittest/test_ansor_feature.py
+++ b/tests/python/unittest/test_ansor_feature.py
@@ -33,7 +33,7 @@ def fequal(a, b):
 def test_cpu_matmul():
     dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512))
     s = dag.get_init_state()
-    C = s.stage_tensors[2]
+    C = s.stage_ops[2]
 
     i, j, k = s[C].iters
     io, ii = s.split(C, i, [16])
@@ -42,7 +42,7 @@
     s.vectorize(C, ji)
     s.parallel(C, io)
     s.parallel(C, jo)
-    s.unroll(2, k)
+    s.unroll(C, k)
 
     target = tvm.target.create('llvm')
     task = ansor.SearchTask(dag, "test", target)
diff --git a/tests/python/unittest/test_ansor_loop_state.py b/tests/python/unittest/test_ansor_loop_state.py
index 87688e276469..d90be1a78421 100644
--- a/tests/python/unittest/test_ansor_loop_state.py
+++ b/tests/python/unittest/test_ansor_loop_state.py
@@ -115,14 +115,14 @@ def test_compute_at_root_inline():
     s0 = dag.get_init_state()
 
     # data, padding, kernel = 0, 1, 2
-    conv = s0.stage_tensors[3]
+    conv = s0.stage_ops[3]
     # bias = 4
-    bias_add = s0.stage_tensors[5]
+    bias_add = s0.stage_ops[5]
     # bn_scale = 6
-    bn_mul = s0.stage_tensors[7]
+    bn_mul = s0.stage_ops[7]
     # bn_offset = 8
-    bn_add = s0.stage_tensors[9]
-    relu = s0.stage_tensors[10]
+    bn_add = s0.stage_ops[9]
+    relu = s0.stage_ops[10]
 
     s0.compute_inline(bn_add)
     s0.compute_inline(bn_mul)
@@ -193,8 +193,8 @@ def test_cache_read_write():
     dag = ansor.ComputeDAG([data, kernel_data, add])
     s0 = dag.get_init_state()
 
-    pad_temp = s0.stage_tensors[1]
-    kernel_split = s0.stage_tensors[3]
+    pad_temp = s0.stage_ops[1]
+    kernel_split = s0.stage_ops[3]
 
     # 0: init state
     ori_its = s0[add].iters

From 14a19cd9597809801d570228818aea61b7082072 Mon Sep 17 00:00:00 2001
From: Chenfan
Date: Wed, 24 Jun
2020 13:22:45 +0800 Subject: [PATCH 39/78] ComputeDAG bug fix & Add Custom TensorCore Matmul Example (#42) * Bug Fix * Sample example of Custom TensorCore Matmul --- scripts/common.py | 34 ++++---- scripts/tune_test.py | 181 +++++++++++++++++++++++++++++++++++++-- src/ansor/compute_dag.cc | 12 ++- 3 files changed, 199 insertions(+), 28 deletions(-) diff --git a/scripts/common.py b/scripts/common.py index ac25b28e55b1..e9cf58e128bb 100644 --- a/scripts/common.py +++ b/scripts/common.py @@ -81,25 +81,25 @@ def add_mn(M, N): @register_workload_func def matmul_nkkm(N, M, K, in_type='float32', out_type='float32', tensor_core_support=False): - A = te.placeholder((N, K), name='A', dtype=in_type) - B = te.placeholder((K, M), name='B', dtype=in_type) - k = te.reduce_axis((0, K), name='k') - if in_type == out_type: - if not (in_type == 'float16' and out_type == 'float16'): - tensor_core_support = False - C = te.compute((N, M), - lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), - name='C', - attrs={"ansor_tensor_core_support": "True" if tensor_core_support else "False"}) - else: + if tensor_core_support: + A = te.placeholder((N // 16, K // 16, 16, 16), name='A', dtype=in_type) + B = te.placeholder((K // 16, M // 16, 16, 16), name='B', dtype=in_type) + k = te.reduce_axis((0, K // 16), name='k') + kk = te.reduce_axis((0, 16), name='kk') if not ((in_type == 'float16' and out_type == 'float32') or \ - (in_type == 'int8' and out_type == 'int32')): - tensor_core_support = False + (in_type == 'int8' and out_type == 'int32')): + raise ValueError + C = te.compute((N // 16, M // 16, 16, 16), + lambda i, j, ii, jj: te.sum(A[i][k][ii][kk].astype(out_type) * B[k][j][kk][jj].astype(out_type), + axis=[k, kk]), + name='C') + else: + A = te.placeholder((N, K), name='A', dtype=in_type) + B = te.placeholder((K, M), name='B', dtype=in_type) + k = te.reduce_axis((0, K), name='k') C = te.compute((N, M), - lambda i, j: te.sum(A[i][k].astype(out_type) * B[k][j].astype(out_type), - axis=[k]), - name='C', - attrs={"ansor_tensor_core_support": "True" if tensor_core_support else "False"}) + lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), + name='C') return [A, B, C] diff --git a/scripts/tune_test.py b/scripts/tune_test.py index c98da3eca53b..6b39cf5e7865 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -24,14 +24,169 @@ import numpy as np import tvm -from tvm import ansor +from tvm import ansor, te from tvm.ansor.utils import request_remote from common import get_workload_keys, get_workload_weights, measure_schedule, str2bool +def tensor_core_meet_condition(meta_policy, state, stage_id): + pass + +def intrin_wmma_load_matrix(scope): + n = 16 + A = te.placeholder((n, n), name='A', dtype='float16') + BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256) + C = te.compute((n, n), lambda i, j: A[i, j], name='C') + BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + ib = tvm.tir.ir_builder.create() + + BA = ins[0] + BC = outs[0] + ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', + BC.data, n, n, n, BC.elem_offset // 256, + BA.access_ptr('r'), n, 'row_major')) + return ib.get() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + +@tvm._ffi.register_func +def intrin_wmma_load_matrix_a(): + return intrin_wmma_load_matrix("wmma.matrix_a") + +@tvm._ffi.register_func +def intrin_wmma_load_matrix_b(): + return intrin_wmma_load_matrix("wmma.matrix_b") + 
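+# Note on how the registered intrinsics above are consumed: `state.tensorize(...)`
+# carries only the function name as a string, which is looked up in TVM's global
+# function registry when the step is applied (hence the @tvm._ffi.register_func
+# decorators). A minimal sketch of the equivalent manual lookup:
+#
+#   f = tvm.get_global_func("intrin_wmma_load_matrix_a")
+#   intrin = f()  # -> a te.TensorIntrin describing a 16x16 fp16 fragment load
+#
+# The `elem_offset // 256` expressions turn a flat element offset into a wmma
+# fragment index: one 16x16 fragment covers 256 elements, and offset_factor=256
+# guarantees the offset is a multiple of 256.
+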
+@tvm._ffi.register_func +def intrin_wmma_gemm(): + n = 16 + A = te.placeholder((n, n), name='A', dtype='float16') + B = te.placeholder((n, n), name='B', dtype='float16') + k = te.reduce_axis((0, n), name="k") + C = te.compute((n, n), + lambda ii, jj: + te.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k), + name='C') + BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256) + BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256) + BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + BA, BB = ins + BC, = outs + + def init(): + ib = tvm.tir.ir_builder.create() + ib.emit(tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0)) + return ib.get() + + def update(): + ib = tvm.tir.ir_builder.create() + ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync', + BC.data, BC.elem_offset // 256, + BA.data, BA.elem_offset // 256, + BB.data, BB.elem_offset // 256, + BC.data, BC.elem_offset // 256)) + return ib.get() + + return update(), init(), update() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) + +@tvm._ffi.register_func +def intrin_wmma_store_matrix(): + n = 16 + A = te.placeholder((n, n), name='A', dtype='float32') + BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256) + C = te.compute((n, n), lambda i, j: A[i, j], name='C') + BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + ib = tvm.tir.ir_builder.create() + BA = ins[0] + BC = outs[0] + ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync', + BA.data, n, n, n, BA.elem_offset // 256, + BC.access_ptr('w'), n, 'row_major')) + return ib.get() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + +def tensor_core_apply(meta_policy, state, stage_id): + ret = [] + state = ansor.loop_state.State(state, meta_policy.cur_task.compute_dag) + + A, B, C = meta_policy.cur_task.compute_dag.ops + + C_local = state.cache_write(C, "wmma.accumulator") + + its0 = state.split(C_local, state[C_local].iters[0], [None, None]) + split_step0 = state.transform_steps_size() - 1 + its1 = state.split(C_local, state[C_local].iters[3], [None, None]) + split_step1 = state.transform_steps_size() - 1 + its2 = state.split(C_local, state[C_local].iters[8], [None]) + + state.reorder(C_local, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], + its2[0], its2[1], + state[C_local].iters[6], + state[C_local].iters[7], + state[C_local].iters[10]]) + state.fuse(C_local, [state[C_local].iters[0], state[C_local].iters[1]]) + state.fuse(C_local, [state[C_local].iters[1], state[C_local].iters[2]]) + state.fuse(C_local, [state[C_local].iters[2], state[C_local].iters[3]]) + + its0 = state.follow_split(C, state[C].iters[0], split_step0, 2) + its1 = state.follow_split(C, state[C].iters[3], split_step1, 2) + state.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], + state[C].iters[6], state[C].iters[7]]) + state.fuse(C, [state[C].iters[0], state[C].iters[1]]) + state.fuse(C, [state[C].iters[1], state[C].iters[2]]) + local_write_pos = state.fuse(C, [state[C].iters[2], state[C].iters[3]]) + state.compute_at(C_local, C, local_write_pos) + shared_read_pos = state[C_local].iters[3] + local_read_pos = 
state[C_local].iters[4] + state.bind_thread(C, state[C].iters[0], "blockIdx.x") + state.bind_thread(C, state[C].iters[1], "vthread") + state.bind_thread(C, state[C].iters[2], "threadIdx.x") + + B_shared = state.cache_read(B, "shared", [C_local]) + B_local = state.cache_read(B_shared, "wmma.matrix_b", [C_local]) + state.compute_at(B_shared, C_local, shared_read_pos) + state.compute_at(B_local, C_local, local_read_pos) + + it = state.fuse(B_shared, state[B_shared].iters[:]) + its = state.split(B_shared, it, [4]) # vectorize add a callback check function + state.vectorize(B_shared, its[1]) + its = state.follow_fused_split(B_shared, its[0], [split_step0, split_step1], 1, True) + state.bind_thread(B_shared, its[1], "threadIdx.x") + + A_shared = state.cache_read(A, "shared", [C_local]) + A_local = state.cache_read(A_shared, "wmma.matrix_a", [C_local]) + state.compute_at(A_shared, C_local, shared_read_pos) + state.compute_at(A_local, C_local, local_read_pos) + + it = state.fuse(A_shared, state[A_shared].iters[:]) + its = state.split(A_shared, it, [4]) # vectorize add a callback check function + state.vectorize(A_shared, its[1]) + its = state.follow_fused_split(A_shared, its[0], [split_step0, split_step1], 1, True) + state.bind_thread(A_shared, its[1], "threadIdx.x") + + state.tensorize(A_local, state[A_local].iters[-2], "intrin_wmma_load_matrix_a") + state.tensorize(B_local, state[B_local].iters[-2], "intrin_wmma_load_matrix_b") + state.tensorize(C_local, state[C_local].iters[-3], "intrin_wmma_gemm") + state.tensorize(C, state[C].iters[-2], "intrin_wmma_store_matrix") + + print(state) + + ret.append([state.state_object, -1]) + return ret + def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose, n_parallel, build_timeout, local_measure, rpc_device_key, rpc_host, - rpc_port, rpc_num_threads, ndk_cc, early_stopping=-1, run_timeout=10): + rpc_port, rpc_num_threads, ndk_cc, early_stopping=-1, run_timeout=10, + tensor_core_matmul=False): builder = runner = measure_ctx = None if local_measure: builder = ansor.LocalBuilder(timeout=build_timeout) @@ -52,13 +207,16 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose config_threadpool = remote.get_function('runtime.config_threadpool') config_threadpool(0, rpc_num_threads) + pre_search_callbacks = [ansor.PreloadMeasuredStates(log_file)] + if tensor_core_matmul: + pre_search_callbacks.append(ansor.PreloadCustomSketchRule(tensor_core_meet_condition, tensor_core_apply)) tune_option = ansor.TuneOption(n_trials=n_trials, early_stopping=early_stopping, num_measure_per_iter=num_measure_per_iter, verbose=verbose, builder=builder, runner=runner, measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=[ansor.PreloadMeasuredStates(log_file)]) + pre_search_callbacks=pre_search_callbacks) return tune_option, measure_ctx @@ -113,10 +271,10 @@ def tune_workload(wkl_key, target, target_host, policy, model_type, model.load(load_model_file) elif load_log_file: model.load_log_file(load_log_file) - elif model_type == "random": - model = ansor.RandomModel() - else: - raise ValueError("Invalid model: " + model_type) + elif model_type == "random": + model = ansor.RandomModel() + else: + raise ValueError("Invalid model: " + model_type) if policy == 'sketch': policy = ansor.SketchSearchPolicy(program_cost_model=model) @@ -200,11 +358,18 @@ def objective_func(costs): load_log_file = args.load_log or log_file weights = get_workload_weights(args.wkl) + # Special check for tensor core + wkl_key = args.wkl + wkl_key 
= wkl_key.split("-") + tensor_core_matmul = False + if wkl_key[0] == "matmul" and wkl_key[6] == "tc": + tensor_core_matmul = True + tune_option, measure_ctx = create_tune_option(target, log_file, args.n_trials, args.num_measure_per_iter, args.verbose, args.n_parallel, args.build_timeout, args.local_measure, args.rpc_device_key, args.rpc_host, args.rpc_port, args.rpc_num_threads, - args.ndk_cc) + args.ndk_cc, tensor_core_matmul=tensor_core_matmul) if args.task_scheduler == 'no': # tune workloads one by one diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index ee87318cdd84..9e6da6ff6f3b 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -569,13 +569,11 @@ State ComputeDAG::GetInitState() const { ComputeDAG::ComputeDAG(Array tensors) { auto node = make_object(); FlopEstimator estimator; - node->tensors = std::move(tensors); node->access_analyzer = AccessAnalyzer(node->tensors); node->ops = Array(node->access_analyzer->ops_topo_order); node->flop_ct = estimator.EstimateFlop(node->ops); node->init_state = State(node->ops); - data_ = std::move(node); } @@ -587,7 +585,15 @@ ComputeDAG::ComputeDAG(const std::string& workload_key) { } else { LOG(FATAL) << "ansor.workload_key_to_tensors is not registered"; } - ComputeDAG(std::move(tens)); + + auto node = make_object(); + FlopEstimator estimator; + node->tensors = std::move(tens); + node->access_analyzer = AccessAnalyzer(node->tensors); + node->ops = Array(node->access_analyzer->ops_topo_order); + node->flop_ct = estimator.EstimateFlop(node->ops); + node->init_state = State(node->ops); + data_ = std::move(node); } std::string BaseName(const std::string& str) { From b012e279419c12591a7642f9b47d8cd6d4bfd65d Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 24 Jun 2020 13:58:55 +0800 Subject: [PATCH 40/78] Rever Commits, Start to build minimum Ansor system --- docs/conf.py | 1 - include/tvm/relay/attrs/transform.h | 13 - include/tvm/relay/transform.h | 14 - include/tvm/runtime/c_runtime_api.h | 23 - include/tvm/runtime/device_api.h | 3 +- include/tvm/runtime/ndarray.h | 12 +- python/tvm/ansor/__init__.py | 15 +- python/tvm/ansor/auto_schedule.py | 86 -- python/tvm/ansor/compute_dag.py | 23 +- python/tvm/ansor/cost_model/__init__.py | 3 +- python/tvm/ansor/cost_model/cost_model.py | 31 - python/tvm/ansor/cost_model/xgb_model.py | 474 -------- python/tvm/ansor/dispatcher.py | 299 ----- python/tvm/ansor/env.py | 25 - python/tvm/ansor/feature.py | 150 --- python/tvm/ansor/loop_state.py | 319 ----- python/tvm/ansor/measure.py | 3 +- python/tvm/ansor/relay_integration.py | 241 ---- python/tvm/ansor/task_scheduler.py | 299 ----- python/tvm/relay/backend/compile_engine.py | 5 +- python/tvm/relay/build_module.py | 7 - python/tvm/relay/op/_transform.py | 2 - python/tvm/relay/op/op_attrs.py | 3 - python/tvm/relay/op/strategy/x86.py | 62 +- python/tvm/relay/op/transform.py | 21 - python/tvm/relay/testing/dqn.py | 27 +- python/tvm/relay/testing/resnet.py | 22 +- python/tvm/runtime/ndarray.py | 33 - python/tvm/te/tensor.py | 8 +- scripts/common.py | 1034 ----------------- scripts/shape_configs.py | 247 ---- scripts/tune_network.py | 405 ------- scripts/tune_op_subgraph.py | 602 ---------- scripts/tune_test.py | 394 ------- src/ansor/compute_dag.cc | 5 +- src/ansor/search_task.cc | 59 - src/arith/rewrite_simplify.cc | 71 +- src/relay/analysis/type_solver.cc | 1 - src/relay/backend/build_module.cc | 32 - src/relay/backend/compile_engine.cc | 5 - src/relay/backend/compile_engine.h | 3 - src/relay/op/tensor/transform.cc | 54 - 
src/relay/transforms/defuse_ops.cc | 91 -- .../transforms/kernel_layout_transform.cc | 66 -- .../transforms/kernel_layout_transform.h | 102 -- src/relay/transforms/pattern_util.h | 2 - src/runtime/cuda/cuda_device_api.cc | 4 - src/runtime/ndarray.cc | 80 +- src/runtime/opencl/opencl_device_api.cc | 3 - src/runtime/rpc/rpc_module.cc | 30 - src/runtime/threading_backend.cc | 9 +- src/te/schedule/schedule_dataflow_rewrite.cc | 66 +- src/tir/analysis/verify_gpu_code.cc | 44 +- src/tir/transforms/unroll_loop.cc | 20 +- tests/python/unittest/test_ansor_feature.py | 150 --- .../python/unittest/test_ansor_loop_state.py | 540 +-------- tests/python/unittest/test_ansor_measure.py | 18 - .../unittest/test_ansor_relay_integration.py | 114 -- .../unittest/test_ansor_search_policy.py | 85 -- .../unittest/test_ansor_task_scheduler.py | 52 - .../test_tir_transform_unroll_loop.py | 24 - topi/include/topi/transform.h | 69 -- topi/python/topi/nn/conv2d.py | 39 +- tutorials/ansor/README.txt | 4 - tutorials/ansor/tune_conv2d_cuda.py | 179 --- tutorials/ansor/tune_simple_subgraph.py | 193 --- tutorials/autotvm/README.txt | 4 +- 67 files changed, 86 insertions(+), 7038 deletions(-) delete mode 100644 python/tvm/ansor/cost_model/xgb_model.py delete mode 100644 python/tvm/ansor/dispatcher.py delete mode 100644 python/tvm/ansor/env.py delete mode 100644 python/tvm/ansor/feature.py delete mode 100644 python/tvm/ansor/relay_integration.py delete mode 100644 python/tvm/ansor/task_scheduler.py delete mode 100644 scripts/common.py delete mode 100644 scripts/shape_configs.py delete mode 100644 scripts/tune_network.py delete mode 100644 scripts/tune_op_subgraph.py delete mode 100644 scripts/tune_test.py delete mode 100644 src/relay/transforms/defuse_ops.cc delete mode 100644 src/relay/transforms/kernel_layout_transform.cc delete mode 100644 src/relay/transforms/kernel_layout_transform.h delete mode 100644 tests/python/unittest/test_ansor_feature.py delete mode 100644 tests/python/unittest/test_ansor_relay_integration.py delete mode 100644 tests/python/unittest/test_ansor_task_scheduler.py delete mode 100644 tutorials/ansor/README.txt delete mode 100644 tutorials/ansor/tune_conv2d_cuda.py delete mode 100644 tutorials/ansor/tune_simple_subgraph.py diff --git a/docs/conf.py b/docs/conf.py index 5826526d55b0..7ece63bd7aa8 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -198,7 +198,6 @@ '../tutorials/language', '../tutorials/optimize', '../tutorials/autotvm', - '../tutorials/ansor', '../tutorials/dev', '../tutorials/topi', '../tutorials/deployment', diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 95476ed61bdd..750a8a43163c 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -296,19 +296,6 @@ struct LayoutTransformAttrs : public tvm::AttrsNode { } }; -/*! \brief Attributes for KernelLayoutTransform operator */ -struct KernelLayoutTransformAttrs : public tvm::AttrsNode { - std::string src_layout; - std::string dst_layout; - - TVM_DECLARE_ATTRS(KernelLayoutTransformAttrs, "relay.attrs.KernelLayoutTransformAttrs") { - TVM_ATTR_FIELD(src_layout) - .describe("The source layout of the tensor. (e.g. 1N32C112H112W)"); - TVM_ATTR_FIELD(dst_layout) - .describe("The destination layout of the tensor. (e.g. 1N2C112H112W16c)"); - } -}; - /*! 
\brief Attributes for ShapeOf operator */ struct ShapeOfAttrs : public tvm::AttrsNode { DataType dtype; diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 5f5d9b643633..1b8b31aee5d1 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -277,20 +277,6 @@ TVM_DLL Pass CanonicalizeOps(); */ TVM_DLL Pass AlterOpLayout(); -/*! - * \brief Alternate the layouts of kernels. - * - * \return The pass. - */ -TVM_DLL Pass KernelLayoutTransform(); - -/*! - * \brief The reverse of FuseOps. - * - * \return The pass. - */ -TVM_DLL Pass DeFuseOps(); - /*! * \brief Given a dest layout, this pass transforms the expr such that most of the ops input data * layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 5a32ac7d3d9f..213c7059a5f9 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -384,29 +384,6 @@ TVM_DLL int TVMFuncListGlobalNames(int* out_size, const char*** out_array); TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits, int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out); -/*! - * \brief Allocate a nd-array's memory of non-empty values, - * including space of shape, of given spec. - * - * \param shape The shape of the array, the data content will be copied to out - * \param ndim The number of dimension of the array. - * \param dtype_code The type code of the dtype - * \param dtype_bits The number of bits of dtype - * \param dtype_lanes The number of lanes in the dtype. - * \param device_type The device type of context - * \param device_id The device id of context. - * \param out The output handle. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMArrayAllocNonEmpty(const tvm_index_t* shape, - int ndim, - int dtype_code, - int dtype_bits, - int dtype_lanes, - int device_type, - int device_id, - TVMArrayHandle* out); - /*! * \brief Free the TVM Array. * \param handle The array handle to be freed. diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 9b2eb6be2160..421811a52c3b 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -44,8 +44,7 @@ enum DeviceAttrKind : int { kMaxClockRate = 6, kMultiProcessorCount = 7, kMaxThreadDimensions = 8, - kGcnArch = 9, - kMaxRegistersPerBlock = 10 + kGcnArch = 9 }; /*! \brief Number of bytes each allocation must align to */ diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 9cc66a371974..e69d802652fd 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -138,17 +138,7 @@ class NDArray : public ObjectRef { * \param ctx The context of the Array. * \return The created Array */ - TVM_DLL static NDArray Empty(std::vector shape, - DLDataType dtype, DLContext ctx); - /*! - * \brief Create an NDArray with non-empty values. - * \param shape The shape of the new array. - * \param dtype The data type of the new array. - * \param ctx The context of the Array. - * \return The created Array - */ - TVM_DLL static NDArray NonEmpty(std::vector shape, - DLDataType dtype, DLContext ctx); + TVM_DLL static NDArray Empty(std::vector shape, DLDataType dtype, DLContext ctx); /*! * \brief Create a NDArray backed by a dlpack tensor. 
* diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index edade490018c..ccd8f27b71c1 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -22,24 +22,15 @@ from . import serialization from . import loop_state from . import utils -from . import feature from . import workload_registry -from . import task_scheduler # Shortcut -from .compute_dag import ComputeDAG, LayoutRewriteLevel -from .auto_schedule import SearchTask, SketchSearchPolicy, TuneOption, HardwareParams, \ - PreloadMeasuredStates, PreloadCustomSketchRule, auto_schedule +from .compute_dag import ComputeDAG +from .auto_schedule import SearchTask, TuneOption, HardwareParams, \ + auto_schedule from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext from .cost_model import RandomModel -from .cost_model.xgb_model import XGBModel from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \ load_from_file, write_measure_records_to_file from .workload_registry import register_workload_func, \ workload_key_to_dag, make_workload_key_func -from .task_scheduler import TaskScheduler, SimpleTaskScheduler -from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest as apply_history_best, \ - FallbackContext -from .relay_integration import extract_from_program, extract_from_multiple_program, \ - finish_layout_rewrite, prepare_layout_rewrite, auto_schedule_topi -from .env import GLOBAL_SCOPE diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 4497bb400703..37e622018658 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -82,96 +82,10 @@ def set_verbose(self, verbose): def run_callbacks(self, callbacks): _ffi_api.SearchPolicyRunCallbacks(self, callbacks) - -@tvm._ffi.register_object("ansor.SketchSearchPolicy") -class SketchSearchPolicy(SearchPolicy): - """ The search policy that searches in a hierarchical search space defined by sketches. - The policy randomly samples programs from the space defined by sketches - and use evolutionary search to fine-tune them. - - Parameters - ---------- - program_cost_model: CostModel - Cost model for programs - params: int - Parameters of the search policy. See `src/ansor/search_policy/sketch_search_policy.h` - to find the definitions. See code below to find the default values - seed: int - Random seed - """ - def __init__(self, - program_cost_model, - params=None, - seed=None): - # set default parameters - default_params = { - "eps_greedy": 0.05, - - 'evolutionary_search_population': 2048, - 'evolutionary_search_num_iters': 15, - "evolutionary_search_mutation_prob": 0.85, - "evolutionary_search_use_measured_ratio": 0.2, - - 'cpu_multi_level_tiling_structure': 'SSRSRS', - 'gpu_multi_level_tiling_structure': 'SSSRRSRS', - - 'disable_change_compute_location': 0, - } - - if params is None: - params = default_params - else: - for key, value in default_params.items(): - if key not in params: - params[key] = value - - self.__init_handle_by_constructor__( - _ffi_api.SketchSearchPolicy, program_cost_model, params, - seed or random.randint(1, 1 << 30)) - - @tvm._ffi.register_object("ansor.SearchCallback") class SearchCallback(Object): """Callback function before or after search process""" - -@tvm._ffi.register_object("ansor.PreloadMeasuredStates") -class PreloadMeasuredStates(SearchCallback): - """ A SearchCallback to load measured states from the log file for a search policy. - This can resume the state of the search policy. 
- - Parameters - ---------- - filename: str - """ - def __init__(self, filename: str): - self.__init_handle_by_constructor__( - _ffi_api.PreloadMeasuredStates, filename) - - -@tvm._ffi.register_object("ansor.PreloadCustomSketchRule") -class PreloadCustomSketchRule(SearchCallback): - """ - A SearchCallback for SketchSearchPolicy that allowing users to add - custom sketch rule. - - Notes - ----- - This is an advanced feature. Make sure you're clear how it - works and this should only be used in SketchSearchPolicy. - - Parameters - ---------- - meet_condition_func: Function - A function with `(policy, state, stage_id) -> int` - apply_func: Function - A function with `(policy, state, stage_id) -> [[State, int], ...]` - """ - def __init__(self, meet_condition_func, apply_func): - self.__init_handle_by_constructor__( - _ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func) - - @tvm._ffi.register_object("ansor.TuneOption") class TuneOption(Object): """ The options for tuning diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index 6304c7bb0e0a..acfec66a166a 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -23,13 +23,6 @@ from . import _ffi_api -class LayoutRewriteLevel(object): - NO_REWRITE = 0 # No layout rewrite - PLACEHOLDER_REWRITE = 1 # Only rewrite layout of placeholder in the compute dag - COMPUTE_REWRITE = 2 # Only rewrite compute body for new layout in the compute dag - BOTH_REWRITE = 3 # Rewrite both placeholder and compute body in the compute dag - - @tvm._ffi.register_object("ansor.ComputeDAG") class ComputeDAG(Object): """ @@ -51,7 +44,7 @@ def get_init_state(self): """ return State(_ffi_api.ComputeDAGGetInitState(self), self) - def apply_steps_from_state(self, state, layout_rewrite_level=LayoutRewriteLevel.NO_REWRITE): + def apply_steps_from_state(self, state): """ Apply transform steps according to the history of a state @@ -97,17 +90,3 @@ def infer_bound_from_state(self, state): """ state_obj = state if isinstance(state, StateObject) else state.state_object return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self) - - def rewrite_layout_from_state(self, state: State): - """ - Rewrite the layout according to the transform steps in the history of a state - - Parameters - ---------- - state : StateObject - - Returns - ------- - state : StateObject - """ - return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state) diff --git a/python/tvm/ansor/cost_model/__init__.py b/python/tvm/ansor/cost_model/__init__.py index 56e4a5f9128b..1454da451b61 100644 --- a/python/tvm/ansor/cost_model/__init__.py +++ b/python/tvm/ansor/cost_model/__init__.py @@ -17,5 +17,4 @@ # pylint: disable=unused-import, redefined-builtin """ Cost model that estimates the performance of programs """ -from .cost_model import RandomModel -from .xgb_model import XGBModel +from .cost_model import RandomModel \ No newline at end of file diff --git a/python/tvm/ansor/cost_model/cost_model.py b/python/tvm/ansor/cost_model/cost_model.py index fbfc8242488b..605db14c19c3 100644 --- a/python/tvm/ansor/cost_model/cost_model.py +++ b/python/tvm/ansor/cost_model/cost_model.py @@ -44,34 +44,3 @@ def random_number(n, return_ptr): return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float)) array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,)) array_wrapper[:] = np.random.uniform(0, 1, (n,)) - - -@tvm._ffi.register_object("ansor.PythonBasedModel") -class PythonBasedModel(CostModel): - """Base class for cost models 
implemented in python""" - def __init__(self): - def update_func(inputs, results): - self.update(inputs, results) - - def predict_func(task, states, return_ptr): - return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float)) - array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(len(states),)) - array_wrapper[:] = self.predict(task, states) - - def predict_stage_func(task, states, return_ptr): - ret = self.predict_stages(task, states) - return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float)) - array_wrapper = np.ctypeslib.as_array(return_ptr, shape=ret.shape) - array_wrapper[:] = ret - - self.__init_handle_by_constructor__(_ffi_api.PythonBasedModel, update_func, - predict_func, predict_stage_func) - - def update(self, inputs, results): - raise NotImplementedError - - def predict(self, task, states): - raise NotImplementedError - - def predict_stages(self, task, states): - raise NotImplementedError diff --git a/python/tvm/ansor/cost_model/xgb_model.py b/python/tvm/ansor/cost_model/xgb_model.py deleted file mode 100644 index 42af17daae2c..000000000000 --- a/python/tvm/ansor/cost_model/xgb_model.py +++ /dev/null @@ -1,474 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Cost model based on xgboost""" -import multiprocessing -import logging -from collections import defaultdict - -import numpy as np -import xgboost as xgb - -from tvm.autotvm.tuner.xgboost_cost_model import get_rank, recall_curve, max_curve -from .cost_model import PythonBasedModel -from ..feature import get_per_stmt_features_from_measure_pairs, get_per_stmt_features_from_states -from ..serialization import LogReader - -logger = logging.getLogger('ansor') - -class XGBDMatrixContext: - """Context to hold additional attributes of xgb.DMatrix""" - def __init__(self): - self.context_dict = defaultdict(dict) - - def get(self, key, matrix, default=None): - return self.context_dict[key].get(matrix.handle.value, default) - - def put(self, key, matrix, value): - self.context_dict[key][matrix.handle.value] = value - -dmatrix_context = XGBDMatrixContext() - -class XGBModel(PythonBasedModel): - """Train a XGBoost model to predict the runtime cost of a program. - The cost of a program = the sum of the costs of all stages in this program. - i.e. Cost(p) = cost_s0 + cost_s1 + ... + cost_sn, where cost_si is the cost of Stage i - - The xgboost model makes prediction per stage, then we sum them up. - The final predction made by this class is normalized throughtput (from 0 to 1, larger is better) - - To support this stage decomposition, we have to implement a custom loss function for - XGBoost, which is the `pack_sum` in the code below. 
- """ - def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None): - self.xgb_params = { - 'max_depth': 10, - 'gamma': 0.001, - 'min_child_weight': 0, - 'eta': 0.2, - # todo(lmzheng): automatically decrease learning rate when the loss is too large - - 'n_gpus': 0, - 'nthread': multiprocessing.cpu_count() // 2, - 'verbosity': 0, - 'seed': seed or 43, - 'disable_default_eval_metric': 1 - } - self.bst = None - self.plan_size = 32 - self.num_warmup_sample = num_warmup_sample - self.verbose_eval = verbose_eval - - super().__init__() - - # measurement input/result pairs - self.inputs = [] - self.results = [] - self.inputs_feature_cache = [] - - def update(self, inputs, results): - if len(inputs) <= 0: - return - - self.inputs.extend(inputs) - self.results.extend(results) - - # extract feature - n_cached = len(self.inputs_feature_cache) - features, normalized_throughputs, task_ids = \ - get_per_stmt_features_from_measure_pairs(self.inputs, self.results, - skip_first_n_feature_extraction=n_cached) - if n_cached > 0: - features = list(features) - features[:n_cached] = self.inputs_feature_cache - features = np.array(features) - self.inputs_feature_cache = features - dtrain = pack_sum_xgbmatrix(features, normalized_throughputs, - task_ids, normalized_throughputs) - - # train xgb model - self.bst = xgb.train(self.xgb_params, dtrain, - num_boost_round=10000, - obj=pack_sum_square_error, - callbacks=[custom_callback( - stopping_rounds=50, - metric='tr-p-rmse', - fevals=[ - pack_sum_rmse, pack_sum_average_peak_score(self.plan_size), - ], - evals=[(dtrain, 'tr')], - maximize=False, - verbose_eval=self.verbose_eval)]) - - def predict(self, task, states): - features = get_per_stmt_features_from_states(states, task) - if self.bst is not None and len(self.inputs) > self.num_warmup_sample: - dtest, pack_ids = pack_sum_xgbmatrix_for_prediction(features) - raw_preds = self.bst.predict(dtest) - ret = pack_sum_predict_throughput(raw_preds, pack_ids) - else: - ret = np.random.uniform(0, 1, (len(states),)) - - # Predict 0 for invalid states that failed to be lowered. - for idx, feature in enumerate(features): - if feature.min() == feature.max() == 0: - ret[idx] = float('-inf') - - return ret - - def predict_stages(self, task, states): - # Format: (s0 score, ..., sN score, s0 n_stage, s0 stage 0, ..., s1 n_stage, s1 stage 0,) - features = get_per_stmt_features_from_states(states, task) - if self.bst is not None and len(self.inputs) > self.num_warmup_sample: - dtest, pack_ids = pack_sum_xgbmatrix_for_prediction(features) - raw_preds = self.bst.predict(dtest) - breakdown = pack_sum_predict_throughput(raw_preds, pack_ids) - stage_scores = [[] for _ in range(len(states))] - for pred, pack_id in zip(raw_preds, pack_ids): - stage_scores[pack_id].append(pred) - for idx, stage_score in enumerate(stage_scores): - breakdown = np.append(breakdown, len(stage_score)) - breakdown = np.concatenate((breakdown, -np.array(stage_score))) - else: - breakdown = np.concatenate( - (np.random.uniform(0, 1, (len(states), )), np.zeros(len(states), ))) - - # Predict 0 for invalid states that failed to be lowered. 
- for idx, feature in enumerate(features): - if feature.min() == feature.max() == 0: - breakdown[idx] = float('-inf') - - return breakdown - - def load_log_file(self, file_name, n_lines=-1): - inputs, results = LogReader(file_name).read_lines(n_lines) - logger.info("XGBModel: Loaded %s lines of history log from %s", len(inputs), file_name) - self.update(inputs, results) - - def save(self, file_name: str): - self.bst.save_model(file_name) - - def load(self, file_name: str): - if self.bst is None: - self.bst = xgb.Booster(self.xgb_params) - self.bst.load_model(file_name) - self.num_warmup_sample = -1 - - -def pack_sum_xgbmatrix_for_prediction(xs): - x_flatten = [] - pack_ids = [] - - for ct, x in enumerate(xs): - for row in x: - x_flatten.append(row) - pack_ids.append(ct) - - return xgb.DMatrix(np.array(x_flatten)), pack_ids - - -def pack_sum_xgbmatrix(xs, ys, gids=None, weights=None): - if gids is not None: - # sort by group - indices = gids.argsort() - xs, ys = xs[indices], ys[indices] - group_sizes = np.bincount(gids) - if weights is not None: - weights = weights[indices] - else: - # assume it has only one group - group_sizes = [len(xs)] - - x_flatten = [] - y_flatten = [] - weights_flatten = [] - pack_ids = [] - - if weights is not None: - for ct, (x, y, w) in enumerate(zip(xs, ys, weights)): - for row in x: - x_flatten.append(row) - y_flatten.append(y) - weights_flatten.append(w) - pack_ids.append(ct) - else: - for ct, (x, y) in enumerate(zip(xs, ys)): - for row in x: - x_flatten.append(row) - y_flatten.append(y) - pack_ids.append(ct) - - ret = xgb.DMatrix(np.array(x_flatten), y_flatten) - if weights is not None: - ret.set_weight(weights_flatten) - dmatrix_context.put('pack_ids', ret, np.array(pack_ids)) - dmatrix_context.put('group_sizes', ret, group_sizes) - return ret - -LOSS_TYPE = 3 - -# Type 0 -# The model predicts cost. Use square error of throughput as loss -# loss = 1/2 * (1 / sum(x_i) - y) ^ 2 -# -# Type 1 -# The model predicts cost. Use square error of cost as loss -# loss = 1/2 * (sum(x_i) - 1 / y) ^ 2 -# -# Type 2 -# The model predicts throughput. Use square error of throughput as loss. -# loss = 1/2 * (1 / sum(1 / x_i) - y) ^ 2 -# -# Type 3 -# The model predicts throughput. Use square error of throughput as loss. -# But approximate 1 / (1 / a_1 + 1 / a_2 + ... + 1 / a_n) with -(b_1 + b_2 + b_3) -# loss = 1/2 * (-sum(x_i) - y) ^ 2 -# -# Type 4 -# The model predicts throughput. Use square error of throughput as loss. -# But approximate 1 / (1 / a_1 + 1 / a_2 + ... 
+ 1 / a_n) with -(b_1 + b_2 + b_3) -# Also add a sigmoid to force the prediction to be within the range of (0, 1) -# loss = 1/2 * (sigmoid(-sum(x_i)) - y) ^ 2 -# - -def pack_sum_predict_throughput(raw_preds, pack_ids): - if LOSS_TYPE == 0: - sum_pred = np.bincount(pack_ids, weights=raw_preds) - return 1 / sum_pred - elif LOSS_TYPE == 1: - sum_pred = np.bincount(pack_ids, weights=raw_preds) - return 1 / sum_pred - elif LOSS_TYPE == 2: - sum_inverse_preds = np.bincount(pack_ids, weights=1 / raw_preds) - return 1 / sum_inverse_preds - elif LOSS_TYPE == 3: - sum_pred = np.bincount(pack_ids, weights=raw_preds) - return - sum_pred # pylint: disable=invalid-unary-operand-type - elif LOSS_TYPE == 4: - sum_pred = np.bincount(pack_ids, weights=raw_preds) - return 1 / (1 + np.exp(sum_pred)) - else: - raise ValueError("Invalid loss type: " + LOSS_TYPE) - -def pack_sum_square_error(preds, dtrain): - pack_ids = dmatrix_context.get("pack_ids", dtrain) - weight = dtrain.get_weight() - - if LOSS_TYPE == 0: - sum_pred = np.bincount(pack_ids, weights=preds) - x = sum_pred[pack_ids] - y = dtrain.get_label() - gradient = (x * y - 1) / np.power(x, 3) - hessian = (3 - 2 * x * y) / np.power(x, 4) - elif LOSS_TYPE == 1: - sum_pred = np.bincount(pack_ids, weights=preds) - x = sum_pred[pack_ids] - y = dtrain.get_label() - gradient = x - 1 / np.minimum(y, 1e6) - hessian = np.ones_like(gradient) - elif LOSS_TYPE == 2: - sum_inverse_preds = np.bincount(pack_ids, weights=1 / preds)[pack_ids] - y = dtrain.get_label() - gradient = (1 / sum_inverse_preds - y) / (np.power(preds * sum_inverse_preds, 2)) - hessian = (2 * preds * y * np.power(sum_inverse_preds, 2) - 2 * y * sum_inverse_preds - 2 * preds * sum_inverse_preds + 3) / (np.power(preds * sum_inverse_preds, 4)) - elif LOSS_TYPE == 3: - sum_pred = np.bincount(pack_ids, weights=preds) - x = sum_pred[pack_ids] - y = dtrain.get_label() - gradient = x + y - hessian = np.ones_like(gradient) - elif LOSS_TYPE == 4: - sum_pred = np.bincount(pack_ids, weights=preds) - exp_x = np.exp(sum_pred[pack_ids]) - exp_2x = np.power(exp_x, 2) - y = dtrain.get_label() - gradient = exp_x * (exp_x * y + y - 1) / np.power(exp_x + 1, 3) - hessian = exp_x * (-exp_2x * y + 2 * exp_x + y - 1) / np.power(exp_x + 1, 4) - else: - raise ValueError("Invalid loss type: " + LOSS_TYPE) - - if len(weight) == 0: - return gradient, hessian - else: - return gradient * weight, hessian * weight - -def pack_sum_rmse(raw_preds, dtrain): - pack_ids = dmatrix_context.get("pack_ids", dtrain) - preds = pack_sum_predict_throughput(raw_preds, pack_ids)[pack_ids] - return 'p-rmse', np.sqrt(np.mean(np.square((preds - dtrain.get_label())))) - -def pack_sum_average_peak_score(N): - """Evaluate pack sum average peak score for xgb""" - - def feval(preds, labels): - group_sizes = dmatrix_context.get('group_sizes', labels, [len(preds)]) - pack_ids = dmatrix_context.get("pack_ids", labels) - - preds = pack_sum_predict_throughput(preds, pack_ids) - labels = (np.bincount(pack_ids, weights=labels.get_label()) - / np.unique(pack_ids, return_counts=True)[1]) - - scores = [] - offset = 0 - for size in group_sizes: - preds_group = preds[offset:offset + size] - labels_group = labels[offset:offset + size] - offset += size - - trials = np.argsort(preds_group)[::-1][:N] - trial_scores = labels_group[trials] - curve = max_curve(trial_scores) / np.max(labels_group) - scores.append(np.mean(curve)) - return "a-peak@%d" % N, np.mean(scores) - return feval - -def pack_sum_average_recall_score(N): - """Evaluate average recall score for xgb""" - 
- def feval(preds, labels): - group_sizes = dmatrix_context.get('group_sizes', labels, [len(preds)]) - pack_ids = dmatrix_context.get("pack_ids", labels) - - preds = pack_sum_predict_throughput(preds, pack_ids) - labels = (np.bincount(pack_ids, weights=labels.get_label()) - / np.unique(pack_ids, return_counts=True)[1]) - - scores = [] - offset = 0 - for size in group_sizes: - preds_group = preds[offset:offset + size] - labels_group = labels[offset:offset + size] - offset += size - - trials = np.argsort(preds_group)[::-1] - ranks = get_rank(labels_group[trials])[:N] - curve = recall_curve(ranks) - scores.append(np.mean(curve)) - return "a-recall@%d" % N, np.mean(scores) - return feval - - -def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None, - maximize=False, verbose_eval=True, skip_every=2): - """Callback function for xgboost to support multiple custom evaluation functions""" - from xgboost.core import EarlyStopException - from xgboost.callback import _fmt_metric - from xgboost.training import aggcv - - state = {} - metric_shortname = metric.split("-")[1] - - def init(env): - """internal function""" - bst = env.model - - state['maximize_score'] = maximize - state['best_iteration'] = 0 - if maximize: - state['best_score'] = float('-inf') - else: - state['best_score'] = float('inf') - - if bst is not None: - if bst.attr('best_score') is not None: - state['best_score'] = float(bst.attr('best_score')) - state['best_iteration'] = int(bst.attr('best_iteration')) - state['best_msg'] = bst.attr('best_msg') - else: - bst.set_attr(best_iteration=str(state['best_iteration'])) - bst.set_attr(best_score=str(state['best_score'])) - else: - assert env.cvfolds is not None - - def callback(env): - """internal function""" - if not state: - init(env) - - bst = env.model - i = env.iteration - cvfolds = env.cvfolds - - res_dict = {} - - if i % skip_every == 1: - return - - ##### evaluation ##### - if cvfolds is not None: - for feval in fevals: - tmp = aggcv([f.eval(i, feval) for f in cvfolds]) - for k, mean, std in tmp: - res_dict[k] = [mean, std] - else: - for feval in fevals: - bst_eval = bst.eval_set(evals, i, feval) - res = [x.split(':') for x in bst_eval.split()] - for kv in res[1:]: - res_dict[kv[0]] = [float(kv[1])] - - eval_res = [] - keys = list(res_dict.keys()) - keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x) - for key in keys: - v = res_dict[key] - eval_res.append([key] + v) - - ##### print eval result ##### - if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0: - infos = ["XGB iter: %3d" % i] - for item in eval_res: - if 'null' in item[0]: - continue - infos.append("%s: %.6f" % (item[0], item[1])) - - logger.debug("\t".join(infos)) - if log_file: - with open(log_file, "a") as fout: - fout.write("\t".join(infos) + '\n') - - ##### choose score and do early stopping ##### - score = None - for item in eval_res: - if item[0] == metric: - score = item[1] - break - assert score is not None - - best_score = state['best_score'] - best_iteration = state['best_iteration'] - maximize_score = state['maximize_score'] - if (maximize_score and score > best_score) or \ - (not maximize_score and score < best_score): - msg = '[%d] %s' % ( - env.iteration, - '\t'.join([_fmt_metric(x) for x in eval_res])) - state['best_msg'] = msg - state['best_score'] = score - state['best_iteration'] = env.iteration - # save the property to attributes, so they will occur in checkpoint. 
- if env.model is not None: - env.model.set_attr(best_score=str(state['best_score']), - best_iteration=str(state['best_iteration']), - best_msg=state['best_msg']) - elif env.iteration - best_iteration >= stopping_rounds: - best_msg = state['best_msg'] - if verbose_eval and env.rank == 0: - logger.debug("XGB stopped. Best iteration: %s ", best_msg) - raise EarlyStopException(best_iteration) - - return callback diff --git a/python/tvm/ansor/dispatcher.py b/python/tvm/ansor/dispatcher.py deleted file mode 100644 index 3a5dc4e9e206..000000000000 --- a/python/tvm/ansor/dispatcher.py +++ /dev/null @@ -1,299 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -The global context that dispatches best configurations to workloads -""" -# pylint: disable=invalid-name - -from __future__ import absolute_import as _abs - -import logging - -import numpy as np - -from tvm.tir.expr import FloatImm - -logger = logging.getLogger('auto_scheduler') - - -class DispatchContext(object): - """ - Base class of dispatch context. - """ - current = None - - def __init__(self): - self._old_ctx = DispatchContext.current - - def query(self, target, workload): - """ - Query the context to get the specific config for a workload. - If cannot find the result inside this context, this function will query it - from the upper contexts. - - Parameters - ---------- - target: Target - The current target - workload : str - The current workload - - Returns - ------- - cfg : State - The schedule configuration for the workload - """ - ret = self._query_inside(target, workload) - return ret - - def update(self, target, workload, cfg): - """ - Update the config for a workload - - Parameters - ---------- - target: Target - The current target - workload : Workload - The current workload. - cfg : State - The schedule configuration for the workload - """ - raise NotImplementedError() - - def _query_inside(self, target, workload): - """ - Query the context to get the specific config for a workload. - This function only query config inside this context. - - Parameters - ---------- - target: Target - The current target - workload : Workload - The current workload. - - Returns - ------- - cfg : State or str - The schedule configuration for the workload - """ - raise NotImplementedError() - - def __enter__(self): - self._old_ctx = DispatchContext.current - DispatchContext.current = self - return self - - def __exit__(self, ptype, value, trace): - DispatchContext.current = self._old_ctx - - -class ApplyConfig(DispatchContext): - """Apply a deterministic config for all queries. 
- - Parameters - ---------- - config : State - The schedule configuration - """ - def __init__(self, config): - super(ApplyConfig, self).__init__() - self._config = config - self.workload = None - - def _query_inside(self, target, workload): - """Override query""" - self.workload = workload - return self._config - - def update(self, target, workload, cfg): - """Override update""" - self.workload = workload - self._config = cfg - - -class ApplyHistoryBest(DispatchContext): - """ - Apply the history best config - - Parameters - ---------- - records : str or iterator of (MeasureInput, MeasureResult) - Collection of tuning records. - If is str, then it should be the filename of a records log file. - Each row of this file is an encoded record pair. - Otherwise, it is an iterator. - n_lines: int (optional) - if it is not None, only load the first `n_lines` lines of log - """ - def __init__(self, records, n_lines=None): - super(ApplyHistoryBest, self).__init__() - - self.best_by_targetkey = {} - self.best_by_model = {} - self._best_user_defined = {} - - if records: - self.load(records, n_lines) - - def load(self, records, n_lines=None): - """Load records to this dispatch context - - Parameters - ---------- - records : str or iterator of (MeasureInput, MeasureResult) - Collection of tuning records. - If is str, then it should be the filename of a records log file. - Each row of this file is an encoded record pair. - Otherwise, it is an iterator. - n_lines: int (optional) - if it is not None, only load the first `n_lines` lines of log - """ - from pathlib import Path - from . import load_from_file - - if isinstance(records, Path): - records = str(records) - - if isinstance(records, str): - records = load_from_file(records) - if not records: - return - - best_by_targetkey = self.best_by_targetkey - best_by_model = self.best_by_model - - counter = 0 - for inp, res in records: - if n_lines is not None and counter >= n_lines: - break - counter += 1 - if res.error_no != 0: - continue - - # use target keys in tvm target system as key to build best map - for k in inp.task.target.keys: - key = (k, inp.task.workload_key) - if key not in best_by_targetkey: - best_by_targetkey[key] = (inp, res) - else: - _, other_res = best_by_targetkey[key] - other_costs = [x.value for x in other_res.costs if isinstance(x, FloatImm)] - costs = [x.value for x in res.costs if isinstance(x, FloatImm)] - if np.mean(other_costs) > np.mean(costs): - best_by_targetkey[key] = (inp, res) - - # use model as key to build best map - key = (inp.task.target.model, inp.task.workload_key) - if key not in best_by_model: - if inp.task.target.model != 'unknown': - best_by_model[key] = (inp, res) - else: - _, other_res = best_by_model[key] - other_costs = [x.value for x in other_res.costs if isinstance(x, FloatImm)] - costs = [x.value for x in res.costs if isinstance(x, FloatImm)] - if np.mean(other_costs) > np.mean(costs): - best_by_model[key] = (inp, res) - - logger.debug("Finish loading %d records", counter) - - def _query_inside(self, target, workload): - if target is None: - raise RuntimeError("Need a target context to find the history best. " - "Hint: If your target is llvm, use `with tvm.target.create('llvm'):`" - " above the dispatcher call. So does other target. 
") - - # first try matching by model - key = (target.model, workload) - if key in self._best_user_defined: - return self._best_user_defined[key] - if key in self.best_by_model: - return self.best_by_model[key][0].state - - # then try matching by target key - for k in target.keys: - key = (k, workload) - if key in self._best_user_defined: - return self._best_user_defined[key] - if key in self.best_by_targetkey: - return self.best_by_targetkey[key][0].state - - return None - - def update(self, target, workload, state): - model = target.model - key = (model, workload) - self._best_user_defined[key] = state - - for k in target.keys: - key = (k, workload) - self._best_user_defined[key] = state - - -class FallbackContext(DispatchContext): - """ - A fallback dispatch context. - This is used as the root context. - """ - - def __init__(self): - super(FallbackContext, self).__init__() - self.memory = {} - self.silent = False - - # a set to prevent print duplicated message - self.messages = set() - - def _query_inside(self, target, workload): - key = (str(target), workload) - if key in self.memory: - return self.memory[key] - - if not self.silent: - msg = "Cannot find config for target=%s, workload=%s. A fallback configuration "\ - "is used, which may bring great performance regression." % (target, workload) - if msg not in self.messages: - self.messages.add(msg) - logger.warning(msg) - cfg = None - - # cache this config to avoid duplicated warning message - self.memory[key] = cfg - return cfg - - def clear_cache(self, target, workload): - """Clear fallback cache. Pass the same argument as _query_inside to this function - to clean the cache. - - Parameters - ---------- - target: Target - The current target - workload : Workload - The current workload. - """ - key = (str(target), workload) - if key in self.memory: - del self.memory[key] - - def update(self, target, workload, cfg): - key = (str(target), workload) - self.memory[key] = cfg - - -DispatchContext.current = FallbackContext() diff --git a/python/tvm/ansor/env.py b/python/tvm/ansor/env.py deleted file mode 100644 index 56e76e26ee4f..000000000000 --- a/python/tvm/ansor/env.py +++ /dev/null @@ -1,25 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" The scope to store global variables in ansor """ - - -class AutoschedulerGlobalScope(object): - def __init__(self): - self.topi_in_compute_rewrite_mode = False - -GLOBAL_SCOPE = AutoschedulerGlobalScope() diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py deleted file mode 100644 index fa1b2cb07dcc..000000000000 --- a/python/tvm/ansor/feature.py +++ /dev/null @@ -1,150 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""
-Python API for feature extraction.
-"""
-
-from typing import List, Tuple
-import struct
-import numpy as np
-
-from .loop_state import State, StateObject
-from .measure import MeasureInput, MeasureResult
-from . import _ffi_api
-
-
-# Maximum number of buffers for one statement to extract features for
-DEFAULT_MAX_N_BUFS = 5
-
-# The length of the feature vector
-DEFAULT_FEATURE_VEC_LEN = 164
-
-
-def unpack_feature(byte_arr: bytearray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-    """Unpack the encoded features (in byte array format) returned from C++"""
-    size_of_int = 4
-    size_of_float = 4
-
-    # The format for n records is:
-    # {
-    #   int n;
-    #   int[n+2] sizes
-
-    #   float[sizes[0]]   feature for record 1
-    #   float[sizes[1]]   feature for record 2
-    #   ...               feature for record i...
-    #   float[sizes[n-1]] feature for record n
-
-    #   float[sizes[n]]   normalized throughput for n records
-    #   int[sizes[n+1]]   task id for n records
-    # }
-
-    vec_len = DEFAULT_FEATURE_VEC_LEN
-
-    # unpack sizes
-    offset = 0
-    n = struct.unpack_from("1i", byte_arr, offset=offset)[0]
-    offset += size_of_int
-
-    sizes = struct.unpack_from("%di" % (n+2), byte_arr, offset=offset)
-    offset += size_of_int * (n+2)
-
-    # unpack features
-    features = []
-    for size in sizes[:-2]:
-        row = []
-
-        # Now we need to unpack the features for multiple statements.
-        # The format is:
-        # {
-        #     int n_stmts
-        #     float[n_stmts][vec_len] feature_vecs
-        # }
-        # where vec_len can be calculated by `(size - 1) / n_stmts`
-
-        if size == 0:
-            # failed during lowering
-            features.append(np.zeros((1, vec_len)))
-        else:
-            n_stmts = struct.unpack_from("f", byte_arr, offset=offset)
-            offset += size_of_float
-
-            n_stmts = int(n_stmts[0] + 0.5)
-            tmp_vec_len = (size - 1) // n_stmts
-            assert tmp_vec_len == vec_len, "The length of the feature vector is wrong. " \
-                                           "Expected %d but got %d."
% (vec_len, tmp_vec_len)
-            assert (size - 1) % n_stmts == 0
-            for _ in range(n_stmts):
-                x = struct.unpack_from("%df" % vec_len, byte_arr, offset=offset)
-                offset += vec_len * size_of_float
-                row.append(x)
-
-        features.append(np.array(row))
-
-    # unpack normalized_throughputs (stored as 4-byte floats)
-    m = sizes[-2]
-    normalized_throughputs = struct.unpack_from("%df" % m, byte_arr, offset=offset)
-    offset += m * size_of_float
-
-    # unpack task_ids
-    m = sizes[-1]
-    task_ids = struct.unpack_from("%di" % m, byte_arr, offset=offset)
-    offset += m * size_of_int
-
-    assert offset == len(byte_arr), "%d vs %d" % (offset, len(byte_arr))
-    return np.array(features), np.array(normalized_throughputs), np.array(task_ids)
-
-
-def get_per_stmt_features_from_file(filename: str,
-                                    n_lines: int,
-                                    max_n_bufs: int = None) \
-        -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-    """Get per-statement features from a log file"""
-    byte_arr = _ffi_api.GetPerStmtFeaturesFromFile(
-        filename, n_lines, max_n_bufs or DEFAULT_MAX_N_BUFS)
-    return unpack_feature(byte_arr)
-
-
-def get_per_stmt_features_from_measure_pairs(inputs: List[MeasureInput],
-                                             results: List[MeasureResult],
-                                             skip_first_n_feature_extraction: int = 0,
-                                             max_n_bufs: int = None) \
-        -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-    """Get per-statement features from measurement pairs"""
-    byte_arr = _ffi_api.GetPerStmtFeaturesFromMeasurePairs(
-        inputs, results, skip_first_n_feature_extraction, max_n_bufs or DEFAULT_MAX_N_BUFS)
-    return unpack_feature(byte_arr)
-
-
-def get_per_stmt_features_from_states(states,
-                                      task: "SearchTask",
-                                      max_n_bufs: int = None) -> List[np.ndarray]:
-    """Get per-statement features from states"""
-    if isinstance(states[0], State):
-        state_objects = [s.state_object for s in states]
-    elif isinstance(states[0], StateObject):
-        state_objects = states
-    byte_arr = _ffi_api.GetPerStmtFeaturesFromStates(
-        state_objects, task, max_n_bufs or DEFAULT_MAX_N_BUFS)
-    return unpack_feature(byte_arr)[0]
-
-
-def get_per_stmt_feature_names(max_n_bufs: int = None) -> List[str]:
-    """Get the names of the elements in the flattened feature vector"""
-    return [x for x in
-            _ffi_api.GetPerStmtFeatureNames(max_n_bufs or DEFAULT_MAX_N_BUFS)]
diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py
index 7aa5de0e9c1d..bf81311ed664 100644
--- a/python/tvm/ansor/loop_state.py
+++ b/python/tvm/ansor/loop_state.py
@@ -152,61 +152,6 @@ def split(self, stage_id, iterator, lengths, inner_to_outer=True):
         self._clear_cache()
         return res
 
-    def follow_split(self, stage_id, iterator, src_step_id, n_split):
-        """
-        Parameters
-        ----------
-        stage_id : Union[int, Operation, Tensor]
-            The index of the stage to split
-        iterator : Iterator
-            The iterator to split
-        src_step_id : int
-            The index of the split step to follow in the history
-        n_split : int
-            The number of split levels
-
-        Returns
-        -------
-        res_its : List[Iterator]
-            The new split Iterators
-        """
-        stage_id = self._resolve_stage_id(stage_id)
-
-        self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, stage_id, iterator,
-                                                           src_step_id, n_split)
-        self._clear_cache()
-        return res
-
-    def follow_fused_split(self, stage_id, iterator, src_step_ids, level,
-                           factor_or_nparts):
-        """
-        Parameters
-        ----------
-        stage_id : Union[int, Operation, Tensor]
-            The index of the stage to split
-        iterator : Iterator
-            The iterator to split
-        src_step_ids : List[int]
-            The indices of the split steps to follow in the history
-        level : int
-            Use the length in this split level
-        factor_or_nparts : bool
-            True to use `factor` for 
split from inner to outer,
-            False to use `nparts` for split from outer to inner
-
-        Returns
-        -------
-        res_its : List[Iterator]
-            The new split Iterators
-        """
-        stage_id = self._resolve_stage_id(stage_id)
-
-        self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, stage_id,
-                                                                iterator, src_step_ids, level,
-                                                                factor_or_nparts)
-        self._clear_cache()
-        return res
-
     def fuse(self, stage_id, iters):
         """
         Parameters
@@ -227,270 +172,6 @@ def fuse(self, stage_id, iters):
         self._clear_cache()
         return res
 
-    def vectorize(self, stage_id, iterator):
-        """
-        Parameters
-        ----------
-        stage_id : Union[int, Operation, Tensor]
-            The index of the stage to vectorize
-        iterator : Iterator
-            The iterator to be vectorized
-
-        Returns
-        -------
-        res_it : Iterator
-            The vectorized Iterator
-        """
-        stage_id = self._resolve_stage_id(stage_id)
-
-        self.state_object, res = _ffi_api.StateVectorize(self.state_object, stage_id, iterator)
-        self._clear_cache()
-        return res
-
-    def parallel(self, stage_id, iterator):
-        """
-        Parameters
-        ----------
-        stage_id : Union[int, Operation, Tensor]
-            The index of the stage to parallelize
-        iterator : Iterator
-            The iterator to be parallelized
-
-        Returns
-        -------
-        res_it : Iterator
-            The parallelized Iterator
-        """
-        stage_id = self._resolve_stage_id(stage_id)
-
-        self.state_object, res = _ffi_api.StateParallel(self.state_object, stage_id, iterator)
-        self._clear_cache()
-        return res
-
-    def unroll(self, stage_id, iterator, max_unroll=-1):
-        """
-        Parameters
-        ----------
-        stage_id : Union[int, Operation, Tensor]
-            The index of the stage to unroll
-        iterator : Iterator
-            The iterator to be unrolled
-        max_unroll: int
-            The maximum length of the iterator that can be unrolled
-
-        Returns
-        -------
-        res_it : Iterator
-            The unrolled Iterator
-        """
-        stage_id = self._resolve_stage_id(stage_id)
-
-        self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id, iterator,
-                                                      max_unroll)
-        self._clear_cache()
-        return res
-
-    def bind_thread(self, stage_id, iterator, thread_name):
-        """
-        Parameters
-        ----------
-        stage_id : Union[int, Operation, Tensor]
-            The index of the stage to bind
-        iterator : Iterator
-            The iterator to be bound
-        thread_name : str
-            The name of the thread (e.g. 
"blockIdx.x", "threadIdx.y", "vthread") - - Returns - ------- - res_it : Iterator - The bound Iterator - """ - trans_table = { - "vthread": 4, - "blockIdx.x": 5, - "threadIdx.x": 6, - "blockIdx.y": 7, - "threadIdx.y": 8, - } - thread_id = trans_table[thread_name] - - stage_id = self._resolve_stage_id(stage_id) - - self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id, iterator, - thread_id) - self._clear_cache() - return res - - def compute_at(self, stage_id, target_stage_id, target_iter): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of source stage - target_stage_id : Union[int, Operation, Tensor] - The index of the target stage of compute_at - target_iter : Iterator - The target Iterator of compute_at - """ - stage_id = self._resolve_stage_id(stage_id) - target_stage_id = self._resolve_stage_id(target_stage_id) - - self.state_object = _ffi_api.StateComputeAt(self.state_object, stage_id, - target_stage_id, target_iter) - self._clear_cache() - - def compute_root(self, stage_id): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to compute root - """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object = _ffi_api.StateComputeRoot(self.state_object, stage_id) - self._clear_cache() - - def compute_inline(self, stage_id): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to compute inline - """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object = _ffi_api.StateComputeInline(self.state_object, stage_id) - self._clear_cache() - - def cache_read(self, stage_id, scope_name, reader_stage_ids): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to do cache_read - scope_name : str - reader_stage_ids : List[int] - - Returns - ------- - new_stage_id : int - The added staged id - """ - stage_id = self._resolve_stage_id(stage_id) - - if isinstance(reader_stage_ids, list): - tmp_list = [] - for reader_stage_id in reader_stage_ids: - tmp_list.append(self._resolve_stage_id(reader_stage_id)) - reader_stage_ids = tmp_list - else: - raise ValueError("reader_stage_ids must be list of Tensor or int") - - self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object, stage_id, - scope_name, reader_stage_ids, - self.compute_dag) - return self._insert_new_stage(new_stage_id) - - def cache_write(self, stage_id, scope_name): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to do cache read - scope_name : str - - Returns - ------- - new_stage_id : int - The added staged id - """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object, new_stage_id = _ffi_api.StateCacheWrite(self.state_object, stage_id, - scope_name, self.compute_dag) - return self._insert_new_stage(new_stage_id) - - def pragma(self, stage_id, iterator, pragma_type): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to add pragma - iterator : Iterator - The iterator to add pragma - pragma_type : str - """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object = _ffi_api.StatePragma(self.state_object, stage_id, iterator, - pragma_type) - self._clear_cache() - - def rfactor(self, stage_id, iterator, factor_iter_id): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to do reduction factor - iterator : Iterator - factor_iter_id 
: int
-
-        Returns
-        -------
-        new_stage_id : int
-            The id of the added stage
-        """
-        stage_id = self._resolve_stage_id(stage_id)
-
-        self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object, stage_id,
-                                                                iterator, factor_iter_id,
-                                                                self.compute_dag)
-        return self._insert_new_stage(new_stage_id)
-
-    def storage_align(self, stage_id, iterator, factor, offset):
-        """
-        Parameters
-        ----------
-        stage_id : Union[int, Operation, Tensor]
-            The index of the stage to do storage align
-        iterator : Iterator
-        factor : int
-        offset : int
-        """
-        stage_id = self._resolve_stage_id(stage_id)
-
-        self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, iterator,
-                                                       factor, offset)
-        self._clear_cache()
-
-    def tensorize(self, stage_id, iterator, ti_func_name):
-        """ The `ti_func_name` corresponds to a globally registered function
-        that returns a TensorIntrin
-
-        Parameters
-        ----------
-        stage_id : Union[int, Operation, Tensor]
-            The index of the stage to be tensorized
-        iterator : Iterator
-            The iterator to be tensorized
-        ti_func_name : str
-            Tensorize intrinsic function name
-
-        Returns
-        -------
-        res_it : Iterator
-            The tensorized Iterator
-        """
-        stage_id = self._resolve_stage_id(stage_id)
-
-        self.state_object, res = _ffi_api.StateTensorize(self.state_object,
-                                                         stage_id, iterator,
-                                                         ti_func_name)
-        self._clear_cache()
-        return res
-
     def _resolve_stage_id(self, stage_id):
         if isinstance(stage_id, Operation):
             return self.stage_id_map[stage_id]
diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py
index be7d69e5ed3a..46c3e3aabd5d 100644
--- a/python/tvm/ansor/measure.py
+++ b/python/tvm/ansor/measure.py
@@ -42,7 +42,6 @@
 from . import _ffi_api
 from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, \
     check_remote
-from .compute_dag import LayoutRewriteLevel
 
 LOGGER = logging.getLogger('ansor')
 
@@ -331,7 +330,7 @@ def timed_func():
         try:
             sch, args = task.compute_dag.apply_steps_from_state(
-                inp.state, LayoutRewriteLevel.BOTH_REWRITE)
+                inp.state)
         except Exception:
             error_no = MeasureErrorNo.INSTANTIATION_ERROR
             error_msg = make_error_msg()
diff --git a/python/tvm/ansor/relay_integration.py b/python/tvm/ansor/relay_integration.py
deleted file mode 100644
index f2873f8c72fd..000000000000
--- a/python/tvm/ansor/relay_integration.py
+++ /dev/null
@@ -1,241 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=unused-variable,invalid-name
-
-"""
-Integrate ansor into relay. It implements the following items:
-1. Extract search tasks from a relay program
-2. 
Provide auto-scheduling for all TOPI compute functions -""" -import os -import json -import threading - -import tvm -from tvm import te, transform -from tvm.te.tensor import PlaceholderOp, ComputeOp -from .dispatcher import DispatchContext -from .workload_registry import register_workload_bufs, compute_dag_hash -from .compute_dag import ComputeDAG, LayoutRewriteLevel -from .env import GLOBAL_SCOPE - -def call_all_topi_funcs(mod, target, params, target_host=None): - """Call all TOPI compute + schedule to extract tasks in a relay program""" - # pylint: disable=import-outside-toplevel - from tvm import relay - - with transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): - bld_mod = relay.build_module.BuildModule() - bld_mod.call_all_topi_funcs(mod, target=target, params=params, target_host=target_host) - -def extract_from_program(mod, params, target, target_host=None): - """ Extract tuning tasks from a relay program. - - This function is the single program version of extract_from_multiple_program. - - Parameters - ---------- - mod : relay.Module - The module to extract. - params: dict of str to numpy array - The associated parameters of the program - ops: List of relay op - List of relay ops to be tuned - target: tvm.target.Target - The compilation target - target_host: tvm.target.Target - The host compilation target - - Returns - ------- - workloads: Array of Tuple(wkl_key, target) - """ - return extract_from_multiple_program([mod], [params], target, target_host) - -def extract_from_multiple_program(mods, params, target, target_host=None): - """ Extract tuning tasks from multiple relay programs. - - Parameters - ---------- - mods : List of relay.Module - The modules to extract. - params: List of dict of str to numpy array - The associated parameters of the programs - ops: List of relay op - List of relay ops to be tuned - target: tvm.target.Target - The compilation target - target_host: tvm.target.Target - The host compilation target - - Returns - ------- - workloads: Array of Tuple(wkl_key, target) - """ - # pylint: disable=import-outside-toplevel - from tvm import relay - - env = TracingEnvironment(TracingMode.EXTRACT_TASK) - with env: - # run compiler to collect all TOPI calls during compilation - for mod, param in zip(mods, params): - # wrap build call in a new thread to avoid the conflict - # between python's multiprocessing and tvm's thread pool - build_thread = threading.Thread(target=call_all_topi_funcs, - args=(mod, target, param, target_host)) - build_thread.start() - build_thread.join() - relay.backend.compile_engine.get().clear() - - # create tasks for target - wkl_keys = [] - wkl_weights = [] - for wkl_key, wkl_weight in env.wkl_key_collection.items(): - wkl_keys.append(wkl_key) - wkl_weights.append(wkl_weight) - - return wkl_keys, wkl_weights - - -def prepare_layout_rewrite(mod, params, target): - """ - Prepare for kernel layout rewrite. This function will write layout infos to a global static - variable. - Then these layout info will be used by a relay pass `kernel_layout_transform`. 
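-
-    A sketch of the intended call order (`mod`, `params` and `target` are
-    assumed to come from the caller):
-
-        prepare_layout_rewrite(mod, params, target)  # may set the global rewrite flag
-        # ... build the relay program; the kernel_layout_transform pass runs here ...
-        finish_layout_rewrite()                      # always clear the flag afterwards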
- """ - # pylint: disable=import-outside-toplevel - from tvm import relay - - env = TracingEnvironment(TracingMode.PREPARE_LAYOUT_REWRITE) - with env: - # wrap build call in a new thread to avoid the conflict - # between python's multiprocessing and tvm's thread pool - build_thread = threading.Thread(target=call_all_topi_funcs, - args=(mod, target, params)) - build_thread.start() - build_thread.join() - relay.backend.compile_engine.get().clear() - - if env.layout_rewrite_success_ct > 0: - GLOBAL_SCOPE.topi_in_compute_rewrite_mode = True - -def finish_layout_rewrite(): - """Clear the global flag for layout rewrite""" - GLOBAL_SCOPE.topi_in_compute_rewrite_mode = False - - -class TracingMode: - """Two modes for tracing""" - EXTRACT_TASK = 0 # trace all topi calls to extract tasks - PREPARE_LAYOUT_REWRITE = 1 # trace all topi calls to prepare layout rewrite - -class TracingEnvironment: - """Global environment for tracing all topi function calls""" - current = None - - def __init__(self, tracing_mode): - self.tracing_mode = tracing_mode - self.relay_disable_build_cache = "false" - self.layout_rewrite_success_ct = 0 - self.wkl_key_collection = {} - - def __enter__(self): - self.relay_disable_build_cache = os.environ.get("TVM_RELAY_DISABLE_BUILD_CACHE", "false") - os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = "true" - TracingEnvironment.current = self - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = self.relay_disable_build_cache - TracingEnvironment.current = None - - def add_workload_key(self, key): - """Add the workload key of an Ansor search task - - Parameters - ---------- - key: str - """ - if key in self.wkl_key_collection: - self.wkl_key_collection[key] += 1 - else: - self.wkl_key_collection[key] = 1 - - -def traverse_to_get_io_tensors(outs): - """Traverse from a list of output tensors to get a whole computational DAG""" - layout_free_ops = [] - inputs = [] - - visited = set() - - def traverse(t): - if t in visited: - return - if isinstance(t.op, PlaceholderOp): - inputs.append(t) - elif isinstance(t.op, ComputeOp): - if "layout_free_placeholders" in t.op.attrs: - layout_free_ops.append(t.op) - for x in t.op.input_tensors: - traverse(x) - visited.add(t) - - for t in outs: - traverse(t) - - has_layout_free = (len(layout_free_ops) > 0) - return inputs + [t for t in outs], has_layout_free - - -def auto_schedule_topi(outs): - """ Use ansor to auto-schedule a topi compute declaration """ - io_tensors, has_layout_free = traverse_to_get_io_tensors(outs) - key = register_workload_bufs(io_tensors) - - env = TracingEnvironment.current - if env is None: # in the final build mode - state = DispatchContext.current.query(tvm.target.Target.current(), key) - if state is None: - return te.create_schedule([x.op for x in outs]) - - dag = ComputeDAG(io_tensors) - # Only update compute body, layout_rewrite_level = LayoutRewriteLevel.COMPUTE_REWRITE, - # Since kernel layout has already been rewritten in relay pass - schedule, _ = dag.apply_steps_from_state( - state, layout_rewrite_level=LayoutRewriteLevel.COMPUTE_REWRITE) - return schedule - if env.tracing_mode == TracingMode.EXTRACT_TASK: # in the task extraction mode - env.add_workload_key(key) - return te.create_schedule([x.op for x in outs]) - if env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE: - # in prepare_layout_rewrite mode - if has_layout_free: - # Rewrite the DAG and update the transform history for - # the new dag in DispatchContext - dispatch_ctx = DispatchContext.current - 
tgt = tvm.target.Target.current() - state = dispatch_ctx.query(tgt, key) - assert state is not None - dag = ComputeDAG(outs) - new_dag = dag.rewrite_layout_from_state(state) - new_key = json.dumps((compute_dag_hash(new_dag),)) - dispatch_ctx.update(tgt, new_key, state) - if new_key != key: - env.layout_rewrite_success_ct += 1 - return te.create_schedule([x.op for x in outs]) - raise ValueError("Invalid tracing mode: " + env.tracing_mode) diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py deleted file mode 100644 index 5b916ed39769..000000000000 --- a/python/tvm/ansor/task_scheduler.py +++ /dev/null @@ -1,299 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""TaskScheduler that allocates the time resources when tuning multiple tasks together""" -from typing import List, Union, Callable -import time - -import numpy as np - -from .auto_schedule import SearchTask, SearchPolicy, SketchSearchPolicy, TuneOption -from .cost_model import RandomModel, XGBModel -from .measure import ProgramMeasurer -from .utils import array_mean, to_str_round - - -class TaskScheduler: - """Allocate the time resources when tuning multiple tasks together""" - def __init__(self, - tasks: List[SearchTask], - objective_func: Callable = None): - self.tasks = tasks - self.objective_func = objective_func or sum - - def compute_score(self, costs: List[float]) -> float: - return self.objective_func(costs) - - -def get_search_policies(search_policy: Union[str, List[SearchPolicy]], tasks: List[SearchTask], - num_measure_per_iter, load_model_file=None, load_log_file=None): - """ ... 
- """ - if search_policy == 'default': - search_policy = 'sketch.xgb' - - if isinstance(search_policy, str): - policy_type, model_type = search_policy.split('.') - if model_type == 'xgb': - cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measure_per_iter) - if load_model_file: - print("Load pretrained model...") - cost_model.load(load_model_file) - elif load_log_file: - cost_model.load_log_file(load_log_file) - elif model_type == 'random': - cost_model = RandomModel() - else: - raise ValueError("Invalid search policy: " + search_policy) - - if policy_type == 'sketch': - search_policies = [SketchSearchPolicy(cost_model) for _ in range(len(tasks))] - elif policy_type == 'limit-space': - search_policies = [SketchSearchPolicy(cost_model, - params={'cpu_multi_level_tiling_structure': 'SRS', - 'disable_change_compute_location': 1}) - for _ in range(len(tasks))] - elif policy_type == 'beam-search': - search_policies = [SketchSearchPolicy(cost_model, - params={'use_beam_search': 1}) - for _ in range(len(tasks))] - else: - raise ValueError("Invalid search policy: " + search_policy) - else: - # check type - assert isinstance(search_policy, (tuple, list)) - for item in search_policy: - assert isinstance(item, SearchPolicy) - search_policies = search_policy - - return search_policies - - -class SimpleTaskScheduler(TaskScheduler): - """The default task scheduler with several strategies - - Parameters - ---------- - tasks: List[SearchTask] - All workloads to tune - weights: List[float] - Weights of tasks (i.e. the number of occurrence of a task in the whole network) - strategy: str - The joint tuning strategy. - "sequential" : Tune tasks sequentially. Divide n_trials equally to every task. - "round-robin": Tune tasks in round robin order. - "gradient" : Tune tasks with gradient descent. - load_log_file: str - Load history log file to pre-train cost model - eps-random: float - Always allocate this percent of n_trials to select tasks randomly. - This is for encouraging exploration. - verbose: int - The level of verbosity. 0 means silent. - alpha: float - The parameter used for 'gradient' strategy - beta: float - The parameter used for 'gradient' strategy - backward_window_size: int - The parameter used for 'gradient' strategy - """ - def __init__(self, - tasks: List[SearchTask], - objective_func: Callable = None, - strategy: str = 'gradient', - load_log_file: str = None, - load_model_file: str = None, - eps_random: float = 0.05, - verbose: int = 1, - alpha: float = 0.2, - beta: float = 2, - gamma: float = 0.5, - backward_window_size: int = 3, - use_debug_measurement_simulator=None): - super().__init__(tasks, objective_func) - self.strategy = strategy - self.eps_random = eps_random - self.verbose = verbose - self.load_log_file = load_log_file - self.load_model_file = load_model_file - self.alpha = alpha - self.beta = beta - self.gamma = gamma - self.backward_window_size = backward_window_size - self.use_debug_measurement_simulator = use_debug_measurement_simulator - - assert self.strategy in ['round-robin', 'gradient'] - - self.task_cts = [] - self.task_costs_history = [] - self.best_costs = self.cur_score = None - self.tune_option = self.measurer = self.search_policies = self.ct = self.tic = None - self.num_measure_per_iter = None - self.dead_tasks = set() - self.sequential_now_task_idx = 0 - self.sequential_now_task_begin_ct = 0 - - def tune(self, tune_option: TuneOption, - search_policy: Union[str, List[SearchPolicy]] = 'default'): - """ Tune tasks. 
- - Notice: This method does not have return value, make sure to set `LogToFile` - measure callback in `tune_option`. - - Parameters - ---------- - tune_option: TuneOption - search_policy: Str or List[SearchPolicy] - """ - # init members - self.task_cts = [0 for _ in range(len(self.tasks))] - self.task_costs_history = [[] for _ in range(len(self.tasks))] - self.best_costs = 1e10 * np.ones(len(self.tasks)) - self.cur_score = self.compute_score(self.best_costs) - self.tune_option = tune_option - if self.use_debug_measurement_simulator is None: - self.measurer = ProgramMeasurer(tune_option.builder, tune_option.runner, - tune_option.measure_callbacks, tune_option.verbose) - self.ct = 0 - self.tic = time.time() - # reset num_measure_per_iter to make sure every task is tuned at least once - self.num_measure_per_iter = min(tune_option.num_measure_per_iter, - tune_option.n_trials // len(self.tasks)) - self.search_policies = get_search_policies(search_policy, self.tasks, - self.num_measure_per_iter, - self.load_model_file, - self.load_log_file) - self.dead_tasks = set() - self.sequential_now_task_idx = 0 - self.sequential_now_task_begin_ct = 0 - - for i in range(len(self.tasks)): - search_policy = self.search_policies[i] - task = self.tasks[i] - search_policy.set_task(task) - search_policy.set_verbose(tune_option.verbose) - search_policy.run_callbacks(tune_option.pre_search_callbacks) - - # do a round robin first - if self.strategy != 'sequential': - for i in range(len(self.tasks)): - self.tune_task(i) - - # use the specific strategy to choose workload to tune - task_idx = -1 - while self.ct < tune_option.n_trials and len(self.dead_tasks) < len(self.tasks): - if self.strategy == 'sequential': - allocated_total_ct = ((tune_option.n_trials - self.sequential_now_task_begin_ct) - / (len(self.tasks) - self.sequential_now_task_idx)) - used_ct = self.ct - self.sequential_now_task_begin_ct - - if self.sequential_now_task_idx in self.dead_tasks or used_ct >= allocated_total_ct: - self.sequential_now_task_idx += 1 - self.sequential_now_task_begin_ct = self.ct - task_idx = self.sequential_now_task_idx - if task_idx >= len(self.tasks): - break - elif self.strategy == 'round-robin': - task_idx = (task_idx + 1) % len(self.tasks) - while task_idx in self.dead_tasks: - task_idx = (task_idx + 1) % len(self.tasks) - elif self.strategy == 'gradient': - gradients = [] - for i in range(len(self.tasks)): - if i in self.dead_tasks: - gradients.append(0) - continue - - # compute gradient from chain rule : (delta f / delta g_i) - delta = 1e-7 - new_costs = list(self.best_costs) - new_costs[i] -= delta - chain_grad = (self.compute_score(self.best_costs) - self.compute_score(new_costs)) / delta - - # compute (g_i(t_i) - g(t_i - \Delta t)) / (\Delta t) - if self.task_cts[i] - 1 - self.backward_window_size >= 0: - backward_grad = (self.task_costs_history[i][self.task_cts[i] - 1] - - self.task_costs_history[i][self.task_cts[i] - 1 - self.backward_window_size]) \ - / self.backward_window_size - else: - backward_grad = 0 - - # compute (g_i(t_i + \Delta t) - g(t_i)) / (\Delta t) - g_next_1 = self.best_costs[i] - (self.best_costs[i] / self.task_cts[i]) - # todo(lmzheng): this needs adding attribute to topi.compute for similarity check - g_next_2 = self.beta * 1e20 - g_next = min(g_next_1, g_next_2) - forward_grad = g_next - self.best_costs[i] - - # combine all grads - grad = chain_grad * (self.alpha * backward_grad + (1 - self.alpha) * forward_grad) - assert grad <= 0 - gradients.append(grad) - - if max(gradients) == 
min(gradients): - task_idx = np.random.choice(len(gradients)) - else: - task_idx = np.argmin(gradients) - else: - raise ValueError("Invalid strategy: " + self.strategy) - - if self.verbose >= 1: - print("Next tuning task: %d" % task_idx) - self.tune_task(task_idx) - - def tune_task(self, task_idx): - """ ... - """ - if self.use_debug_measurement_simulator is not None: - measure_inputs, measure_results = \ - self.use_debug_measurement_simulator.get_next_batch( - self.tasks[task_idx], - self.num_measure_per_iter, - ) - else: - measure_inputs, measure_results = \ - self.search_policies[task_idx].continue_search( - self.tasks[task_idx], - self.num_measure_per_iter, - self.tune_option.verbose, - self.measurer) - - for inp, res in zip(measure_inputs, measure_results): - cost = array_mean(res.costs) - if cost < self.best_costs[task_idx]: - self.best_costs[task_idx] = cost - - if len(measure_inputs) == 0: - self.dead_tasks.add(task_idx) - - self.task_cts[task_idx] += 1 - self.task_costs_history[task_idx].append(self.best_costs[task_idx]) - - self.ct += len(measure_inputs) - self.cur_score = self.compute_score(self.best_costs) - - if self.verbose >= 1: - print(("TaskScheduler\tct: %d\testimated cost (ms): %.3f\ttime elapsed: %.2f\t" + - "best_costs (ms): %s\ttask_ct: %s") % - (self.ct, self.cur_score * 1e3, time.time() - self.tic, - to_str_round(self.best_costs * 1e3, decimal=3), - self.task_cts)) - - def remove_dead_task(self, prob): - for idx in self.dead_tasks: - prob[idx] = 0 - return prob / prob.sum() diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index b6bedb411540..8e6698e4a164 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -18,7 +18,6 @@ """Backend code generation engine.""" from __future__ import absolute_import -import os import logging import numpy as np import tvm @@ -142,6 +141,7 @@ def get_valid_implementations(op, attrs, inputs, out_type, target): ret.append(impl) return ret + def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True): """Select the best implementation from the op strategy. @@ -179,9 +179,6 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) ret : tuple(relay.op.OpImplementation, List[tvm.te.Tensor]) The best op implementation and the corresponding output tensors. 
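 
     With `use_autotvm` disabled, the selection reduces to picking the highest
     priority level; roughly (a sketch of the logic below, not extra API):
 
         best_plevel_impl = max(all_impls, key=lambda impl: impl.plevel)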
""" - if os.environ.get('TVM_USE_AUTOTVM', 'false') == 'false': - use_autotvm = False - all_impls = get_valid_implementations(op, attrs, inputs, out_type, target) best_plevel_impl = None diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index d1a39ceb630e..30c5971e32b9 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -72,7 +72,6 @@ def __init__(self): self._get_module = self.mod["get_module"] self._build = self.mod["build"] self._optimize = self.mod["optimize"] - self._call_all_topi_funcs = self.mod["call_all_topi_funcs"] self._set_params_func = self.mod["set_params"] self._get_params_func = self.mod["get_params"] @@ -161,12 +160,6 @@ def optimize(self, mod, target=None, params=None): return mod, params - def call_all_topi_funcs(self, mod, target=None, target_host=None, params=None): - """Call all topi compute and schedule used in a relay function""" - target = _update_target(target) - if params: - self._set_params(params) - self._call_all_topi_funcs(mod, target, target_host) def _set_params(self, params): self._set_params_func(_convert_param_map(params)) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 41bd10cabe3e..d104c1b1c2f8 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -74,8 +74,6 @@ def compute_strided_set(attrs, inputs, output_type): # layout_transform _reg.register_injective_schedule("layout_transform") _reg.register_pattern("layout_transform", OpPattern.INJECTIVE) -_reg.register_injective_schedule("kernel_layout_transform") -_reg.register_pattern("kernel_layout_transform", OpPattern.INJECTIVE) # argwhere @_reg.register_compute("argwhere") diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 58b9269a4c48..486d63c36ff0 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -261,9 +261,6 @@ class ClipAttrs(Attrs): class LayoutTransformAttrs(Attrs): """Attributes for transform.layout_transform""" -@tvm._ffi.register_object("relay.attrs.KernelLayoutTransformAttrs") -class KernelLayoutTransformAttrs(Attrs): - """Attributes for transform.kernel_layout_transform""" @tvm._ffi.register_object("relay.attrs.ShapeOfAttrs") class ShapeOfAttrs(Attrs): diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 3453b089f373..b02db416bdc8 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -16,15 +16,14 @@ # under the License. """Definition of x86 operator strategy.""" # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import +import logging +import re +import topi from tvm.te import SpecializedCondition -from tvm import ansor from .generic import * from .. 
import op as _op -# Set the priority level to use the Ansor auto-scheduler -ansor_plevel = 11 - logger = logging.getLogger('strategy') _NCHWc_matcher = re.compile("^NCHW[0-9]+c$") @@ -40,7 +39,7 @@ def schedule_injective_cpu(attrs, outs, target): def schedule_reduce_cpu(attrs, outs, target): """schedule reduction ops for x86""" with target: - return ansor.auto_schedule_topi(outs) + return topi.x86.schedule_reduce(outs) @schedule_concatenate.register("cpu") def schedule_concatenate_cpu(attrs, outs, target): @@ -52,13 +51,13 @@ def schedule_concatenate_cpu(attrs, outs, target): def schedule_pool_cpu(attrs, outs, target): """schedule pooling ops for x86""" with target: - return ansor.auto_schedule_topi(outs) + return topi.x86.schedule_pool(outs, attrs.layout) @schedule_adaptive_pool.register("cpu") def schedule_adaptive_pool_cpu(attrs, outs, target): """schedule adaptive pooling ops for x86""" with target: - return ansor.auto_schedule_topi(outs) + return topi.x86.schedule_adaptive_pool(outs) @softmax_strategy.register("cpu") def softmax_strategy_cpu(attrs, inputs, out_type, target): @@ -66,15 +65,15 @@ def softmax_strategy_cpu(attrs, inputs, out_type, target): strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_softmax(topi.nn.softmax), - wrap_topi_schedule(ansor.auto_schedule_topi), - name="ansor") + wrap_topi_schedule(topi.x86.schedule_softmax), + name="softmax.x86") return strategy @schedule_log_softmax.register("cpu") def schedule_log_softmax_cpu(attrs, outs, target): """schedule log_softmax op for x86""" with target: - return ansor.auto_schedule_topi(outs) + return topi.x86.schedule_softmax(outs) @conv2d_strategy.register("cpu") def conv2d_strategy_cpu(attrs, inputs, out_type, target): @@ -106,18 +105,18 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif layout == "NHWC": assert kernel_layout == "HWIO" - #logger.warning("For x86 target, NCHW layout is recommended for conv2d.") + logger.warning("For x86 target, NCHW layout is recommended for conv2d.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_nhwc), - wrap_topi_schedule(ansor.auto_schedule_topi), - name="ansor") + wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc), + name="conv2d_nhwc.x86") elif layout == "HWCN": assert kernel_layout == "HWIO" - #logger.warning("conv2d HWCN layout is not optimized for x86.") + logger.warning("conv2d HWCN layout is not optimized for x86.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_hwcn), - wrap_topi_schedule(ansor.auto_schedule_topi), - name="ansor") + wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn), + name="conv2d_hwcn.generic") else: raise RuntimeError("Unsupported conv2d layout {} for x86".format(layout)) elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): @@ -144,8 +143,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): logger.warning("depthwise_conv2d NHWC layout is not optimized for x86.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), - wrap_topi_schedule(ansor.auto_schedule_topi), - name="ansor") + wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc), + name="depthwise_conv2d_nhwc.generic") else: raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout)) else: # group_conv2d @@ -154,8 +153,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): logger.warning("group_conv2d is not optimized for x86.") strategy.add_implementation( 
wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True), - wrap_topi_schedule(ansor.auto_schedule_topi), - name="ansor") + wrap_topi_schedule(topi.generic.schedule_group_conv2d_nchw), + name="group_conv2d_nchw.generic") else: raise RuntimeError("Unsupported group_conv2d layout {}".format(layout)) return strategy @@ -232,8 +231,8 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target): name="conv3d_ncdhw.x86") elif layout == "NDHWC": strategy.add_implementation(wrap_compute_conv3d(topi.x86.conv3d_ndhwc), - wrap_topi_schedule(ansor.auto_schedule_topi), - name="ansor") + wrap_topi_schedule(topi.x86.schedule_conv3d_ndhwc), + name="conv3d_ndhwc.x86") else: raise ValueError("Not support this layout {} yet".format(layout)) return strategy @@ -252,8 +251,8 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target): name="conv1d_ncw.x86") elif layout == "NWC": strategy.add_implementation(wrap_compute_conv1d(topi.nn.conv1d_nwc), - wrap_topi_schedule(ansor.auto_schedule_topi), - name="ansor") + wrap_topi_schedule(topi.x86.schedule_conv1d_nwc), + name="conv1d_nwc.x86") else: raise ValueError("Unsupported conv1d layout {}".format(layout)) return strategy @@ -262,23 +261,16 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target): def dense_strategy_cpu(attrs, inputs, out_type, target): """dense x86 strategy""" strategy = _op.OpStrategy() - - strategy.add_implementation(wrap_compute_dense(topi.nn.dense), - wrap_topi_schedule(ansor.auto_schedule_topi), - name='ansor', - plevel=ansor_plevel) - + m, _ = inputs[0].shape strategy.add_implementation(wrap_compute_dense(topi.x86.dense_nopack), wrap_topi_schedule(topi.x86.schedule_dense_nopack), name="dense_nopack.x86", plevel=10) - if "cblas" in target.libs: strategy.add_implementation(wrap_compute_dense(topi.x86.dense_cblas), wrap_topi_schedule(topi.x86.schedule_dense_cblas), name="dense_cblas.x86", plevel=15) - m, _ = inputs[0].shape with SpecializedCondition(m >= 16): # this implementation may not be well-optimized, so use plevel=8 for now. strategy.add_implementation(wrap_compute_dense(topi.x86.dense_pack), @@ -291,12 +283,6 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): """batch_matmul x86 strategy""" strategy = _op.OpStrategy() - - strategy.add_implementation(wrap_compute_dense(topi.nn.batch_matmul), - wrap_topi_schedule(ansor.auto_schedule_topi), - name='ansor', - plevel=ansor_plevel) - strategy.add_implementation(wrap_compute_batch_matmul(topi.x86.batch_matmul), wrap_topi_schedule(topi.x86.schedule_batch_matmul), name="batch_matmul.x86", diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index f2fa2b5f5b90..a37226ea4f58 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -815,27 +815,6 @@ def layout_transform(data, src_layout, dst_layout): """ return _make.layout_transform(data, src_layout, dst_layout) -def kernel_layout_transform(data, src_layout, dst_layout): - """Transform the layout of a kernel - - Parameters - ---------- - data : relay.Expr - The source tensor to be transformed - - src_layout: str - The source layout. (e.g 1N32C112H112W) - - dst_layout: str - The destination layout. (e.g. 1N2C112H112W16c) - - Returns - ------- - ret : relay.Expr - The transformed tensor. 
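-
-    A small usage sketch (layout strings follow the examples given above):
-
-        out = kernel_layout_transform(data, "1N32C112H112W", "1N2C112H112W16c")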
- """ - return _make.kernel_layout_transform(data, src_layout, dst_layout) - def reverse_reshape(data, newshape): """Reshapes the input array where the special values are inferred from diff --git a/python/tvm/relay/testing/dqn.py b/python/tvm/relay/testing/dqn.py index 3d6883362c9b..10da37001f12 100644 --- a/python/tvm/relay/testing/dqn.py +++ b/python/tvm/relay/testing/dqn.py @@ -26,32 +26,27 @@ from . import layers from .init import create_workload -def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32", layout="NCHW"): +def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"): """get symbol of nature dqn""" data_shape = (batch_size,) + image_shape data = relay.var("data", shape=data_shape, dtype=dtype) - bias_axis = layout.index('C') - conv1_bias = relay.var("conv1_bias") conv1 = layers.conv2d(data, kernel_size=(8, 8), strides=(4, 4), padding=(0, 0), - channels=32, name="conv1", data_layout=layout, - kernel_layout=layers.conv_kernel_layout(layout)) - conv1 = relay.nn.bias_add(conv1, conv1_bias, bias_axis) + channels=32, name="conv1") + conv1 = relay.nn.bias_add(conv1, conv1_bias) relu1 = relay.nn.relu(conv1) conv2_bias = relay.var("conv2_bias") conv2 = layers.conv2d(relu1, kernel_size=(4, 4), strides=(2, 2), padding=(0, 0), - channels=64, name="conv2", data_layout=layout, - kernel_layout=layers.conv_kernel_layout(layout)) - conv2 = relay.nn.bias_add(conv2, conv2_bias, bias_axis) + channels=64, name="conv2") + conv2 = relay.nn.bias_add(conv2, conv2_bias) relu2 = relay.nn.relu(conv2) conv3_bias = relay.var("conv3_bias") conv3 = layers.conv2d(relu2, kernel_size=(3, 3), strides=(1, 1), padding=(0, 0), - channels=64, name="conv3", data_layout=layout, - kernel_layout=layers.conv_kernel_layout(layout)) - conv3 = relay.nn.bias_add(conv3, conv3_bias, bias_axis) + channels=64, name="conv3") + conv3 = relay.nn.bias_add(conv3, conv3_bias) relu3 = relay.nn.relu(conv3) bf1 = relay.nn.batch_flatten(relu3) @@ -63,8 +58,7 @@ def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32" return relay.Function(args, dense2) -def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32", - layout="NCHW"): +def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"): """Get benchmark workload for a Deep Q Network Parameters ---------- @@ -78,11 +72,10 @@ def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="flo The data type Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module that contains a DQN network. params : dict of str to NDArray The parameters. 
""" - net = get_net(batch_size, num_actions=num_actions, image_shape=image_shape, dtype=dtype, - layout=layout) + net = get_net(batch_size, num_actions=num_actions, image_shape=image_shape, dtype=dtype) return create_workload(net) diff --git a/python/tvm/relay/testing/resnet.py b/python/tvm/relay/testing/resnet.py index ac63afde4cba..b431dd096f9d 100644 --- a/python/tvm/relay/testing/resnet.py +++ b/python/tvm/relay/testing/resnet.py @@ -59,11 +59,9 @@ def residual_unit(data, name : str Base name of the operators """ - bn_axis = data_layout.index('C') if bottle_neck: bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, - axis=bn_axis, name=name + '_bn1') act1 = relay.nn.relu(data=bn1) conv1 = layers.conv2d( @@ -75,13 +73,13 @@ def residual_unit(data, name=name + '_conv1', data_layout=data_layout, kernel_layout=kernel_layout) - bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, axis=bn_axis, name=name + '_bn2') + bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2') act2 = relay.nn.relu(data=bn2) conv2 = layers.conv2d( data=act2, channels=int(num_filter*0.25), kernel_size=(3, 3), strides=(1, 1), padding=(1, 1), name=name + '_conv2', data_layout=data_layout, kernel_layout=kernel_layout) - bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, axis=bn_axis, name=name + '_bn3') + bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, name=name + '_bn3') act3 = relay.nn.relu(data=bn3) conv3 = layers.conv2d( data=act3, channels=num_filter, kernel_size=(1, 1), @@ -96,13 +94,13 @@ def residual_unit(data, data_layout=data_layout, kernel_layout=kernel_layout) return relay.add(conv3, shortcut) - bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, axis=bn_axis, name=name + '_bn1') + bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, name=name + '_bn1') act1 = relay.nn.relu(data=bn1) conv1 = layers.conv2d( data=act1, channels=num_filter, kernel_size=(3, 3), strides=stride, padding=(1, 1), name=name + '_conv1', data_layout=data_layout, kernel_layout=kernel_layout) - bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, axis=bn_axis, name=name + '_bn2') + bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2') act2 = relay.nn.relu(data=bn2) conv2 = layers.conv2d( data=act2, channels=num_filter, kernel_size=(3, 3), @@ -158,16 +156,12 @@ def resnet(units, data_layout = layout kernel_layout = "OIHW" if layout == "NCHW" else "HWIO" - bn_axis = data_layout.index('C') num_unit = len(units) assert num_unit == num_stages data = relay.var("data", shape=data_shape, dtype=dtype) - data = layers.batch_norm_infer(data=data, epsilon=2e-5, axis=bn_axis, scale=False, - name='bn_data') + data = layers.batch_norm_infer(data=data, epsilon=2e-5, scale=False, name='bn_data') (_, _, height, _) = data_shape - if layout == "NHWC": - (_, height, _, _) = data_shape if height <= 32: # such as cifar10 body = layers.conv2d( data=data, channels=filter_list[0], kernel_size=(3, 3), @@ -178,7 +172,7 @@ def resnet(units, data=data, channels=filter_list[0], kernel_size=(7, 7), strides=(2, 2), padding=(3, 3), name="conv0", data_layout=data_layout, kernel_layout=kernel_layout) - body = layers.batch_norm_infer(data=body, epsilon=2e-5, axis=bn_axis, name='bn0') + body = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn0') body = relay.nn.relu(data=body) body = relay.nn.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), layout=data_layout) @@ -193,7 +187,7 @@ def resnet(units, body, filter_list[i+1], (1, 1), True, name='stage%d_unit%d' % (i + 1, j + 2), 
bottle_neck=bottle_neck, data_layout=data_layout, kernel_layout=kernel_layout) - bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, axis=bn_axis, name='bn1') + bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn1') relu1 = relay.nn.relu(data=bn1) # Although kernel is not used here when global_pool=True, we should put one pool1 = relay.nn.global_avg_pool2d(data=relu1, layout=data_layout) @@ -215,8 +209,6 @@ def get_net(batch_size, Original author Wei Wu """ (_, height, _) = image_shape - if layout == "NHWC": - (height, _, _) = image_shape data_shape = (batch_size,) + image_shape if height <= 28: num_stages = 3 diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 967bfcdd3cde..060673dc19c6 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -279,39 +279,6 @@ def empty(shape, dtype="float32", ctx=context(1, 0)): return _make_array(handle, False, False) -def non_empty(shape, dtype="float32", ctx=context(1, 0)): - """Create an non-empty array given shape and device - - Parameters - ---------- - shape : tuple of int - The shape of the array - - dtype : type or str - The data type of the array. - - ctx : TVMContext - The context of the array - - Returns - ------- - arr : tvm.nd.NDArray - The array tvm supported. - """ - shape = c_array(tvm_shape_index_t, shape) - ndim = ctypes.c_int(len(shape)) - handle = TVMArrayHandle() - dtype = DataType(dtype) - check_call(_LIB.TVMArrayAllocNonEmpty( - shape, ndim, - ctypes.c_int(dtype.type_code), - ctypes.c_int(dtype.bits), - ctypes.c_int(dtype.lanes), - ctx.device_type, - ctx.device_id, - ctypes.byref(handle))) - return _make_array(handle, False, False) - def from_dlpack(dltensor): """Produce an array from a DLPack tensor without memory copy. Retreives the underlying DLPack tensor's pointer to create an array from the diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 6a2120817eb1..7d73bf42ab7d 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -56,11 +56,9 @@ class Tensor(DataProducer, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" def __call__(self, *indices): - # ndim = self.ndim - # After ansor kernel layout rewrite, len(indices) <= ndim, - # and the indices will get modified by Ansor during schedule generation. - # if len(indices) != ndim: - # raise ValueError("Need to provide %d index in tensor slice" % ndim) + ndim = self.ndim + if len(indices) != ndim: + raise ValueError("Need to provide %d index in tensor slice" % ndim) indices = convert_to_object(indices) args = [] for x in indices: diff --git a/scripts/common.py b/scripts/common.py deleted file mode 100644 index e9cf58e128bb..000000000000 --- a/scripts/common.py +++ /dev/null @@ -1,1034 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Common utility for scripts""" -import argparse -import math -import os -import re -import time -from collections import defaultdict, namedtuple -from typing import Dict, List, Tuple - -import numpy as np -import matplotlib.pyplot as plt - -import topi -import tvm -from tvm import te -from tvm.ansor import (LogReader, make_workload_key_func, - register_workload_func, - write_measure_records_to_file) -from tvm.contrib import ndk, util - -############################################################ -###################### Test Workloads #################### -############################################################ - -@register_workload_func -def min_mn(M, N): - A = te.placeholder((M, N), name='A') - B = topi.min(A, axis=1) - - return [A, B] - -@register_workload_func -def argmin_mn(M, N): - A = te.placeholder((M, N), name='A') - B = topi.argmin(A, axis=1) - - return [A, B] - -@register_workload_func -def softmax_mn(M, N): - A = te.placeholder((M, N), name='A') - B = topi.nn.softmax(A, axis=1) - - return [A, B] - -@register_workload_func -def norm_bmn(B, M, N): - A = te.placeholder((B, M, N), name='A') - i = te.reduce_axis((0, M)) - j = te.reduce_axis((0, N)) - C = te.compute((B,), lambda b: te.sum(A[b][i][j] * A[b][i][j], axis=[i, j]), name='C') - D = te.compute((B,), lambda b: te.sqrt(C[b]), name='D') - - return [A, D] - -@register_workload_func -def add_mn(M, N): - A = te.placeholder((M, N), name='A') - B = te.placeholder((M, N), name='B') - C = te.compute((M, N), lambda i, j: A[i][j] + B[i][j], name='C') - - return [A, B, C] - -@register_workload_func -def matmul_nkkm(N, M, K, in_type='float32', out_type='float32', - tensor_core_support=False): - if tensor_core_support: - A = te.placeholder((N // 16, K // 16, 16, 16), name='A', dtype=in_type) - B = te.placeholder((K // 16, M // 16, 16, 16), name='B', dtype=in_type) - k = te.reduce_axis((0, K // 16), name='k') - kk = te.reduce_axis((0, 16), name='kk') - if not ((in_type == 'float16' and out_type == 'float32') or \ - (in_type == 'int8' and out_type == 'int32')): - raise ValueError - C = te.compute((N // 16, M // 16, 16, 16), - lambda i, j, ii, jj: te.sum(A[i][k][ii][kk].astype(out_type) * B[k][j][kk][jj].astype(out_type), - axis=[k, kk]), - name='C') - else: - A = te.placeholder((N, K), name='A', dtype=in_type) - B = te.placeholder((K, M), name='B', dtype=in_type) - k = te.reduce_axis((0, K), name='k') - C = te.compute((N, M), - lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), - name='C') - - return [A, B, C] - -@register_workload_func -def dense_layer(batch, in_dim, out_dim): - A = te.placeholder((batch, in_dim), name='A') - B = te.placeholder((out_dim, in_dim), name='B') - k = te.reduce_axis((0, in_dim), name='k') - C = te.compute((batch, out_dim), lambda i, j: te.sum(A[i][k] * B[j][k], axis=[k]), name='C') - - return [A, B, C] - -@register_workload_func -def max_pool_2d_nchw(N, C, H, W): - data = te.placeholder((N, C, H, W), name='data') - out = topi.nn.pool(data, (2, 2), (1, 1), (0, 0, 0, 0), pool_type='max', ceil_mode=True, - layout="NCHW", count_include_pad=True) - - return [data, out] - -@register_workload_func -def add_min_relu(M, N): - A = te.placeholder((M, N), name='A') - B = te.placeholder((M, N), name='B') - C = topi.add(A, B) - D = topi.min(C, axis=1) - out = topi.nn.relu(D) - return [A, B, out] - -@register_workload_func -def conv2d_relu_softmax_min(N, H, W, CI, CO, KH, KW, strides, padding, dilation): - data = 
te.placeholder((N, CI, H, W), name='data')
-    kernel = te.placeholder((CO, CI, KH, KW), name='kernel')
-    conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation)
-    relu = topi.nn.relu(conv)
-    softmax = topi.nn.softmax(relu, axis=1)
-    out = topi.min(softmax, axis=1)
-
-    return [data, kernel, out]
-
-@register_workload_func
-def conv2d_nchw_bias(N, H, W, CI, CO, KH, KW, strides, padding, dilation):
-    data = te.placeholder((N, CI, H, W), name='data')
-    kernel = te.placeholder((CO, CI, KH, KW), name='kernel')
-    bias = te.placeholder((CO, 1, 1), name='bias')
-    conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation)
-    #out = topi.nn.relu(conv)
-    out = topi.add(conv, bias)
-    return [data, kernel, bias, out]
-
-def conv2d_nhwc_without_layout_rewrite(Input, Filter, stride, padding, dilation, out_dtype='float32'):
-    """A copy of `topi.nn.conv2d_nhwc` but without the `layout_free` attribute.
-    We use this in single op and subgraph evaluation because we don't want to introduce graph-level optimization.
-    """
-    assert isinstance(stride, int) or len(stride) == 2
-    assert isinstance(dilation, int) or len(dilation) == 2
-
-    if isinstance(stride, int):
-        stride_h = stride_w = stride
-    else:
-        stride_h, stride_w = stride
-
-    if isinstance(dilation, int):
-        dilation_h = dilation_w = dilation
-    else:
-        dilation_h, dilation_w = dilation
-
-    batch, in_height, in_width, in_channel = Input.shape
-    if len(Filter.shape) == 10:
-        kernel_h = Filter.shape[2] * Filter.shape[6]
-        kernel_w = Filter.shape[3] * Filter.shape[7]
-        channel = Filter.shape[4] * Filter.shape[8]
-        num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[5] * Filter.shape[9]
-        #Filter = te.placeholder([kernel_h, kernel_w, channel, num_filter], Filter.dtype, Filter.name)
-    elif len(Filter.shape) == 11:
-        kernel_h = Filter.shape[3] * Filter.shape[7]
-        kernel_w = Filter.shape[4] * Filter.shape[8]
-        channel = Filter.shape[5] * Filter.shape[9]
-        num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[2] * Filter.shape[6] * Filter.shape[10]
-    else:
-        kernel_h, kernel_w, channel, num_filter = Filter.shape
-
-    # compute the output shape
-    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
-    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
-    pad_top, pad_left, pad_down, pad_right = topi.nn.get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
-    out_channel = num_filter
-    out_height = topi.util.simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
-    out_width = topi.util.simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
-    pad_before = [0, pad_top, pad_left, 0]
-    pad_after = [0, pad_down, pad_right, 0]
-    PaddedInput = topi.nn.pad(Input, pad_before, pad_after, name="PaddedInput")
-    rc = te.reduce_axis((0, in_channel), name='rc')
-    ry = te.reduce_axis((0, kernel_h), name='ry')
-    rx = te.reduce_axis((0, kernel_w), name='rx')
-    Output = te.compute(
-        (batch, out_height, out_width, out_channel),
-        lambda nn, yy, xx, ff: te.sum(
-            PaddedInput[nn, yy * stride_h + ry * dilation_h,
-                        xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
-            Filter[ry, rx, rc, ff].astype(out_dtype)
-            , axis=[ry, rx, rc]),
-        name="Conv2dOutput", tag="conv2d_nhwc")
-    return Output
-
-
-@register_workload_func
-def conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, padding, dilation):
-    data = te.placeholder((N, H, W, CI), name='data')
-    kernel = te.placeholder((KH, KW, CI, CO), name='kernel')
-    bias = te.placeholder((CO, ), name='bias')
-    conv = topi.nn.conv2d_nhwc(data,
kernel, strides, padding, dilation) - out = topi.add(conv, bias) - return [data, kernel, bias, out] - -@register_workload_func -def depthwise_conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, padding, dilation): - data = te.placeholder((N, H, W, CI), name='data') - kernel = te.placeholder((KH, KW, CI, 1), name='kernel') - bias = te.placeholder((CO, ), name='bias') - conv = topi.nn.depthwise_conv2d_nhwc(data, kernel, strides, padding, dilation) - out = topi.add(conv, bias) - return [data, kernel, bias, out] - -@register_workload_func -def conv2d_nhwc_bias(N, H, W, CI, CO, KH, KW, strides, padding, dilation): - data = te.placeholder((N, H, W, CI), name='data') - kernel = te.placeholder((KH, KW, CI, CO), name='kernel') - bias = te.placeholder((CO, ), name='bias') - conv = conv2d_nhwc_without_layout_rewrite(data, kernel, strides, padding, dilation) - out = topi.add(conv, bias) - return [data, kernel, bias, out] - - -@register_workload_func -def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): - data = te.placeholder((N, CI, H, W), name='data') - kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='kernel') - bias = te.placeholder((CO, 1, 1), name='bias') - bn_scale = te.placeholder((CO, 1, 1), name='bn_scale') - bn_offset = te.placeholder((CO, 1, 1), name='bn_offset') - - OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 - OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 - - conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation) - conv = te.compute((N, CO, OH, OW), - lambda i, j, k, l: conv[i, j, k, l] + bias[j, 0, 0], - name='bias_add') - conv = te.compute((N, CO, OH, OW), - lambda i, j, k, l: conv[i, j, k, l] * bn_scale[j, 0, 0], - name='bn_mul') - conv = te.compute((N, CO, OH, OW), - lambda i, j, k, l: conv[i, j, k, l] + bn_offset[j, 0, 0], - name='bn_add') - out = topi.nn.relu(conv) - - return [data, kernel, bias, bn_offset, bn_scale, out] - -@register_workload_func -def conv2d_nhwc_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): - data = te.placeholder((N, H, W, CI), name='data') - kernel = te.placeholder((kernel_size, kernel_size, CI, CO), name='kernel') - bias = te.placeholder((CO,), name='bias') - bn_scale = te.placeholder((CO,), name='bn_scale') - bn_offset = te.placeholder((CO,), name='bn_offset') - - OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 - OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 - - conv = conv2d_nhwc_without_layout_rewrite(data, kernel, strides, padding, dilation) - conv = te.compute((N, OH, OW, CO), - lambda i, j, k, l: conv[i, j, k, l] + bias[l], - name='bias_add') - conv = te.compute((N, OH, OW, CO), - lambda i, j, k, l: conv[i, j, k, l] * bn_scale[l], - name='bn_mul') - conv = te.compute((N, OH, OW, CO), - lambda i, j, k, l: conv[i, j, k, l] + bn_offset[l], - name='bn_add') - out = topi.nn.relu(conv) - - return [data, kernel, bias, bn_offset, bn_scale, out] - -resnet_conv2d_configs = { - # format : N, H, W, CI, CO, KH, KW, strides, padding, dilation - '18': [ - (1, 224, 224, 3, 64, 7, 7, (2, 2), (3, 3), (1, 1)), - (1, 56, 56, 64, 128, 3, 3, (2, 2), (1, 1), (1, 1)), - (1, 56, 56, 64, 128, 1, 1, (2, 2), (0, 0), (1, 1)), - (1, 56, 56, 64, 64, 3, 3, (1, 1), (1, 1), (1, 1)), - (1, 56, 56, 64, 64, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 28, 28, 128, 256, 3, 3, (2, 2), (1, 1), (1, 1)), - (1, 28, 28, 128, 256, 1, 1, (2, 2), (0, 0), (1, 1)), - (1, 28, 28, 128, 128, 3, 3, (1, 1), (1, 
1), (1, 1)), - (1, 14, 14, 256, 512, 3, 3, (2, 2), (1, 1), (1, 1)), - (1, 14, 14, 256, 512, 1, 1, (2, 2), (0, 0), (1, 1)), - (1, 14, 14, 256, 256, 3, 3, (1, 1), (1, 1), (1, 1)), - (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)), - ], - '50': [ - (1, 224, 224, 3, 64, 7, 7, (2, 2), (3, 3), (1, 1)), - (1, 56, 56, 256, 512, 1, 1, (2, 2), (0, 0), (1, 1)), - (1, 56, 56, 256, 128, 1, 1, (2, 2), (0, 0), (1, 1)), - (1, 56, 56, 256, 64, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 56, 56, 64, 256, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 56, 56, 64, 64, 3, 3, (1, 1), (1, 1), (1, 1)), - (1, 56, 56, 64, 64, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 28, 28, 512, 1024, 1, 1, (2, 2), (0, 0), (1, 1)), - (1, 28, 28, 512, 256, 1, 1, (2, 2), (0, 0), (1, 1)), - (1, 28, 28, 512, 128, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 28, 28, 128, 512, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 28, 28, 128, 128, 3, 3, (1, 1), (1, 1), (1, 1)), - (1, 14, 14, 1024, 2048, 1, 1, (2, 2), (0, 0), (1, 1)), - (1, 14, 14, 1024, 512, 1, 1, (2, 2), (0, 0), (1, 1)), - (1, 14, 14, 1024, 256, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 14, 14, 256, 1024, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 14, 14, 256, 256, 3, 3, (1, 1), (1, 1), (1, 1)), - (1, 7, 7, 2048, 512, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 7, 7, 512, 2048, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)), - ], -} - -# number of appearance for all conv2ds in resnet -resnet_conv2d_weights = { - '18': [1, 1, 1, 4, 1, 1, 1, 3, 1, 1, 3, 3], - '50': [1, 1, 1, 2, 4, 3, 1, 1, 1, 3, 4, 4, 1, 1, 5, 6, 6, 2, 3, 3], -} - - -def parse_workload_name(name: str) -> List[str]: - """Parse workload name with wildcard character and abbreviation to standard names""" - if name.startswith('matmul-'): # e.g. matmul-512, matmul-1024, matmul-+ - N = name.split('-', maxsplit=1)[1] - if N == '+': - cfg_list = [256, 512, 1024] - else: - cfg_list = [N] - return ["matmul-%s" % x for x in cfg_list] - elif name.startswith('dense-'): # e.g. dense-1-512-1024, dense-16-512-512 - N = name.split('-', maxsplit=1)[1] - if N == '+': - cfg_list = ["1-512-512", "16-512-512"] - else: - cfg_list = [N] - return ["dense-%s" % x for x in cfg_list] - elif name.startswith('min-'): # e.g. min-4096 - N = name.split('-', maxsplit=1)[1] - if N == '+': - cfg_list = [4096, 8192, 16384] - else: - cfg_list = [N] - return ["min-%s" % x for x in cfg_list] - elif name.startswith('argmin-'): # e.g. argmin-4096 - N = name.split('-', maxsplit=1)[1] - if N == '+': - cfg_list = [4096, 8192, 16384] - else: - cfg_list = [N] - return ["argmin-%s" % x for x in cfg_list] - elif name.startswith('softmax-'): # e.g. softmax-4096 - N = name.split('-', maxsplit=1)[1] - if N == '+': - cfg_list = [4096, 8192, 16384] - else: - cfg_list = [N] - return ["softmax-%s" % x for x in cfg_list] - elif name.startswith('add-'): # e.g. add-4096 - N = name.split('-', maxsplit=1)[1] - if N == '+': - cfg_list = [4096, 8192, 16384] - else: - cfg_list = [N] - return ["add-%s" % x for x in cfg_list] - elif name.startswith('norm-'): # e.g. norm-1024 - N = name.split('-', maxsplit=1)[1] - if N == '+': - cfg_list = [4096, 8192, 16384] - else: - cfg_list = [N] - return ["norm-%s" % x for x in cfg_list] - elif name.startswith('add-min-relu'): # e.g. add-min-relu-4096 - N = name.split('-', maxsplit=3)[3] - if N == '+': - cfg_list = [4096, 8192, 16384] - else: - cfg_list = [N] - return ["add-min-relu-%s" % x for x in cfg_list] - elif name.startswith('nhwc-resnet-'): # e.g. 
nhwc-resnet-50.C1 - res = re.match(r'nhwc-resnet-(\d+).C([\d\+]+)(.B(\d+))?', name) - n_layers = res.group(1) - if res.group(2) == '+': - idx_list = range(len(resnet_conv2d_configs[n_layers])) - else: - idx_list = [int(res.group(2))] - - batch_size = 1 if res.group(4) is None else int(res.group(4)) - return ['nhwc-resnet-%s.C%d.B%d' % (n_layers, i, batch_size) for i in idx_list] - elif name.startswith('resnet-'): # e.g. resnet-50.C1, resnet-50.C1.B2, resnet-50.C+.B2 - res = re.match(r'resnet-(\d+).C([\d\+]+)(.B(\d+))?', name) - n_layers = res.group(1) - if res.group(2) == '+': - idx_list = range(len(resnet_conv2d_configs[n_layers])) - else: - idx_list = [int(res.group(2))] - - batch_size = 1 if res.group(4) is None else int(res.group(4)) - return ['resnet-%s.C%d.B%d' % (n_layers, i, batch_size) for i in idx_list] - elif name in ['conv2d-bn-relu', 'conv2d-relu-softmax-min', 'max-pool-2d', 'conv2d-rewrite', 'depthwise-conv2d-rewrite']: - return [name] - else: - raise ValueError("Invalid workload " + name) - - -def get_workload_keys(name: str) -> List[str]: - """Parse workload name and return the workload keys""" - normalized_names = parse_workload_name(name) - - ret = [] - for name in normalized_names: - if name.startswith('matmul-'): - name_split = name.split('-') - in_type = out_type = 'float32' - tensor_core_support = False - if len(name_split) == 2: # e.g. matmul-512 - N = K = M = int(name_split[1]) - elif len(name_split) == 4: # e.g. matmul-32-256-512 - N = int(name_split[1]) - K = int(name_split[2]) - M = int(name_split[3]) - elif len(name_split) == 6: # e.g. matmul-32-512-512-float16-float32 - N = int(name_split[1]) - K = int(name_split[2]) - M = int(name_split[3]) - in_type = name_split[4] - out_type = name_split[5] - elif len(name_split) == 7: # e.g. matmul-32-512-512-float16-float32-tc - N = int(name_split[1]) - K = int(name_split[2]) - M = int(name_split[3]) - in_type = name_split[4] - out_type = name_split[5] - tensor_core_support = name_split[6] == "tc" - else: - raise ValueError("Invalid matmul workload") - ret.append(make_workload_key_func(matmul_nkkm, - (N, M, K, in_type, out_type, tensor_core_support))) - elif name.startswith('dense-'): # e.g. dense-1-512-1024, dense-16-512-512 - name_split = name.split('-') - assert len(name_split) == 4 - batch = int(name_split[1]) - in_dim = int(name_split[2]) - out_dim = int(name_split[3]) - ret.append(make_workload_key_func(dense_layer, (batch, in_dim, out_dim))) - elif name.startswith('min-'): # e.g. min-4096 - name_split = name.split('-') - if len(name_split) == 2: - M = 64 - N = int(name_split[1]) - elif len(name_split) == 3: - M = int(name_split[1]) - N = int(name_split[2]) - else: - raise ValueError("Invalid min workload") - ret.append(make_workload_key_func(min_mn, (M, N))) - elif name.startswith('argmin-'): # e.g. argmin-4096 - name_split = name.split('-') - if len(name_split) == 2: - M = 64 - N = int(name_split[1]) - elif len(name_split) == 3: - M = int(name_split[1]) - N = int(name_split[2]) - else: - raise ValueError("Invalid argmin workload") - ret.append(make_workload_key_func(argmin_mn, (M, N))) - elif name.startswith('softmax-'): # e.g. softmax-4096 - name_split = name.split('-') - if len(name_split) == 2: - M = 64 - N = int(name_split[1]) - elif len(name_split) == 3: - M = int(name_split[1]) - N = int(name_split[2]) - else: - raise ValueError("Invalid softmax workload") - ret.append(make_workload_key_func(softmax_mn, (M, N))) - elif name.startswith('add-min-relu'): # e.g. 
add-min-relu-4096 - name_split = name.split('-') - if len(name_split) == 4: - M = 64 - N = int(name_split[3]) - elif len(name_split) == 5: - M = int(name_split[3]) - N = int(name_split[4]) - else: - raise ValueError("Invalid workload") - ret.append(make_workload_key_func(add_min_relu, (M, N))) - elif name.startswith('add-'): # e.g. add-4096 - name_split = name.split('-') - if len(name_split) == 2: - N = M = int(name_split[1]) - elif len(name_split) == 3: - M = int(name_split[1]) - N = int(name_split[2]) - else: - raise ValueError("Invalid add workload") - ret.append(make_workload_key_func(add_mn, (M, N))) - elif name.startswith('norm-'): # e.g. norm-4096 - name_split = name.split('-') - B = 2 - if len(name_split) == 2: - N = M = int(name_split[1]) - elif len(name_split) == 3: - M = int(name_split[1]) - N = int(name_split[2]) - else: - raise ValueError("Invalid norm workload") - ret.append(make_workload_key_func(norm_bmn, (B, M, N))) - elif name.startswith('nhwc-resnet-'): # e.g. nhwc-resnet-50.C1.B2 - res = re.match(r'nhwc-resnet-(\d+).C(\d+).B(\d+)', name) - n_layers = res.group(1) - idx = int(res.group(2)) - batch_size = 1 if res.group(3) is None else int(res.group(3)) - args = list(resnet_conv2d_configs[n_layers][idx]) - args[0] = batch_size - ret.append(make_workload_key_func(conv2d_nhwc_bias, args)) - elif name.startswith('resnet-'): # e.g. resnet-50.C1.B2 - res = re.match(r'resnet-(\d+).C(\d+).B(\d+)', name) - n_layers = res.group(1) - idx = int(res.group(2)) - batch_size = 1 if res.group(3) is None else int(res.group(3)) - args = list(resnet_conv2d_configs[n_layers][idx]) - args[0] = batch_size - ret.append(make_workload_key_func(conv2d_nchw_bias, args)) - elif name == 'max-pool-2d': - return [make_workload_key_func(max_pool_2d_nchw, (2, 512, 7, 7))] - elif name == 'conv2d-bn-relu': - return [make_workload_key_func(conv2d_nhwc_bn_relu, - (1, 7, 7, 512, 512, 3, 1, 1, 1)) ] - elif name == 'conv2d-rewrite': - return [ make_workload_key_func(conv2d_nhwc_bias_with_rewrite, - (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)))] - elif name == 'depthwise-conv2d-rewrite': - return [ make_workload_key_func(depthwise_conv2d_nhwc_bias_with_rewrite, - (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)))] - elif name == 'conv2d-relu-softmax-min': - return [make_workload_key_func(conv2d_relu_softmax_min, - (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)))] - else: - raise ValueError("Invalid workload " + name) - - return ret - - -def get_workload_weights(name: str) -> List[float]: - """Return weights for workload name""" - if name.startswith('resnet-'): - res = re.match(r'resnet-(\d+).C+', name) - n_layers = res.group(1) - return np.array(resnet_conv2d_weights[n_layers]) - else: - return np.ones(len(get_workload_keys(name))) - - -############################################################ -###################### Measure Tools #################### -############################################################ - - -def measure_schedule(s, - bufs, - target, - target_host=None, - remote=None, - ndk_cc=None, - number=10, - repeat=3, - min_repeat_ms=500): - """Measure the time cost of a schedule""" - func = tvm.build(s, bufs, target=target, target_host=target_host) - if remote: - ctx = remote.context(str(target), 0) - temp = util.tempdir() - remote_path = temp.relpath("tmp_deploy_lib.so") - os.environ['TVM_NDK_CC'] = ndk_cc - func.export_library(remote_path, ndk.create_shared) - remote.upload(remote_path) - func = remote.load_module("tmp_deploy_lib.so") - else: - ctx = tvm.context(str(target), 0) - - if 
os.environ.get('TVM_AUTO_CACHE_FLUSH', '0') == '1':
-        min_repeat_ms = 0
-        number = 1
-
-    time_f = func.time_evaluator(func.entry_name,
-                                 ctx,
-                                 number=number,
-                                 repeat=repeat,
-                                 min_repeat_ms=min_repeat_ms)
-
-    np_args = [np.ones(topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs]
-    args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
-    ctx.sync()
-
-    costs = time_f(*args).results
-
-    return costs
-
-def check_correctness(s, bufs, s_ref, buf_ref, target, target_host=None, remote=None, ndk_cc=None):
-    """Check the correctness of a schedule against a reference schedule"""
-    func = tvm.build(s, bufs, target=target, target_host=target_host)
-    func_ref = tvm.build(s_ref, buf_ref, target='llvm')
-
-    if remote:
-        raise NotImplementedError
-    else:
-        ctx = tvm.context(str(target), 0)
-        ctx_ref = tvm.cpu()
-
-    np_args = [np.ones(topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs]
-    args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
-    args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args]
-    ctx.sync()
-
-    func(*args)
-    func_ref(*args_ref)
-
-    for arr, arr_ref in zip(args, args_ref):
-        np.testing.assert_allclose(arr.asnumpy(), arr_ref.asnumpy())
-
-
-############################################################
-#####################  Other Utilities  ####################
-############################################################
-
-
-def geomean(xs):
-    """Compute geometric mean"""
-    return math.exp(math.fsum(math.log(x) for x in xs) / len(xs))
-
-
-def str2bool(v):
-    if isinstance(v, bool):
-        return v
-    if v.lower() in ('yes', 'true', 't', 'y', '1'):
-        return True
-    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
-        return False
-    else:
-        raise argparse.ArgumentTypeError('Boolean value expected.')
-
-
-global last_tic
-last_tic = None
-
-
-def PRINT_TIME(msg):
-    """Print the time interval between different calls. This is for debugging, so the name is in capital letters"""
-    global last_tic
-    now = time.time()
-
-    if last_tic is None:
-        last_tic = now
-
-    print(msg, now - last_tic)
-    last_tic = now
-
-
-############################################################
-######################  I/O Utilities  #####################
-############################################################
-
-# The format for a line in the results file
-BenchmarkRecord = namedtuple("BenchmarkRecord", [
-    'device', 'backend', 'workload_type', 'workload_name', 'library', 'algorithm', 'value',
-    'time_stamp'
-])
-
-
-class BaselineDatabase:
-    """A class for querying records in the baseline database"""
-    def __init__(self, filename):
-        self.filename = filename
-
-        self.lines = []
-        for line in open(filename):
-            if line.startswith('#') or line.isspace():
-                continue
-            self.lines.append(line.split('\t'))
-
-    def filter_records(self, devices=None, backends=None, wkl_names=None, libraries=None):
-        ret = []
-        for line in self.lines:
-            line = BenchmarkRecord(*line)
-
-            if devices is not None and line.device not in devices:
-                continue
-            if backends is not None and line.backend not in backends:
-                continue
-            if wkl_names is not None and line.workload_name not in wkl_names:
-                continue
-            if libraries is not None and line.library not in libraries:
-                continue
-
-            ret.append(line)
-        return ret
-
-    def get_data_dict(self, device, target, wkl_names) -> Tuple[Dict, List]:
-        """Return a data dict s.t.
data[wkl][library] = cost""" - data = defaultdict(lambda: defaultdict(lambda: 1e10)) - - all_libraries = set() - - if "cpu" in target.keys: - backends = ['cpu'] - elif "gpu" in target.keys: - backends = ['gpu'] - else: - raise ValueError("Invalid target: " + target) - - # Read costs for baselines - records = self.filter_records(devices=[device], backends=backends, wkl_names=wkl_names) - for record in records: - # use min over (possible) multiple algorithms - all_libraries.add(record.library) - data[record.workload_name][record.library] = \ - min(data[record.workload_name][record.library], - np.mean(eval(record.value)['costs'])) - - return data, list(all_libraries) - - -class LogFileDatabase: - """A class for indexing best records in a log file""" - def __init__(self, filename: str, n_lines: int = -1): - inputs, results = LogReader(filename).read_lines(n_lines) - - # best records, search by (target_key, workload_key). e.g. ('gpu', 'conv2d...') - self.best_by_targetkey = {} - - # best according to (model, workload_key). e.g. ('1080ti', 'conv2d...')) - self.best_by_model = {} - - # find best records and build the index - for inp, res in zip(inputs, results): - if res.error_no != 0: - continue - - # use target keys in tvm target system as key to build best map - for target_key in inp.task.target.keys: - key = (target_key, inp.task.workload_key) - if key not in self.best_by_targetkey: - self.best_by_targetkey[key] = (inp, res) - else: - _, other_res = self.best_by_targetkey[key] - if np.mean([x.value for x in other_res.costs]) > \ - np.mean([x.value for x in res.costs]): - self.best_by_targetkey[key] = (inp, res) - - # use model as key to build best map - key = (inp.task.target.model, inp.task.workload_key) - if key not in self.best_by_model: - if inp.task.target.model != 'unknown': - self.best_by_model[key] = (inp, res) - else: - _, other_res = self.best_by_model[key] - if np.mean([x.value for x in other_res.costs]) > \ - np.mean([x.value for x in res.costs]): - self.best_by_model[key] = (inp, res) - - def write_best(self, filename: str): - best_records = list(self.best_by_targetkey.values()) - inputs = [x[0] for x in best_records] - results = [x[1] for x in best_records] - write_measure_records_to_file(filename, inputs, results) - - -############################################################ -###################### Plot Utilities #################### -############################################################ - -def max_curve(raw_curve): - """Return b[i] = max(a[:i]) """ - ret = [] - cur_max = -np.inf - for x in raw_curve: - cur_max = max(cur_max, x) - ret.append(cur_max) - return ret - -def min_curve(raw_curve): - """Return b[i] = min(a[:i]) """ - ret = [] - cur_min = np.inf - for x in raw_curve: - cur_min = min(cur_min, x) - ret.append(cur_min) - return ret - -def mean_curve(raw_curve, window_size=None): - """Return b[i] = mean(a[:i]) """ - ret = [] - mean = 0 - if window_size is None: - for i, x in enumerate(raw_curve): - mean = (mean * i + x) / (i + 1) - ret.append(mean) - else: - for i, x in enumerate(raw_curve): - if i >= window_size: - mean = (mean * window_size + x - raw_curve[i - window_size]) / window_size - else: - mean = (mean * i + x) / (i + 1) - ret.append(mean) - return ret - - -def enhance_color(color, h=1, l=1, s=1): - """Make color looks better for pyplot""" - import matplotlib.colors as mc - import colorsys - try: - c = mc.cnames[color] - except: - c = color - c = np.array(colorsys.rgb_to_hls(*mc.to_rgb(c))) - - h, l, s = h * c[0], l * c[1], s * c[2] - h, l, s = 
[max(min(x, 1), 0) for x in [h, l, s]] - - return colorsys.hls_to_rgb(h, l, s) - - -method_color_dict = { - 'ours': 'C0', - 'AutoTVM': 'C1', - - 'tensorflow': 'C2', - 'tensorflow-tensorrt': 'C9', - 'tflite': 'C2', - - 'pytorch': enhance_color('C3', l=1.1, s=0.9), - - 'FlexTensor': enhance_color('C5'), - 'halide': enhance_color('teal', l=1.25), - - 'Limit space': 'C7', - 'No fine-tuning': 'C8', - 'No task scheduler': 'C1', -} - -def method2color(method): - if '-batch-' in method: - method, batch_size = method.split('-batch-') - #return enhance_color(method_color_dict[method], s=1.1, l=1.5) - return method_color_dict[method] - else: - return method_color_dict[method] - -method_order_list = [ - 'pytorch', 'tensorflow', 'tensorflow-xla', 'tensorflow-tensorrt', - 'tflite', 'halide', 'FlexTensor', 'AutoTVM', - - 'Limit space', 'No fine-tuning', - 'ours', -] - -def method2order(method): - if '-batch-' in method: - method, batch_size = method.split('-batch-') - batch_size = int(batch_size) - return method_order_list.index(method) + batch_size / 100 - else: - return method_order_list.index(method) - -show_name_replace_dict = { - 'pytorch': "PyTorch", - 'tensorflow-tensorrt': 'TensorRT-TF', - 'tensorflow': 'TensorFlow', - 'tflite': 'TensorFlow Lite', - 'halide': 'Halide', - - 'ours': 'Ansor (ours)', - 'batch-16': 'batch', - - 'resnet_50': 'ResNet-50', - 'mobilenet_v2': 'Mobilenet V2', - 'resnet_18_3d': '3D-ResNet', - 'dcgan': 'DCGAN', - 'dqn': 'DQN', - 'bert': 'BERT', -} - -def show_name(name): - # if name.startswith('resnet-'): - # return name.split('.')[1] - for key, value in show_name_replace_dict.items(): - name = name.replace(key, value) - - return name - -def draw_grouped_bar_chart(data, baseline='pytorch', output='out.png', - yscale_log=False, yticks=None, y_max=None, - legend_bbox_to_anchor=None, legend_nrow=None, - figure_size=None, figax=None, draw_ylabel=True, draw_legend=True): - width = 1 - gap = 1.5 - fontsize = 19 - xticks_font_size = fontsize - 2 - - figure_size = figure_size or (11, 4) - legend_bbox_to_anchor = legend_bbox_to_anchor or (0.45, 1.35) - - all_methods = set() - legend_set = {} - - if figax is None: - fig, ax = plt.subplots() - axes = [] - axes.append(ax) - else: - ax = figax - - x0 = 0 - xticks = [] - xlabels = [] - - workloads = list(data.keys()) - for wkl in workloads: - ys = [] - colors = [] - - methods = list(data[wkl].keys()) - - if baseline in data[wkl]: - baseline_cost = data[wkl][baseline] - else: - # normalize to best library - baseline_cost = 1e10 - for method in methods: - if data[wkl][method] < baseline_cost: - baseline_cost = data[wkl][method] - - methods.sort(key=lambda x: method2order(x)) - for method in methods: - relative_speedup = baseline_cost / data[wkl][method] - if yticks is None: - ys.append(relative_speedup) - else: - ys.append(max(relative_speedup, yticks[0] * 1.1)) - colors.append(method2color(method)) - - # draw the bars - xs = np.arange(x0, x0 + len(ys)) - bars = ax.bar(xs, ys, width=width, color=colors) - - for method, bar_obj in zip(methods, bars): - all_methods.add(method) - if method not in legend_set: - legend_set[method] = bar_obj - - # tick and label - x0 += len(ys) + gap - - xticks.append(x0 - gap - len(ys)*width/2.0 - width/2.0) - xlabels.append(show_name(wkl)) - - ax.set_xticks(xticks) - ax.set_xticklabels(xlabels, fontsize=xticks_font_size) - plt.tick_params(axis='x', which='both', bottom='off', top='off') - - if draw_ylabel is True: - ax.set_ylabel('Relative Speedup', fontsize=fontsize) - elif isinstance(draw_ylabel, str): - 
ax.set_ylabel(draw_ylabel, fontsize=fontsize) - - if yscale_log: - ax.set_yscale('log', basey=2) - if yticks is not None: - ax.set_yticks(yticks) - if y_max: - ax.set_ylim(top=y_max) - - from matplotlib.ticker import FormatStrFormatter - ax.set_yticklabels(ax.get_yticks(), fontsize=fontsize) - ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f')) - ax.yaxis.grid(linewidth=0.4, linestyle='dotted') # draw grid line - ax.set_axisbelow(True) # grid lines are behind the rest - ax.tick_params(bottom=False, top=False, right=False) - - # put legend outside the plot - all_methods = list(all_methods) - all_methods.sort(key=lambda x : method2order(x)) - - if draw_legend: - legend_nrow = legend_nrow or 2 - ncol = (len(all_methods) + legend_nrow - 1)// legend_nrow - ax.legend([legend_set[x] for x in all_methods], - [show_name(x) for x in all_methods], - fontsize=fontsize-1, - loc='upper center', - bbox_to_anchor=legend_bbox_to_anchor, - ncol=ncol, - handlelength=1.0, - handletextpad=0.5, - columnspacing=1.1) - - if figax is None: - fig.set_size_inches(figure_size) - fig.savefig(output, bbox_inches='tight') - print("Output the plot to %s" % output) - - -def to_str_round(x, decimal=6): - if isinstance(x, str): - return x - if isinstance(x, (list, tuple)) or isinstance(x, np.ndarray): - return "[" + ", ".join([to_str_round(y, decimal=decimal) - for y in x]) + "]" - if isinstance(x, dict): - return str({k: eval(to_str_round(v)) for k, v in x.items()}) - if isinstance(x, int): - return str(x) - if isinstance(x, float): - format_str = "%%.%df" % decimal - return format_str % x - raise ValueError("Invalid value: " + str(x)) - diff --git a/scripts/shape_configs.py b/scripts/shape_configs.py deleted file mode 100644 index db6b3b9dc9aa..000000000000 --- a/scripts/shape_configs.py +++ /dev/null @@ -1,247 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -""" Shape configurations for single operator / subgraph evaluation -This file is shared by tune_op_subgraph.py and scripts in scripts/baseline/ -""" - -matmul_shapes = [ - (1, 128, 128, 128), - (1, 512, 32, 512), - (1, 512, 512, 512), - (1, 1024, 1024, 1024), -] - -conv1d_shapes = [ - # derived from conv2d_shapes - (1, 256, 64, 128, 3, 2, 1), -# (1, 256, 64, 128, 1, 2, 0), -# (1, 256, 64, 64, 1, 1, 0), -# (1, 128, 128, 256, 3, 2, 1), - (1, 128, 128, 256, 1, 2, 0), -# (1, 128, 128, 128, 3, 1, 1), -# (1, 64, 256, 512, 3, 2, 1), -# (1, 64, 256, 512, 1, 2, 0), - (1, 64, 256, 256, 5, 1, 2), - (1, 32, 512, 512, 3, 1, 1), -] - -conv2d_shapes = [ - # all conv2d layers in resnet-18 - (1, 224, 224, 3, 64, 7, 2, 3), -# (1, 56, 56, 64, 128, 3, 2, 1), -# (1, 56, 56, 64, 128, 1, 2, 0), -# (1, 56, 56, 64, 64, 3, 1, 1), - (1, 56, 56, 64, 64, 1, 1, 0), -# (1, 28, 28, 128, 256, 3, 2, 1), -# (1, 28, 28, 128, 256, 1, 2, 0), -# (1, 28, 28, 128, 128, 3, 1, 1), -# (1, 14, 14, 256, 512, 3, 2, 1), -# (1, 14, 14, 256, 512, 1, 2, 0), - (1, 14, 14, 256, 256, 3, 1, 1), - (1, 7, 7, 512, 512, 3, 1, 1), -] - -conv3d_shapes = [ - # Derived from cnov2d_shapes. Use depth=16 for all configurations - (1, 16, 224, 224, 3, 64, 7, 2, 3), -# (1, 16, 56, 56, 64, 128, 3, 2, 1), -# (1, 16, 56, 56, 64, 128, 1, 2, 0), -# (1, 16, 56, 56, 64, 64, 3, 1, 1), - (1, 16, 56, 56, 64, 64, 1, 1, 0), -# (1, 16, 28, 28, 128, 256, 3, 2, 1), -# (1, 16, 28, 28, 128, 256, 1, 2, 0), -# (1, 16, 28, 28, 128, 128, 3, 1, 1), -# (1, 16, 14, 14, 256, 512, 3, 2, 1), -# (1, 16, 14, 14, 256, 512, 1, 2, 0), - (1, 16, 14, 14, 256, 256, 3, 1, 1), - (1, 16, 7, 7, 512, 512, 3, 1, 1), -] - -group_conv2d_shapes = [ - # Derived from cnov2d_shapes. Use group=4 for all configurations - (1, 56, 56, 64, 128, 3, 2, 1 , 1, 4), -# (1, 56, 56, 64, 128, 1, 2, 0 , 1, 4), -# (1, 56, 56, 64, 64, 3, 1, 1 , 1, 4), - (1, 56, 56, 64, 64, 1, 1, 0 , 1, 4), -# (1, 28, 28, 128, 256, 3, 2, 1, 1, 4), -# (1, 28, 28, 128, 256, 1, 2, 0, 1, 4), -# (1, 28, 28, 128, 128, 3, 1, 1, 1, 4), -# (1, 14, 14, 256, 512, 3, 2, 1, 1, 4), -# (1, 14, 14, 256, 512, 1, 2, 0, 1, 4), - (1, 14, 14, 256, 256, 3, 1, 1, 1, 4), - (1, 7, 7, 512, 512, 3, 1, 1 , 1, 4), -] - -dilation_conv2d_shapes = [ - # Derived from cnov2d_shapes. 
Use dilation=2 for all configurations - (1, 224, 224, 3, 64, 7, 2, 3 , 2), -# (1, 56, 56, 64, 128, 3, 2, 1 , 2), -# (1, 56, 56, 64, 128, 1, 2, 0 , 2), -# (1, 56, 56, 64, 64, 3, 1, 1 , 2), - (1, 56, 56, 64, 64, 1, 1, 0 , 2), -# (1, 28, 28, 128, 256, 3, 2, 1, 2), -# (1, 28, 28, 128, 256, 1, 2, 0, 2), -# (1, 28, 28, 128, 128, 3, 1, 1, 2), -# (1, 14, 14, 256, 512, 3, 2, 1, 2), -# (1, 14, 14, 256, 512, 1, 2, 0, 2), - (1, 14, 14, 256, 256, 3, 1, 1, 2), - (1, 7, 7, 512, 512, 3, 1, 1 , 2), -] - -depthwise_conv2d_shapes = [ - # all depthwise conv2d layers in mobilenet - (1, 112, 112, 32, 3, 1, 1), - (1, 112, 112, 64, 3, 2, 1), -# (1, 56, 56, 128, 3, 1, 1), -# (1, 56, 56, 128, 3, 2, 1), -# (1, 28, 28, 256, 3, 1, 1), -# (1, 28, 28, 256, 3, 2, 1), -# (1, 14, 14, 512, 3, 1, 1), - (1, 14, 14, 512, 3, 2, 1), - (1, 7, 7, 1024, 3, 1, 1), -] - -conv2d_transpose_shapes = [ - # all conv2d tranpose layers in DCGAN - (1, 4, 4, 512, 256, 4, 2, 1), - (1, 8, 8, 256, 128, 4, 2, 1), - (1, 16, 16, 128, 64, 4, 2, 1), - (1, 32, 32, 64, 3, 4, 2, 1), -] - -conv2d_capsule_shapes = [ - # all conv2d capsule layers in matrix capsules withemrouting (ICLR 2018) - (1, 16, 16, 32, 32, 3, 2, 1), - (1, 8, 8, 32, 32, 3, 1, 1), - (1, 16, 16, 8, 16, 3, 2, 1), - (1, 8, 8, 16, 16, 3, 1, 1), -] - -conv2d_winograd_nhwc_shapes = [ - (1, 56, 56, 64, 64, 3, 1, 1), - (1, 28, 28, 128, 128, 3, 1, 1), - (1, 14, 14, 256, 256, 3, 1, 1), - (1, 7, 7, 512, 512, 3, 1, 1), -] - -conv2d_winograd_nchw_shapes = [ - (1, 64, 56, 56, 64, 3, 1, 1), - (1, 128, 28, 28, 128, 3, 1, 1), - (1, 256, 14, 14, 256, 3, 1, 1), - (1, 512, 7, 7, 512, 3, 1, 1), -] - -matmul_tensor_core_shapes = [ - (16, 512, 512, 'float16', 'float32', True), - (32, 512, 512, 'float16', 'float32', True), - (512, 512, 512, 'float16', 'float32', True), -] - -norm_shapes = [ - (1, 256, 256), - (1, 512, 512), - (1, 1024, 1024), - (1, 4096, 1024), -] - -single_op_shape_dict = { - 'C1D': conv1d_shapes, - 'C2D': conv2d_shapes, - 'C3D': conv3d_shapes, - 'GMM': matmul_shapes, - 'GRP': group_conv2d_shapes, - 'DIL': dilation_conv2d_shapes, - 'DEP': depthwise_conv2d_shapes, - 'T2D': conv2d_transpose_shapes, - 'CAP': conv2d_capsule_shapes, - 'NRM': norm_shapes, - -# The following workloads are not in our sinle op evaluation plan. -# They should be moved to `common.py` and be used by `tune_wkl.py`. 
-# 'C2D_NCHW': conv2d_nchw_shapes, -# 'C2DWG_NHWC': conv2d_winograd_nhwc_shapes, -# 'C2DWG_NCHW': conv2d_winograd_nchw_shapes, -# 'GMM_TC': matmul_tensor_core_shapes, -} - -conv2d_bn_relu_shapes = [ - (1, 224, 224, 3, 64, 7, 2, 3), - (1, 56, 56, 64, 128, 3, 2, 1), - (1, 28, 28, 128, 256, 1, 2, 0), - (1, 7, 7, 512, 512, 3, 1, 1, 1), - (16, 224, 224, 3, 64, 7, 2, 3), - (16, 56, 56, 64, 128, 3, 2, 1), - (16, 28, 28, 128, 256, 1, 2, 0), - (16, 7, 7, 512, 512, 3, 1, 1, 1), -] - -transpose_batch_matmul_shapes = [ - (1, 128, 12, 64), - (1, 128, 16, 64), - (1, 64, 12, 128), - (1, 128, 12, 128), - (16, 128, 12, 64), - (16, 128, 16, 64), - (16, 64, 12, 128), - (16, 128, 12, 128), -] - -subgraph_shape_dict = { - "conv2d_bn_relu": conv2d_bn_relu_shapes, - "transpose_batch_matmul": transpose_batch_matmul_shapes, -} - -resnet_shapes = [ - (1, ), - (16, ), -] - -mobilenet_v2_shapes = [ - (1, ), - (16, ), -] - -dcgan_shapes = [ - (1, ), - (16, ), -] - -dqn_shapes = [ - (1, ), - (16, ), -] - -bert_shapes = [ - (1, ), - (16, ), -] - -resnet18_3d_shapes = [ - (1, ), - (16, ), -] - -network_shape_dict = { - 'resnet_50': resnet_shapes, - 'mobilenet_v2': mobilenet_v2_shapes, - 'dcgan': dcgan_shapes, - 'dqn': dqn_shapes, - 'bert': bert_shapes, - 'resnet_18_3d': resnet18_3d_shapes, -} - diff --git a/scripts/tune_network.py b/scripts/tune_network.py deleted file mode 100644 index 188da6cbe6e6..000000000000 --- a/scripts/tune_network.py +++ /dev/null @@ -1,405 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -"""Tune a whole neural network""" -import argparse -import logging -import random -import os -import numpy as np - -import tvm -from tvm import ansor, relay -import tvm.contrib.graph_runtime as runtime -from tvm.contrib.debugger import debug_runtime -from tvm.contrib import util, ndk -from tvm.relay import testing -from tvm.ansor.utils import request_remote -#from baseline.utils import log_line, BenchmarkRecord - -from common import str2bool -from tune_test import create_tune_option - -dtype = "float32" - -def get_network(name, network_path, batch_size, layout): - """Get the relay module and random weights for a network""" - input_shape = (batch_size, 3, 224, 224) - output_shape = (batch_size, 1000) - input_name = 'data' - - if name.startswith("resnet3d"): - n_layer = int(name.split('-')[1]) - layout = "NDHWC" - image_shape = (16, 112, 112, 3) - input_shape = (batch_size, *image_shape) - mod, params = relay.testing.resnet3d.get_workload(num_layers=n_layer, batch_size=batch_size, image_shape=image_shape, dtype=dtype, layout=layout) - elif name.startswith("resnet"): - n_layer = int(name.split('-')[1]) - image_shape = (224, 224, 3) if layout == 'NHWC' else (3, 224, 224) - input_shape = (batch_size, *image_shape) - mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, layout=layout, image_shape=image_shape, dtype=dtype) - elif "lstm" in name: - mod, params = relay.testing.lstm.get_workload(iterations=10, num_hidden=512, batch_size=batch_size, dtype=dtype) - elif "mlp" in name: - input_shape = (batch_size, 1, 28, 28) - mod, params = relay.testing.mlp.get_workload(batch_size=batch_size, dtype=dtype) - elif "vgg" in name: - n_layer = int(name.split('-')[1]) - mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype) - elif name == 'dcgan': - input_shape = (batch_size, 100) - mod, params = relay.testing.dcgan.get_workload(batch_size=batch_size) - elif name == 'dqn': - layout = "NHWC" - image_shape = (84, 84, 4) - input_shape = (batch_size, *image_shape) - mod, params = relay.testing.dqn.get_workload(batch_size=batch_size, image_shape=image_shape, dtype=dtype, layout=layout) - elif name == 'mobilenet': - image_shape = (224, 224, 3) if layout == 'NHWC' else (3, 224, 224) - input_shape = (batch_size, *image_shape) - mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, layout=layout, image_shape=image_shape, dtype=dtype) - elif name == 'r3d_18': - import torch - import torchvision - - model = getattr(torchvision.models.video, name)(pretrained=False) - model = model.eval() - - # We grab the TorchScripted model via tracing - input_shape = [batch_size, 3, 16, 112, 112] - input_data = torch.randn(input_shape) - scripted_model = torch.jit.trace(model, input_data).eval() - - input_name = 'input0' # only one input, set it to this name - shape_list = {input_name: input_shape} - mod, params = relay.frontend.from_pytorch(scripted_model, - shape_list) - elif name == 'squeezenet_v1.1': - mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype) - elif name == 'inception_v3': - input_shape = (batch_size, 3, 299, 299) - mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) - elif name == 'mxnet': - # an example for mxnet model - from mxnet.gluon.model_zoo.vision import get_model - block = get_model('resnet18_v1', pretrained=True) - mod, params = relay.frontend.from_mxnet(block, shape={"input_name": input_shape}, dtype=dtype) - net = 
mod["main"] - net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) - mod = relay.Module.from_expr(net) - elif name == 'tflite-mobilenet-v2' or name == 'tflite-resnet-v2-50': - try: - import tflite.Model - except ImportError: - raise ImportError("The tflite package must be installed") - input_name = "input" - input_shape = (1, 224, 224, 3) - output_shape = (1, 1001) - input_dtype = "float32" - tflite_model_buf = open(network_path, "rb").read() - tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) - mod, params = relay.frontend.from_tflite(tflite_model, - shape_dict={input_name: input_shape}, - dtype_dict={input_name: input_dtype}) - elif name == 'pytorch-mobilenet-v2': - import torch - - model = torch.hub.load('pytorch/vision:v0.5.0', 'mobilenet_v2', pretrained=False) - model.eval() - - input_shape = [batch_size, 3, 224, 224] - input_data = torch.randn(input_shape) - scripted_model = torch.jit.trace(model, input_data).eval() - - input_name = 'input0' - shape_list = {input_name: input_shape} - mod, params = relay.frontend.from_pytorch(scripted_model, - shape_list) - elif name == 'bert': - import tensorflow as tf - - bert_pb = './baseline/tensorflow/tf_models/bert/bert-B%d.pb' % batch_size - try: - with tf.compat.v1.gfile.GFile(bert_pb, 'rb') as f: - graph_def = tf.compat.v1.GraphDef() - graph_def.ParseFromString(f.read()) - except: - raise ValueError("Need to run ./baseline/tensorflow/bert/generate_bert_pb.py to get model first") - - input_shape = (batch_size, 128) - input_name = ['input'] - shape_dict = { - 'input': input_shape - } - out_names = [ - 'bert/pooler/dense/Tanh' - ] - - mod, params = relay.frontend.from_tensorflow(graph_def, - shape=shape_dict, - outputs=out_names) - else: - raise ValueError("Unsupported network: " + name) - - return mod, params, input_name, input_shape, output_shape - - -def create_module(data_shape, graph, lib, target, input_name, params, debug_profile, - local_measure, ndk_cc, rpc_device_key, rpc_host, rpc_port, rpc_num_threads, seed=43): - if local_measure: - if target.target_name == "cuda": - ctx = tvm.gpu() - else: - ctx = tvm.cpu() - else: - print("=============== Request Remote ===============") - if 'TVM_NDK_CC' not in os.environ: - os.environ['TVM_NDK_CC'] = ndk_cc - remote = request_remote(rpc_device_key, rpc_host, rpc_port) - - print("=============== Export ===============") - ctx = remote.cpu() - temp = util.tempdir() - path_lib = temp.relpath("deploy_lib.so") - lib.export_library(path_lib, ndk.create_shared) - - print("=============== Upload ===============") - remote.upload(path_lib) - - print("=============== Load ===============") - lib = remote.load_module("deploy_lib.so") - - if rpc_num_threads: - config_threadpool = remote.get_function('runtime.config_threadpool') - config_threadpool(0, rpc_num_threads) - - np.random.seed(seed) - data_tvm = tvm.nd.array(100 * (np.random.uniform(size=data_shape)).astype(dtype), ctx=ctx) - if debug_profile: - module = debug_runtime.create(graph, lib, ctx) - else: - module = runtime.create(graph, lib, ctx) - - if type(input_name) == list: - for name in input_name: - module.set_input(name, data_tvm) - else: - module.set_input(input_name, data_tvm) - for k, v in params.items(): - module.set_input(k, v) - - return module, ctx - - -def tune_and_evaluate(network_arguments, target, target_host, - search_policy, task_scheduler_arguments, tune_option_arguments, - tune, debug_profile, check_correctness, log_n_lines): - # Extract tasks from relay program - mod, 
params, input_name, data_shape, out_shape = get_network(**network_arguments) - - # Tune all - if tune: - print("=============== Extract Workloads ===============") - workloads, wkl_weights = ansor.extract_from_program(mod, target=target, params=params) - print("Extract %d workloads in total" % (len(workloads))) - - # Tune workloads with auto scheduler - print("=============== Tune ===============") - tasks = [] - for i, wkl_key in enumerate(workloads): - dag = ansor.workload_key_to_dag(wkl_key) - print("[========= Task %d =========]\n" % i, dag) - tasks.append(ansor.SearchTask(dag, wkl_key, target, target_host)) - - tuner = ansor.SimpleTaskScheduler(tasks, - lambda costs: sum(c * w for c, w in zip(costs, wkl_weights)), - **task_scheduler_arguments) - tune_option, measure_ctx = create_tune_option(target, **tune_option_arguments) - - if tune_option_arguments['local_measure'] and target.target_name != 'cuda': - os.environ['TVM_BIND_MASTER_CORE_0'] = "1" - tuner.tune(tune_option, search_policy) - - if measure_ctx: - del measure_ctx - - kernel_layout_rewrite = True - - # Compile graph with best states found by auto-scheduler - print("=============== Compile ===============") - with ansor.apply_history_best(tune_option_arguments['log_file'], log_n_lines): - os.environ['TVM_AUTO_CACHE_FLUSH'] = "0" - - if kernel_layout_rewrite: - ansor.prepare_layout_rewrite(mod, target=target, params=params) - else: - # disable layout rewrite - ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE - ansor.LayoutRewriteLevel.COMPUTE_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE - - with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): - graph, lib, opt_params = relay.build_module.build( - mod, target=target, params=params) - - ansor.finish_layout_rewrite() - print("=============== Compile Finish ===============") - - module, ctx = create_module(data_shape, graph, lib, target, input_name, - opt_params, debug_profile, **common_measure_parameters) - - # Evaluate - print("========== Evaluate ==========") - ftimer = module.module.time_evaluator("run", ctx, number=10, repeat=3) - prof_res = np.array(ftimer().results) - - # display profile information - if debug_profile or check_correctness: - module.run() - if check_correctness: - actual_output = module.get_output(0).asnumpy() - print(actual_output) - - print("Mean inference time (std dev): %.2f ms (%.2f ms)" % - (np.mean(prof_res) * 1000, np.std(prof_res) * 1000)) - #log_line(BenchmarkRecord(target.target_name, 'gpu' if target.target_name == 'cuda' else 'cpu', 'network', - # "%s.B%d" % (network_name, batch_size), 'AutoSchedule', layout, - # {"costs": prof_res}, time.time()), record_file) - - if check_correctness: - print("========== Check Correctness ==========") - # clean relay cache - relay.backend.compile_engine.get().clear() - - # disable layout rewrite - ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE - ansor.LayoutRewriteLevel.COMPUTE_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE - target = tvm.target.create('llvm') - with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): - graph, lib, opt_params = relay.build_module.build( - mod, target=target, params=params) - - module, _ = create_module(data_shape, graph, lib, target, input_name, - opt_params, debug_profile, **common_measure_parameters) - module.run() - - expected_output = module.get_output(0).asnumpy() - np.testing.assert_allclose(actual_output, expected_output, rtol=1e-3, atol=1e-3) - - -if __name__ == 
"__main__": - parser = argparse.ArgumentParser() - - # Search task related arguments - parser.add_argument("--network", type=str, required=True) - parser.add_argument("--network-path", type=str, default=None, help="The path of tflite model") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--layout", type=str, default='NHWC') - parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') - parser.add_argument("--target-host", type=str, default=None) - parser.add_argument("--check-correctness", type=str2bool, nargs='?', const=True, default=False) - parser.add_argument("--debug-profile", type=str2bool, nargs='?', const=True, default=False) - parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) - - # Search strategy related arguments - parser.add_argument("--n-trials", type=int, default=1000) - parser.add_argument("--policy", type=str, choices=['sketch'], default='sketch') - parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') - parser.add_argument("--task-scheduler", type=str, default='gradient', - choices=['no', 'gradient', 'round-robin'], - help='The strategy of task scheduler') - parser.add_argument("--seed", type=int, default=0, help='random seed') - - # Log file related arguments - parser.add_argument("--log-file", type=str, help="Write measurement records to this log file") - parser.add_argument("--load-log", type=str, help="Load history log to resume the status of search") - parser.add_argument("--log-n-lines", type=int, help="Only load the first n lines for history log") - parser.add_argument("--load-model", type=str, help="Load pre trained cost model file") - - # Measurement related and other arguments - parser.add_argument("--num-measure-per-iter", type=int, default=48, - help="The number of programs to be measured at each iteration") - parser.add_argument("--build-timeout", type=int, default=10) - parser.add_argument("--run-timeout", type=int, default=10) - parser.add_argument("--early-stopping", type=int, default=-1) - parser.add_argument("--verbose", type=int, default=1) - parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) - parser.add_argument("--rpc-device-key", type=str, default=None) - parser.add_argument("--rpc-host", type=str, default='0.0.0.0') - parser.add_argument("--rpc-port", type=int, default=9190) - parser.add_argument("--rpc-num-threads", type=int, default=None) - parser.add_argument("--n-parallel", type=int, default=1) - parser.add_argument("--ndk-cc", type=str, default=None) - args = parser.parse_args() - - np.random.seed(args.seed) - random.seed(args.seed) - logging.basicConfig() - logging.getLogger('ansor').setLevel(logging.DEBUG) - os.environ["TOPHUB_LOCATION"] = "NONE" # disable autotvm - - target = tvm.target.create(args.target) - log_file = args.log_file or "%s-B%d-%s.json" % (args.network, args.batch_size, - target.target_name) - load_log_file = args.load_log or log_file - search_policy = "%s.%s" % (args.policy, args.model_type) - if args.layout: - layout = args.layout - elif target.target_name == "cuda": - layout = "NCHW" - else: - layout = "NHWC" - - network_arguments = { - 'name': args.network, - 'network_path': args.network_path, - 'batch_size': args.batch_size, - 'layout': layout - } - - task_scheduler_parameters = { - 'strategy': args.task_scheduler, - 'load_log_file': load_log_file, - 'load_model_file': args.load_model, - 'verbose': args.verbose, - } - - common_measure_parameters = { - 
'local_measure': args.local_measure, - 'rpc_device_key': args.rpc_device_key, - 'rpc_host': args.rpc_host, - 'rpc_port': args.rpc_port, - 'rpc_num_threads': args.rpc_num_threads, - 'ndk_cc': args.ndk_cc, - } - - tune_option_arguments = { - 'log_file': log_file, - 'n_trials': args.n_trials, - 'num_measure_per_iter': args.num_measure_per_iter, - 'verbose': args.verbose, - 'n_parallel': args.n_parallel, - 'build_timeout': args.build_timeout, - 'run_timeout': args.run_timeout, - 'early_stopping': args.early_stopping, - **common_measure_parameters - } - - tune_and_evaluate(network_arguments, target, args.target_host, - search_policy, task_scheduler_parameters, tune_option_arguments, - args.tune, args.debug_profile, args.check_correctness, - args.log_n_lines) diff --git a/scripts/tune_op_subgraph.py b/scripts/tune_op_subgraph.py deleted file mode 100644 index d3e70501873e..000000000000 --- a/scripts/tune_op_subgraph.py +++ /dev/null @@ -1,602 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Tune all workloads for single op & subgraph evaluation""" -import argparse -import logging -import random - -import numpy as np - -import tvm -from tvm import te, ansor -import topi -from topi.nn.winograd_util import winograd_transform_matrices -from topi.util import get_const_tuple - -from common import measure_schedule, str2bool, norm_bmn, conv2d_nhwc_bn_relu, conv2d_nchw_bn_relu -from shape_configs import single_op_shape_dict, subgraph_shape_dict -from tune_test import tune_workloads_jointly, replay_workload, create_tune_option - -# ========================== Single Ops ========================== - -@ansor.register_workload_func -def batch_matmul_nkkm(B, N, M, K): - X = te.placeholder((B, N, K), name='A') - Y = te.placeholder((B, K, M), name='B') - k = te.reduce_axis((0, K), name='k') - Z = te.compute((B, N, M), lambda b, i, j: te.sum(X[b][i][k] * Y[b][k][j], axis=[k]), name='C') - return [X, Y, Z] - -@ansor.register_workload_func -def conv1d_nlc(N, L, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): - inputs = te.placeholder((N, L, CI), name='inputs') - weight = te.placeholder((kernel_size, CI//groups, CO), name='weight') - - batch_size, in_len, in_channel = inputs.shape - k_len, channel_per_group, out_channel = weight.shape - out_channel_per_group = out_channel // groups - out_len = (in_len + 2 * padding - dilation * (k_len - 1) - 1) // stride + 1 - rc = te.reduce_axis((0, channel_per_group), name='rc') - rl = te.reduce_axis((0, k_len), name='rl') - - padded = topi.nn.pad(inputs, [0, padding, 0]) - output = te.compute( - (batch_size, out_len, out_channel), - lambda n, l, co: te.sum( - (padded[n, l * stride + rl * dilation, co // out_channel_per_group * channel_per_group + rc] * - weight[rl, rc, co]), axis=[rl, rc]), - name='conv1d_nlc' - ) - 
return [inputs, weight, output] - -@ansor.register_workload_func -def conv2d_nhwc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): - inputs = te.placeholder((N, H, W, CI), name='inputs') - weight = te.placeholder((kernel_size, kernel_size, CI//groups, CO), name='weight') - batch_size, in_h, in_w, in_channel = inputs.shape - k_h, k_w, channel_per_group, out_channel = weight.shape - out_channel_per_group = out_channel // groups - - out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 - out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 - rh = te.reduce_axis((0, k_h), name="rh") - rw = te.reduce_axis((0, k_w), name="rw") - rc = te.reduce_axis((0, channel_per_group), name="rc") - - padded = topi.nn.pad(inputs, [0, padding, padding, 0]) - output = te.compute( - (batch_size, out_h, out_w, out_channel), - lambda n, h, w, co: te.sum( - (padded[n, h * stride + rh * dilation, w * stride + rw * dilation, - co // out_channel_per_group * channel_per_group + rc] - * weight[rh, rw, rc, co]), axis=[rh, rw, rc] - ), - name='conv2d_nhwc' - ) - return [inputs, weight, output] - -@ansor.register_workload_func -def conv2d_nchw(N, CI, H, W, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): - inputs = te.placeholder((N, CI, H, W), name='inputs') - weight = te.placeholder((CO, CI//groups, kernel_size, kernel_size), name='weight') - batch_size, in_channel, in_h, in_w = inputs.shape - out_channel, channel_per_group, k_h, k_w, = weight.shape - out_channel_per_group = out_channel // groups - - out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 - out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 - rc = te.reduce_axis((0, channel_per_group), name="rc") - rh = te.reduce_axis((0, k_h), name="rh") - rw = te.reduce_axis((0, k_w), name="rw") - - padded = topi.nn.pad(inputs, [0, 0, padding, padding]) - output = te.compute( - (batch_size, out_channel, out_h, out_w), - lambda n, co, h, w: te.sum( - (padded[n, co // out_channel_per_group * channel_per_group + rc, - h * stride + rh * dilation, w * stride + rw * dilation] - * weight[co, rc, rh, rw]), axis=[rc, rh, rw] - ), - name='conv2d_nchw' - ) - return [inputs, weight, output] - -@ansor.register_workload_func -def conv3d_ndhwc(N, D, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): - inputs = te.placeholder((N, D, H, W, CI)) - weight = te.placeholder((kernel_size, kernel_size, kernel_size, CI//groups, CO)) - batch_size, in_d, in_h, in_w, in_channel = inputs.shape - k_d, k_h, k_w, channel_per_group, out_channel = weight.shape - out_channel_per_group = out_channel // groups - - out_d = (in_d + 2 * padding - dilation * (k_d - 1) - 1) // stride + 1 - out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 - out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 - rd = te.reduce_axis((0, k_d), name='rd') - rh = te.reduce_axis((0, k_h), name='rh') - rw = te.reduce_axis((0, k_w), name='rw') - rc = te.reduce_axis((0, channel_per_group), name='rc') - - padded = topi.nn.pad(inputs, [0, padding, padding, padding, 0]) - output = te.compute( - (batch_size, out_d, out_h, out_w, out_channel), - lambda n, d, h, w, co: te.sum( - (padded[n, d * stride + rd * dilation, - h * stride + rh * dilation, w * stride + rw * dilation, - co // out_channel_per_group * channel_per_group + rc] - * weight[rd, rh, rw, rc, co]), - axis=[rd, rh, rw, rc] - ), - name='conv3d_ndhwc' - ) - return [inputs, weight, output] - -@ansor.register_workload_func 
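- # Depthwise conv2d: each input channel is convolved with its own filter; `factor` is
- # the channel multiplier (the assert in the body restricts this workload to factor == 1).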
-def depthwise_conv2d_nhwc(N, H, W, C, kernel_size, stride=1, padding=0, dilation=1, factor=1): - inputs = te.placeholder((N, H, W, C)) - weight = te.placeholder((factor, kernel_size, kernel_size, C)) - - batch_size, in_h, in_w, in_channel = inputs.shape - factor, k_h, k_w, in_channel = weight.shape - out_channel = in_channel * factor - - assert factor.value == 1, "Not optimized for factor != 1" - - out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 - out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 - rh = te.reduce_axis((0, k_h), name='rh') - rw = te.reduce_axis((0, k_w), name='rw') - - padded = topi.nn.pad(inputs, [0, padding, padding, 0]) - output = te.compute( - (batch_size, out_h, out_w, out_channel), - lambda n, h, w, c: te.sum( - (padded[n, h * stride + rh * dilation, w * stride + rw * dilation, c // factor] - * weight[c % factor, rh, rw, c // factor]), - axis=[rh, rw] - ), - name="depth_conv2d_nhwc" - ) - return [inputs, weight, output] - -@ansor.register_workload_func -def conv2d_transpose_nhwc(N, H, W, CI, CO, kernel_size, stride=1, padding=0): - inputs = te.placeholder((N, H, W, CI), name='inputs') - weight = te.placeholder((kernel_size, kernel_size, CI, CO), name='weight') - - batch, in_h, in_w, in_c = inputs.shape - filter_h, filter_w, in_c, out_c = weight.shape - stride_h, stride_w = (stride, stride) - - # compute padding - fpad_top, fpad_left, fpad_bottom, fpad_right = topi.nn.get_pad_tuple(padding, (filter_h, filter_w)) - bpad_top = filter_h - 1 - fpad_top - bpad_bottom = filter_h - 1 - fpad_bottom - bpad_left = filter_w - 1 - fpad_left - bpad_right = filter_w - 1 - fpad_right - - # padding stage - padded = topi.nn.pad(inputs, - [0, (bpad_top + stride_h - 1) // stride_h, - (bpad_left + stride_w - 1) // stride_w, 0], - [0, (bpad_bottom + stride_h - 1) // stride_h, - (bpad_right + stride_w - 1) // stride_w, 0]) - - # remove extra padding introduced by dilation - idxdiv = te.indexdiv - idxmod = te.indexmod - border_h = idxmod(stride_h - idxmod(bpad_top, stride_h), stride_h) - border_w = idxmod(stride_w - idxmod(bpad_left, stride_w), stride_w) - - # dilation stage - strides = [1, stride_h, stride_w, 1] - n = len(padded.shape) - - # We should embed this dilation directly into te.compute rather than creating a new te.compute. - # Only in this way can we use unroll to eliminate the multiplication of zeros.
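- # As a minimal sketch of that idea (illustrative only; `x` here is an assumed
- # 1-D placeholder of length n, not part of this workload), a stride-2 dilation
- # embedded directly in the consumer compute would look like
- #     out = te.compute((2 * n,),
- #                      lambda i: te.if_then_else(te.indexmod(i, 2).equal(0),
- #                                                x[te.indexdiv(i, 2)],
- #                                                tvm.tir.const(0.0, x.dtype)))
- # so that fully unrolling the loop over i lets the compiler fold the zero
- # branches away instead of multiplying materialized zeros.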
- def _dilate(*indices): - not_zero = [] - index_tuple = [] - for i in range(n): - if not strides[i] == 1: - index_tuple.append(idxdiv(indices[i], strides[i])) - not_zero.append(idxmod(indices[i], strides[i]).equal(0)) - else: - index_tuple.append(indices[i]) - if not_zero: - not_zero = te.all(*not_zero) - return te.if_then_else(not_zero, padded(*index_tuple), tvm.tir.const(0.0, padded.dtype)) - return padded(*index_tuple) - - # convolution stage - out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h - out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w - rc = te.reduce_axis((0, in_c), name='rc') - rh = te.reduce_axis((0, filter_h), name='rh') - rw = te.reduce_axis((0, filter_w), name='rw') - - output = te.compute( - (batch, out_h, out_w, out_c), - lambda n, h, w, co: te.sum( - _dilate(n, h + rh + border_h, w + rw + border_w, rc) * - weight[filter_h - 1 - rh, filter_w - 1 - rw, rc, co], - axis=[rh, rw, rc]), - name="conv2d_transpose_nhwc", - attrs={"ansor_always_unroll_inner": ["h", "w", "rh", "rw", "h_c", "w_c"]}) - # todo(lmzheng): add constraints on the tile size of h and w - - return [inputs, weight, output] - -@ansor.register_workload_func -def conv2d_capsule_nhwijc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, capsule_size=4): - inputs = te.placeholder((N, H, W, capsule_size, capsule_size, CI), name='inputs') - weight = te.placeholder((kernel_size, kernel_size, capsule_size, capsule_size, CI, CO), name='weight') - batch_size, in_h, in_w, _, _, in_channel = inputs.shape - k_h, k_w, _, _, _, out_channel = weight.shape - - out_h = (in_h + 2 * padding - kernel_size) // stride + 1 - out_w = (in_w + 2 * padding - kernel_size) // stride + 1 - - rh = te.reduce_axis((0, k_h), name="rh") - rw = te.reduce_axis((0, k_w), name="rw") - cap_k = te.reduce_axis((0, capsule_size), name='cap_k') - rc = te.reduce_axis((0, in_channel), name="rc") - - padded = topi.nn.pad(inputs, [0, padding, padding, 0, 0, 0]) - output = te.compute( - (batch_size, out_h, out_w, capsule_size, capsule_size, out_channel), - lambda n, h, w, cap_i, cap_j, co: te.sum( - (padded[n, h * stride + rh, w * stride + rw, cap_i, cap_k, rc] - * weight[rh, rw, cap_k, cap_j, rc, co]), axis=[rh, rw, cap_k, rc] - ), - name='conv2d_capsule_nhwijc' - ) - return [inputs, weight, output] - - -@ansor.register_workload_func -def conv2d_winograd_nhwc(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, dilation=1): - # TODO: implement tile_size - tile_size = 4 #_infer_tile_size(data, kernel) - inputs = te.placeholder((N, H, W, CI), name='inputs') - #weight = te.placeholder((kernel_size, kernel_size, CI, CO), name='weight') - N, H, W, CI = get_const_tuple(inputs.shape) - if isinstance(dilation, int): - dilation_h = dilation_w = dilation - else: - dilation_h, dilation_w = dilation - # if dilation_h != 1 or dilation_w != 1: - # weight = topi.nn.dilate(weight, (1, 1, dilation_h, dilation_w)) - KH = KW = kernel_size - HPAD, WPAD, _, _ = topi.nn.get_pad_tuple(padding, (KH, KW)) - HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride - assert HSTR == 1 and WSTR == 1 and KH == KW - - data_pad = topi.nn.pad(inputs, (0, HPAD, WPAD, 0), (0, HPAD, WPAD, 0), name="data_pad") - - r = KW - m = tile_size - alpha = m + r - 1 - A, B, G = winograd_transform_matrices(m, r, 'float32') - - H = (H + 2 * HPAD - KH) // HSTR + 1 - W = (W + 2 * WPAD - KW) // WSTR + 1 - nH, nW = (H + m - 1) // m, (W + m - 1) // m - P = N * nH * nW - r_kh = te.reduce_axis((0, KH), name='r_kh') - r_kw = te.reduce_axis((0, KW), name='r_kw') - # 
kernel_pack = te.compute((alpha, alpha, CO, CI), lambda eps, nu, co, ci: - # weight[0][0][0][0], - # name='kernel_pack') - kshape = (alpha, alpha, CO, CI) - kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight") - - idxdiv = te.indexdiv - idxmod = te.indexmod - # pack input tile - input_tile = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci: - data_pad[idxdiv(p, (nH * nW))][idxmod(idxdiv(p, nW), nH) * m + eps] - [idxmod(p, nW) * m + nu][ci], name='input_tile',) - - # transform data - r_a = te.reduce_axis((0, alpha), 'r_a') - r_b = te.reduce_axis((0, alpha), 'r_b') - data_pack = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci: - te.sum(input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], - axis=[r_a, r_b]), name='data_pack', - attrs={"ansor_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], - "ansor_last_split_is_one": ["ci", "p"], - "ansor_always_unroll": ["eps", "nu", "r_a", "r_b"], - "ansor_no_cache_write": "True", - }) - - # do batch gemm - ci = te.reduce_axis((0, CI), name='ci') - bgemm = te.compute((alpha, alpha, P, CO), lambda eps, nu, p, co: - te.sum(data_pack[eps][nu][p][ci] * - kernel_pack[eps][nu][co][ci], - axis=[ci]), name='bgemm') - - # inverse transform - r_a = te.reduce_axis((0, alpha), 'r_a') - r_b = te.reduce_axis((0, alpha), 'r_b') - inverse = te.compute((m, m, P, CO), lambda vh, vw, p, co: - te.sum(bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], - axis=[r_a, r_b]), name='inverse', - attrs={"ansor_no_split_at_inner": ["vh", "vw", "r_a", "r_b"], - "ansor_always_unroll": ["vh", "vw", "r_a", "r_b"], - "ansor_last_split_is_one": ["co", "p"], - "ansor_no_cache_write": "True", - }) - - # output - output = te.compute((N, H, W, CO), lambda n, h, w, co: - inverse[idxmod(h, m), - idxmod(w, m), - n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), - co], - name='conv2d_winograd', - tag='conv2d_winograd_nhwc', - attrs={"ansor_no_split_at_outer": ["n", "h", "w", "co"],}) - return [inputs, kernel_pack, output] - -@ansor.register_workload_func -def conv2d_winograd_nchw(N, CI, H, W, CO, kernel_size=3, stride=1, padding=0, dilation=1, precompute=False): - # TODO: implement tile_size - tile_size = 4 #_infer_tile_size(data, kernel) - inputs = te.placeholder((N, CI, H, W), name='inputs') - #weight = te.placeholder((CO, CI, kernel_size, kernel_size), name='weight') - N, CI, H, W = get_const_tuple(inputs.shape) - # if isinstance(dilation, int): - # dilation_h = dilation_w = dilation - # else: - # dilation_h, dilation_w = dilation - # if dilation_h != 1 or dilation_w != 1: - # weight = topi.nn.dilate(weight, (1, 1, dilation_h, dilation_w)) - KH = KW = kernel_size - HPAD, WPAD, _, _ = topi.nn.get_pad_tuple(padding, (KH, KW)) - HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride - assert HSTR == 1 and WSTR == 1 and KH == KW - - data_pad = topi.nn.pad(inputs, (0, 0, HPAD, WPAD), (0, 0, HPAD, WPAD), name="data_pad") - - r = KW - m = tile_size - alpha = m + r - 1 - A, B, G = winograd_transform_matrices(m, r, 'float32') - - H = (H + 2 * HPAD - KH) // HSTR + 1 - W = (W + 2 * WPAD - KW) // WSTR + 1 - nH, nW = (H + m - 1) // m, (W + m - 1) // m - P = N * nH * nW - r_kh = te.reduce_axis((0, KH), name='r_kh') - r_kw = te.reduce_axis((0, KW), name='r_kw') - # kernel_pack = te.compute((alpha, alpha, CI, CO), lambda eps, nu, ci, co: - # weight[0][0][0][0], - # name='kernel_pack') - kshape = (alpha, alpha, CI, CO) - kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight") - - idxdiv = te.indexdiv - idxmod = te.indexmod - # pack input tile - input_tile = 
te.compute((CI, P, alpha, alpha), lambda ci, p, eps, nu: - data_pad[idxdiv(p, (nH * nW))][ci][idxmod(idxdiv(p, nW), nH) * m + eps] - [idxmod(p, nW) * m + nu], name='input_tile') - - # transform data - r_a = te.reduce_axis((0, alpha), 'r_a') - r_b = te.reduce_axis((0, alpha), 'r_b') - data_pack = te.compute((alpha, alpha, CI, P), lambda eps, nu, ci, p: - te.sum(input_tile[ci][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu], - axis=[r_a, r_b]), name='data_pack', - attrs={"ansor_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], - "ansor_no_split_at_outer": ["ci", "p"], - "ansor_always_unroll": ["eps", "nu", "r_a", "r_b"], - "ansor_no_cache_write": "True", - }) - - # do batch gemm - ci = te.reduce_axis((0, CI), name='ci') - bgemm = te.compute((alpha, alpha, CO, P), lambda eps, nu, co, p: - te.sum(data_pack[eps][nu][ci][p] * - kernel_pack[eps][nu][ci][co], - axis=[ci]), name='bgemm') - - # inverse transform - r_a = te.reduce_axis((0, alpha), 'r_a') - r_b = te.reduce_axis((0, alpha), 'r_b') - inverse = te.compute((CO, P, m, m), lambda co, p, vh, vw: - te.sum(bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw], - axis=[r_a, r_b]), name='inverse', - attrs={"ansor_no_split_at_outer": ["co", "p", "vh", "vw", "r_a", "r_b"], - "ansor_always_unroll": ["vh", "vw", "r_a", "r_b"], - "ansor_no_cache_write": "True"}) - - # output - output = te.compute((N, CO, H, W), lambda n, co, h, w: - inverse[co, n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), - idxmod(h, m), - idxmod(w, m)], - name='conv2d_winograd', - attrs={"ansor_no_split_at_outer": ["n", "co", "h", "w"],}) - return [inputs, kernel_pack, output] - -# ========================== Subgraphs ========================== - -@ansor.register_workload_func -def transpose_batch_matmul(batch, seq_len, n_head, n_dim): - query = te.placeholder((batch, seq_len, n_head, n_dim), name='query') - value = te.placeholder((batch, seq_len, n_head, n_dim), name='value') - query_T = te.compute((batch, n_head, seq_len, n_dim), - lambda b, h, l, d: query[b, l, h, d], name="query_T") - value_T = te.compute((batch, n_head, n_dim, seq_len), - lambda b, h, d, l: value[b, l, h, d], name="value_T") - k = te.reduce_axis((0, n_dim), name='k') - out = te.compute((batch, n_head, seq_len, seq_len), - lambda b, h, i, j: te.sum(query_T[b][h][i][k] * value_T[b][h][k][j], axis=[k]), - name='C') - return [query, value, out] - -# ========================== Tune function & Task dicts ========================== - -def tune_wkl(task_func_dict, shape_dict, wkl_type, args): - target = tvm.target.create(args.target) - - for wkl_meta_name, func in task_func_dict.items(): - if not args.wkl in ["all", wkl_type, wkl_meta_name]: - continue - - log_file = args.log_file or wkl_meta_name + ".json" - wkl_keys = [] - for shape in shape_dict[wkl_meta_name]: - if shape[0] == 1: - shape = list(shape) - shape[0] = args.batch_size - - wkl_key = ansor.make_workload_key_func(func, shape) - wkl_keys.append(wkl_key) - if args.fast_check: - break - - if not args.tune: - cost, gflops = replay_workload( - wkl_key, target, args.target_host, log_file, - args.local_measure, args.rpc_device_key, args.rpc_host, - args.rpc_port, args.rpc_num_threads, args.ndk_cc, False) - # log_line(BenchmarkRecord(target.name, 'gpu' if target.name == 'cuda' else 'cpu', 'subgraph', - # workload_name, "AutoSchedule", "default", - # {"costs": [cost]}, time.time()), args.out_file) - - if args.tune: - print("========== Tune for %s (%d shapes) ========== " % (wkl_meta_name, len(wkl_keys))) - - load_log_file = args.load_log or log_file - n_trials = 
args.n_trials_per_shape * len(wkl_keys) - - tune_option, measure_ctx = create_tune_option(target, log_file, - n_trials, args.num_measure_per_iter, args.verbose, - args.n_parallel, args.build_timeout, args.local_measure, - args.rpc_device_key, args.rpc_host, args.rpc_port, - args.rpc_num_threads, args.ndk_cc) - - # tune workloads jointly using the task scheduler - tune_workloads_jointly(wkl_keys, np.ones(len(wkl_keys)), args.task_scheduler, - target, args.target_host, args.policy, args.model_type, - args.load_model, load_log_file, tune_option) - - if measure_ctx: - del measure_ctx - - -single_op_task_func_dict = { - 'GMM': batch_matmul_nkkm, - 'C1D': conv1d_nlc, - 'C2D': conv2d_nhwc, - 'C3D': conv3d_ndhwc, - 'GRP': conv2d_nhwc, - 'DIL': conv2d_nhwc, - 'DEP': depthwise_conv2d_nhwc, - 'T2D': conv2d_transpose_nhwc, - 'CAP': conv2d_capsule_nhwijc, - 'NRM': norm_bmn, - #'SMX': softmax_mn, - -# The following workloads are not in our single op evaluation plan. -# They should be moved to `common.py` and be used by `tune_wkl.py`. -# 'C2D_NCHW': conv2d_nchw, -# 'C2DWG_NHWC': conv2d_winograd_nhwc, -# 'C2DWG_NCHW': conv2d_winograd_nchw, -# 'GMM_TC': matmul_nkkm, -} - -subgraph_task_func_dict = { - 'conv2d_bn_relu': conv2d_nhwc_bn_relu, - #'conv2d_bn_relu': conv2d_nchw_bn_relu, # some old log uses conv2d_nchw_bn_relu - 'transpose_batch_matmul': transpose_batch_matmul, -} - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Search task related arguments - parser.add_argument("--wkl", type=str, required=True, - help="all - Tune all workloads; \ - op - Tune all single ops; \ - subgraph - Tune all subgraphs; \ - specific wkl name - Tune a specific workload") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') - parser.add_argument("--target-host", type=str, default=None) - parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) - parser.add_argument("--fast-check", action='store_true', - help='Only run one shape for each workload. 
This is used for fast checking') - - # Search strategy related arguments - parser.add_argument("--n-trials-per-shape", type=int, default=1000) - parser.add_argument("--policy", type=str, choices=['sketch', 'beam-search'], default='sketch') - parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') - parser.add_argument("--task-scheduler", type=str, default='round-robin', - choices=['no', 'gradient', 'round-robin'], help='The strategy of task scheduler') - parser.add_argument("--seed", type=int, default=0, help='random seed') - - # Log file related arguments - parser.add_argument("--log-file", type=str, help="Write measurement records to this log file") - parser.add_argument("--load-log", type=str, help="Load history log to resume the status of search") - parser.add_argument("--load-model", type=str, help="Load pre-trained cost model from this file") - - # Measurement related and other arguments - parser.add_argument("--num-measure-per-iter", type=int, default=48, - help="The number of programs to be measured at each iteration") - parser.add_argument("--build-timeout", type=int, default=10) - parser.add_argument("--run-timeout", type=int, default=60) - parser.add_argument("--verbose", type=int, default=1) - parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) - parser.add_argument("--rpc-device-key", type=str, default=None) - parser.add_argument("--rpc-host", type=str, default='0.0.0.0') - parser.add_argument("--rpc-port", type=int, default=9190) - parser.add_argument("--rpc-num-threads", type=int, default=None) - parser.add_argument("--n-parallel", type=int, default=1) - parser.add_argument("--ndk-cc", type=str, default=None) - args = parser.parse_args() - - np.random.seed(args.seed) - random.seed(args.seed) - logging.basicConfig() - logging.getLogger('ansor').setLevel(logging.DEBUG) - - # compute the number of tasks - num_tasks = 0 - for wkl_meta_name in single_op_task_func_dict: - if not args.wkl in ["all", "op", wkl_meta_name]: - continue - if args.fast_check: - num_tasks += 1 - else: - num_tasks += len(single_op_shape_dict[wkl_meta_name]) - for wkl_meta_name in subgraph_task_func_dict: - if not args.wkl in ["all", "subgraph", wkl_meta_name]: - continue - if args.fast_check: - num_tasks += 1 - else: - num_tasks += len(subgraph_shape_dict[wkl_meta_name]) - print("Number of tasks: %d\tTotal trials: %d" % (num_tasks, num_tasks * args.n_trials_per_shape)) - - # tune for tasks - tune_wkl(single_op_task_func_dict, single_op_shape_dict, "op", args) - tune_wkl(subgraph_task_func_dict, subgraph_shape_dict, "subgraph", args) diff --git a/scripts/tune_test.py b/scripts/tune_test.py deleted file mode 100644 index 6b39cf5e7865..000000000000 --- a/scripts/tune_test.py +++ /dev/null @@ -1,394 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Use auto scheduler to tune workloads""" -import argparse -import logging -import os -import random - -import numpy as np - -import tvm -from tvm import ansor, te -from tvm.ansor.utils import request_remote - -from common import get_workload_keys, get_workload_weights, measure_schedule, str2bool - -def tensor_core_meet_condition(meta_policy, state, stage_id): - pass - -def intrin_wmma_load_matrix(scope): - n = 16 - A = te.placeholder((n, n), name='A', dtype='float16') - BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256) - C = te.compute((n, n), lambda i, j: A[i, j], name='C') - BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256) - - def intrin_func(ins, outs): - ib = tvm.tir.ir_builder.create() - - BA = ins[0] - BC = outs[0] - ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', - BC.data, n, n, n, BC.elem_offset // 256, - BA.access_ptr('r'), n, 'row_major')) - return ib.get() - - return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) - -@tvm._ffi.register_func -def intrin_wmma_load_matrix_a(): - return intrin_wmma_load_matrix("wmma.matrix_a") - -@tvm._ffi.register_func -def intrin_wmma_load_matrix_b(): - return intrin_wmma_load_matrix("wmma.matrix_b") - -@tvm._ffi.register_func -def intrin_wmma_gemm(): - n = 16 - A = te.placeholder((n, n), name='A', dtype='float16') - B = te.placeholder((n, n), name='B', dtype='float16') - k = te.reduce_axis((0, n), name="k") - C = te.compute((n, n), - lambda ii, jj: - te.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k), - name='C') - BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256) - BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256) - BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256) - - def intrin_func(ins, outs): - BA, BB = ins - BC, = outs - - def init(): - ib = tvm.tir.ir_builder.create() - ib.emit(tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0)) - return ib.get() - - def update(): - ib = tvm.tir.ir_builder.create() - ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync', - BC.data, BC.elem_offset // 256, - BA.data, BA.elem_offset // 256, - BB.data, BB.elem_offset // 256, - BC.data, BC.elem_offset // 256)) - return ib.get() - - return update(), init(), update() - - return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) - -@tvm._ffi.register_func -def intrin_wmma_store_matrix(): - n = 16 - A = te.placeholder((n, n), name='A', dtype='float32') - BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256) - C = te.compute((n, n), lambda i, j: A[i, j], name='C') - BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=256) - - def intrin_func(ins, outs): - ib = tvm.tir.ir_builder.create() - BA = ins[0] - BC = outs[0] - ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync', - BA.data, n, n, n, BA.elem_offset // 256, - BC.access_ptr('w'), n, 'row_major')) - return ib.get() - - return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) - -def tensor_core_apply(meta_policy, state, stage_id): - ret = [] - state = ansor.loop_state.State(state, 
meta_policy.cur_task.compute_dag) - - A, B, C = meta_policy.cur_task.compute_dag.ops - - C_local = state.cache_write(C, "wmma.accumulator") - - its0 = state.split(C_local, state[C_local].iters[0], [None, None]) - split_step0 = state.transform_steps_size() - 1 - its1 = state.split(C_local, state[C_local].iters[3], [None, None]) - split_step1 = state.transform_steps_size() - 1 - its2 = state.split(C_local, state[C_local].iters[8], [None]) - - state.reorder(C_local, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], - its2[0], its2[1], - state[C_local].iters[6], - state[C_local].iters[7], - state[C_local].iters[10]]) - state.fuse(C_local, [state[C_local].iters[0], state[C_local].iters[1]]) - state.fuse(C_local, [state[C_local].iters[1], state[C_local].iters[2]]) - state.fuse(C_local, [state[C_local].iters[2], state[C_local].iters[3]]) - - its0 = state.follow_split(C, state[C].iters[0], split_step0, 2) - its1 = state.follow_split(C, state[C].iters[3], split_step1, 2) - state.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], - state[C].iters[6], state[C].iters[7]]) - state.fuse(C, [state[C].iters[0], state[C].iters[1]]) - state.fuse(C, [state[C].iters[1], state[C].iters[2]]) - local_write_pos = state.fuse(C, [state[C].iters[2], state[C].iters[3]]) - state.compute_at(C_local, C, local_write_pos) - shared_read_pos = state[C_local].iters[3] - local_read_pos = state[C_local].iters[4] - state.bind_thread(C, state[C].iters[0], "blockIdx.x") - state.bind_thread(C, state[C].iters[1], "vthread") - state.bind_thread(C, state[C].iters[2], "threadIdx.x") - - B_shared = state.cache_read(B, "shared", [C_local]) - B_local = state.cache_read(B_shared, "wmma.matrix_b", [C_local]) - state.compute_at(B_shared, C_local, shared_read_pos) - state.compute_at(B_local, C_local, local_read_pos) - - it = state.fuse(B_shared, state[B_shared].iters[:]) - its = state.split(B_shared, it, [4]) # vectorize add a callback check function - state.vectorize(B_shared, its[1]) - its = state.follow_fused_split(B_shared, its[0], [split_step0, split_step1], 1, True) - state.bind_thread(B_shared, its[1], "threadIdx.x") - - A_shared = state.cache_read(A, "shared", [C_local]) - A_local = state.cache_read(A_shared, "wmma.matrix_a", [C_local]) - state.compute_at(A_shared, C_local, shared_read_pos) - state.compute_at(A_local, C_local, local_read_pos) - - it = state.fuse(A_shared, state[A_shared].iters[:]) - its = state.split(A_shared, it, [4]) # vectorize add a callback check function - state.vectorize(A_shared, its[1]) - its = state.follow_fused_split(A_shared, its[0], [split_step0, split_step1], 1, True) - state.bind_thread(A_shared, its[1], "threadIdx.x") - - state.tensorize(A_local, state[A_local].iters[-2], "intrin_wmma_load_matrix_a") - state.tensorize(B_local, state[B_local].iters[-2], "intrin_wmma_load_matrix_b") - state.tensorize(C_local, state[C_local].iters[-3], "intrin_wmma_gemm") - state.tensorize(C, state[C].iters[-2], "intrin_wmma_store_matrix") - - print(state) - - ret.append([state.state_object, -1]) - return ret - -def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose, - n_parallel, build_timeout, local_measure, rpc_device_key, rpc_host, - rpc_port, rpc_num_threads, ndk_cc, early_stopping=-1, run_timeout=10, - tensor_core_matmul=False): - builder = runner = measure_ctx = None - if local_measure: - builder = ansor.LocalBuilder(timeout=build_timeout) - if target.target_name == "cuda": - measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) - runner = 
measure_ctx.runner - else: - os.environ['TVM_AUTO_CACHE_FLUSH'] = "1" - runner = ansor.LocalRunner(repeat=10, number=1, min_repeat_ms=0, timeout=run_timeout) - else: - os.environ['TVM_NDK_CC'] = ndk_cc - builder = ansor.LocalBuilder(timeout=build_timeout, build_func='ndk') - runner = ansor.RPCRunner(key=rpc_device_key, host=rpc_host, port=rpc_port, - timeout=run_timeout, n_parallel=n_parallel, - repeat=1, min_repeat_ms=200) - remote = request_remote(rpc_device_key, rpc_host, rpc_port) - if rpc_num_threads: - config_threadpool = remote.get_function('runtime.config_threadpool') - config_threadpool(0, rpc_num_threads) - - pre_search_callbacks = [ansor.PreloadMeasuredStates(log_file)] - if tensor_core_matmul: - pre_search_callbacks.append(ansor.PreloadCustomSketchRule(tensor_core_meet_condition, tensor_core_apply)) - tune_option = ansor.TuneOption(n_trials=n_trials, early_stopping=early_stopping, - num_measure_per_iter=num_measure_per_iter, - verbose=verbose, - builder=builder, - runner=runner, - measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=pre_search_callbacks) - - return tune_option, measure_ctx - - -def replay_workload(wkl_key, target, target_host, log_file, - local_measure=True, rpc_device_key=None, rpc_host="0.0.0.0", - rpc_port=9190, rpc_num_threads=None, ndk_cc=None, - show_lower_result=True): - cost = gflops = None - - inp, res = ansor.best_measure_pair_in_file(log_file, wkl_key, target) - if inp is None: - print("Cannot find log for: %s" % wkl_key) - else: - dag = ansor.workload_key_to_dag(inp.task.workload_key) - print("Found schedule for: %s" % wkl_key) - - s, bufs = dag.apply_steps_from_state(inp.state) - if show_lower_result: - print(tvm.lower(s, bufs, simple_mode=True)) - - if local_measure: - remote = None - else: - remote = request_remote(rpc_device_key, rpc_host, rpc_port) - if rpc_num_threads: - config_threadpool = remote.get_function('runtime.config_threadpool') - config_threadpool(0, rpc_num_threads) - - cost = np.mean((measure_schedule(s, bufs, target, target_host, - remote=remote, ndk_cc=ndk_cc))) - gflops = ansor.ComputeDAG(bufs).flop_ct / cost / 1e9 - print("Best schedule: %.2f GFLOPS\tcost: %.3f ms" % (gflops, cost * 1e3)) - - return cost, gflops - - -def tune_workload(wkl_key, target, target_host, policy, model_type, - load_model_file, load_log_file, tune_option): - """Tune a workload""" - - if False: - # Debug info. 
Print static analysis results from the access analyzer - dag = ansor.workload_key_to_dag(wkl_key) - print(dag.access_analyzer) - exit() - - if model_type == 'xgb': - model = ansor.XGBModel() - if load_model_file: - print("Load pretrained model...") - model.load(load_model_file) - elif load_log_file: - model.load_log_file(load_log_file) - elif model_type == "random": - model = ansor.RandomModel() - else: - raise ValueError("Invalid model: " + model_type) - - if policy == 'sketch': - policy = ansor.SketchSearchPolicy(program_cost_model=model) - elif policy == 'beam-search': - policy = ansor.SketchSearchPolicy(program_cost_model=model, - params={'use_beam_search': 1}) - else: - raise ValueError("Invalid search policy: " + policy) - - s, bufs = ansor.auto_schedule(wkl_key, - target=target, target_host=target_host, - search_policy=policy, - tune_option=tune_option) - -def tune_workloads_jointly(wkl_keys, weights, task_scheduler, target, target_host, - search_policy, model_type, load_model_file, load_log_file, - tune_option): - """Tune multiple workloads together with the TaskScheduler""" - tasks = [] - for wkl_key in wkl_keys: - dag = ansor.workload_key_to_dag(wkl_key) - tasks.append(ansor.SearchTask(dag, wkl_key, target, target_host)) - - def objective_func(costs): - return sum(c * w for c, w in zip(costs, weights)) - - tuner = ansor.SimpleTaskScheduler(tasks, objective_func, strategy=task_scheduler, - load_log_file=load_log_file, load_model_file=load_model_file) - search_policy = "%s.%s" % (search_policy, model_type) - tuner.tune(tune_option, search_policy) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Search task related arguments - parser.add_argument("--wkl", type=str, required=True) - parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') - parser.add_argument("--target-host", type=str, default=None) - parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) - - # Search strategy related arguments - parser.add_argument("--n-trials", type=int, default=1000) - parser.add_argument("--policy", type=str, choices=['sketch', 'beam-search'], default='sketch') - parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') - parser.add_argument("--task-scheduler", type=str, default='no', - choices=['no', 'gradient', 'round-robin'], - help='The strategy of task scheduler') - parser.add_argument("--seed", type=int, default=0, help='random seed') - - # Log file related arguments - parser.add_argument("--log-file", type=str, help="Write measurement records to this log file") - parser.add_argument("--load-log", type=str, help="Load history log to resume the status of search") - parser.add_argument("--load-model", type=str, help="Load pre-trained cost model from this file") - - # Measurement related and other arguments - parser.add_argument("--num-measure-per-iter", type=int, default=48, - help="The number of programs to be measured at each iteration") - parser.add_argument("--build-timeout", type=int, default=10) - parser.add_argument("--run-timeout", type=int, default=60) - parser.add_argument("--verbose", type=int, default=1) - parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) - parser.add_argument("--rpc-device-key", type=str, default=None) - parser.add_argument("--rpc-host", type=str, default='0.0.0.0') - parser.add_argument("--rpc-port", type=int, default=9190) - parser.add_argument("--rpc-num-threads", type=int, default=None) - 
parser.add_argument("--n-parallel", type=int, default=1) - parser.add_argument("--ndk-cc", type=str, default=None) - args = parser.parse_args() - - np.random.seed(args.seed) - random.seed(args.seed) - logging.basicConfig() - logging.getLogger('ansor').setLevel(logging.DEBUG) - - wkl_keys = get_workload_keys(args.wkl) - target = tvm.target.create(args.target) - log_file = args.log_file or args.wkl + ".json" - - # Tune workloads - if args.tune: - load_log_file = args.load_log or log_file - weights = get_workload_weights(args.wkl) - - # Special check for tensor core - wkl_key = args.wkl - wkl_key = wkl_key.split("-") - tensor_core_matmul = False - if wkl_key[0] == "matmul" and wkl_key[6] == "tc": - tensor_core_matmul = True - - tune_option, measure_ctx = create_tune_option(target, log_file, - args.n_trials, args.num_measure_per_iter, args.verbose, - args.n_parallel, args.build_timeout, args.local_measure, - args.rpc_device_key, args.rpc_host, args.rpc_port, args.rpc_num_threads, - args.ndk_cc, tensor_core_matmul=tensor_core_matmul) - - if args.task_scheduler == 'no': - # tune workloads one by one - for wkl_key in wkl_keys: - tune_workload(wkl_key, target, args.target_host, args.policy, - args.model_type, args.load_model, load_log_file, - tune_option) - else: - # tune workloads jointly with TaskScheduler - tune_workloads_jointly(wkl_keys, weights, args.task_scheduler, - target, args.target_host, args.policy, - args.model_type, args.load_model, load_log_file, - tune_option) - if measure_ctx: - del measure_ctx - - # Replay the best found schedule - if len(wkl_keys) == 1 or not args.tune: - for wkl_key in wkl_keys: - replay_workload(wkl_key, target, args.target_host, log_file, - args.local_measure, args.rpc_device_key, args.rpc_host, - args.rpc_port, args.rpc_num_threads, args.ndk_cc) diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 9e6da6ff6f3b..d7af8b94729a 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -38,7 +38,6 @@ #include #include "transform_step.h" #include "search_policy/utils.h" -#include "../relay/transforms/kernel_layout_transform.h" namespace tvm { namespace ansor { @@ -737,7 +736,7 @@ void ComputeDAG::RewriteLayout( CHECK_EQ(placeholder_axis_names.size(), placeholder->shape.size()); std::string ori_layout = os.str(); os.str(""); - ::tvm::relay::KernelLayoutVisitor::global_ori_layouts_queue.push_back(ori_layout); + // ::tvm::relay::KernelLayoutVisitor::global_ori_layouts_queue.push_back(ori_layout); } } @@ -800,7 +799,7 @@ void ComputeDAG::RewriteLayout( } std::string new_layout = os.str(); os.str(""); - ::tvm::relay::KernelLayoutVisitor::global_new_layouts_queue.push_back(new_layout); + // ::tvm::relay::KernelLayoutVisitor::global_new_layouts_queue.push_back(new_layout); placeholder_new_names[placeholder_op] = new_names; placeholder_new_shapes[placeholder_op] = new_shape; diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index 17ab73efb6aa..6be4773fe780 100644 --- a/src/ansor/search_task.cc +++ b/src/ansor/search_task.cc @@ -52,65 +52,6 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams( if (target->target_name == "llvm") { return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 32, 64, 16, 64); - } else if (target->device_type == kDLGPU) { - // TODO(jcf94): temp implementation, max vectorize size in GPU is related - // to the data type - auto hardware_params = HardwareParams(100000, 16, 64, 4, 64); - auto* p_hardware_params = hardware_params.CopyOnWrite(); - - auto ctx = TVMContext{kDLGPU, 0}; - auto 
func = tvm::runtime::Registry::Get("device_api.gpu"); - CHECK(func != nullptr) << "Cannot find GPU device_api in registry"; - auto device_api = - static_cast(((*func)()).operator void*()); - - tvm::runtime::TVMRetValue ret; - device_api->GetAttr( - ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret); - p_hardware_params->max_shared_memory_per_block = ret; - - device_api->GetAttr( - ctx, tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock, &ret); - p_hardware_params->max_registers_per_block = ret; - - device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, - &ret); - p_hardware_params->max_threads_per_block = ret; - - device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret); - p_hardware_params->warp_size = ret; - - // Manually set now - p_hardware_params->max_vthread_extent = 4; - - return hardware_params; - } else if (target->device_type == kDLOpenCL) { - // TODO(jcf94): temp implementation - auto hardware_params = HardwareParams(100000, 16, 64, 4, 64); - auto p_hardware_params = hardware_params.CopyOnWrite(); - - auto ctx = TVMContext{kDLOpenCL, 0}; - auto func = tvm::runtime::Registry::Get("device_api.opencl"); - CHECK(func != nullptr) << "Cannot find GPU device_api in registry"; - auto device_api = - static_cast(((*func)()).operator void*()); - - tvm::runtime::TVMRetValue ret; - device_api->GetAttr( - ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret); - p_hardware_params->max_shared_memory_per_block = ret; - - device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, - &ret); - p_hardware_params->max_threads_per_block = ret; - - device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret); - p_hardware_params->warp_size = ret; - - // Manually set now - p_hardware_params->max_vthread_extent = 4; - - return hardware_params; } else { LOG(FATAL) << "No default hardware parameters for target: " << target; } diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index d3af64a4f576..4887ef0ee47d 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -132,13 +132,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), ramp(b1 + x, s1, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), ramp(x + b1, s1, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes), broadcast(x + y, lanes)); - if ((x + broadcast(y, lanes)).Match(ret)) { - if (auto ps = y.Eval().as()) { - if (ps->value == 0.0) { - return x.Eval(); - } - } - } } if (IsIndexType(op->dtype)) { @@ -429,13 +422,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x * y, lanes)); TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes), ramp(b1 * x, s1 * x, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes), ramp(b1 * x, s1 * x, lanes)); - if ((broadcast(x, lanes) * y).Match(ret)) { - if (auto ps = x.Eval().as()) { - if (ps->value == 0.0) { - return make_const(op->dtype, 0.0); - } - } - } } if (IsIndexType(op->dtype)) { @@ -714,9 +700,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar w, x, y, z, b1; + PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2, c3, c4; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar 
lanes; @@ -781,11 +767,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), max(floordiv(y, c2), x * floordiv(c1, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), floordiv(x * c1, c2), - c1.Eval()->value > 0 && c2.Eval()->value > 0 && - c2.Eval()->value % c1.Eval()->value == 0 && - CanProveGreaterEqual(-y.Eval(), -c1.Eval()->value + 1)); - // Rules involving 3-operands. TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); @@ -802,13 +783,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - TVM_TRY_REWRITE_IF(floordiv(x * c1 + y * c2 + z, c3), floordiv(x * c1 + y * c2, c3), - c1.Eval()->value > 0 && c2.Eval()->value > 0 && c3.Eval()->value > 0 && - c3.Eval()->value % c1.Eval()->value == 0 && - c3.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(-z.Eval(), - std::max(-c1.Eval()->value, -c2.Eval()->value) + 1)); - TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); @@ -833,18 +807,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(y + z * x, z), floordiv(y, z) + x, CanProveGreaterEqual(z.Eval(), 0)); - - // Rules involving 4-operands - TVM_TRY_REWRITE_IF(floordiv(w * c1 + x * c2 + y * c3 + z, c4), - floordiv(w * c1 + x * c2 + y * c3, c4), - c1.Eval()->value > 0 && c2.Eval()->value > 0 && - c3.Eval()->value > 0 && c4.Eval()->value > 0 && - c4.Eval()->value % c1.Eval()->value == 0 && - c4.Eval()->value % c2.Eval()->value == 0 && - c4.Eval()->value % c3.Eval()->value == 0 && - CanProveGreaterEqual(-z.Eval(), - std::max(-c1.Eval()->value, - std::max(-c2.Eval()->value, -c3.Eval()->value)) + 1)); } return ret; } @@ -856,9 +818,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar w, x, y, z, b1; + PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2, c3, c4; + PVar c1, c2; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -902,31 +864,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x, floordiv(c2, c1)) * c1 + y, - c1.Eval()->value > 0 && c2.Eval()->value > 0 && - c2.Eval()->value % c1.Eval()->value == 0 && - CanProveGreaterEqual(-y.Eval(), -c1.Eval()->value + 1)); - - // TODO(jcf94): For the next three rules, better use the max common factor - // of c1, c2, c3 to do the simplify - TVM_TRY_REWRITE_IF(floormod(x * c1 + y * c2 + z, c3), - floormod(x * floordiv(c1, c2) + y, floordiv(c3, c2)) * c2 + z, - c1.Eval()->value > 0 && c2.Eval()->value > 0 && - c3.Eval()->value > 0 && - c3.Eval()->value % c2.Eval()->value == 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(-z.Eval(), -c2.Eval()->value + 1)); - - TVM_TRY_REWRITE_IF(floormod(w * c1 + x * c2 + y * c3 + z, c4), - floormod(w * floordiv(c1, 
c3) + x * floordiv(c2, c3) + y, - floordiv(c4, c3)) * c3 + z, - c1.Eval()->value > 0 && c2.Eval()->value > 0 && - c3.Eval()->value > 0 && c4.Eval()->value > 0 && - c4.Eval()->value % c3.Eval()->value == 0 && - c1.Eval()->value % c3.Eval()->value == 0 && - c2.Eval()->value % c3.Eval()->value == 0 && - CanProveGreaterEqual(-z.Eval(), -c3.Eval()->value + 1)); - // try modular analysis if (floormod(x, c1).Match(ret)) { ModularSet mod = analyzer_->modular_set(x.Eval()); diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 5b063eca4337..a192002825e6 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -219,7 +219,6 @@ class TypeSolver::Unifier : public TypeFunctor { return Type(nullptr); } - tt1 = tt2; tvm::Array shape; if (tt1->shape.size() != tt2->shape.size()) { this->solver_->ReportError(ErrorBuilder() << "tensor type `" << PrettyPrint(tt1) << "` has " diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index a8cd1d3c2462..34c3487e3ef2 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -153,11 +153,6 @@ class RelayBuildModule : public runtime::ModuleNode { CHECK_EQ(args.num_args, 2); *rv = this->Optimize(args[0], args[1], this->params_); }); - } else if (name == "call_all_topi_funcs") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue *rv) { - CHECK_EQ(args.num_args, 3); - this->CallAllTopiFuncs(args[0], args[1], args[2]); - }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); @@ -232,21 +227,6 @@ class RelayBuildModule : public runtime::ModuleNode { BuildRelay(mod, params_); } - /*! \brief Call all used TOPI compute and schedule in a relay function */ - void CallAllTopiFuncs(IRModule mod, - const TargetsMap& targets, - const tvm::Target& target_host) { - targets_ = targets; - target_host_ = target_host; - - IRModule relay_module = Optimize(mod, targets_, params_); - auto func = Downcast(relay_module->Lookup("main")); - - graph_codegen_ = std::unique_ptr(new GraphCodegen()); - graph_codegen_->Init(nullptr, targets_); - graph_codegen_->Codegen(func); - } - protected: /*! * \brief Optimize a Relay IRModule. @@ -335,18 +315,6 @@ class RelayBuildModule : public runtime::ModuleNode { // Fuse the operations if it is needed. relay_module = transform::FuseOps()(relay_module); - - if (targets.size() == 1) { - pass_seqs.push_back(transform::KernelLayoutTransform()); - pass_seqs.push_back(transform::DeFuseOps()); - pass_seqs.push_back(transform::FoldConstant()); - transform::Pass seq = transform::Sequential(pass_seqs); - const auto& it = targets.begin(); - With tctx((*it).second); - relay_module = seq(relay_module); - relay_module = transform::FuseOps()(relay_module); - } - relay_module = transform::InferType()(relay_module); // Inline the functions that have been lifted by the module scope. 
// diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index fde880b10f1d..2aae8546248f 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -68,11 +68,6 @@ CCacheKey::CCacheKey(Function source_func, Target target) { auto n = make_object(); n->source_func = std::move(source_func); n->target = std::move(target); - n->disabled = false; - char* envar = getenv("TVM_RELAY_DISABLE_BUILD_CACHE"); - if (envar != nullptr && strcmp(envar, "true") == 0) { - n->disabled = true; - } data_ = std::move(n); } diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index b290462a4b22..a5f3f6359f89 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -115,8 +115,6 @@ class CCacheKeyNode : public Object { /*! \brief The hardware target.*/ Target target; - bool disabled; - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("source_func", &source_func); v->Visit("target", &target); @@ -261,7 +259,6 @@ inline size_t CCacheKeyNode::Hash() const { } inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { - if (disabled) return false; if (Hash() != other->Hash()) return false; return this->target->str() == other->target->str() && tvm::StructuralEqual()(this->source_func, other->source_func); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 30269b85795f..ee5e291e3d53 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2455,60 +2455,6 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] .set_support_level(5) .set_attr("FTVMCompute", LayoutTransformCompute); -// relay.kernel_layout_transform -TVM_REGISTER_NODE_TYPE(KernelLayoutTransformAttrs); - -Array KernelLayoutTransformCompute(const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - const auto* param = attrs.as(); - CHECK(param != nullptr); - return Array{ - topi::kernel_layout_transform(inputs[0], param->src_layout, param->dst_layout) - }; -} - -bool KernelLayoutTransformRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - const auto* data = types[0].as(); - CHECK(data != nullptr); - const KernelLayoutTransformAttrs* params = attrs.as(); - - Array dst_shape; - std::vector dst_axes; - - topi::parse_kernel_layout(params->dst_layout, &dst_shape, &dst_axes); - - reporter->Assign(types[1], TensorType(dst_shape, data->dtype)); - return true; -} - -Expr MakeKernelLayoutTransform(Expr data, - String src_layout, - String dst_layout) { - auto attrs = make_object(); - attrs->src_layout = std::move(src_layout); - attrs->dst_layout = std::move(dst_layout); - static const Op& op = Op::Get("kernel_layout_transform"); - return Call(op, {data}, Attrs(attrs), {}); -} - -TVM_REGISTER_GLOBAL("relay.op._make.kernel_layout_transform") -.set_body_typed(MakeKernelLayoutTransform); - -RELAY_REGISTER_OP("kernel_layout_transform") - .describe(R"code(Transform the input kernel layout. 
-)code" TVM_ADD_FILELINE) - .set_attrs_type() - .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") - .add_type_rel("kernel_layout_transform", KernelLayoutTransformRel) - .set_support_level(5) - .set_attr("FTVMCompute", KernelLayoutTransformCompute); - - /* relay._contrib_reverse_reshape */ Expr MakeReverseReshape(Expr data, Array newshape) { auto attrs = make_object(); diff --git a/src/relay/transforms/defuse_ops.cc b/src/relay/transforms/defuse_ops.cc deleted file mode 100644 index 1a108fb08888..000000000000 --- a/src/relay/transforms/defuse_ops.cc +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "pattern_util.h" - -namespace tvm { -namespace relay { - -class DefuseOpsMutator : public ExprMutator { - public: - class FuncBodyMutator : public ExprMutator { - public: - Array args_; - - explicit FuncBodyMutator(const Array& args) : ExprMutator() { args_ = args; } - - Expr VisitExpr_(const VarNode* n) { - const std::string& name = n->name_hint(); - CHECK_EQ(name[0], 'p'); - std::string id_str = name.substr(1); - int id = atoi(id_str.c_str()); - CHECK(id >= 0 && size_t(id) < args_.size()); - return args_[id]; - } - }; - - Expr VisitExpr_(const CallNode* n) { - auto new_n = ExprMutator::VisitExpr_(n); - - const auto* call = new_n.as(); - if (call) { - const auto* func = call->op.as(); - if (func) { - const auto& func_call = func->body.as(); - if (func_call) { - return FuncBodyMutator(call->args).Mutate(func->body); - } - } - } - return new_n; - } -}; - -Expr DeFuseOps(const Expr& expr) { return DefuseOpsMutator().Mutate(expr); } - -namespace transform { - -Pass DeFuseOps() { - runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::DeFuseOps(f)); - }; - return CreateFunctionPass(pass_func, 3, "DeFuseOps", {"InferType"}); -} - -TVM_REGISTER_GLOBAL("relay._transform.DeFuseOps").set_body_typed(DeFuseOps); - -} // namespace transform - -} // namespace relay -} // namespace tvm diff --git a/src/relay/transforms/kernel_layout_transform.cc b/src/relay/transforms/kernel_layout_transform.cc deleted file mode 100644 index 421968b8a6b9..000000000000 --- a/src/relay/transforms/kernel_layout_transform.cc +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "kernel_layout_transform.h" - -#include -#include -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace relay { - -// Todo: do not use global variables -std::deque KernelLayoutVisitor::global_ori_layouts_queue; -std::deque KernelLayoutVisitor::global_new_layouts_queue; - -Expr KernelLayoutTransform(const Expr& expr) { - KernelLayoutVisitor visitor; - - // Do a pre-order DFS to gather the optimal kernel layouts for all conv2d nodes. - // These layouts were written to global static variables in python function - // `prepare_layout_rewrite` - visitor.VisitExpr(expr); - - // Do a post-order DSF to mutate layout for all conv2d nodes - return KernelLayoutTransformer(&visitor).Mutate(expr); -} - -namespace transform { - -Pass KernelLayoutTransform() { - runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::KernelLayoutTransform(f)); - }; - return CreateFunctionPass(pass_func, 3, "KernelLayoutTransform", {"InferType"}); -} - -TVM_REGISTER_GLOBAL("relay._transform.KernelLayoutTransform").set_body_typed(KernelLayoutTransform); - -} // namespace transform - -} // namespace relay -} // namespace tvm diff --git a/src/relay/transforms/kernel_layout_transform.h b/src/relay/transforms/kernel_layout_transform.h deleted file mode 100644 index c6c38fb71cf4..000000000000 --- a/src/relay/transforms/kernel_layout_transform.h +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#ifndef TVM_RELAY_TRANSFORMS_KERNEL_LAYOUT_TRANSFORM_H_ -#define TVM_RELAY_TRANSFORMS_KERNEL_LAYOUT_TRANSFORM_H_ - -#include -#include - -#include -#include -#include -#include -#include - -#include "../../ansor/compute_dag.h" -#include "pattern_util.h" - -namespace tvm { -namespace relay { - -/*! \brief A visitor to gather the optimal kernel layout for all conv2d nodes. 
*/ -class KernelLayoutVisitor : public ExprVisitor { - public: - void VisitExpr_(const CallNode* n) { - if (n && n->op.as() && - (std::find(op_white_lists.begin(), op_white_lists.end(), n->op.as()->name) != - op_white_lists.end()) && - n->args[1]->type_as()->shape[3].as()->value > 1 && - !global_ori_layouts_queue.empty() && !global_new_layouts_queue.empty()) { - ori_layouts_map[n] = global_ori_layouts_queue.front(); - new_layouts_map[n] = global_new_layouts_queue.front(); - // std::cout << "ori_layout " << global_ori_layouts_queue.front() - // << " Filter_shape " << n->args[1]->type_as()->shape << std::endl; - global_ori_layouts_queue.pop_front(); - global_new_layouts_queue.pop_front(); - } - ExprVisitor::VisitExpr_(n); - } - - std::unordered_map ori_layouts_map; - std::unordered_map new_layouts_map; - std::vector op_white_lists{"nn.contrib_conv2d_winograd_without_weight_transform", - "nn.conv2d", "nn.conv3d"}; - - static std::deque global_ori_layouts_queue; - static std::deque global_new_layouts_queue; -}; - -/*! \brief A mutator to rewrite kernel layout for all conv2d nodes */ -class KernelLayoutTransformer : public ExprMutator { - public: - explicit KernelLayoutTransformer(KernelLayoutVisitor* visitor) - : ExprMutator(), visitor_(visitor) {} - - Expr VisitExpr_(const CallNode* n) { - auto new_n = ExprMutator::VisitExpr_(n); - - const auto* call = new_n.as(); - std::vector op_white_lists{"nn.contrib_conv2d_winograd_without_weight_transform", - "nn.conv2d", "nn.conv3d"}; - if (call && call->op.as() && - (std::find(op_white_lists.begin(), op_white_lists.end(), n->op.as()->name) != - op_white_lists.end() && - n->args[1]->type_as()->shape[3].as()->value > 1)) { - auto ori_layout_iter = visitor_->ori_layouts_map.find(n); - auto new_layout_iter = visitor_->new_layouts_map.find(n); - if (ori_layout_iter != visitor_->ori_layouts_map.end() && - new_layout_iter != visitor_->new_layouts_map.end()) { - const std::string& ori_layout = ori_layout_iter->second; - const std::string& new_layout = new_layout_iter->second; - Expr updated_kernel = MakeKernelLayoutTransform(call->args[1], ori_layout, new_layout); - Array updated_args = {call->args[0], updated_kernel}; - new_n = Call(call->op, updated_args, call->attrs); - } - } - return new_n; - } - - private: - KernelLayoutVisitor* visitor_; -}; - -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_TRANSFORMS_KERNEL_LAYOUT_TRANSFORM_H_ diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index a9d3b5168e47..7518eb9ac81a 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -685,8 +685,6 @@ Expr MakeExpandDims(Expr data, int axis, int num_newaxis); Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout); -Expr MakeKernelLayoutTransform(Expr data, String src_layout, String dst_layout); - Expr StopFusion(Expr data); Expr CastHint(Expr data, DataType dtype); diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 4e71383cc1bb..a6d4a5499469 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -94,10 +94,6 @@ class CUDADeviceAPI final : public DeviceAPI { } case kGcnArch: return; - case kMaxRegistersPerBlock: { - CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxRegistersPerBlock, ctx.device_id)); - break; - } } *rv = value; } diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 714535ecc8a6..800a9167dadc 100644 --- a/src/runtime/ndarray.cc +++ 
b/src/runtime/ndarray.cc @@ -26,9 +26,6 @@ #include #include -#include -#include - #include "runtime_base.h" extern "C" { @@ -183,8 +180,7 @@ NDArray NDArray::CreateView(std::vector shape, DLDataType dtype) { DLManagedTensor* NDArray::ToDLPack() const { return Internal::ToDLPack(get_mutable()); } -NDArray NDArray::Empty(std::vector shape, DLDataType dtype, - DLContext ctx) { +NDArray NDArray::Empty(std::vector shape, DLDataType dtype, DLContext ctx) { NDArray ret = Internal::Create(shape, dtype, ctx); // setup memory content size_t size = GetDataSize(ret.get_mutable()->dl_tensor); @@ -194,59 +190,6 @@ NDArray NDArray::Empty(std::vector shape, DLDataType dtype, return ret; } - -NDArray NDArray::NonEmpty(std::vector shape, DLDataType dtype, - DLContext ctx) { - NDArray ret = Internal::Create(shape, dtype, ctx); - NDArray dummy_cpu_arr = Internal::Create(shape, dtype, {kDLCPU, 0}); - - // setup memory content - size_t size = GetDataSize(ret.get_mutable()->dl_tensor); - size_t alignment = GetDataAlignment(ret.get_mutable()->dl_tensor); - dummy_cpu_arr.get_mutable()->dl_tensor.data = - DeviceAPI::Get(dummy_cpu_arr->ctx)->AllocDataSpace( - {kDLCPU, 0}, size, alignment, dummy_cpu_arr->dtype); - size_t elem_cnt = 1; - for (tvm_index_t i = 0; i < dummy_cpu_arr->ndim; ++i) { - elem_cnt *= static_cast(dummy_cpu_arr->shape[i]); - } - - // TODO(..): maybe we could have a better solution for assigning the values - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_real_distribution<> dis(1.0, 10.0); - // Using a float representation works for both float and int types. - for (size_t i = 0; i < elem_cnt; ++i) { - if (dummy_cpu_arr->dtype.bits == 1) { - (reinterpret_cast( - dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = dis(gen); - } else if (dummy_cpu_arr->dtype.bits == 8) { - (reinterpret_cast( - dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = dis(gen); - } else if (dummy_cpu_arr->dtype.bits == 16) { - (reinterpret_cast( - dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = - __truncXfYf2__( - static_cast(dis(gen))); - } else if (dummy_cpu_arr->dtype.bits == 32) { - (reinterpret_cast( - dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = dis(gen); - } else if (dummy_cpu_arr->dtype.bits == 64) { - (reinterpret_cast( - dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = dis(gen); - } else { - LOG(FATAL) << "Doesn't support dtype code " << dtype.code - << " dtype bits " << dtype.bits; - } - } - ret.get_mutable()->dl_tensor.data = - DeviceAPI::Get(ret->ctx)->AllocDataSpace( - ret->ctx, size, alignment, ret->dtype); - CopyFromTo(&(dummy_cpu_arr.get_mutable()->dl_tensor), - &(ret.get_mutable()->dl_tensor)); - return ret; -} - NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { NDArray::Container* data = new NDArray::Container(); // construct header @@ -314,9 +257,8 @@ int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex) { API_END(); } -int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, - int dtype_bits, int dtype_lanes, int device_type, - int device_id, TVMArrayHandle* out) { +int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits, + int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out) { API_BEGIN(); DLDataType dtype; dtype.code = static_cast(dtype_code); @@ -330,22 +272,6 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, API_END(); } -int TVMArrayAllocNonEmpty(const tvm_index_t* shape, int ndim, int dtype_code, - int dtype_bits, int dtype_lanes, int device_type, - int device_id,
TVMArrayHandle* out) { - API_BEGIN(); - DLDataType dtype; - dtype.code = static_cast(dtype_code); - dtype.bits = static_cast(dtype_bits); - dtype.lanes = static_cast(dtype_lanes); - DLContext ctx; - ctx.device_type = static_cast(device_type); - ctx.device_id = device_id; - *out = NDArray::Internal::MoveToFFIHandle( - NDArray::NonEmpty(std::vector(shape, shape + ndim), dtype, ctx)); - API_END(); -} - int TVMArrayFree(TVMArrayHandle handle) { API_BEGIN(); NDArray::Internal::FFIDecRef(handle); diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 71d3232ca4d5..6d9835e6231c 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -109,9 +109,6 @@ void OpenCLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* } case kGcnArch: return; - default: { - LOG(WARNING) << "Attr not implemented."; - } } } diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index d58130d700f4..89f3e7c6c7f8 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -24,14 +24,9 @@ #include #include -#include #include #include -#if defined(_M_X64) || defined(__x86_64__) -#include -#endif - #include "rpc_endpoint.h" #include "rpc_session.h" @@ -305,22 +300,6 @@ std::shared_ptr RPCModuleGetSession(Module mod) { return rmod->sess(); } -inline void CacheFlush(const char* p, unsigned int allocation_size) { -// TODO(FrozenGene): Support ARM. -#if (defined(_M_X64) || defined(__x86_64__)) - size_t cache_line = 64; - - if (p == nullptr || allocation_size <= 0) { - return; - } - - for (size_t i = 0; i < allocation_size; i += cache_line) { - _mm_clflush(static_cast(&p[i])); - } - -#endif -} - PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repeat, int min_repeat_ms) { CHECK(pf != nullptr); @@ -334,21 +313,12 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repe auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue* rv) mutable { TVMRetValue temp; std::ostringstream os; - const char* cache_flush = std::getenv("TVM_AUTO_CACHE_FLUSH"); // skip first time call, to activate lazy compilation components. 
pf.CallPacked(args, &temp); DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); for (int i = 0; i < repeat; ++i) { - if (cache_flush && std::atoi(cache_flush) != 0) { - CHECK_EQ(number, 1); - // we want to keep input data - for (int j = 1; j < args.size(); j++) { - CacheFlush(reinterpret_cast(args[j].operator DLTensor*()->data), - GetDataSize(*(args[j].operator DLTensor*()))); - } - } std::chrono::time_point tbegin, tend; double duration_ms = 0.0; diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc index 3b1889aed8ef..e5520efe30a6 100644 --- a/src/runtime/threading_backend.cc +++ b/src/runtime/threading_backend.cc @@ -166,13 +166,8 @@ class ThreadGroup::Impl { #if defined(_M_X64) || defined(__x86_64__) big_count /= 2; // ignore hyper-threading #endif - const char* bind_master_core_0 = getenv("TVM_BIND_MASTER_CORE_0"); - if (bind_master_core_0 && atoi(bind_master_core_0) != 0) { - CPU_SET(sorted_order_[0], &cpuset); - } else { - for (int i = 0; i < big_count; ++i) { - CPU_SET(sorted_order_[i], &cpuset); - } + for (int i = 0; i < big_count; ++i) { + CPU_SET(sorted_order_[i], &cpuset); } } #if defined(__ANDROID__) diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index 04a3f0b25bee..af72d3b1a1df 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -461,7 +461,7 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { for (IterVar iv : root_iter_vars) { size_t idx = FindNodeRef(leaf_vars, iv); auto it = s->iter_var_attrs.find(iv); - // don't need to rebase path that are binded. + // don't need to rebase paths that are bound. if (it != s->iter_var_attrs.end() && (*it).second->bind_thread.defined()) { continue; } @@ -614,74 +614,10 @@ void InjectInline(ScheduleNode* sch) { } } -void LegalizeInvalidAttach(ScheduleNode* sch) { - std::unordered_map replace_map; - - for (Stage stage : sch->stages) { - for (Stage s = stage; s.defined();) { - Stage spec = s.GetAttachSpec(); - if (spec->attach_type != kScope) { - break; - } - bool start_attach = false; - IterVar attach_ivar = spec->attach_ivar; - s = spec->attach_stage; - CHECK(attach_ivar.defined()); - CHECK(s.defined()); - - for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) { - IterVar iv = s->leaf_iter_vars[i - 1]; - if (!start_attach && iv.same_as(attach_ivar)) { - start_attach = true; - } - } - if (!start_attach) { - // If the attach_var is fused into another iter_var, update the - // attach_var to be the fused one - // Do this recursively, as in the sketch below.
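A minimal Python sketch of the remapping loop being deleted here: it chases fuse/split relations until the attach point names a loop that still exists. The Fuse/Split records and the helper name are hypothetical stand-ins for TVM's C++ FuseNode/SplitNode relations, not real TVM API:

```python
from collections import namedtuple

# Hypothetical stand-ins for te::FuseNode / te::SplitNode relation records.
Fuse = namedtuple("Fuse", ["outer", "inner", "fused"])
Split = namedtuple("Split", ["parent", "outer", "inner"])

def legalize_attach_ivar(attach_ivar, relations):
    """Follow fuse/split relations until the attach ivar maps to a live loop."""
    new_ivar, updated = attach_ivar, True
    while updated:
        updated = False
        for rel in relations:
            if isinstance(rel, Fuse) and new_ivar == rel.inner:
                new_ivar, updated = rel.fused, True    # attach to the fused loop
            elif isinstance(rel, Split) and new_ivar == rel.parent:
                new_ivar, updated = rel.inner, True    # attach to the inner split loop
    return new_ivar

# e.g. after fuse(i, j) -> i.j, an attach point at j is remapped to i.j
assert legalize_attach_ivar("j", [Fuse("i", "j", "i.j")]) == "i.j"
```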
- IterVar new_attach_ivar = attach_ivar;; - bool updated = true; - while (updated) { - updated = false; - for (const auto& rel : s->relations) { - if (const FuseNode* r = rel.as()) { - if (new_attach_ivar.same_as(r->inner)) { - new_attach_ivar = r->fused; - updated = true; - } - } else if (const SplitNode* r = rel.as()) { - if (new_attach_ivar.same_as(r->parent)) { - new_attach_ivar = r->inner; - updated = true; - } - } - } - replace_map[attach_ivar] = new_attach_ivar; - } - } - } - } - - // remap the parent relation - for (Stage s : sch->stages) { - if (s->attach_type != kScope) continue; - if (replace_map.count(s->attach_ivar)) { - s->attach_ivar = replace_map.at(s->attach_ivar); - } - } - for (Stage s : sch->groups) { - if (s->attach_type != kScope) continue; - if (replace_map.count(s->attach_ivar)) { - s->attach_ivar = replace_map.at(s->attach_ivar); - } - } -} - Schedule Schedule::normalize() { Schedule sn = copy(); InjectInline(sn.operator->()); RebaseNonZeroMinLoop(sn); - LegalizeInvalidAttach(sn.operator->()); return sn; } diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index f6a8ad034aa5..1fbae0fd2dcd 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -33,22 +33,20 @@ namespace tvm { namespace tir { -class GPUCodeVerifier : public StmtExprVisitor { +class GPUCodeVerifier : public StmtVisitor { public: bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block, int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y, - int64_t max_thread_z, int64_t max_vector_bytes) { + int64_t max_thread_z) { max_local_memory_per_block_ = static_cast(max_local_memory_per_block); max_shared_memory_per_block_ = static_cast(max_shared_memory_per_block); max_threads_per_block_ = static_cast(max_threads_per_block); max_thread_x_ = static_cast(max_thread_x); max_thread_y_ = static_cast(max_thread_y); max_thread_z_ = static_cast(max_thread_z); - max_vector_bytes_ = static_cast(max_vector_bytes); Reset_(); - // TODO(jcf94): Add support of detecting CUDA Misaligned Address error this->VisitStmt(stmt); return valid_; @@ -64,10 +62,6 @@ class GPUCodeVerifier : public StmtExprVisitor { size_t size = static_cast(op->constant_allocation_size()); shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } - - if (op->dtype.lanes() > 1) { - valid_ &= op->dtype.lanes() * op->dtype.bytes() <= static_cast(max_vector_bytes_); - } } void VisitStmt_(const AttrStmtNode* op) final { @@ -135,18 +129,6 @@ class GPUCodeVerifier : public StmtExprVisitor { } } - void VisitExpr_(const LoadNode* op) { - // Currently not able to check: - // if the index expression failed to be simplified to a Ramp - if (op->index->IsInstance()) { - if (op->dtype.lanes() > 1) { - valid_ &= op->dtype.lanes() * op->dtype.bytes() <= - static_cast(max_vector_bytes_); - } - } - ExprVisitor::VisitExpr_(op); - } - private: int nest_level_{0}; @@ -164,7 +146,6 @@ class GPUCodeVerifier : public StmtExprVisitor { size_t max_shared_memory_per_block_; size_t max_threads_per_block_; size_t max_thread_x_, max_thread_y_, max_thread_z_; - size_t max_vector_bytes_; bool valid_{true}; @@ -188,32 +169,27 @@ bool VerifyGPUCode(const PrimFunc& func, Map constraints) { int64_t max_thread_x = INT64_MAX; int64_t max_thread_y = INT64_MAX; int64_t max_thread_z = INT64_MAX; - int64_t max_vector_bytes = INT64_MAX; for (auto iter : constraints) { const IntImmNode* val = iter.second.as(); - if (iter.first == 
"max_local_memory_per_block") { + if (iter.first == "max_local_memory_per_block") max_local_memory_per_block = val->value; - } else if (iter.first == "max_shared_memory_per_block") { + else if (iter.first == "max_shared_memory_per_block") max_shared_memory_per_block = val->value; - } else if (iter.first == "max_threads_per_block") { + else if (iter.first == "max_threads_per_block") max_threads_per_block = val->value; - } else if (iter.first == "max_thread_x") { + else if (iter.first == "max_thread_x") max_thread_x = val->value; - } else if (iter.first == "max_thread_y") { + else if (iter.first == "max_thread_y") max_thread_y = val->value; - } else if (iter.first == "max_thread_z") { + else if (iter.first == "max_thread_z") max_thread_z = val->value; - } else if (iter.first == "max_vector_bytes") { - max_vector_bytes = val->value; - } else { + else LOG(FATAL) << "Invalid check item: " << iter.first; - } } return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block, - max_threads_per_block, max_thread_x, max_thread_y, max_thread_z, - max_vector_bytes); + max_threads_per_block, max_thread_x, max_thread_y, max_thread_z); } TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode); diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 4f1078165f34..a15190665949 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -43,7 +43,6 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { int auto_max_depth; int auto_max_extent; int explicit_unroll; - int explicit_unroll_max_extent; TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") { TVM_ATTR_FIELD(auto_max_step) @@ -58,9 +57,6 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { TVM_ATTR_FIELD(explicit_unroll) .describe("Whether to explicitly unroll the loop instead of setting a pragma") .set_default(true); - TVM_ATTR_FIELD(explicit_unroll_max_extent) - .describe("The maximum extent of a loop that can be unrolled explicitly (-1 for infinite)") - .set_default(32); } }; @@ -75,12 +71,11 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig); class LoopUnroller : public StmtExprMutator { public: explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent, - bool explicit_unroll, int explicit_unroll_max_extent) + bool explicit_unroll) : auto_max_step_(auto_max_step), auto_max_depth_(auto_max_depth), auto_max_extent_(auto_max_extent), - explicit_unroll_(explicit_unroll), - explicit_unroll_max_extent_(explicit_unroll_max_extent) {} + explicit_unroll_(explicit_unroll) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_auto_unroll_max_step") { @@ -170,12 +165,6 @@ class LoopUnroller : public StmtExprMutator { // For loop must have a constant integer extent CHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; if (value == 0) return Evaluate(0); - if (explicit_unroll_max_extent_ > 0 && value > explicit_unroll_max_extent_ && - explicit_unroll_) { - // Do not unroll too long loops - ForType for_type = op->for_type == ForType::Unrolled ? 
ForType::Serial : op->for_type; - return For(op->loop_var, op->min, op->extent, for_type, op->device_api, op->body); - } Stmt body = op->body; Map vmap; Array unrolled; @@ -208,10 +197,7 @@ class LoopUnroller : public StmtExprMutator { // max extent of loop to auto unroll // this does not count the total steps, only the number of loops int auto_max_extent_; - // Whether to explicitly unroll the loop instead of setting a pragma bool explicit_unroll_; - // The maximum extent of a loop that can be unrolled explicitly (-1 means infinite) - int explicit_unroll_max_extent_; // Number of normal loops in scope int normal_loop_depth_{0}; // number of unrolled cases in current scope. @@ -224,7 +210,7 @@ Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) { Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent, - cfg->explicit_unroll, cfg->explicit_unroll_max_extent)(stmt); + cfg->explicit_unroll)(stmt); if (!ret.same_as(stmt)) { return ConvertSSA(ret); } else { diff --git a/tests/python/unittest/test_ansor_feature.py b/tests/python/unittest/test_ansor_feature.py deleted file mode 100644 index 705556c65edf..000000000000 --- a/tests/python/unittest/test_ansor_feature.py +++ /dev/null @@ -1,150 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License.
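Context for the deleted feature test that follows: its assertions imply that Ansor stores per-statement features on a log scale. A minimal sketch of the encoding the assertions below appear to assume (log2(value + 1) is an inference from the asserts, not a documented API):

```python
import math

# The test below asserts fea_dict[c_name + ".bytes"] == log2(512**3 * 4 + 1):
# C is touched 512^3 times, 4 bytes per float32 access.
accessed_bytes = 512 ** 3 * 4           # = 2**29 bytes
feature = math.log2(accessed_bytes + 1)
print(round(feature, 6))                # ~29.0: features are log2(x + 1) encoded
```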
- -"""Test feature extraction""" - -import math -import tempfile - -import tvm -from tvm import te, ansor - -from test_ansor_common import matmul_ansor_test - - -def fequal(a, b): - return math.fabs(a - b) < 1e-6 - - -def test_cpu_matmul(): - dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) - s = dag.get_init_state() - C = s.stage_ops[2] - - i, j, k = s[C].iters - io, ii = s.split(C, i, [16]) - jo, ji = s.split(C, j, [8]) - s.reorder(C, [io, jo, k, ji, ii]) - s.vectorize(C, ji) - s.parallel(C, io) - s.parallel(C, jo) - s.unroll(C, k) - - target = tvm.target.create('llvm') - task = ansor.SearchTask(dag, "test", target) - names = ansor.feature.get_per_stmt_feature_names() - fea = ansor.feature.get_per_stmt_features_from_states([s], task)[0] - - stage_0 = fea[0] - assert len(stage_0) == len(names), "%d vs %d" % (len(stage_0), len(names)) - fea_dict = {} - for name, value in zip(names, stage_0): - fea_dict[name] = value - - for name in ["B0", "B1", "B2"]: - if fequal(fea_dict[name + ".acc_type.kReadWrite"], 1.0): - c_name = name - if fequal(fea_dict[name + ".acc_type.kRead"], 1.0): - if fequal(fea_dict[name + ".stride"], 0.0): - b_name = name - else: - a_name = name - - assert fequal(fea_dict[c_name + ".bytes"], math.log2(512 ** 3 * 4 + 1)) - assert fequal(fea_dict[b_name + ".unique_bytes"], math.log2(512 ** 2 * 4 + 1)) - assert fequal(fea_dict[c_name + ".reuse_dis_iter"], math.log2(8 * 16 + 1)) - assert fequal(fea_dict[c_name + ".reuse_dis_bytes"], math.log2((8 * 16 + 8 + 16) * 4 + 1)) - assert fequal(fea_dict[c_name + ".reuse_ct"], math.log2(512 + 1)) - - assert fequal(fea_dict["unroll_num"], math.log2(1 + 1)) - # assert fequal(fea_dict["unroll_type.kPosInnerReduce"], 1.0) - assert fequal(fea_dict["vec_num"], math.log2(1 + 1)) - assert fequal(fea_dict["parallel_num"], math.log2(2 + 1)) - assert fequal(fea_dict["parallel_prod"], math.log2((512 * 512 / 16 / 8) + 1)) - - -def test_cpu_fusion(): - def fusion_test(N, M): - A = te.placeholder((N, M), name='A') - B = te.compute((N, M), lambda i, j: A[i][j], name='B') - C = te.compute((N, M), lambda i, j: B[i][j], name='C') - return [A, B, C] - - dag = ansor.ComputeDAG(fusion_test(64, 32)) - s = dag.get_init_state() - s.compute_at(1, 2, s.stages[2].iters[1]) - - target = tvm.target.create('llvm') - task = ansor.SearchTask(dag, "test", target) - names = ansor.feature.get_per_stmt_feature_names() - fea = ansor.feature.get_per_stmt_features_from_states([s], task)[0] - - found = False - for stage_fea in fea: - for i, (name, value) in enumerate(zip(names, stage_fea)): - if 'reuse_type.kSerialMultipleReadWrite' in name and value > 0.5: - assert fequal(stage_fea[i + 2], 1.0) - assert fequal(stage_fea[i + 3], math.log2(16 + 1)) - found = True - assert found - - -def test_gpu_feature(): - ctx = tvm.context("cuda", 0) - if not ctx.exist: - return - - json_records = "\n".join(( - """{"i": [["[\\"matmul_ansor_test\\", 512, 512, 512]", "cuda"], [[], [["CHW", 2, "local"], ["SP", 2, 0, 512, [1, 16, 32, 1], 1], ["SP", 2, 5, 512, [4, 1, 1, 16], 1], ["SP", 2, 10, 512, [1, 2], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 3, 0, 1, 3], ["FSP", 3, 4, 2, 3], ["RE", 3, [0, 4, 1, 5, 2, 6, 3, 7]], ["FU", 2, [0, 1]], ["FU", 3, [0, 1]], ["FU", 2, [1, 2]], ["FU", 3, [1, 2]], ["FU", 2, [2, 3]], ["FU", 3, [2, 3]], ["CA", 2, 3, 2], ["CHR", 1, "shared", [2]], ["CA", 2, 3, 3], ["FU", 2, [0, 1]], ["FFSP", 2, 0, [1, 2], 1, 1], ["AN", 2, 1, 6], ["CHR", 0, "shared", [3]], ["CA", 1, 4, 3], ["FU", 1, [0, 1]], ["FFSP", 1, 0, [1, 2], 1, 1], ["AN", 1, 1, 6], 
["AN", 5, 0, 5], ["AN", 5, 1, 4], ["AN", 5, 2, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"]]]], "r": [[0.00536798], 0, 2.49277, 1585564852], "v": "v0.1"}""", - )) - - # load states - with tempfile.NamedTemporaryFile(mode='w') as f: - f.write(json_records) - f.flush() - inputs, results = ansor.LogReader(f.name).read_lines() - - inp = inputs[0] - dag = ansor.workload_key_to_dag(inp.task.workload_key) - task = ansor.SearchTask(dag, inp.task.workload_key, inp.task.target, None, ansor.HardwareParams(100000, 16, 64, 4, 64)) - - state = ansor.serialization.get_states_from_measure_inputs(inputs, task)[0] - state = dag.infer_bound_from_state(state) - fea = ansor.feature.get_per_stmt_features_from_states([state], task)[0] - names = ansor.feature.get_per_stmt_feature_names() - - # build feature dict - fea_dicts = [] - for i in range(len(fea)): - tmp_dict = {} - for j in range(len(names)): - tmp_dict[names[j]] = fea[i][j] - fea_dicts.append(tmp_dict) - - # check values - assert fequal(fea_dicts[0]['blockIdx_x_len'], math.log2(8 + 1)) - assert fequal(fea_dicts[0]['vthread_len'], math.log2(4 + 1)) - assert fequal(fea_dicts[1]['threadIdx_x_len'], math.log2(16 + 1)) - assert fequal(fea_dicts[0]['threadIdx_y_len'], math.log2(1 + 1)) - assert fequal(fea_dicts[2]['blockIdx_z_len'], math.log2(1 + 1)) - assert fequal(fea_dicts[0]['is_gpu'], 1.0) - - -if __name__ == "__main__": - test_cpu_matmul() - test_cpu_fusion() - test_gpu_feature() diff --git a/tests/python/unittest/test_ansor_loop_state.py b/tests/python/unittest/test_ansor_loop_state.py index d90be1a78421..35894354349f 100644 --- a/tests/python/unittest/test_ansor_loop_state.py +++ b/tests/python/unittest/test_ansor_loop_state.py @@ -26,7 +26,7 @@ from test_ansor_common import matmul_ansor_test, conv2d_nchw_bn_relu -def test_split_fuse_reorder_annotation(): +def test_split_fuse_reorder(): A, B, C = matmul_ansor_test(512, 512, 512) dag = ansor.ComputeDAG([A, B, C]) s0 = dag.get_init_state() @@ -61,541 +61,5 @@ def test_split_fuse_reorder_annotation(): assert s1[C].iters[4].range.extent == 8 assert s1[C].iters[5].range.extent == 2 - s1.parallel(C, j1) - s1.unroll(C, j2) - s1.vectorize(C, j3) - s1.bind_thread(C, i1, "blockIdx.x") - s1.bind_thread(C, i2, "vthread") - s1.bind_thread(C, i3, "threadIdx.y") - - -def test_follow_split_follow_fused_split(): - A, B, C = matmul_ansor_test(512, 512, 512) - dag = ansor.ComputeDAG([A, B, C]) - s0 = dag.get_init_state() - - C_global = s0.cache_write(C, "global") - - its0 = s0.split(C, s0[C].iters[0], [4, 2, 8, 4], True) - split_step0 = s0.transform_steps_size() - 1 - for level in range(1, 6): - tmp = s0.copy() - tmp.follow_split(C_global, tmp[C_global].iters[0], split_step0, level) - for i in range(0, level): - assert tmp[C].iters[i].range.extent == \ - tmp[C_global].iters[i].range.extent - - its1 = s0.split(C, s0[C].iters[5], [2, 2, 4, 8]) - split_step1 = s0.transform_steps_size() - 1 - its = [] - for i0, i1 in zip(its0, its1): - its.append(i0) - its.append(i1) - s0.reorder(C, its) - for i in range(0, 5): - s0.fuse(C, [s0[C].iters[i], s0[C].iters[i + 1]]) - - for level in range(0, 4): - tmp = s0.copy() - tmp.follow_fused_split(C_global, tmp[C_global].iters[0], - [split_step0, split_step1], level, False) - assert tmp[C].iters[level + 1].range.extent == \ - tmp[C_global].iters[0].range.extent - - for level in range(0, 4): - tmp = s0.copy() - tmp.follow_fused_split(C_global, tmp[C_global].iters[0], - [split_step0, split_step1], level, True) - assert tmp[C].iters[level + 1].range.extent == \ - 
tmp[C_global].iters[1].range.extent - - -def test_compute_at_root_inline(): - dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) - s0 = dag.get_init_state() - - # data, padding, kernel = 0, 1, 2 - conv = s0.stage_ops[3] - # bias = 4 - bias_add = s0.stage_ops[5] - # bn_scale = 6 - bn_mul = s0.stage_ops[7] - # bn_offset = 8 - bn_add = s0.stage_ops[9] - relu = s0.stage_ops[10] - - s0.compute_inline(bn_add) - s0.compute_inline(bn_mul) - s0.compute_inline(bias_add) - s0.compute_at(conv, relu, s0[relu].iters[2]) - assert str(s0) == \ - "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ - "for i1 (0,3)\n" + \ - " for i2 (0,230)\n" + \ - " for i3 (0,230)\n" + \ - " pad_temp = ...\n" + \ - "for i1 (0,64)\n" + \ - " for i2 (0,112)\n" + \ - " for nn (None)\n" + \ - " for ff (None)\n" + \ - " for yy (None)\n" + \ - " for xx (None)\n" + \ - " for rc (None)\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute = ...\n" + \ - " for i3 (0,112)\n" + \ - " compute = ...\n" - - s0.compute_root(conv) - s0.compute_root(bn_mul) - assert str(s0) == \ - "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ - "for i1 (0,3)\n" + \ - " for i2 (0,230)\n" + \ - " for i3 (0,230)\n" + \ - " pad_temp = ...\n" + \ - "for nn (None)\n" + \ - " for ff (None)\n" + \ - " for yy (None)\n" + \ - " for xx (None)\n" + \ - " for rc (None)\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute = ...\n" + \ - "for i (None)\n" + \ - " for j (None)\n" + \ - " for k (None)\n" + \ - " for l (None)\n" + \ - " Bn_mul = ...\n" + \ - "for i1 (0,64)\n" + \ - " for i2 (0,112)\n" + \ - " for i3 (0,112)\n" + \ - " compute = ...\n" - - -def test_cache_read_write(): - N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, ( - 1, 1), (1, 1) - - data = te.placeholder((N, CI, H, W), name='Data') - kernel_data = te.placeholder((CO, CI, KH, KW), name='Kernel_data') - k0, k1 = te.compute(kernel_data.shape, - lambda *i: (kernel_data(*i)+1, kernel_data(*i)/2), - name='Kernel_split') - kernel = te.compute(kernel_data.shape, - lambda *i: k0(*i) + k1(*i), - name='Kernel') - conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1) - relu = topi.nn.relu(conv) - add = topi.add(data, relu) - - dag = ansor.ComputeDAG([data, kernel_data, add]) - s0 = dag.get_init_state() - - pad_temp = s0.stage_ops[1] - kernel_split = s0.stage_ops[3] - - # 0: init state - ori_its = s0[add].iters - its = s0.split(add, s0[add].iters[0], [2]) - s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]]) - s0.compute_inline(relu) - - # 1: simple cache_write with compute_at - conv_global = s0.cache_write(conv, "global") - s0.compute_at(conv_global, conv, s0[conv].iters[3]) - - # 2: simple cache_read with compute_at - kernel_global = s0.cache_read(kernel, "global", [conv_global]) - s0.compute_at(kernel_global, conv_global, s0[conv_global].iters[4]) - assert str(s0) == \ - "Placeholder: Data, Kernel_data\n" + \ - "for i0 (0,4)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,9)\n" + \ - " for i3 (0,9)\n" + \ - " pad_temp = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel_split = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel = ...\n" + \ - "for nn (0,4)\n" + \ - " for ff (0,512)\n" + \ - " for yy (0,7)\n" + \ - " for xx (0,7)\n" + \ - " for nn_c (None)\n" + \ - " for ff_c (None)\n" + \ - " for yy_c (None)\n" + \ - " for xx_c (None)\n" + \ - " 
for rc (None)\n" + \ - "          for ax0 (None)\n" + \ - "            for ax1 (None)\n" + \ - "              for ax2 (None)\n" + \ - "                for ax3 (None)\n" + \ - "                  Kernel.global = ...\n" + \ - "          for ry (None)\n" + \ - "            for rx (None)\n" + \ - "              compute.global = ...\n" + \ - "        compute = ...\n" + \ - "for ax0.0 (0,2)\n" + \ - "  for ax1 (0,512)\n" + \ - "    for ax0.1 (0,2)\n" + \ - "      for ax2 (0,7)\n" + \ - "        for ax3 (0,7)\n" + \ - "          T_add = ...\n" - - # 3: two-level cache_read with compute_at - # preparing for GPU's shared memory & local memory - pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global]) - pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global]) - s0.compute_at(pad_temp_global, conv_global, s0[conv_global].iters[2]) - s0.compute_at(pad_temp_shared, conv_global, s0[conv_global].iters[4]) - - # 4: cache_read with multiple readers - # This stage cannot be computed at its consumer - s0.cache_read(data, "global", [pad_temp, add]) - assert str(s0) == \ - "Placeholder: Data, Kernel_data\n" + \ - "for ax0 (0,4)\n" + \ - "  for ax1 (0,512)\n" + \ - "    for ax2 (0,7)\n" + \ - "      for ax3 (0,7)\n" + \ - "        Data.global = ...\n" + \ - "for i0 (0,4)\n" + \ - "  for i1 (0,512)\n" + \ - "    for i2 (0,9)\n" + \ - "      for i3 (0,9)\n" + \ - "        pad_temp = ...\n" + \ - "for i0 (0,512)\n" + \ - "  for i1 (0,512)\n" + \ - "    for i2 (0,3)\n" + \ - "      for i3 (0,3)\n" + \ - "        Kernel_split = ...\n" + \ - "for i0 (0,512)\n" + \ - "  for i1 (0,512)\n" + \ - "    for i2 (0,3)\n" + \ - "      for i3 (0,3)\n" + \ - "        Kernel = ...\n" + \ - "for nn (0,4)\n" + \ - "  for ff (0,512)\n" + \ - "    for yy (0,7)\n" + \ - "      for xx (0,7)\n" + \ - "        for nn_c (None)\n" + \ - "          for ff_c (None)\n" + \ - "            for yy_c (None)\n" + \ - "              for ax0 (None)\n" + \ - "                for ax1 (None)\n" + \ - "                  for ax2 (None)\n" + \ - "                    for ax3 (None)\n" + \ - "                      pad_temp.global = ...\n" + \ - "              for xx_c (None)\n" + \ - "                for rc (None)\n" + \ - "                  for ax0 (None)\n" + \ - "                    for ax1 (None)\n" + \ - "                      for ax2 (None)\n" + \ - "                        for ax3 (None)\n" + \ - "                          Kernel.global = ...\n" + \ - "                  for ax0 (None)\n" + \ - "                    for ax1 (None)\n" + \ - "                      for ax2 (None)\n" + \ - "                        for ax3 (None)\n" + \ - "                          pad_temp.global.shared = ...\n" + \ - "                  for ry (None)\n" + \ - "                    for rx (None)\n" + \ - "                      compute.global = ...\n" + \ - "        compute = ...\n" + \ - "for ax0.0 (0,2)\n" + \ - "  for ax1 (0,512)\n" + \ - "    for ax0.1 (0,2)\n" + \ - "      for ax2 (0,7)\n" + \ - "        for ax3 (0,7)\n" + \ - "          T_add = ...\n" - - # 5: cache_write with multiple outputs - # TVM's cache_write actually has a bug with this case: - # - # After schedule.cache_write, TVM generates one new stage: - # From: kernel_data -> kernel_split -> kernel - # To: kernel_data -> kernel_split_global -> kernel_split -> kernel - # - # But with topo-sort analysis, we get: - # // kernel_data -> kernel_split_global -> kernel_split -> kernel - # \ / - # ----------------> kernel_split ----------------> - # - # There seems to be a bug with the input/output tensors. Such multi-output cases - # should be unusual, so we apply a workaround in DoCacheWrite - # (to be fixed in the future) - s0.cache_write(kernel_split, "global") - assert str(s0) == \ - "Placeholder: Data, Kernel_data\n" + \ - "for ax0 (0,4)\n" + \ - "  for ax1 (0,512)\n" + \ - "    for ax2 (0,7)\n" + \ - "      for ax3 (0,7)\n" + \ - "        Data.global = ...\n" + \ - "for i0 (0,4)\n" + \ - "  for i1 (0,512)\n" + \ - "    for i2 (0,9)\n" + \ - "      for i3 (0,9)\n" + \ - "        pad_temp = ...\n" + \ - "for i0_c (0,512)\n" + \ - "  for i1_c (0,512)\n" + \ - "    for i2_c (0,3)\n" + \ - "      for i3_c (0,3)\n" + \ - "        Kernel_split.global = ...\n" + \ - "for i0 (0,512)\n" + \ - "  for i1 (0,512)\n" + \ - "    for i2 (0,3)\n" + \ - "      for i3 (0,3)\n" + \ - "        Kernel_split = ...\n" + \ - "for i0 (0,512)\n" + \ - "  for i1 (0,512)\n" + \ - "    for i2 (0,3)\n" + \ - "      for i3 (0,3)\n" + \ - "        Kernel_split = ...\n" + \ - "for i0 (0,512)\n" + \ - "  for i1 (0,512)\n" + \ - "    for i2 (0,3)\n" + \ - "      for i3 (0,3)\n" + \ - "        Kernel = ...\n" + \ - "for nn (0,4)\n" + \ - "  for ff (0,512)\n" + \ - "    for yy (0,7)\n" + \ - "      for xx (0,7)\n" + \ - "        for nn_c (None)\n" + \ - "          for ff_c (None)\n" + \ - "            for yy_c (None)\n" + \ - "              for ax0 (None)\n" + \ - "                for ax1 (None)\n" + \ - "                  for ax2 (None)\n" + \ - "                    for ax3 (None)\n" + \ - "                      pad_temp.global = ...\n" + \ - "              for xx_c (None)\n" + \ - "                for rc (None)\n" + \ - "                  for ax0 (None)\n" + \ - "                    for ax1 (None)\n" + \ - "                      for ax2 (None)\n" + \ - "                        for ax3 (None)\n" + \ - "                          Kernel.global = ...\n" + \ - "                  for ax0 (None)\n" + \ - "                    for ax1 (None)\n" + \ - "                      for ax2 (None)\n" + \ - "                        for ax3 (None)\n" + \ - "                          pad_temp.global.shared = ...\n" + \ - "                  for ry (None)\n" + \ - "                    for rx (None)\n" + \ - "                      compute.global = ...\n" + \ - "        compute = ...\n" + \ - "for ax0.0 (0,2)\n" + \ - "  for ax1 (0,512)\n" + \ - "    for ax0.1 (0,2)\n" + \ - "      for ax2 (0,7)\n" + \ - "        for ax3 (0,7)\n" + \ - "          T_add = ...\n" - - -def test_rfactor(): - A, B, C = matmul_ansor_test(8, 8, 512) - dag = ansor.ComputeDAG([A, B, C]) - s0 = dag.get_init_state() - - ko, ki = s0.split(C, s0[C].iters[2], [16]) - - s1 = s0.copy() - s1.rfactor(C, ko, 2) - assert str(s1) == \ - "Placeholder: A, B\n" + \ - "for i (0,8)\n" + \ - "  for j (0,8)\n" + \ - "    for k_o (0,32)\n" + \ - "      for k_i (0,16)\n" + \ - "        C.rf = ...\n" + \ - "for ax0 (0,8)\n" + \ - "  for ax1 (0,8)\n" + \ - "    for k_o_v (0,32)\n" + \ - "      C.repl = ...\n" - - s2 = s0.copy() - s2.rfactor(C, ki, 2) - assert str(s2) == \ - "Placeholder: A, B\n" + \ - "for i (0,8)\n" + \ - "  for j (0,8)\n" + \ - "    for k_i (0,16)\n" + \ - "      for k_o (0,32)\n" + \ - "        C.rf = ...\n" + \ - "for ax0 (0,8)\n" + \ - "  for ax1 (0,8)\n" + \ - "    for k_i_v (0,16)\n" + \ - "      C.repl = ...\n" - - -def vcf_init_common(): - A, B, C = matmul_ansor_test(512, 512, 512) - dag = ansor.ComputeDAG([A, B, C]) - s0 = dag.get_init_state() - B_shared = s0.cache_read(B, "shared", [C]) - B_local = s0.cache_read(B_shared, "local", [C]) - A_shared = s0.cache_read(A, "shared", [C]) - A_local = s0.cache_read(A_shared, "local", [C]) - - return A_shared, A_local, B_shared, B_local, C, dag, s0 - - -def vcf_check_common(dag, state): - s, args = dag.apply_steps_from_state(state) - # Check that every vectorized loop is lowered to a ramp expression - # TODO(jcf94): Find a better way to perform this check on the AST - print(tvm.lower(s, args)) - - if tvm.context("cuda", 0).exist: - tgt = tvm.target.cuda() - mod = tvm.build(s, args, tgt) - # Check that every vectorized loop is lowered to the correct instructions - print(mod.imported_modules[0].get_source()) - - ctx =
tvm.context("cuda", 0) - dtype = dag.tensors[0].dtype - a = tvm.nd.array(np.random.uniform(size=(512, 512)).astype(dtype), ctx) - b = tvm.nd.array(np.random.uniform(size=(512, 512)).astype(dtype), ctx) - c = tvm.nd.array(np.zeros((512, 512), dtype=dtype), ctx) - mod(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), np.dot( - a.asnumpy(), b.asnumpy()), rtol=1e-5) - else: - print("CUDA device not found, skip this test.") - - -def test_vectorized_cooperative_fetching_x(): - A_shared, A_local, B_shared, B_local, C, dag, s0 = vcf_init_common() - - its0 = s0.split(C, s0[C].iters[0], [1, 8, 2, 4]) - its1 = s0.split(C, s0[C].iters[5], [2, 8, 2, 4]) - its2 = s0.split(C, s0[C].iters[10], [8, 8]) - s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its2[0], - its2[1], its0[3], its1[3], its2[2], its0[4], its1[4]]) - s0.fuse(C, [s0[C].iters[0], s0[C].iters[1]]) - s0.bind_thread(C, s0[C].iters[0], "blockIdx.x") - s0.fuse(C, [s0[C].iters[1], s0[C].iters[2]]) - s0.bind_thread(C, s0[C].iters[1], "vthread") - s0.fuse(C, [s0[C].iters[2], s0[C].iters[3]]) - s0.bind_thread(C, s0[C].iters[2], "threadIdx.x") - s0.vectorize(C, its1[4]) - - s0.compute_at(B_shared, C, s0[C].iters[3]) - fused_it = s0.fuse(B_shared, s0[B_shared].iters[:]) - its = s0.split(B_shared, fused_it, [64, 4]) - s0.bind_thread(B_shared, its[1], "threadIdx.x") - s0.vectorize(B_shared, its[2]) - s0.compute_at(B_local, C, s0[C].iters[4]) - fused_it = s0.fuse(B_local, s0[B_local].iters[:]) - its = s0.split(B_local, fused_it, [4]) - s0.vectorize(B_local, its[1]) - - s0.compute_at(A_shared, C, s0[C].iters[3]) - fused_it = s0.fuse(A_shared, s0[A_shared].iters[:]) - its = s0.split(A_shared, fused_it, [64, 4]) - s0.bind_thread(A_shared, its[1], "threadIdx.x") - s0.vectorize(A_shared, its[2]) - s0.compute_at(A_local, C, s0[C].iters[4]) - fused_it = s0.fuse(A_local, s0[A_local].iters[:]) - its = s0.split(A_local, fused_it, [4]) - s0.vectorize(A_local, its[1]) - - vcf_check_common(dag, s0) - - -def test_vectorized_cooperative_fetching_xy(): - A_shared, A_local, B_shared, B_local, C, dag, s0 = vcf_init_common() - - its0 = s0.split(C, s0[C].iters[0], [1, 8, 2, 4]) - its1 = s0.split(C, s0[C].iters[5], [2, 8, 2, 4]) - its2 = s0.split(C, s0[C].iters[10], [8, 8]) - s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its2[0], - its2[1], its0[3], its1[3], its2[2], its0[4], its1[4]]) - s0.fuse(C, [s0[C].iters[0], s0[C].iters[1]]) - s0.bind_thread(C, s0[C].iters[0], "blockIdx.x") - s0.fuse(C, [s0[C].iters[1], s0[C].iters[2]]) - s0.bind_thread(C, s0[C].iters[1], "vthread") - s0.bind_thread(C, s0[C].iters[2], "threadIdx.x") - s0.bind_thread(C, s0[C].iters[3], "threadIdx.y") - s0.vectorize(C, its1[4]) - - s0.compute_at(B_shared, C, s0[C].iters[4]) - fused_it = s0.fuse(B_shared, s0[B_shared].iters[:]) - its = s0.split(B_shared, fused_it, [8, 8, 4]) - s0.bind_thread(B_shared, its[1], "threadIdx.x") - s0.bind_thread(B_shared, its[2], "threadIdx.y") - s0.vectorize(B_shared, its[3]) - s0.compute_at(B_local, C, s0[C].iters[5]) - fused_it = s0.fuse(B_local, s0[B_local].iters[:]) - its = s0.split(B_local, fused_it, [4]) - s0.vectorize(B_local, its[1]) - - s0.compute_at(A_shared, C, s0[C].iters[4]) - fused_it = s0.fuse(A_shared, s0[A_shared].iters[:]) - its = s0.split(A_shared, fused_it, [8, 8, 4]) - s0.bind_thread(A_shared, its[1], "threadIdx.x") - s0.bind_thread(A_shared, its[2], "threadIdx.y") - s0.vectorize(A_shared, its[3]) - s0.compute_at(A_local, C, s0[C].iters[5]) - fused_it = s0.fuse(A_local, s0[A_local].iters[:]) - its = 
s0.split(A_local, fused_it, [4]) - s0.vectorize(A_local, its[1]) - - vcf_check_common(dag, s0) - - -@tvm._ffi.register_func -def test_intrin_gemv(): - m = 16 - l = 64 - a = te.placeholder((l,), name='a') - b = te.placeholder((l, m), name='b') - k = te.reduce_axis((0, l), name='k') - c = te.compute((m,), lambda i: te.sum(a[k] * b[k, i], axis=k), name='c') - Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="A", - offset_factor=1, strides=[1]) - Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="B", - offset_factor=1, strides=[te.var("s0"), 1]) - Cb = tvm.tir.decl_buffer(c.shape, c.dtype, name="C", - offset_factor=1, strides=[1]) - def intrin_func(ins, outs): - ib = tvm.tir.ir_builder.create() - aa, bb = ins - cc = outs[0] - ib.emit(tvm.tir.call_extern("float32", "gemv_update", - cc.access_ptr("w"), - aa.access_ptr("r"), - bb.access_ptr("r"))) - return ib.get() - return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb}) - -def test_tensorize(): - A, B, C = matmul_ansor_test(1024, 512, 64) - dag = ansor.ComputeDAG([A, B, C]) - s0 = dag.get_init_state() - - its = s0.split(C, s0[C].iters[1], [16]) - s0.tensorize(C, its[1], "test_intrin_gemv") - - sch, tensors = dag.apply_steps_from_state(s0) - tvm.lower(sch, tensors, simple_mode=True) - if __name__ == "__main__": - test_split_fuse_reorder_annotation() - test_follow_split_follow_fused_split() - test_compute_at_root_inline() - test_cache_read_write() - test_rfactor() - test_vectorized_cooperative_fetching_x() - test_vectorized_cooperative_fetching_xy() - test_tensorize() + test_split_fuse_reorder() diff --git a/tests/python/unittest/test_ansor_measure.py b/tests/python/unittest/test_ansor_measure.py index d457dd2c55cc..f8d41edd27dd 100644 --- a/tests/python/unittest/test_ansor_measure.py +++ b/tests/python/unittest/test_ansor_measure.py @@ -62,24 +62,6 @@ def test_measure_local_builder_runner(): assert mress[0].error_no == 0 -def test_measure_local_builder_rpc_runner(): - dag, s0 = get_tiled_matmul() - - tgt = tvm.target.create("llvm") - task = ansor.SearchTask(dag, "test", tgt) - minp = ansor.MeasureInput(task, s0) - - local_builder = ansor.LocalBuilder() - measure_ctx = ansor.LocalRPCMeasureContext() - rpc_runner = measure_ctx.runner - - bress = local_builder.build([minp]) - assert bress[0].error_no == 0 - mress = rpc_runner.run([minp], bress) - assert mress[0].error_no == 0 - - if __name__ == "__main__": test_serialization() test_measure_local_builder_runner() - test_measure_local_builder_rpc_runner() diff --git a/tests/python/unittest/test_ansor_relay_integration.py b/tests/python/unittest/test_ansor_relay_integration.py deleted file mode 100644 index 1ad507e2f371..000000000000 --- a/tests/python/unittest/test_ansor_relay_integration.py +++ /dev/null @@ -1,114 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -""" Test Relay Integration """ - -import tempfile -import numpy as np - -import tvm -from tvm import ansor, relay -import tvm.contrib.graph_runtime as runtime -from tvm.relay.testing import dqn - -def test_tune_dense_graph(): - def dense_graph(N, dtype="float32"): - ori_data = relay.var("data", shape=(N, N), dtype=dtype) - weight = relay.var("weight", shape=(N, N), dtype=dtype) - data = relay.multiply(ori_data, relay.const(2, dtype=dtype)) - dense = relay.nn.dense(data, weight, out_dtype=dtype) - dense = relay.add(dense, weight) - dense = relay.nn.dense(dense, weight, out_dtype=dtype) - return ori_data, weight, dense - - N = 128 - data, weight, dense = dense_graph(N) - mod = relay.Function([data, weight], dense) - mod = tvm.IRModule.from_expr(mod) - - ctx = tvm.context("llvm") - target = tvm.target.create("llvm") - d = tvm.nd.array(np.random.uniform(size=(N, N)).astype(data.type_annotation.dtype), ctx) - w = tvm.nd.array(np.random.uniform(size=(N, N)).astype(weight.type_annotation.dtype), ctx) - wkl_keys, wkl_weights = ansor.extract_from_program(mod, {}, target=target) - - assert len(wkl_keys) == 2 - assert len(wkl_weights) == 2 - - tasks = [] - for wkl_key in wkl_keys: - dag = ansor.workload_key_to_dag(wkl_key) - tasks.append(ansor.SearchTask(dag, wkl_key, target)) - - tuner = ansor.SimpleTaskScheduler(tasks) - measure_ctx = ansor.LocalRPCMeasureContext() - with tempfile.NamedTemporaryFile() as fp: - tuner.tune(ansor.TuneOption(n_trials=2, runner=measure_ctx.runner, - measure_callbacks=[ansor.LogToFile(fp.name)])) - with ansor.apply_history_best(fp.name): - with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): - graph, lib, opt_params = relay.build_module.build( - mod, target=target) - - m = runtime.create(graph, lib, ctx) - m.set_input('data', d) - m.set_input('weight', w) - m.run() - res = m.get_output(0) - - del measure_ctx - - d = d.asnumpy() - d = d * 2 - w = w.asnumpy() - d = np.dot(d, np.transpose(w)) - d = d + w - d = np.dot(d, np.transpose(w)) - - tvm.testing.assert_allclose(res.asnumpy(), d, rtol=1e-5) - - -def test_tune_dqn(): - mod, params = dqn.get_workload(1, image_shape=(84, 84, 4), layout='NHWC') - target = tvm.target.create('llvm') - - wkl_keys, wkl_weights = ansor.extract_from_program(mod, params, target) - - tasks = [] - for wkl_key in wkl_keys: - dag = ansor.workload_key_to_dag(wkl_key) - tasks.append(ansor.SearchTask(dag, wkl_key, target)) - - assert len(tasks) == 5 - - tuner = ansor.SimpleTaskScheduler(tasks) - measure_ctx = ansor.LocalRPCMeasureContext() - with tempfile.NamedTemporaryFile() as fp: - tuner.tune(ansor.TuneOption(n_trials=len(tasks), runner=measure_ctx.runner, - measure_callbacks=[ansor.LogToFile('tmp.json')]), - search_policy='sketch.random') - with ansor.apply_history_best('tmp.json'): - ansor.prepare_layout_rewrite(mod, params, target) - with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): - graph, lib, opt_params = relay.build_module.build(mod, target=target) - ansor.finish_layout_rewrite() - - del measure_ctx - -if __name__ == "__main__": - test_tune_dense_graph() - test_tune_dqn() - diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index deff561a4547..984434b9c58b 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -79,90 +79,5 @@ def test_search_basic(): 
t.start() t.join() - -def test_search_xgb_model_rpc_runner(): - measure_ctx = ansor.LocalRPCMeasureContext() - search_common(seed=456787236, cost_model=ansor.XGBModel(), - runner=measure_ctx.runner) - - -def test_search_opencl(): - if tvm.context("opencl", 0).exist: - measure_ctx = ansor.LocalRPCMeasureContext() - search_common("opencl", 380344973, measure_ctx.runner) - else: - print("OpenCL device not found, skip this test.") - - -def test_search_cuda(): - if tvm.context("cuda", 0).exist: - measure_ctx = ansor.LocalRPCMeasureContext() - search_common("cuda", 903667810, measure_ctx.runner) - else: - print("CUDA device not found, skip this test.") - - -def test_search_custom_sketch_rule(): - def meet_condition_func(meta_policy, state, stage_id): - # Apply and Skip the Rest if this function does not return - pass - - # Expecting: - # i.0 - # i.1 - # i.2 - # j.0 - # j.1 - # ax0 - # ax1 - # B.global - # j.2 - # k - # C - def apply_func1(meta_policy, state, stage_id): - # Stage by stage way - ret = [] - if stage_id == 2: - state = ansor.loop_state.State(state, meta_policy.cur_task.compute_dag) - state.split(2, state.stages[2].iters[0], [4, 4]) - state.split(2, state.stages[2].iters[3], [4, 4]) - ret.append([state.state_object, stage_id - 1]) - elif stage_id == 1: - state = ansor.loop_state.State(state, meta_policy.cur_task.compute_dag) - state.cache_read(1, "global", [2]) - state.compute_at(2, 3, state.stages[3].iters[4]) - ret.append([state.state_object, stage_id - 1]) - else: - ret.append([state, stage_id - 1]) - return ret - - def apply_func2(meta_policy, state, stage_id): - # More template like way - ret = [] - state = ansor.loop_state.State(state, meta_policy.cur_task.compute_dag) - - state.split(2, state.stages[2].iters[0], [4, 4]) - state.split(2, state.stages[2].iters[3], [4, 4]) - state.cache_read(1, "global", [2]) - state.compute_at(2, 3, state.stages[3].iters[4]) - - ret.append([state.state_object, -1]) - return ret - - measure_ctx = ansor.LocalRPCMeasureContext() - search_common(seed=887823438, runner=measure_ctx.runner, - pre_search_callbacks=[ansor.PreloadCustomSketchRule( - meet_condition_func, apply_func1)], - params={'disable_change_compute_location': 1}) - search_common(seed=887823438, runner=measure_ctx.runner, - pre_search_callbacks=[ansor.PreloadCustomSketchRule( - meet_condition_func, apply_func2)], - params={'disable_change_compute_location': 1}) - - if __name__ == "__main__": test_search_basic() - test_search_xgb_model_rpc_runner() - test_search_opencl() - test_search_cuda() - test_search_custom_sketch_rule() diff --git a/tests/python/unittest/test_ansor_task_scheduler.py b/tests/python/unittest/test_ansor_task_scheduler.py deleted file mode 100644 index 53cf2059c1f3..000000000000 --- a/tests/python/unittest/test_ansor_task_scheduler.py +++ /dev/null @@ -1,52 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the - # specific language governing permissions and limitations - # under the License. - -"""Test the task scheduler""" - -import threading - -import tvm -from tvm import ansor - -from test_ansor_common import matmul_ansor_test - -def test_task_scheduler_basic(): - N = 128 - A, B, C = matmul_ansor_test(N, N, N) - dag = ansor.ComputeDAG([A, B, C]) - tgt = tvm.target.create("llvm") - task1 = ansor.SearchTask(dag, "test", tgt) - task2 = ansor.SearchTask(dag, "test", tgt) - - def basic_test_func(task1, task2): - def objective(costs): - return sum(costs) - - task_scheduler = ansor.SimpleTaskScheduler([task1, task2], objective) - tune_option = ansor.TuneOption(n_trials=3, runner='local') - task_scheduler.tune(tune_option) - - # The Ansor search process with the local runner modifies thread binding; - # run the test body in a separate thread to avoid affecting other tests - t = threading.Thread(target=basic_test_func, - kwargs={'task1': task1, 'task2': task2}) - t.start() - t.join() - - -if __name__ == "__main__": - test_task_scheduler_basic() diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py b/tests/python/unittest/test_tir_transform_unroll_loop.py index 12c686634548..68639940bb05 100644 --- a/tests/python/unittest/test_tir_transform_unroll_loop.py +++ b/tests/python/unittest/test_tir_transform_unroll_loop.py @@ -110,31 +110,7 @@ def test_unroll_single_count_loops(): ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body assert ret == stmt -def test_unroll_explicitly_max_extent(): - n = 64 - A = te.placeholder((n,), name='A') - B = te.compute((n,), lambda *i: A(*i), name='B') - s = te.create_schedule(B.op) - s = s.normalize() - dom_map = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, dom_map) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) - - with tvm.transform.PassContext(config={ - "tir.UnrollLoop": {"explicit_unroll_max_extent": n-1} - }): - ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body - assert tvm.ir.structural_equal(ret, stmt) - - with tvm.transform.PassContext(config={ - "tir.UnrollLoop": {"explicit_unroll_max_extent": n} - }): - ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body - assert not tvm.ir.structural_equal(ret, stmt) - - if __name__ == "__main__": test_unroll_loop() test_unroll_fake_loop() test_unroll_single_count_loops() - test_unroll_explicitly_max_extent() diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 7dd782f5b622..e0e455667889 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1295,75 +1295,6 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, name, tag); } -/*! - * \brief utility function for kernel_layout_transform - */ -inline void parse_kernel_layout(const String& layout, - Array* shape, - std::vector* axes) { - int32_t factor = 0; - std::string axis = ""; - for (char c : std::string(layout)) { - if (c >= 'A' && c <= 'z') { - axis += c; - if (factor != 0) { - shape->push_back(factor); - factor = 0; - } - } else if (c >= '0' && c <= '9') { - factor = factor * 10 + c - '0'; - if (!axis.empty()) { - axes->push_back(axis); - axis = ""; - } - } else { - LOG(FATAL) << "Invalid layout " << layout; - } - } - if (!axis.empty()) { - axes->push_back(axis); - } -} - -/*! - * \brief Transform the kernel layout according to \p src_layout and \p dst_layout - * \param src the source input. - * \param src_layout the source layout. - * \param dst_layout the destination layout.
- * \param name output tensor name. - * \param tag output tensor tag. - * \return A tensor with shape in \p dst_layout - */ -inline Tensor kernel_layout_transform(const Tensor& src, - const String& src_layout, - const String& dst_layout, - const String name = "T_kernel_layout_trans", - const String tag = kInjective) { - Array src_shape; - std::vector src_axes; - Array dst_shape; - std::vector dst_axes; - - parse_kernel_layout(src_layout, &src_shape, &src_axes); - parse_kernel_layout(dst_layout, &dst_shape, &dst_axes); - return compute( - dst_shape, [&](const Array& dst_indices) { - Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); - Array src_indices; - for (const std::string& src_axis : src_axes) { - PrimExpr src_index = 0; - CHECK_EQ(dst_indices_expr.size(), dst_axes.size()); - for (size_t i = 0; i < dst_axes.size(); ++i) { - if (dst_axes[i] == src_axis) { - src_index = src_index * dst_shape[i] + dst_indices_expr[i]; - } - } - src_indices.push_back(src_index); - } - return src(src_indices); - }, name, tag); -} - /*! * \brief Get the shape of input tensor. * \param src the input tensor. diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 6800129c12aa..4c7941b49692 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs from collections import namedtuple import tvm -from tvm import te, ansor +from tvm import te from .pad import pad from .util import get_pad_tuple @@ -342,37 +342,7 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'): dilation_h, dilation_w = dilation batch, in_height, in_width, in_channel = Input.shape - if ansor.GLOBAL_SCOPE.topi_in_compute_rewrite_mode: - # infer shape for the rewritten layout - if len(Filter.shape) >= 10: - # For cpu tile structure SSRSRS - base = len(Filter.shape) - 10 - kernel_h = Filter.shape[2 + base] * Filter.shape[6 + base] - kernel_w = Filter.shape[3 + base] * Filter.shape[7 + base] - channel = Filter.shape[4 + base] * Filter.shape[8 + base] - num_filter = Filter.shape[5 + base] * Filter.shape[9 + base] - for i in range(base + 2): - num_filter *= Filter.shape[i] - elif len(Filter.shape) == 6: - # For cpu tile structure SRS - num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[5] - kernel_h = Filter.shape[2] - kernel_w = Filter.shape[3] - channel = Filter.shape[4] - elif len(Filter.shape) == 5: - # For cpu tile structure SRS - num_filter = Filter.shape[0] * Filter.shape[4] - kernel_h = Filter.shape[1] - kernel_w = Filter.shape[2] - channel = Filter.shape[3] - elif len(Filter.shape) == 4: - num_filter, kernel_h, kernel_w, channel = Filter.shape - else: - raise ValueError("Don't know how to infer layout for filter shape: %s. " \ - "You can add a new branch for it to fix this." 
% str(Filter)) - else: - kernel_h, kernel_w, channel, num_filter = Filter.shape - + kernel_h, kernel_w, channel, num_filter = Filter.shape # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 @@ -392,9 +362,8 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'): lambda nn, yy, xx, ff: te.sum( PaddedInput[nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * - Filter[ry, rx, rc, ff].astype(out_dtype) - , axis=[ry, rx, rc]), - name="Conv2dOutput", tag="conv2d_nhwc", attrs={"layout_free_placeholders": [Filter]}) + Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]), + name="Conv2dOutput", tag="conv2d_nhwc") return Output diff --git a/tutorials/ansor/README.txt b/tutorials/ansor/README.txt deleted file mode 100644 index 85b6ba401dae..000000000000 --- a/tutorials/ansor/README.txt +++ /dev/null @@ -1,4 +0,0 @@ -.. _tutorial-ansor-auto-schedule: - -Ansor: Template Free Auto Scheduling ------------------------------------- diff --git a/tutorials/ansor/tune_conv2d_cuda.py b/tutorials/ansor/tune_conv2d_cuda.py deleted file mode 100644 index 03f1b24a768e..000000000000 --- a/tutorials/ansor/tune_conv2d_cuda.py +++ /dev/null @@ -1,179 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Auto-scheduling High Performance Convolution on NVIDIA GPUs -=========================================================== -**Author**: `Lianmin Zheng `_, \ - `Chengfan Jia `_, \ - `Minmin Sun `_, \ - `Zhao Wu `_ - -This is an tutorial for searching high performance schedule for NVIDIA GPU using -Ansor auto-scheduler. By running Ansor on this template, we can outperform the -vendor provided library CuDNN in many cases. -""" - -###################################################################### -# Install dependencies -# -------------------- -# To use autotvm package in tvm, we need to install some extra dependencies. -# (change "3" to "2" if you use python2): -# -# .. code-block:: bash -# -# pip3 install --user psutil xgboost tornado -# -# To make TVM run faster in tuning, it is recommended to use cython -# as FFI of tvm. In the root directory of tvm, execute -# -# .. code-block:: bash -# -# pip3 install --user cython -# sudo make cython3 -# -# Now return to python code. Import packages. - -import random -import sys - -import numpy as np -import tvm -import topi -from topi.testing import conv2d_nchw_python -from tvm import te - -# the module is called `ansor` -from tvm import ansor - -###################################################################### -# Step 1: Define the search task -# ------------------------------- -# There are plenty of useful schedule primitives in tvm. 
You can also find
-# some tutorials that describe them in more detail, such as
-# (1). :ref:`opt-conv-gpu`
-# (2). `Optimizing DepthwiseConv on NVIDIA GPU `_
-#
-# Getting a high-performance schedule for a specific workload is usually a
-# hard job. Even writing an AutoTVM tunable template requires the user to have
-# expertise in how each schedule primitive works as well as how they finally
-# map to the hardware architecture.
-#
-# With Ansor, however, this is quite simple. First, define the target workload.
-# Either the :code:`tvm.te` API or the topi op API can be used.
-#
-# We can use the returned :code:`Tensors` to create a ComputeDAG just like what we do
-# in :ref:`ansor-simple-subgraph`, but using the workload registry is the
-# recommended way.
-
-# Use an extra function decorator to register this workload
-@ansor.register_workload_func
-def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding):
-    data = te.placeholder((N, CI, H, W), name='data')
-    kernel = te.placeholder((CO, CI, KH, KW), name='kernel')
-    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype='float32')
-
-    return [data, kernel, conv]
-
-######################################################################
-# Step 2: Search through the schedule space
-# ------------------------------------------
-# We pick the last layer of ResNet as the test case.
-# Since our space is very large, :code:`XGBModel` is the most suitable
-# cost model for our case. Here we only do 20 trials for demonstration.
-# In practice, running 1000 trials can usually find some good kernels
-# for this workload.
-
-tgt = tvm.target.cuda()
-
-# The last layer in ResNet
-N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
-# Generate the workload key with the ansor API
-wkl_key = ansor.make_workload_key_func(conv2d_nchw, (N, H, W, CO, CI, KH, KW, strides, padding))
-# Generate the ComputeDAG using the workload key
-dag = ansor.workload_key_to_dag(wkl_key)
-task = ansor.SearchTask(dag, wkl_key, target=tgt)
-
-log_file = "conv2d_nchw.json"
-seed = 0
-random.seed(seed)
-cost_model = ansor.XGBModel(seed=seed)
-search_policy = ansor.SketchSearchPolicy(cost_model, seed=seed)
-
-#########################################################################
-# The :code:`ansor.LocalRPCMeasureContext` is used to create an RPC runner environment.
-#
-# Use the local GPU and measure 10 times for every schedule to reduce variance. The
-# timeout for each run is set to 4 seconds.
-#
-# During the search process, we may generate several invalid schedules, and they
-# will be filtered out. It's fine to see "Encountered errors during feature extraction."
-# in the tuning logs.
-# The :code:`ansor.LogToFile` callback will log the tuning results into a
-# log file, which can be used to get the best config later.
-# The :code:`ansor.PreloadMeasuredStates` callback will load measured states
-# from the history log before the schedule search; we can add this callback to make
-# sure the same schedule is never measured multiple times.
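To make that deduplication concrete, here is a minimal standalone sketch of preloading a measurement log and filtering out repeat candidates. This is not the ansor implementation; the JSON-lines record layout and both helper names are hypothetical:

    import json
    import os

    def load_measured_keys(log_file):
        # Collect one key per record already present in the log.
        measured = set()
        if os.path.exists(log_file):
            with open(log_file) as f:
                for line in f:
                    record = json.loads(line)
                    measured.add(json.dumps(record["input"], sort_keys=True))
        return measured

    def should_measure(candidate_input, measured):
        # Serialize the candidate the same way and skip exact repeats.
        key = json.dumps(candidate_input, sort_keys=True)
        if key in measured:
            return False
        measured.add(key)
        return True

With such a filter in front of the measurer, the search only pays hardware time for schedules it has never seen, which is the effect the PreloadMeasuredStates callback provides here.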
- -measure_ctx = ansor.LocalRPCMeasureContext(repeat=3, min_repeat_ms=100, timeout=4) -tune_option = ansor.TuneOption(n_trials=20, - runner=measure_ctx.runner, - measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=[ansor.PreloadMeasuredStates(log_file)]) -s, arg_bufs = ansor.auto_schedule(task, search_policy=search_policy, tune_option=tune_option) - -print("==== Get Lowered Stmt ====") -print(tvm.lower(s, arg_bufs, simple_mode=True)) - -# Release the RPC runner environment -del measure_ctx - -######################################################################### -# From the example lower result showed above, we can see that Ansor has tried -# techniques such as `Shared Memory Cooperative Fetching`, `Kernel Fusion`, -# `Axis unroll`, `Axis Vectorize` and so on. There is no need for users to care -# about the details, and Ansor will catch them well. -# -# Finally we can directly use the returned result to get the generated schedule, -# while in the following tutorial we'll show how to inspect the best config from -# log file, check correctness, and measure running time. - -# Get history best from log file -inp, res = ansor.best_measure_pair_in_file(log_file) -# Get the task ComputeDAG from log result -dag = ansor.workload_key_to_dag(inp.task.workload_key) -# Apply log result to TVM schedule -s, arg_bufs = dag.apply_steps_from_state(inp.state) -func = tvm.build(s, arg_bufs, target=tgt) - -# check correctness -a_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32) -w_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32) -c_np = conv2d_nchw_python(a_np, w_np, strides, padding) - -ctx = tvm.gpu() -a_tvm = tvm.nd.array(a_np, ctx=ctx) -w_tvm = tvm.nd.array(w_np, ctx=ctx) -c_tvm = tvm.nd.empty(c_np.shape, ctx=ctx) -func(a_tvm, w_tvm, c_tvm) - -tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2) - -# Evaluate running time. Here we choose a large repeat number (400) to reduce the noise -# and the overhead of kernel launch. You can also use nvprof to validate the result. -evaluator = func.time_evaluator(func.entry_name, ctx, number=400) -print('Time cost of this operator: %f s' % evaluator(a_tvm, w_tvm, c_tvm).mean) - diff --git a/tutorials/ansor/tune_simple_subgraph.py b/tutorials/ansor/tune_simple_subgraph.py deleted file mode 100644 index 00bef82cf855..000000000000 --- a/tutorials/ansor/tune_simple_subgraph.py +++ /dev/null @@ -1,193 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -.. _ansor-simple-subgraph: - -Writing compute expression and Using Ansor auto-scheduler -========================================================= -**Author**: `Lianmin Zheng `_, \ - `Chengfan Jia `_, \ - `Minmin Sun `_, \ - `Zhao Wu `_ - -This is an introduction tutorial to the auto-scheduler module in TVM. 
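Condensed, the whole flow this tutorial walks through fits in a few lines. The sketch below reuses the same ansor API demonstrated in the rest of the file, with a trivial vector add standing in for the matmul subgraph; treat it as an outline under those assumptions, not a tested script:

    import tvm
    from tvm import te, ansor

    def vecadd(N, dtype='float32'):
        A = te.placeholder((N,), name='A', dtype=dtype)
        B = te.placeholder((N,), name='B', dtype=dtype)
        C = te.compute((N,), lambda i: A[i] + B[i], name='C')
        return [A, B, C]

    # Step 1: define the task from a plain compute declaration.
    dag = ansor.ComputeDAG(vecadd(1024))
    task = ansor.SearchTask(dag, 'vecadd', tvm.target.create('llvm'))
    # Step 2: search; a tiny trial budget keeps the preview cheap.
    s, arg_bufs = ansor.auto_schedule(task, tune_option=ansor.TuneOption(n_trials=2))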
-
-There are two steps in auto-scheduling.
-The first step is defining the target task.
-The second step is running a search algorithm to automatically explore the schedule space.
-In this tutorial, you can learn how to perform these two steps in TVM.
-The whole workflow is illustrated by a matrix multiplication with bias add example.
-"""
-
-######################################################################
-# Install dependencies
-# --------------------
-# To use the Ansor package in TVM, we need to install some extra dependencies.
-# This step (installing xgboost) can be skipped as this tutorial doesn't need XGBoost
-# (change "3" to "2" if you use python2):
-#
-# .. code-block:: bash
-#
-#   pip3 install --user psutil xgboost
-#
-# To make TVM run faster in tuning, it is recommended to use cython
-# as the FFI of TVM. In the root directory of TVM, execute
-# (change "3" to "2" if you use python2):
-#
-# .. code-block:: bash
-#
-#   pip3 install --user cython
-#   sudo make cython3
-#
-# Now return to python code. Import packages.
-
-import random
-import sys
-
-import numpy as np
-import tvm
-from tvm import te
-
-# the module is called `ansor`
-from tvm import ansor
-
-######################################################################
-# Step 1: Define the target compute subgraph
-# -------------------------------------------
-# In this section, we will write a deterministic TVM compute expression
-# that defines a compute subgraph.
-#
-# .. note:: Compared with :ref:`tutorials-autotvm-sec`
-#
-# In Ansor, users do not need to provide a schedule template; the only input
-# is the compute expression written with the :code:`tvm.te` API or topi op API.
-#
-# Here is how we implement a matrix multiplication subgraph in TVM.
-
-# Matmul with bias add
-def matmul_add(N, L, M, dtype):
-    A = te.placeholder((N, L), name='A', dtype=dtype)
-    B = te.placeholder((L, M), name='B', dtype=dtype)
-    C = te.placeholder((N, M), name='C', dtype=dtype)
-
-    k = te.reduce_axis((0, L), name='k')
-    mul = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
-                     name='Mul')
-    D = te.compute((N, M), lambda i, j: C[i, j] + mul[i, j], name='D')
-
-    return [A, B, C, D]
-
-######################################################################
-# Step 2: Search through the schedule space
-# ------------------------------------------
-# In step 1, we built the compute subgraph.
-# The next step is to pick a cost model as well as a search policy and explore
-# the possible schedules.
-#
-# Auto-scheduler in TVM
-# ^^^^^^^^^^^^^^^^^^^^^
-# The job of the Ansor auto-scheduler can be described by the following pseudo code
-#
-# .. code-block:: c
-#
-#   ct = 0
-#   while ct < max_number_of_trials:
-#       auto generate a batch of schedules
-#       measure this batch of schedules on real hardware and get results
-#       ct += batch_size
-#
-# When proposing the next batch of schedules, Ansor can take different cost models to
-# guide the schedule generating process.
-#
-# * :code:`RandomModel`: Generates and picks new schedules randomly
-# * :code:`XGBModel`: Uses an XGBoost model to estimate the performance of potential schedules, and tries to pick schedules with better performance in each step
-#
-# XGBModel can explore more efficiently and find better schedules.
-
-################################################################
-# Begin tuning
-# ^^^^^^^^^^^^
-# Here we continue our matrix multiplication example.
-#
-# The :code:`ansor.ComputeDAG` takes the Tensor list as input, and generates
-# a dag structure.
During which process, :code:`ansor.ComputeDAG` will -# do some analyzes with the target subgraph and the results will be used in -# search policy later. -# -# Then we create the :code:`tvm.target` and a tuning task. - -N, L, M = 128, 128, 128 -A, B, C, D = matmul_add(N, L, M, 'float32') -dag = ansor.ComputeDAG([A, B, C, D]) - -print(dag) -print(dag.access_analyzer) - -tgt = tvm.target.create("llvm") -task = ansor.SearchTask(dag, "test", tgt) - -################################################################ -# Next, we choose random model and create a default search policy: -# :code:`ansor.SketchSearchPolicy`. -# -# We only make 5 trials in this tutorial for demonstration. In practice, -# you can do more trials according to your time budget. -# :code:`ansor.LogToFile` callback will log the tuning results into a -# log file, which can be used to get the best config later. -# :code:`ansor.PreloadMeasuredStates` callback will load measured states -# from history log before schedule search, we can add this callback to make -# sure a same schedule will never be measured for multiple times. - -log_file = "matmul_add.json" - -seed = 0 -random.seed(seed) -cost_model = ansor.RandomModel() -search_policy = ansor.SketchSearchPolicy(cost_model, seed=seed) - -tune_option = ansor.TuneOption(n_trials=5, - measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=[ansor.PreloadMeasuredStates(log_file)]) - -################################################################ -# Then just call :code:`ansor.auto_schedule` and Ansor will try to find a high -# performance schedule for the target subgraph automatically. -# -# The returned result will be a :code:`te.schedule` and a list of :code:`te.Tensor`, -# which can be used as the input of :code:`tvm.lower` or :code:`tvm.build`. - -s, arg_bufs = ansor.auto_schedule(task, search_policy=search_policy, - tune_option=tune_option) - -print("==== Get Lowered Stmt ====") -print(tvm.lower(s, arg_bufs, simple_mode=True)) - -######################################################################### -# Check the correctness to make sure we generate a right schedule. - -func = tvm.build(s, arg_bufs) - -# check correctness -a_np = np.random.uniform(size=(N, L)).astype(np.float32) -b_np = np.random.uniform(size=(L, M)).astype(np.float32) -c_np = np.random.uniform(size=(N, M)).astype(np.float32) -d_np = a_np.dot(b_np) + c_np - -d_tvm = tvm.nd.empty(d_np.shape) -func(tvm.nd.array(a_np), tvm.nd.array(b_np), tvm.nd.array(c_np), d_tvm) - -tvm.testing.assert_allclose(d_np, d_tvm.asnumpy(), rtol=1e-2) diff --git a/tutorials/autotvm/README.txt b/tutorials/autotvm/README.txt index 4ad36c000e3c..38e3b3343f4e 100644 --- a/tutorials/autotvm/README.txt +++ b/tutorials/autotvm/README.txt @@ -1,4 +1,4 @@ .. 
_tutorials-autotvm-sec: -AutoTVM: Template Based Auto Tuning ------------------------------------ +Auto tuning +----------- From d6d6b859214b6519fa6aae894d73ab38e95dd902 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 24 Jun 2020 15:14:43 +0800 Subject: [PATCH 41/78] Code clean for minimum Ansor system --- python/tvm/ansor/__init__.py | 5 +- python/tvm/ansor/auto_schedule.py | 29 +- python/tvm/ansor/compute_dag.py | 17 +- python/tvm/ansor/cost_model/__init__.py | 20 - python/tvm/ansor/cost_model/cost_model.py | 46 - python/tvm/ansor/measure.py | 205 +-- python/tvm/ansor/utils.py | 66 - src/ansor/auto_schedule.cc | 1 - src/ansor/compute_dag.cc | 359 +--- src/ansor/compute_dag.h | 4 - src/ansor/cost_model/cost_model.cc | 198 --- src/ansor/cost_model/cost_model.h | 157 -- src/ansor/feature.cc | 1573 ----------------- src/ansor/feature.h | 80 - src/ansor/loop_state.cc | 549 ------ src/ansor/loop_state.h | 45 +- src/ansor/measure.cc | 42 - src/ansor/measure.h | 36 - src/ansor/search_policy/empty_policy.cc | 98 + src/ansor/search_policy/empty_policy.h | 81 + src/ansor/search_policy/search_policy.cc | 60 +- src/ansor/search_policy/search_policy.h | 38 - .../search_policy/sketch_search_policy.cc | 1541 ---------------- .../search_policy/sketch_search_policy.h | 157 -- src/ansor/search_policy/utils.cc | 744 -------- src/ansor/search_policy/utils.h | 483 ----- src/ansor/serialization.cc | 175 +- src/ansor/transform_step.cc | 602 ------- src/ansor/transform_step.h | 427 +---- tests/python/unittest/test_ansor_common.py | 11 +- .../python/unittest/test_ansor_compute_dag.py | 27 - .../unittest/test_ansor_search_policy.py | 5 +- 32 files changed, 210 insertions(+), 7671 deletions(-) delete mode 100644 python/tvm/ansor/cost_model/__init__.py delete mode 100644 python/tvm/ansor/cost_model/cost_model.py delete mode 100644 src/ansor/cost_model/cost_model.cc delete mode 100644 src/ansor/cost_model/cost_model.h delete mode 100644 src/ansor/feature.cc delete mode 100644 src/ansor/feature.h create mode 100644 src/ansor/search_policy/empty_policy.cc create mode 100644 src/ansor/search_policy/empty_policy.h delete mode 100644 src/ansor/search_policy/sketch_search_policy.cc delete mode 100644 src/ansor/search_policy/sketch_search_policy.h delete mode 100644 src/ansor/search_policy/utils.cc delete mode 100644 src/ansor/search_policy/utils.h diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index ccd8f27b71c1..93a82f073ac3 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -27,9 +27,8 @@ # Shortcut from .compute_dag import ComputeDAG from .auto_schedule import SearchTask, TuneOption, HardwareParams, \ - auto_schedule -from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext -from .cost_model import RandomModel + auto_schedule, EmptyPolicy +from .measure import MeasureInput, LocalBuilder, LocalRunner from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \ load_from_file, write_measure_records_to_file from .workload_registry import register_workload_func, \ diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 37e622018658..8fddac567529 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -22,14 +22,13 @@ import tvm._ffi from tvm.runtime import Object from .measure import LocalBuilder, LocalRunner -from .cost_model import RandomModel from . 
import _ffi_api @tvm._ffi.register_object("ansor.HardwareParams") class HardwareParams(Object): - """ - The parameters of target hardware + """ The parameters of target hardware, this is used to guide the search process of + SearchPolicy. Parameters ---------- @@ -48,8 +47,7 @@ def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes, @tvm._ffi.register_object("ansor.SearchTask") class SearchTask(Object): - """ - The meta-information of a search task + """ The meta-information of a search task Parameters ---------- @@ -69,22 +67,21 @@ def __init__(self, dag, workload_key, target, target_host=None, @tvm._ffi.register_object("ansor.SearchPolicy") class SearchPolicy(Object): """ The base class for search policy """ - def continue_search(self, task, num_measure, verbose, measurer): - return _ffi_api.SearchPolicyContinueSearchOneRound(self, task, - num_measure, verbose, measurer) - def set_task(self, task): - _ffi_api.SearchPolicySetTask(self, task) - def set_verbose(self, verbose): - _ffi_api.SearchPolicySetVerbose(self, verbose) +@tvm._ffi.register_object("ansor.EmptyPolicy") +class EmptyPolicy(SearchPolicy): + """ This is an example empty search policy which will always generate + the init state of target ComputeDAG. + """ + def __init__(self): + self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy) - def run_callbacks(self, callbacks): - _ffi_api.SearchPolicyRunCallbacks(self, callbacks) @tvm._ffi.register_object("ansor.SearchCallback") class SearchCallback(Object): - """Callback function before or after search process""" + """ Callback function before or after search process """ + @tvm._ffi.register_object("ansor.TuneOption") class TuneOption(Object): @@ -164,7 +161,7 @@ def auto_schedule(workload, target=None, """ if isinstance(search_policy, str): if search_policy == 'default': - search_policy = SketchSearchPolicy(RandomModel()) + search_policy = EmptyPolicy() else: raise ValueError("Invalid search policy: " + search_policy) diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index acfec66a166a..d591d615d1c5 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -59,7 +59,7 @@ def apply_steps_from_state(self, state): args : List[Tensor] """ state_obj = state if isinstance(state, StateObject) else state.state_object - return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj, layout_rewrite_level) + return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj) def print_python_code_from_state(self, state): """ @@ -75,18 +75,3 @@ def print_python_code_from_state(self, state): """ state_obj = state if isinstance(state, StateObject) else state.state_object return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state_obj) - - def infer_bound_from_state(self, state): - """ - Infer bound for a state - - Parameters - ---------- - state : StateObject - - Returns - ------- - state : State - """ - state_obj = state if isinstance(state, StateObject) else state.state_object - return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self) diff --git a/python/tvm/ansor/cost_model/__init__.py b/python/tvm/ansor/cost_model/__init__.py deleted file mode 100644 index 1454da451b61..000000000000 --- a/python/tvm/ansor/cost_model/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-import, redefined-builtin -""" Cost model that estimates the performance of programs """ - -from .cost_model import RandomModel \ No newline at end of file diff --git a/python/tvm/ansor/cost_model/cost_model.py b/python/tvm/ansor/cost_model/cost_model.py deleted file mode 100644 index 605db14c19c3..000000000000 --- a/python/tvm/ansor/cost_model/cost_model.py +++ /dev/null @@ -1,46 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" Cost model that estimates the performance of programs """ -import ctypes -import numpy as np - -import tvm._ffi -from tvm.runtime import Object -from .. import _ffi_api - - -@tvm._ffi.register_object("ansor.CostModel") -class CostModel(Object): - """The base class for cost model""" - - -@tvm._ffi.register_object("ansor.RandomModel") -class RandomModel(Object): - """A model returns random estimation for all inputs""" - def __init__(self): - self.__init_handle_by_constructor__(_ffi_api.RandomModel) - - -@tvm._ffi.register_func("ansor.cost_model.random_number") -def random_number(n, return_ptr): - """ A random number generator func for c++'s RandomModel """ - if n == 0: - return - return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float)) - array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,)) - array_wrapper[:] = np.random.uniform(0, 1, (n,)) diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 46c3e3aabd5d..af0eddc59653 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -35,13 +35,9 @@ from tvm.runtime import Object, module, ndarray from tvm.driver import build_module from tvm.ir import transform -from tvm.rpc.tracker import Tracker -from tvm.rpc.server import Server -from tvm.autotvm.measure.measure_methods import set_cuda_target_arch -from tvm.contrib import tar, ndk + from . 
import _ffi_api -from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, \ - check_remote +from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout LOGGER = logging.getLogger('ansor') @@ -178,104 +174,6 @@ def __init__(self, self.__init_handle_by_constructor__( _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval) -@tvm._ffi.register_object("ansor.ProgramMeasurer") -class ProgramMeasurer(Object): - """ - Parameters - ---------- - builder : Builder - runner : Runner - callbacks : List[MeasureCallback] - verbose : Int - max_continuous_error : Float - """ - - def __init__(self, builder: Builder, runner: Runner, - callbacks: List[MeasureCallback], - verbose: int, max_continuous_error: int = -1): - self.__init_handle_by_constructor__( - _ffi_api.ProgramMeasurer, builder, runner, callbacks, verbose, max_continuous_error) - -@tvm._ffi.register_object("ansor.RPCRunner") -class RPCRunner(Runner): - """ - Parameters - ---------- - key : Str - host : Str - port : Int - priority : Int - n_parallel : Int - timeout : Int - number : Int - repeat : Int - min_repeat_ms : Int - cooldown_interval : Float - """ - - def __init__(self, key, host, port, priority=1, - n_parallel=1, - timeout=10, - number=3, - repeat=1, - min_repeat_ms=0, - cooldown_interval=0.0): - self.__init_handle_by_constructor__( - _ffi_api.RPCRunner, key, host, port, priority, timeout, n_parallel, - number, repeat, min_repeat_ms, cooldown_interval) - - if check_remote(key, host, port, priority, timeout): - LOGGER.info("Get devices for measurement successfully!") - else: - raise RuntimeError("Cannot get remote devices from the tracker. " - "Please check the status of tracker by " - "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' " - "and make sure you have free devices on the queue status.") - - -class LocalRPCMeasureContext: - """ A context wrapper for running RPCRunner locally. - This will launch a local RPC Tracker and local RPC Server. 
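Typical usage, as in the CUDA tutorial removed earlier in this series, pairs the context's runner with a TuneOption and relies on deletion for cleanup; a minimal sketch assuming that same API:

    measure_ctx = LocalRPCMeasureContext(repeat=3, min_repeat_ms=100, timeout=4)
    tune_option = TuneOption(n_trials=20, runner=measure_ctx.runner)
    # ... run the search with this tune_option ...
    del measure_ctx  # __del__ terminates the local RPC server and tracker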
- - Parameters - ---------- - priority : Int - n_parallel : Int - timeout : Int - number : Int - repeat : Int - min_repeat_ms : Int - cooldown_interval : Float - """ - - def __init__(self, - priority=1, - n_parallel=1, - timeout=10, - number=10, - repeat=1, - min_repeat_ms=0, - cooldown_interval=0.0): - ctx = tvm.context("cuda", 0) - if ctx.exist: - cuda_arch = "sm_" + "".join(ctx.compute_version.split('.')) - set_cuda_target_arch(cuda_arch) - host = '0.0.0.0' - self.tracker = Tracker(host, port=9000, port_end=10000, silent=True) - device_key = '$local$device$%d' % self.tracker.port - self.server = Server(host, port=self.tracker.port, port_end=10000, - key=device_key, use_popen=True, silent=True, - tracker_addr=(self.tracker.host, self.tracker.port)) - self.runner = RPCRunner(device_key, host, self.tracker.port, priority, - n_parallel, timeout, number, repeat, - min_repeat_ms, cooldown_interval) - # wait for the processes to start - time.sleep(0.5) - - def __del__(self): - self.server.terminate() - self.tracker.terminate() - class MeasureErrorNo(object): """Error type for MeasureResult""" @@ -389,103 +287,6 @@ def local_builder_build(inputs: List[MeasureInput], timeout: float, n_parallel: return results - -@tvm._ffi.register_func("ansor.rpc_runner.run") -def rpc_runner_run(inputs: List[MeasureInput], build_results: List[BuildResult], - key: str, host: str, port: int, priority: int, timeout: float, - n_parallel: int, number: int, repeat: int, min_repeat_ms: int, - cooldown_interval: float, verbose: int): - global global_run_arguments - global_run_arguments = (inputs, build_results, key, host, port, priority, timeout, number, - repeat, min_repeat_ms, cooldown_interval, verbose) - - assert len(inputs) == len(build_results), \ - "Measure input size should be equal to build results" - pool = NoDaemonPool(n_parallel) - tuple_res = pool.map(rpc_run_worker, range(len(build_results))) - pool.terminate() - pool.join() - del pool - - results = [] - for res in tuple_res: - results.append(MeasureResult(*res)) - - if verbose >= 1: - print("") - - return results - - -def rpc_run_worker(index): - """ ... 
- """ - inputs, build_results, key, host, port, priority, timeout, number, \ - repeat, min_repeat_ms, cooldown_interval, verbose = global_run_arguments - - MAX_FLOAT = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log - inp = inputs[index] - build_res = build_results[index] - - if build_res.error_no != MeasureErrorNo.NO_ERROR: - return (MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, \ - time.time() - - def timed_func(): - tic = time.time() - error_no = 0 - error_msg = None - try: - # upload built module - remote = request_remote(key, host, port, priority, timeout) - remote.upload(build_res.filename) - func = remote.load_module(os.path.split(build_res.filename)[1]) - ctx = remote.context(str(inp.task.target), 0) - time_f = func.time_evaluator( - func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms) - except Exception: - costs = (MAX_FLOAT,) - error_no = MeasureErrorNo.COMPILE_DEVICE - error_msg = make_error_msg() - - if error_no == 0: - try: - args = [ndarray.non_empty(get_const_tuple(x.shape), x.dtype, ctx) for x in - build_res.args] - ctx.sync() - - costs = time_f(*args).results - # clean up remote files - remote.remove(build_res.filename) - remote.remove(os.path.splitext(build_res.filename)[0] + '.so') - remote.remove('') - except Exception: - costs = (MAX_FLOAT,) - error_no = MeasureErrorNo.RUNTIME_DEVICE - error_msg = make_error_msg() - - shutil.rmtree(os.path.dirname(build_res.filename)) - toc = time.time() - - time.sleep(cooldown_interval) - if verbose >= 1: - if error_no == MeasureErrorNo.NO_ERROR: - print("*", end="") - else: - print("*E", end="") # Run error - - return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc - - res = call_func_with_timeout(timeout, timed_func) - - if isinstance(res, TimeoutError): - if verbose >= 1: - print("*T", end="") # Run timeout - res = (MAX_FLOAT,), MeasureErrorNo.RUN_TIMEOUT, None, build_res.time_cost + \ - timeout, time.time() - return res - - @tvm._ffi.register_func("ansor.local_runner.run") def local_run(inputs: List[MeasureInput], build_results: List[BuildResult], timeout: float, number: int, repeat: int, min_repeat_ms: int, @@ -510,7 +311,7 @@ def timed_func(inp, build_res): if error_no == 0: try: - args = [ndarray.non_empty(get_const_tuple(x.shape), x.dtype, ctx) for x in + args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] ctx.sync() diff --git a/python/tvm/ansor/utils.py b/python/tvm/ansor/utils.py index 9e3c857aba36..b406824ba811 100644 --- a/python/tvm/ansor/utils.py +++ b/python/tvm/ansor/utils.py @@ -177,69 +177,3 @@ def func_wrapper(que): del que return res - - -def request_remote(device_key, host=None, port=None, priority=1, timeout=60): - """Request a remote session - - Parameters - ---------- - device_key: string - The device key of registered device in tracker - host: host, optional - The host address of rpc tracker. - If is none, will use environment variable "TVM_TRACKER_HOST" - port: int, optional - The port of rpc tracker. 
- If is none, will use environment variable "TVM_TRACKER_PORT" - priority: int, optional - The priority of this request, larger is more prior - timeout: float, optional - The timeout of this session (units: second) - - Returns - ------ - session: RPCSession - """ - # connect to the tracker - host = host or os.environ['TVM_TRACKER_HOST'] - port = port or int(os.environ['TVM_TRACKER_PORT']) - - tracker = rpc.connect_tracker(host, port) - remote = tracker.request(device_key, priority=priority, - session_timeout=timeout) - return remote - - -def check_remote(device_key, host=None, port=None, priority=100, timeout=10): - """ - Check the availability of a remote device - - Parameters - ---------- - device_key: string - device key of registered device in tracker - host: host, optional - The host address of rpc tracker. - If is none, will use environment variable "TVM_TRACKER_HOST" - port: int, optional - The port address of rpc tracker. - If is none, will use environment variable "TVM_TRACKER_PORT" - priority: int, optional - The priority of this request, larger is more prior - timeout: float, optional - The timeout of this check (units: seconds). - - Returns - ------- - available: bool - True if can find available device - """ - - def _check(): - remote = request_remote(device_key, host, port, priority) - - t = threading.Thread(target=_check, ) - t.start() - t.join(timeout) - return not t.is_alive() diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index 05cb95c2c451..82ec07930adc 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -26,7 +26,6 @@ #include #include #include -#include "search_policy/sketch_search_policy.h" namespace tvm { namespace ansor { diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index d7af8b94729a..7638f98e65ea 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -37,7 +37,6 @@ #include #include #include "transform_step.h" -#include "search_policy/utils.h" namespace tvm { namespace ansor { @@ -599,323 +598,6 @@ std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); } -class IndexRewriter : public StmtExprMutator { - public: - IndexRewriter(const OperationMap >& placeholder_new_names, - const OperationMap >& placeholder_new_shapes): - placeholder_new_names_(placeholder_new_names), - placeholder_new_shapes_(placeholder_new_shapes) {} - - PrimExpr Rewrite(PrimExpr expr) { - return this->VisitExpr(expr); - } - - PrimExpr VisitExpr_(const ProducerLoadNode* op) final { - te::Tensor t = Downcast(op->producer); - auto it = placeholder_new_names_.find(t->op); - if (it != placeholder_new_names_.end()) { - const std::vector& new_names = it->second; - const Array& new_shape = placeholder_new_shapes_.at(t->op); - std::unordered_map name_to_arg; - for (const auto& arg : op->indices) { - std::string axis_name; - if (const auto* pimm = arg.as()) { - CHECK_EQ(pimm->value, 0); - axis_name = "IntImm"; - } else { - axis_name = BaseName(CleanName(Downcast(arg)->name_hint)); - CHECK_EQ(name_to_arg.count(axis_name), 0); - name_to_arg[axis_name] = arg; - } - } - - std::unordered_map div_factors; - std::vector r_new_args; - for (int i = new_names.size() - 1; i >= 0; --i) { - auto ori_iter_name = new_names[i]; - auto name_it = name_to_arg.find(ori_iter_name); - CHECK(name_it != name_to_arg.end()); - PrimExpr ori_arg = name_it->second; - - PrimExpr mod_factor = new_shape[i]; - - PrimExpr div_factor = 1; - if (div_factors.count(ori_iter_name)) { - div_factor = div_factors[ori_iter_name]; - } - 
div_factors[ori_iter_name] = div_factor * new_shape[i]; - - PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor); - - r_new_args.push_back(new_arg); - } - - Array new_args(std::make_move_iterator(r_new_args.rbegin()), - std::make_move_iterator(r_new_args.rend())); - - return ProducerLoad(op->producer, new_args); - } - return GetRef(op); - } - - private: - const OperationMap >& placeholder_new_names_; - const OperationMap >& placeholder_new_shapes_; -}; - -void ComputeDAG::RewriteLayout( - const std::vector &transform_steps, LayoutRewriteLevel layout_rewrite_level) const { - ComputeDAGNode* pdag = const_cast(this)->CopyOnWrite(); - const State& state = ReplayAndInferBound(transform_steps); - - OperationMap > placeholder_new_names; - OperationMap > placeholder_new_shapes; - int stage_id = -1; - for (const auto& stage : state->stages) { - stage_id += 1; - const te::Operation& op = stage->op; - if (op->IsInstance()) { - const Map& attrs = op->attrs; - if (attrs.count(layout_free_placeholders_key)) { - const ObjectRef& attr_value = attrs[layout_free_placeholders_key]; - Array placeholders = Downcast>(attr_value); - for (auto& placeholder : placeholders) { - const auto placeholder_op = placeholder->op; - - // Check whether this placeholder has already been handled - if (placeholder_new_names.count(placeholder_op)) { - continue; - } - - // skip the op that is not direct consumer of this placeholder, - // mostly due to cache read/write. - bool direct_consumer = false; - for (auto& t : op->InputTensors()) { - if (t->op == placeholder_op) { - direct_consumer = true; - break; - } - } - if (!direct_consumer) { - continue; - } - - std::set placeholder_axis_names; - TensorAccessExtractor extractor; - for (const auto& exp : op.as()->body) { - extractor.Extract(exp); - } - bool rewrite_placeholder = (layout_rewrite_level == kPlaceholderRewrite || - layout_rewrite_level == kBothRewrite); - bool rewrite_body = (layout_rewrite_level == kComputeRewrite || - layout_rewrite_level == kBothRewrite); - std::ostringstream os; - - uint i = 0; - if (extractor.buf_accesses.count(placeholder_op)) { - for (const auto& ev : extractor.buf_accesses[placeholder_op]) { - for (const auto& e : ev) { - // TODO(minminsun): check whether the extents match the shape of placeholder - std::string axis_name; - if (const auto* pimm = e.as()) { - CHECK_EQ(pimm->value, 0); - // CHECK_EQ(placeholder->shape[i].as()->value, 1); - axis_name = "IntImm"; - } else { - axis_name = BaseName(CleanName(Downcast(e)->name_hint)); - } - - placeholder_axis_names.insert(axis_name); - if (rewrite_placeholder) { - os << placeholder->shape[i++] << axis_name; - } - } - } - - if (rewrite_placeholder) { - CHECK_EQ(placeholder_axis_names.size(), placeholder->shape.size()); - std::string ori_layout = os.str(); - os.str(""); - // ::tvm::relay::KernelLayoutVisitor::global_ori_layouts_queue.push_back(ori_layout); - } - } - - std::vector stage_iters; - - auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id); - int attach_pos = -1; - size_t iters_before_attach = 0; - if (attach_it != state->attach_map->stage_to_attach_iter.end()) { - auto attach = attach_it->second; - const auto& attach_stage = state->stages[attach.first]; - attach_pos = attach.second; - stage_iters.insert(stage_iters.end(), - attach_stage->iters.begin(), - attach_stage->iters.begin() + attach_pos + 1); - } - - stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end()); - - std::vector iters; - for (size_t i = 0; i < stage_iters.size(); ++i) 
{ - const auto& iter = stage_iters[i]; - if (iter->ori_iters.empty()) { - iters.push_back(iter); - } else { - for (const Iterator& ori_iter : iter->ori_iters) { - iters.push_back(ori_iter); - } - } - if (static_cast(i) == attach_pos) { - iters_before_attach = iters.size(); - } - } - - std::vector new_names; - Array new_shape; - std::vector new_axis_names; - for (const Iterator& iter : iters) { - std::set ori_iter_names; - ExtractOriginalIterators(iter->name, &ori_iter_names); - // fused iters have been replaced with iter->ori_iters. - // So there should be only one ori iter name extracted from iter->name. - CHECK_EQ(ori_iter_names.size(), 1); - auto ori_iter_name = BaseName(*ori_iter_names.begin()); - new_axis_names.push_back(ori_iter_name); - } - for (size_t i = 0; i < new_axis_names.size(); ++i) { - auto iter = iters[i]; - std::string ori_iter_name; - if (i < iters_before_attach) { - ori_iter_name = new_axis_names[i + iters_before_attach]; - } else { - ori_iter_name = new_axis_names[i]; - } - if (placeholder_axis_names.count(ori_iter_name)) { - os << iter->range->extent << ori_iter_name; - new_names.push_back(ori_iter_name); - new_shape.push_back(iter->range->extent); - } - } - std::string new_layout = os.str(); - os.str(""); - // ::tvm::relay::KernelLayoutVisitor::global_new_layouts_queue.push_back(new_layout); - placeholder_new_names[placeholder_op] = new_names; - placeholder_new_shapes[placeholder_op] = new_shape; - - Array old_ops = pdag->ops; - ArrayNode* pops = pdag->ops.CopyOnWrite(); - - // Create new placeholder - te::Operation new_placeholder_op; - if (rewrite_placeholder) { - new_placeholder_op = - te::PlaceholderOp(placeholder_op->name, - new_shape, - placeholder_op.as()->dtype); - } else { - new_placeholder_op = placeholder_op; - } - - te::Operation new_compute_op, old_compute_op; - if (rewrite_body) { - Array new_body; - IndexRewriter index_rewriter(placeholder_new_names, - placeholder_new_shapes); - for (auto& op : old_ops) { - if (auto* pop = op.as()) { - bool need_update = false; - for (auto& t : op->InputTensors()) { - if (t->op == placeholder_op) { - need_update = true; - break; - } - } - if (need_update) { - for (auto& body : pop->body) { - new_body.push_back(index_rewriter.Rewrite(body)); - } - old_compute_op = op; - CHECK(!new_compute_op.defined()); - new_compute_op = te::ComputeOp( - pop->name, pop->tag, pop->attrs, pop->axis, new_body); - } - } - } - } - - // construct the map from old_op to new_op - std::unordered_map updated_ops; - for (size_t i = 0; i < old_ops.size(); ++i) { - auto old_op = old_ops[i]; - if (rewrite_placeholder && old_op == placeholder_op) { - pops->SetItem(i, new_placeholder_op); - updated_ops[placeholder_op] = new_placeholder_op; - } else if (rewrite_body && old_op == old_compute_op) { - pops->SetItem(i, new_compute_op); - updated_ops[old_compute_op] = new_compute_op; - } else { - pops->SetItem(i, old_op); - } - } - - // Because ops is sorted in topo-order, only do one pass linear scan here. 
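The single linear scan is valid because topological order guarantees every producer has been rewritten before any of its consumers is visited. A small Python sketch of the same chain-following relink; the tuple-based op representation is invented for illustration and is not the C++ data structure:

    # ops: (op_id, input_ids) pairs in topological order.
    # updated: old op_id -> replacement op_id; entries may chain old -> new -> newer.
    def relink(ops, updated):
        result = []
        for op_id, input_ids in ops:
            new_inputs = []
            for i in input_ids:
                while i in updated:      # follow replacement chains, as the C++ loop does
                    i = updated[i]
                new_inputs.append(i)
            if new_inputs != input_ids:  # rebuilding gives the op a fresh identity
                updated[op_id] = op_id + "'"
                op_id = updated[op_id]
            result.append((op_id, new_inputs))
        return result

    # relink([('A', []), ('B', ['A']), ('C', ['B'])], {'A': 'A2'})
    # -> [('A', []), ("B'", ['A2']), ("C'", ["B'"])]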
- for (size_t i = 0; i < pops->size(); ++i) { - auto old_op = Downcast(pops->at(i)); - if (auto* pop = old_op.as()) { - auto inputs = pop->InputTensors(); - std::unordered_map rmap; - for (auto input : inputs) { - auto it = updated_ops.find(input->op); - te::Operation new_op; - while (it != updated_ops.end()) { - new_op = it->second; - it = updated_ops.find(new_op); - } - if (new_op.defined()) { - int index = input->value_index; - rmap[input] = new_op.output(index); - } - } - if (!rmap.empty()) { - te::Operation new_op = pop->ReplaceInputs(old_op, rmap); - updated_ops[old_op] = new_op; - pops->SetItem(i, new_op); - } - } - } - - pdag->init_state = State(pdag->ops); - - Array old_tensors = pdag->tensors; - ArrayNode* ptensors = pdag->tensors.CopyOnWrite(); - - for (size_t i = 0; i < old_tensors.size(); ++i) { - const auto& old_tensor = old_tensors[i]; - auto it = updated_ops.find(old_tensor->op); - te::Operation new_op; - while (it != updated_ops.end()) { - new_op = it->second; - it = updated_ops.find(new_op); - } - if (new_op.defined()) { - if (layout_rewrite_level == kBothRewrite) { - auto index = old_tensor->value_index; - ptensors->SetItem(i, new_op.output(index)); - } else if (layout_rewrite_level == kComputeRewrite) { - te::TensorNode* old_tensor_node = - const_cast(old_tensor.as()); - old_tensor_node->op = new_op; - } - } - } - } // end for placeholder - } - } - } // end for stage -} - - void UpdateStageAxis(const te::Stage& stage, StageToAxesMap *stage_to_axes) { if (auto pop = stage->op.as()) { std::vector& axes = (*stage_to_axes)[stage]; @@ -938,13 +620,7 @@ std::pair > ComputeDAG::ApplySteps( LayoutRewriteLevel layout_rewrite_level) const { std::vector stages; StageToAxesMap stage_to_axes; - if (layout_rewrite_level != kNoRewrite && !transform_steps.empty()) { - ComputeDAG new_dag = *this; - new_dag.RewriteLayout(transform_steps, layout_rewrite_level); - return new_dag.ReplaySteps(transform_steps, &stages, &stage_to_axes); - } else { - return ReplaySteps(transform_steps, &stages, &stage_to_axes); - } + return ReplaySteps(transform_steps, &stages, &stage_to_axes); } std::string ComputeDAG::PrintStepsAsPython(const std::vector& transform_steps) const { @@ -1135,32 +811,8 @@ std::pair > ComputeDAG::ReplaySteps( ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes, &schedule); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes, &schedule); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes, &schedule); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); } else { LOG(FATAL) << "Invalid Step"; } @@ 
-1270,15 +922,6 @@ TVM_REGISTER_GLOBAL("ansor.ComputeDAG") TVM_REGISTER_GLOBAL("ansor.ComputeDAGGetInitState") .set_body_method(&ComputeDAG::GetInitState); -TVM_REGISTER_GLOBAL("ansor.ComputeDAGRewriteLayoutFromState") -.set_body([](TVMArgs args, TVMRetValue *ret) { - ComputeDAG dag = args[0]; - State state = args[1]; - - dag.RewriteLayout(state->transform_steps, kPlaceholderRewrite); - *ret = dag; -}); - TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") .set_body([](TVMArgs args, TVMRetValue *ret) { ComputeDAG dag = args[0]; diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index b1b60e678904..2f1330d612dd 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -148,10 +148,6 @@ class ComputeDAG: public ObjectRef { const std::vector& transform_steps, LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const; - // Rewrite the the layout of "layout free" placeholders according to transform steps - void RewriteLayout(const std::vector& transform_steps, - LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const; - // Print transform steps as equivalent python schedule API std::string PrintStepsAsPython(const std::vector& steps) const; diff --git a/src/ansor/cost_model/cost_model.cc b/src/ansor/cost_model/cost_model.cc deleted file mode 100644 index ee7bf8b26053..000000000000 --- a/src/ansor/cost_model/cost_model.cc +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * \file ansor/cost_model.h - * \brief Cost model that estimates the performance of programs - */ - -#include "cost_model.h" - -#include -#include - -#include - -namespace tvm { -namespace ansor { - -using ::tvm::runtime::NDArray; - -TVM_REGISTER_OBJECT_TYPE(CostModelNode); -TVM_REGISTER_OBJECT_TYPE(RandomModelNode); -TVM_REGISTER_OBJECT_TYPE(MeasureModelNode); -TVM_REGISTER_OBJECT_TYPE(PythonBasedModelNode); - -void RandomNumber(TVMArgs args, TVMRetValue* rv) { - int n = args[0]; - void* data = args[1]; - float* fdata = reinterpret_cast(data); - for (int i = 0; i < n; i++) { - fdata[i] = static_cast(rand_r(nullptr)) / (static_cast(RAND_MAX)); - } -} - -RandomModel::RandomModel() { - ObjectPtr node = make_object(); - node->random_number_func = - runtime::Registry::Get("ansor.cost_model.random_number"); - if (node->random_number_func == nullptr) { - LOG(WARNING) << "ansor.cost_model.random_number is not registered, " - << "use C++ default random_number func instead."; - static PackedFunc cost_model_random_number(RandomNumber); - node->random_number_func = &cost_model_random_number; - } - data_ = std::move(node); -} - -void RandomModelNode::Update(const Array& inputs, - const Array& results) {} - -void RandomModelNode::Predict(const SearchTask& task, - const std::vector& states, - std::vector* scores) { - scores->resize(states.size()); - (*random_number_func)(states.size(), static_cast(scores->data())); -} - -MeasureModel::MeasureModel(Builder builder, Runner runner) { - ObjectPtr node = make_object(); - node->measurer = ProgramMeasurer(std::move(builder), std::move(runner), - Array(), 0); - data_ = std::move(node); -} - -void MeasureModelNode::Update(const Array& inputs, - const Array& results) {} - -void MeasureModelNode::Predict(const SearchTask& task, - const std::vector& states, - std::vector* scores) { - std::vector inputs; - std::vector results; - - inputs.clear(); - inputs.reserve(states.size()); - for (const auto& state : states) { - inputs.push_back(MeasureInput(task, state)); - } - measurer->SilentMeasure(task, inputs, &results); - - scores->clear(); - scores->reserve(results.size()); - for (const auto& res : results) { - scores->push_back(1.0 / FloatArrayMean(res->costs)); - } -} - -PythonBasedModel::PythonBasedModel(PackedFunc update_func, - PackedFunc predict_func, - PackedFunc predict_stage_func) { - auto node = make_object(); - node->update_func = std::move(update_func); - node->predict_func = std::move(predict_func); - node->predict_stage_func = std::move(predict_stage_func); - data_ = std::move(node); -} - -void PythonBasedModelNode::Update(const Array& inputs, - const Array& results) { - update_func(inputs, results); -} - -void PythonBasedModelNode::Predict(const SearchTask& task, - const std::vector& states, - std::vector* scores) { - scores->resize(states.size()); - predict_func(task, Array(states.begin(), states.end()), - static_cast(scores->data())); -} - -void PythonBasedModelNode::PredictStages(const SearchTask& task, - const std::vector& states, std::vector* state_scores, - std::vector>* stage_scores) { - int n_states = states.size(); - int n_stages = task->compute_dag.GetInitState()->stages.size(); - std::vector flatten_scores; - // Allocate sufficient spaces. - flatten_scores.resize(n_states * n_stages * 2); - predict_stage_func(task, Array(states.begin(), states.end()), - static_cast(flatten_scores.data())); - - // Unpack flatten scores. - state_scores->clear(); - stage_scores->clear(); - - // Score of each states. 
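The packed buffer convention (first one score per state, then for each state a length field followed by that many per-stage scores) is easiest to see in a short sketch. The layout is read off the C++ here, and the helper deliberately ignores the placeholder/inlined-stage bookkeeping handled below:

    def unpack_scores(flat, n_states):
        # flat: [state scores..., then per state: length, stage scores...]
        state_scores = list(flat[:n_states])
        stage_scores = []
        idx = n_states
        for _ in range(n_states):
            length = int(flat[idx])  # number of scored stages for this state
            idx += 1
            stage_scores.append(list(flat[idx:idx + length]))
            idx += length
        return state_scores, stage_scores

    # unpack_scores([0.9, 0.5, 2, 0.4, 0.6, 0], n_states=2)
    # -> ([0.9, 0.5], [[0.4, 0.6], []])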
- for (int i = 0; i < n_states; ++i) { - state_scores->push_back(flatten_scores[i]); - } - - // Score of each stage in each states. - size_t idx = n_states; - for (int i = 0; i < n_states; ++i) { - CHECK_LE(idx, flatten_scores.size()); - - // Number of scored stages of this state. - int s_length = static_cast(flatten_scores[idx++]); - - if (s_length > 0) { - std::vector scores; - int offset = 0; - - if ((*state_scores)[i] > -INFINITY) { - // If the score is valid. Copy scored stages and assign 0 to placeholder - // and inlined stages. If the score is 0, meaning this state failed to - // be lowered. Just bypass to update offset. - for (const Stage& stage : states[i]->stages) { - if (stage->op_type == kPlaceholder) { - scores.push_back(0); - continue; - } - if (stage->compute_at == kInlined) { - scores.push_back(0); - continue; - } - scores.push_back(flatten_scores[idx + offset]); - offset++; - } - CHECK_EQ(offset, s_length); - stage_scores->push_back(std::move(scores)); - } - idx += s_length; - } else { - // Cost model does not provide any stage score details. - stage_scores->push_back({}); - } - } -} - -TVM_REGISTER_GLOBAL("ansor.RandomModel").set_body_typed([]() { - return RandomModel(); -}); - -TVM_REGISTER_GLOBAL("ansor.PythonBasedModel") -.set_body_typed([](PackedFunc update_func, PackedFunc predict_func, - PackedFunc predict_stage_func) { - return PythonBasedModel(update_func, predict_func, - predict_stage_func); -}); - -} // namespace ansor -} // namespace tvm diff --git a/src/ansor/cost_model/cost_model.h b/src/ansor/cost_model/cost_model.h deleted file mode 100644 index f38624a3572c..000000000000 --- a/src/ansor/cost_model/cost_model.h +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ansor/cost_model.h - * \brief Cost model that estimates the performance of programs -*/ - -#ifndef TVM_ANSOR_COST_MODEL_COST_MODEL_H_ -#define TVM_ANSOR_COST_MODEL_COST_MODEL_H_ - -#include -#include -#include -#include -#include "../measure.h" - -namespace tvm { -namespace ansor { - -using runtime::PackedFunc; - -/*! 
\brief The base class for cost model */ -class CostModelNode: public Object { - public: - // Update the cost model according to new measurement pairs - virtual void Update(const Array& inputs, - const Array& results) = 0; - - // Predict the scores of states - virtual void Predict(const SearchTask& task, const std::vector& states, - std::vector* scores) = 0; - - // Predict the scores of all stages in states - virtual void PredictStages(const SearchTask& task, - const std::vector& states, - std::vector* state_scores, - std::vector>* stage_scores) { - LOG(FATAL) << "Not Implemented"; - } - - static constexpr const char *_type_key = "ansor.CostModel"; - TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object); -}; -TVM_DEFINE_MUTABLE_OBJECT_REF(CostModel, CostModelNode); - -/*! \brief The cost model returns random value for all predictions */ -class RandomModelNode: public CostModelNode { - public: - const PackedFunc* random_number_func; - - void Update(const Array& inputs, - const Array& results) final; - void Predict(const SearchTask& task, const std::vector& states, - std::vector* scores) final; - - static constexpr const char *_type_key = "ansor.RandomModel"; - TVM_DECLARE_FINAL_OBJECT_INFO(RandomModelNode, CostModelNode); -}; - -/*! - * \brief Managed reference to RandomModelNode. - * \sa RandomModelNode - */ -class RandomModel : public CostModel { - public: - RandomModel(); - explicit RandomModel(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) - : CostModel(n) {} - - RandomModelNode* operator->() const { - return static_cast(data_.get()); - } - - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(RandomModel); - using ContainerType = RandomModelNode; -}; - -/*! \brief The cost model returns actual cost by measurement */ -class MeasureModelNode : public CostModelNode { - public: - ProgramMeasurer measurer; - - void Update(const Array& inputs, - const Array& results) final; - void Predict(const SearchTask& task, const std::vector& states, - std::vector* scores) final; - - static constexpr const char* _type_key = "ansor.MeasureModel"; - TVM_DECLARE_FINAL_OBJECT_INFO(MeasureModelNode, CostModelNode); -}; - -/*! - * \brief Managed reference to MeasureModelNode. - * \sa MeasureModelNode - */ -class MeasureModel : public CostModel { - public: - MeasureModel(Builder builder, Runner runner); - - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureModel, CostModel, - MeasureModelNode); -}; - -/*! \brief A wrapper for cost model defined by python code - * This class will call python's function */ -class PythonBasedModelNode: public CostModelNode { - public: - PackedFunc update_func; - PackedFunc predict_func; - PackedFunc predict_stage_func; - - void Update(const Array& inputs, - const Array& results) final; - void Predict(const SearchTask& task, const std::vector& states, - std::vector* scores) final; - void PredictStages(const SearchTask& task, const std::vector& states, - std::vector* state_scores, - std::vector>* stage_scores) final; - - static constexpr const char *_type_key = "ansor.PythonBasedModel"; - TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedModelNode, CostModelNode); -}; - -/*! - * \brief Managed reference to PythonBasedModelNode. 
- * \sa PythonBasedModelNode - */ -class PythonBasedModel : public CostModel { - public: - PythonBasedModel(PackedFunc update_func, PackedFunc predict_func, - PackedFunc predict_stage_func); - - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedModel, CostModel, - PythonBasedModelNode); -}; - -} // namespace ansor -} // namespace tvm - -#endif // TVM_ANSOR_COST_MODEL_COST_MODEL_H_ diff --git a/src/ansor/feature.cc b/src/ansor/feature.cc deleted file mode 100644 index 73f6bad0d432..000000000000 --- a/src/ansor/feature.cc +++ /dev/null @@ -1,1573 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ansor/feature.cc - * \brief Feature extraction for the cost model - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "measure.h" -#include "serialization.h" -#include "utils.h" - -namespace tvm { -/* Import the function from driver_api.cc */ -extern void GetBinds(const Array& args, bool compact, - const std::unordered_map& binds, - Map* out_binds, Array* out_arg_list); -} // namespace tvm - - -namespace tvm { -namespace ansor { - -using namespace tvm::tir; -using arith::ConstIntBound; -using arith::Analyzer; - -template -using BufferMap = std::unordered_map; - -static const int ARITH_INTENSITY_CURVE_SAMPLE_N = 10; - -// Annotation position encoding -enum AnnotationPosType { - kPosNone, kPosInnerSpatial, kPosMiddleSpatial, kPosOuterSpatial, - kPosInnerReduce, kPosMiddleReduce, kPosOuterReduce, kPosMixed -}; - -// Buffer access type -enum BufferAccessType { - kRead, kWrite, kReadWrite, kUnknownRW -}; - -// Accesses to a buffer -struct BufferAccess { - BufferAccessType acc_type{kUnknownRW}; - std::vector > indices; -}; - -// Data reuse type -enum ReuseType { - kLoopMultipleRead, kSerialMultipleReadWrite, kNoReuse -}; - -// Feature for an access of a buffer -struct BufferAccessFeature { - std::string buffer_name; - BufferAccessType acc_type; - float bytes; - float unique_bytes; - float lines; - float unique_lines; - ReuseType reuse_type; - float reuse_dis_iter; // reuse distance in iterator number - float reuse_dis_bytes; // reuse distance in total touched bytes - float reuse_ct; // reuse times - float bytes_d_reuse_ct; - float unique_bytes_d_reuse_ct; - float lines_d_reuse_ct; - float unique_lines_d_reuse_ct; - float stride; -}; - -// Feature set of a statement -struct FeatureSet { - // compute feature - float float_mad; - float float_addsub; - float float_mul; - float float_divmod; - float float_cmp; - float float_math_func; - float float_other_func; - float int_mad; - float int_addsub; - float int_mul; - float int_divmod; - float int_cmp; - float int_math_func; - float int_other_func; - float bool_op; - float select_op; - float vec_num; // The 
number of vectorized iterators
- float vec_prod; // The product of the lengths of vectorized iterators
- float vec_len; // The length of the innermost vectorized iterator
- AnnotationPosType vec_type;
- float unroll_num; // The number of unrolled iterators
- float unroll_prod; // The product of the lengths of unrolled iterators
- float unroll_len; // The length of the innermost unrolled iterator
- AnnotationPosType unroll_type;
- float parallel_num; // The number of parallelized iterators
- float parallel_prod; // The product of the lengths of parallelized iterators
- float parallel_len; // The length of the innermost parallelized iterator
- AnnotationPosType parallel_type;
- float is_gpu;
- float blockIdx_x_len;
- float blockIdx_y_len;
- float blockIdx_z_len;
- float threadIdx_x_len;
- float threadIdx_y_len;
- float threadIdx_z_len;
- float vthread_len;
-
- float arith_intensity_curve[ARITH_INTENSITY_CURVE_SAMPLE_N];
-
- // buffer access feature (per buffer)
- std::vector access_feas;
-
- // allocation feature
- float alloc_size;
- float alloc_prod;
- float alloc_outer_prod;
- float alloc_inner_prod;
-
- // overall feature
- float outer_prod;
- float num_loops;
- float auto_unroll_max_step;
-};
-
-// Return whether a var is in an expr
-bool VarInExpr(const Var& var, const PrimExpr& expr) {
- bool find = false;
-
- PostOrderVisit(expr, [&find, &var](const ObjectRef &node) {
- if (find) {
- return;
- }
-
- if (const VarNode* op = node.as()) {
- if (op == var.get()) {
- find = true;
- }
- }
- });
-
- return find;
-}
-
-// Get position encoding for annotation
-AnnotationPosType GetAnnotationPosEncoding(
- const Var& var, const Array& spatial_args,
- const Array& axis, const Array& reduce_axis) {
- // Try to match spatial args first
- size_t find_i = 0;
- size_t find_ct = 0;
- for (size_t i = 0; i < spatial_args.size(); ++i) {
- if (VarInExpr(var, spatial_args[i])) {
- find_i = i;
- find_ct += 1;
- }
- }
-
- if (find_ct == 0) {
- // If it is not found in the spatial args, then it is a reduce iterator.
- // Use name to match
- const std::string& var_name = var->name_hint;
- for (size_t i = 0; i < reduce_axis.size(); ++i) {
- if (var_name.find(reduce_axis[i]->var->name_hint) != std::string::npos) {
- find_i = i;
- find_ct++;
- }
- }
- if (find_ct >= 1) {
- if (find_i == 0) {
- return kPosInnerReduce;
- } else if (find_i == reduce_axis.size() - 1) {
- return kPosOuterReduce;
- } else {
- return kPosMiddleReduce;
- }
- } else {
- // If the axis is found in neither the spatial args nor the reduce axis,
- // then this stage must be compute_at'ed somewhere under this axis and
- // the axis has been simplified out.
- // We assume it is an outer spatial
- return kPosOuterSpatial;
- }
- } else if (find_ct == 1) {
- if (find_i == spatial_args.size() - 1) {
- return kPosInnerSpatial;
- } else if (find_i == 0) {
- return kPosOuterSpatial;
- } else {
- return kPosMiddleSpatial;
- }
- } else {
- return kPosMixed;
- }
-}
-
-// Count math ops in an expr
-class MathOpCounter : public StmtExprVisitor {
- public:
-#define VisitBinary(Type, float_ct, int_ct) \
- void VisitExpr_(const Type* op) final { \
- if (op->a.dtype().is_float()) { \
- float_ct++; \
- } else { \
- int_ct++; \
- } \
- StmtExprVisitor::VisitExpr_(op); \
- } \
-
- VisitBinary(AddNode, float_addsub, int_addsub);
- VisitBinary(SubNode, float_addsub, int_addsub);
- VisitBinary(MulNode, float_mul, int_mul);
- VisitBinary(DivNode, float_divmod, int_divmod);
- VisitBinary(ModNode, float_divmod, int_divmod);
- VisitBinary(FloorDivNode, float_divmod, int_divmod);
- VisitBinary(FloorModNode, float_divmod, int_divmod);
- VisitBinary(MaxNode, float_cmp, int_cmp);
- VisitBinary(MinNode, float_cmp, int_cmp);
- VisitBinary(EQNode, float_cmp, int_cmp);
- VisitBinary(NENode, float_cmp, int_cmp);
- VisitBinary(LTNode, float_cmp, int_cmp);
- VisitBinary(LENode, float_cmp, int_cmp);
- VisitBinary(GTNode, float_cmp, int_cmp);
- VisitBinary(GENode, float_cmp, int_cmp);
-
- void VisitExpr_(const AndNode* op) final { bool_op++; StmtExprVisitor::VisitExpr_(op); }
- void VisitExpr_(const OrNode* op) final { bool_op++; StmtExprVisitor::VisitExpr_(op); }
- void VisitExpr_(const NotNode* op) final { bool_op++; StmtExprVisitor::VisitExpr_(op); }
- void VisitExpr_(const SelectNode* op) final { select_op++; StmtExprVisitor::VisitExpr_(op); }
-
- void VisitExpr_(const CallNode* op) final {
- if (op->call_type == CallNode::CallType::PureIntrinsic) {
- if (op->dtype.is_float()) {
- float_math_func++;
- } else {
- int_math_func++;
- }
- } else {
- if (op->dtype.is_float()) {
- float_other_func++;
- } else {
- int_other_func++;
- }
- }
- StmtExprVisitor::VisitExpr_(op);
- }
-
- // todo(lmzheng): detect mad
- size_t float_mad{0}, float_addsub{0}, float_mul{0}, float_divmod{0},
- float_cmp{0}, float_math_func{0}, float_other_func{0};
- size_t int_mad{0}, int_addsub{0}, int_mul{0}, int_divmod{0},
- int_cmp{0}, int_math_func{0}, int_other_func{0};
- size_t bool_op{0}, select_op{0};
-};
-
-
-// Extract all buffer accesses in an expr
-class BufferAccessExtractor : public StmtExprVisitor {
- public:
- void ExtractReads(const PrimExpr& expr) {
- this->VisitExpr(expr);
- }
-
- void InsertAccess(const Buffer& buf, BufferAccessType acc_type,
- const Array& indices) {
- BufferAccess& acc = buf_accesses[buf];
- acc.acc_type = acc_type;
- acc.indices.push_back(std::vector(indices.begin(), indices.end()));
- }
-
- void VisitExpr_(const BufferLoadNode *op) final {
- BufferAccess& acc = buf_accesses[op->buffer];
- switch (acc.acc_type) {
- case kRead:
- break;
- case kWrite:
- acc.acc_type = kReadWrite; break;
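- // An access to a buffer already classified as read-write stays
- // read-write; a load on a still-unclassified buffer defaults to read.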
- case kReadWrite:
- break;
- case kUnknownRW:
- default:
- acc.acc_type = kRead; break;
- }
-
- if (acc.acc_type != kReadWrite) {
- // If a buffer is both read and written, in the TVM DSL it must be an update,
- // so the indices should be the same and we can skip appending them for it.
- // Otherwise we do the following.
- buf_accesses[op->buffer].indices.push_back(
- std::vector(op->indices.begin(), op->indices.end()));
- }
- StmtExprVisitor::VisitExpr_(op);
- }
-
- BufferMap buf_accesses;
-};
-
-// Compute the coefficient for a loop iterator in an expression
-// Note: we use an approximation strategy to find the coefficient.
-// Hopefully, it is faster than DetectLinearEquation and can handle more cases (non-linear)
-class CoefficientExtractor : public StmtExprVisitor {
- public:
- void VisitExpr_(const MulNode *node) final {
- StmtExprVisitor::VisitExpr_(node);
- if (visited_var) {
- if (!visited_add) {
- if (auto a = node->a.as()) {
- visited_mul = true;
- stride = a->value;
- } else if (auto b = node->b.as()) {
- visited_mul = true;
- stride = b->value;
- }
- }
- }
- }
-
- void VisitExpr_(const AddNode *node) final {
- StmtExprVisitor::VisitExpr_(node);
- if (visited_var) {
- if (!visited_mul) {
- visited_add = true;
- stride = 1;
- }
- }
- }
-
- void VisitExpr_(const VarNode *node) final {
- if (node == var_) {
- visited_var = true;
- // This is a magic default stride in case our approximation strategy fails
- stride = 2;
- }
- }
-
- int ExtractCoefficient(const PrimExpr& expr, const VarNode* var) {
- visited_var = visited_mul = visited_add = false;
- var_ = var;
-
- this->VisitExpr(expr);
-
- if (visited_var && !visited_mul && !visited_add) {
- return 1;
- } else {
- return stride;
- }
- }
-
- bool visited_var{false};
- bool visited_mul{false};
- bool visited_add{false};
- int stride{0};
-
- private:
- const VarNode* var_{nullptr};
-};
-
-// Compute the stride for the accesses to a buffer
-int64_t ComputeStride(const std::vector >& indices,
- const std::vector& shape,
- const VarNode* stride_var) {
- int64_t min_stride = std::numeric_limits::max();
- bool find = false;
- CoefficientExtractor extractor;
-
- for (const auto &index : indices) {
- int64_t shape_stride = 1;
- for (int i = static_cast(index.size()) - 1; i >= 0; i--) {
- int coefficient = extractor.ExtractCoefficient(index[i], stride_var);
- if (extractor.visited_var) {
- find = true;
- min_stride = std::min(min_stride, std::abs(coefficient) * shape_stride);
- break;
- }
- shape_stride *= shape[i];
- }
- }
-
- return find ? min_stride : 0;
-}
-
-// Compute touched bytes and cache lines for accesses to a buffer
-void ComputeRegion(
- const std::vector > &indices,
- arith::Analyzer* ana,
- std::vector* region) {
- region->clear();
-
- if (indices.empty()) {
- return;
- }
-
- region->reserve(indices[0].size());
-
- if (indices.size() == 1) {
- for (const auto& index : indices[0]) {
- ConstIntBound bound = ana->const_int_bound(index);
- region->push_back(bound->max_value - bound->min_value + 1);
- }
- } else {
- // future(lmzheng): implement a more accurate IntSet?
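- // With multiple accesses, conservatively bound each dimension by the
- // span from the smallest lower bound to the largest upper bound over
- // all accesses.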
- for (size_t i = 0; i < indices[0].size(); ++i) { - int64_t minimum = ConstIntBound::kPosInf, maximum = ConstIntBound::kNegInf; - for (size_t j = 0; j < indices.size(); ++j) { - ConstIntBound bound = ana->const_int_bound(indices[j][i]); - - minimum = std::min(minimum, bound->min_value); - maximum = std::max(maximum, bound->max_value); - } - region->push_back(maximum - minimum + 1); - } - } -} - -// Compute reuse distance and reuse ratio for accesses to a buffer -// return values: reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct -std::tuple ComputeReuse( - const Buffer& buf, - const std::vector >& indices, - const std::vector& for_loop_stack, - const std::unordered_map > > >& for_touch_regions) { - float reuse_dis_iter = 1.0f; - float reuse_dis_bytes = -1.0f; - - for (int i = static_cast(for_loop_stack.size()) - 1; i >= 0; --i) { - const ForNode* cur_for = for_loop_stack[i]; - bool find = false; - - for (size_t j = 0; j < indices.size(); j++) { - for (size_t k = 0; k < indices[j].size(); k++) { - if (VarInExpr(cur_for->loop_var, indices[j][k])) { - find = true; - break; - } - } - if (find) { - break; - } - } - - int64_t extent = GetIntImm(for_loop_stack[i]->extent); - if (find) { - // accumulate/update reuse distance - reuse_dis_iter *= extent; - reuse_dis_bytes = 0.0f; - for (const auto& iter : for_touch_regions.at(cur_for)) { - for (const auto& access : iter.second) { - reuse_dis_bytes += std::get<1>(access) * std::get<2>(access); - } - } - } else { - // Have LoopMultipleRead reuse - if (reuse_dis_bytes < 0) { - // For the reuse in the innermost axis, the above code won't be executed. - // So we compute bytes here - reuse_dis_bytes = 0.0f; - for (const auto& iter : for_touch_regions.at(cur_for)) { - for (const auto& access : iter.second) { - reuse_dis_bytes += 1 * std::get<2>(access); - } - } - } - return std::make_tuple(kLoopMultipleRead, reuse_dis_iter, reuse_dis_bytes, extent); - } - - const BufferMap > >& buffer_map - = for_touch_regions.at(cur_for); - - int serial_reuse = static_cast(buffer_map.at(buf).size()) - 1; - if (serial_reuse > 0) { - int64_t extent = GetIntImm(cur_for->extent); - - // Have SerialMultipleReadWrite reuse - reuse_dis_iter = std::numeric_limits::max(); - for (const auto& acc_info : buffer_map.at(buf)) { - reuse_dis_iter = std::min(reuse_dis_iter, static_cast(std::get<1>(acc_info))); - } - - reuse_dis_bytes = 0.0f; - for (const auto& iter : for_touch_regions.at(cur_for)) { - for (const auto& access : iter.second) { - reuse_dis_bytes += std::get<1>(access) * std::get<2>(access); - } - } - - return std::make_tuple(kSerialMultipleReadWrite, - reuse_dis_iter / extent, reuse_dis_bytes / extent, serial_reuse); - } - } - - return std::make_tuple(kNoReuse, 0, 0, 0); -} - -// Extract features for every Provide statement -class PerStmtFeatureExtractor : public StmtExprVisitor { - public: - explicit PerStmtFeatureExtractor(int cache_line_size) : - cache_line_size_(cache_line_size) {} - - void VisitStmt_(const AttrStmtNode* node) final { - if (node->attr_key == tir::attr::thread_extent || - node->attr_key == tir::attr::virtual_thread) { - const Var& var = node->node.as()->var; - int extent = GetIntImm(node->value); - - int* plen = nullptr; - - const std::string& name = var.get()->name_hint; - if (node->attr_key == tir::attr::thread_extent) { - if (name == "blockIdx.x") { - plen = &blockIdx_x_len; - } else if (name == "blockIdx.y") { - plen = &blockIdx_y_len; - } else if (name == "blockIdx.z") { - plen = &blockIdx_z_len; - } else if (name == "threadIdx.x") { - plen = 
&threadIdx_x_len; - } else if (name == "threadIdx.y") { - plen = &threadIdx_y_len; - } else if (name == "threadIdx.z") { - plen = &threadIdx_z_len; - } else { - LOG(FATAL) << "invalid thread itervar " + name; - } - } else { - plen = &vthread_len; - } - - int extent_before = *plen; - if (node->attr_key == tir::attr::thread_extent) { - *plen = extent; - } else { - *plen *= extent; - } - - is_gpu = true; - - // make a fake for node for blockIdx.x or threadIdx.x - Stmt fake_for_node = For(var, 0, extent, ForType::Parallel, - DeviceAPI::None, node->body); - - outer_loop_prod *= extent; - for_loop_stack.push_back(fake_for_node.as()); - StmtExprVisitor::VisitStmt_(node); - for_loop_stack.pop_back(); - outer_loop_prod /= extent; - - *plen = extent_before; - } else if (node->attr_key == "pragma_auto_unroll_max_step") { - int value = GetIntImm(node->value); - - int16_t old_value = cur_auto_unroll_max_step; - cur_auto_unroll_max_step = value; - StmtExprVisitor::VisitStmt_(node); - cur_auto_unroll_max_step = old_value; - } else { - StmtExprVisitor::VisitStmt_(node); - } - } - - void VisitStmt_(const ForNode* node) final { - int64_t loop_extent = GetIntImm(node->extent); - - if (node->for_type == ForType::Vectorized) { - vec_for_stack.push_back(node); - } else if (node->for_type == ForType::Unrolled) { - unroll_for_stack.push_back(node); - } else if (node->for_type == ForType::Parallel) { - parallel_for_stack.push_back(node); - } - - outer_loop_prod *= loop_extent; - for_loop_stack.push_back(node); - StmtExprVisitor::VisitStmt_(node); - for_loop_stack.pop_back(); - outer_loop_prod /= loop_extent; - - if (node->for_type == ForType::Vectorized) { - vec_for_stack.pop_back(); - } else if (node->for_type == ForType::Unrolled) { - unroll_for_stack.pop_back(); - } else if (node->for_type == ForType::Parallel) { - parallel_for_stack.pop_back(); - } - } - - void VisitStmt_(const BufferStoreNode* node) final { - FeatureSet &fea = buffer_features[node->buffer]; - - // compute feature - MathOpCounter mathops; - mathops(node->value); - fea.float_mad = outer_loop_prod * mathops.float_mad; - fea.float_addsub = outer_loop_prod * mathops.float_addsub; - fea.float_mul = outer_loop_prod * mathops.float_mul; - fea.float_divmod = outer_loop_prod * mathops.float_divmod; - fea.float_cmp = outer_loop_prod * mathops.float_cmp; - fea.float_math_func = outer_loop_prod * mathops.float_math_func; - fea.float_other_func = outer_loop_prod * mathops.float_other_func; - fea.int_mad = outer_loop_prod * mathops.int_mad; - fea.int_addsub = outer_loop_prod * mathops.int_addsub; - fea.int_mul = outer_loop_prod * mathops.int_mul; - fea.int_divmod = outer_loop_prod * mathops.int_divmod; - fea.int_math_func = outer_loop_prod * mathops.int_math_func; - fea.int_cmp = outer_loop_prod * mathops.int_cmp; - fea.int_other_func = outer_loop_prod * mathops.int_other_func; - fea.bool_op = outer_loop_prod * mathops.bool_op; - fea.select_op = outer_loop_prod * mathops.select_op; - - fea.outer_prod = outer_loop_prod; - fea.num_loops = for_loop_stack.size(); - fea.auto_unroll_max_step = cur_auto_unroll_max_step; - fea.vec_len = fea.unroll_len = fea.parallel_len = 0.0f; - fea.vec_type = fea.unroll_type = fea.parallel_type = kPosNone; - - fea.vec_num = vec_for_stack.size(); - if (!vec_for_stack.empty()) { - fea.vec_len = GetIntImm(vec_for_stack.back()->extent); - fea.vec_prod = 1.0; - for (const ForNode* pfor : vec_for_stack) { - fea.vec_prod *= GetIntImm(pfor->extent); - } - fea.vec_type = kPosMixed; - // todo(lmzheng): this feature requires operation 
(tvm.compute) information - // GetAnnotationPosEncoding(vec_for_stack.back()->loop_var, - // node->args, pcompute->axis, pcompute->reduce_axis); - } - - fea.unroll_num = unroll_for_stack.size(); - if (!unroll_for_stack.empty()) { - fea.unroll_len = GetIntImm(unroll_for_stack.back()->extent); - fea.unroll_prod = 1.0; - for (const ForNode* pfor : unroll_for_stack) { - fea.unroll_prod *= GetIntImm(pfor->extent); - } - fea.unroll_type = kPosMixed; - // GetAnnotationPosEncoding(unroll_for_stack.back()->loop_var, - // node->args, pcompute->axis, pcompute->reduce_axis); - } - - fea.parallel_num = parallel_for_stack.size(); - if (!parallel_for_stack.empty()) { - fea.parallel_len = GetIntImm(parallel_for_stack.back()->extent); - fea.parallel_prod = 1.0; - for (const ForNode* pfor : parallel_for_stack) { - fea.parallel_prod *= GetIntImm(pfor->extent); - } - fea.parallel_type = kPosMixed; - // GetAnnotationPosEncoding(parallel_for_stack.back()->loop_var, - // node->args, pcompute->axis, pcompute->reduce_axis); - } - - // GPU threads - fea.is_gpu = is_gpu; - fea.blockIdx_x_len = blockIdx_x_len; - fea.blockIdx_y_len = blockIdx_y_len; - fea.blockIdx_z_len = blockIdx_z_len; - fea.threadIdx_x_len = threadIdx_x_len; - fea.threadIdx_y_len = threadIdx_y_len; - fea.threadIdx_z_len = threadIdx_z_len; - fea.vthread_len = vthread_len; - - // Extract all buffer access - std::vector acc_feas; - BufferAccessExtractor buf_extractor; - buf_extractor.InsertAccess(node->buffer, kWrite, node->indices); - buf_extractor.ExtractReads(node->value); - - // Compute touched region for all outer loops - Analyzer ana; - for (auto x : for_loop_stack) { - ana.Bind(x->loop_var, Range::make_by_min_extent(x->min, 1), true); - } - - std::vector mem_bytes_list; - std::vector compute_ops_list; - - mem_bytes_list.reserve(for_loop_stack.size()); - compute_ops_list.reserve(for_loop_stack.size()); - - int cur_compute_ops = mathops.float_mad + mathops.float_addsub + mathops.float_mul + - mathops.float_divmod + mathops.float_cmp + - mathops.float_math_func + mathops.float_other_func; - - std::vector tmp_region; - for (int i = static_cast(for_loop_stack.size()) - 1; i >= 0; i--) { - const ForNode* p_for = for_loop_stack[i]; - - ana.Bind(p_for->loop_var, - Range::make_by_min_extent(for_loop_stack[i]->min, for_loop_stack[i]->extent), true); - - // Note, here we do overwrite. - // So if there are multiple Provides, the last one will overwrite the first few. - // e.g. The update part in gemm will overwrite the init part. - BufferMap > >& - buffer_regions_map = for_touch_regions[p_for]; - - int64_t mem_bytes = 0; - for (const auto &x : buf_extractor.buf_accesses) { - const Buffer& t = x.first; - const BufferAccess& acc = x.second; - - ComputeRegion(acc.indices, &ana, &tmp_region); - int64_t touched_size = ElementProduct(tmp_region); - buffer_regions_map[t].push_back(std::make_tuple(acc.acc_type, - touched_size, t->dtype.bytes())); - mem_bytes += touched_size * t->dtype.bytes(); - } - - mem_bytes_list.push_back(std::log2(mem_bytes)); - cur_compute_ops *= GetIntImm(for_loop_stack[i]->extent); - compute_ops_list.push_back(std::log2(cur_compute_ops)); - } - - // Compute arithmetic intensity curve (y axis : arithmetic intensity, x axis : flops). - // We use piecewise linear interpolation to fit this curve. 
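- // (compute_ops_list and mem_bytes_list hold log2-scale cumulative values,
- // ordered from the innermost loop outwards; we sample the curve at
- // ARITH_INTENSITY_CURVE_SAMPLE_N evenly spaced fractions of the total op
- // count and interpolate linearly between adjacent recorded points.)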
- int pt = 0; - if (cur_compute_ops <= 0 || compute_ops_list.empty()) { - std::fill(fea.arith_intensity_curve, - fea.arith_intensity_curve + ARITH_INTENSITY_CURVE_SAMPLE_N, 0.0); - } else { - for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { - float cur_compute_ops = compute_ops_list.back() * (i+1) / ARITH_INTENSITY_CURVE_SAMPLE_N; - while (compute_ops_list[pt] < cur_compute_ops - 1e-4) { - pt++; - } - CHECK_LT(pt, compute_ops_list.size()); - - float value; - if (pt == 0) { - value = compute_ops_list[pt] / mem_bytes_list[pt]; - } else { - float base = compute_ops_list[pt-1] / mem_bytes_list[pt-1]; - float slope = (compute_ops_list[pt] / mem_bytes_list[pt] - - compute_ops_list[pt-1] / mem_bytes_list[pt-1]) / - (compute_ops_list[pt] - compute_ops_list[pt-1]); - value = base + slope * (cur_compute_ops - compute_ops_list[pt-1]); - } - fea.arith_intensity_curve[i] = value; - } - } - - // Compute buffer access feature - for (const auto &x : buf_extractor.buf_accesses) { - const Buffer& t = x.first; - const BufferAccess& acc = x.second; - - std::vector int_shape; - for (const auto& dim : t->shape) { - int_shape.push_back(GetIntImm(dim)); - } - - size_t ele_bytes = t->dtype.bytes(); - - // calculate bytes - float bytes = outer_loop_prod * ele_bytes; - float unique_bytes; - - // calculate cache lines - int64_t stride; - float lines; - float unique_lines; - - if (for_loop_stack.empty()) { - unique_bytes = ele_bytes; - stride = 0; - lines = 1.0f; - unique_lines = 1.0f; - } else { - unique_bytes = std::get<1>(for_touch_regions[for_loop_stack.front()][t].front()) - * ele_bytes; - - stride = 0; - int64_t reduce_ratio = 1; - - int i; - for (i = static_cast(for_loop_stack.size()) - 1; i >= 0; i--) { - stride = ComputeStride(acc.indices, int_shape, for_loop_stack[i]->loop_var.get()); - if (stride != 0) { - break; - } - reduce_ratio *= GetIntImm(for_loop_stack.back()->extent); - } - - lines = outer_loop_prod / reduce_ratio * - std::min(1.0f, 1.0f * stride * ele_bytes / cache_line_size_); - lines = std::max(lines, 1.0f); - - // convert `stride` back to the stride of the innermost iterator - stride = (i == static_cast(for_loop_stack.size()) - 1 ? 
stride : 0); - - float n_continuous = ele_bytes; - for (int i = static_cast(tmp_region.size()) - 1; i >= 0; i--) { - if (tmp_region[i] == int_shape[i]) { - n_continuous *= tmp_region[i]; - break; - } - } - unique_lines = unique_bytes / std::min(n_continuous, - static_cast(cache_line_size_)); - unique_lines = std::max(unique_lines, 1.0f); - } - - ReuseType reuse_type; - float reuse_dis_iter, reuse_dis_bytes, reuse_ct; - std::tie(reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct) = - ComputeReuse(t, acc.indices, for_loop_stack, for_touch_regions); - - acc_feas.emplace_back(); - BufferAccessFeature& acc_fea = acc_feas.back(); - - acc_fea.buffer_name = t->name; - acc_fea.acc_type = acc.acc_type; - acc_fea.stride = stride; - acc_fea.bytes = bytes; - acc_fea.unique_bytes = unique_bytes; - acc_fea.lines = lines; - acc_fea.unique_lines = unique_lines; - acc_fea.reuse_type = reuse_type; - acc_fea.reuse_dis_iter = reuse_dis_iter; - acc_fea.reuse_dis_bytes = reuse_dis_bytes; - acc_fea.reuse_ct = reuse_ct; - if (acc_fea.reuse_ct > 0.5) { - acc_fea.bytes_d_reuse_ct = bytes / reuse_ct; - acc_fea.unique_bytes_d_reuse_ct = unique_bytes / reuse_ct; - acc_fea.lines_d_reuse_ct = lines / reuse_ct; - acc_fea.unique_lines_d_reuse_ct = unique_lines / reuse_ct; - } else { - // no reuse, multiply by a magic number '2' - acc_fea.bytes_d_reuse_ct = bytes * 2; - acc_fea.unique_bytes_d_reuse_ct = unique_bytes * 2; - acc_fea.lines_d_reuse_ct = lines * 2; - acc_fea.unique_lines_d_reuse_ct = unique_lines* 2; - } - } - - fea.access_feas = acc_feas; - } - - void VisitStmt_(const BufferRealizeNode *node) final { - StmtExprVisitor::VisitStmt_(node); - - FeatureSet& fea = buffer_features[node->buffer]; - - float allocation_size = 1.0f; - for (const auto& x : node->bounds) { - allocation_size *= GetIntImm(x->extent); - } - // allocation feature - fea.alloc_size = allocation_size * node->buffer->dtype.bytes(); - fea.alloc_prod = allocation_size * outer_loop_prod; - fea.alloc_outer_prod = outer_loop_prod; - fea.alloc_inner_prod = fea.outer_prod / outer_loop_prod; - } - - float outer_loop_prod = 1.0f; - - std::vector for_loop_stack; - std::vector parallel_for_stack; - std::vector vec_for_stack; - std::vector unroll_for_stack; - - bool is_gpu; - int blockIdx_x_len{1}; - int blockIdx_y_len{1}; - int blockIdx_z_len{1}; - int threadIdx_x_len{1}; - int threadIdx_y_len{1}; - int threadIdx_z_len{1}; - int vthread_len{1}; - int16_t cur_auto_unroll_max_step{0}; - - BufferMap buffer_features; - - // for a loop, for all its touched buffers, for all different accesses to the buffers, - // its (access type, number of touched elements, number of bytes of single element) - std::unordered_map > > > for_touch_regions; - - private: - const int cache_line_size_ = 64; -}; - -// shifted log to incorporate the property that slog(0) = 0 -inline float slog(float x) { - return x < 0 ? -std::log2(-x+1) : std::log2(x+1); -} - -// Get features for all ir::Provide statements in a TVM program. 
-// So we call it `PerStmt` feature -void GetPerStmtFeature(const Stmt& stmt, - int cache_line_size, - int max_n_bufs, - std::vector* ret) { - PerStmtFeatureExtractor extractor(cache_line_size); - extractor(stmt); - - ret->push_back(extractor.buffer_features.size()); - - for (const auto& x : extractor.buffer_features) { - const FeatureSet& fea_set = x.second; - - /***** compute feature *****/ - ret->push_back(slog(fea_set.float_mad)); - ret->push_back(slog(fea_set.float_addsub)); - ret->push_back(slog(fea_set.float_mul)); - ret->push_back(slog(fea_set.float_divmod)); - ret->push_back(slog(fea_set.float_cmp)); - ret->push_back(slog(fea_set.float_math_func)); - ret->push_back(slog(fea_set.float_other_func)); - ret->push_back(slog(fea_set.int_mad)); - ret->push_back(slog(fea_set.int_addsub)); - ret->push_back(slog(fea_set.int_mul)); - ret->push_back(slog(fea_set.int_divmod)); - ret->push_back(slog(fea_set.int_cmp)); - ret->push_back(slog(fea_set.int_math_func)); - ret->push_back(slog(fea_set.int_other_func)); - ret->push_back(slog(fea_set.bool_op)); - ret->push_back(slog(fea_set.select_op)); - - ret->push_back(slog(fea_set.vec_num)); - ret->push_back(slog(fea_set.vec_prod)); - ret->push_back(slog(fea_set.vec_len)); - for (int i = 0; i <= kPosMixed; i++) { - ret->push_back(i == fea_set.vec_type); - } - - ret->push_back(slog(fea_set.unroll_num)); - ret->push_back(slog(fea_set.unroll_prod)); - ret->push_back(slog(fea_set.unroll_len)); - for (int i = 0; i <= kPosMixed; i++) { - ret->push_back(i == fea_set.unroll_type); - } - - ret->push_back(slog(fea_set.parallel_num)); - ret->push_back(slog(fea_set.parallel_prod)); - ret->push_back(slog(fea_set.parallel_len)); - for (int i = 0; i <= kPosMixed; i++) { - ret->push_back(i == fea_set.parallel_type); - } - - ret->push_back(fea_set.is_gpu); - ret->push_back(slog(fea_set.blockIdx_x_len)); - ret->push_back(slog(fea_set.blockIdx_y_len)); - ret->push_back(slog(fea_set.blockIdx_z_len)); - ret->push_back(slog(fea_set.threadIdx_x_len)); - ret->push_back(slog(fea_set.threadIdx_y_len)); - ret->push_back(slog(fea_set.threadIdx_z_len)); - ret->push_back(slog(fea_set.vthread_len)); - - for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { - ret->push_back(fea_set.arith_intensity_curve[i]); - } - - /***** access feature *****/ - // sort according to pair (lines, bytes) - std::vector > buf_order_key; - for (const auto& acc_fea : fea_set.access_feas) { - buf_order_key.emplace_back(acc_fea.lines, acc_fea.bytes); - } - std::vector buf_order(buf_order_key.size()); - std::iota(buf_order.begin(), buf_order.end(), 0); - - auto cmp = [&buf_order_key](int l, int r) { - return buf_order_key[l].first > buf_order_key[r].first - || (buf_order_key[l].first == buf_order_key[r].first - && buf_order_key[l].second > buf_order_key[r].second); - }; - std::sort(buf_order.begin(), buf_order.end(), cmp); - int n_bufs = std::min(max_n_bufs, static_cast(buf_order.size())); - buf_order.resize(n_bufs); - - for (int idx : buf_order) { - const auto& acc_fea = fea_set.access_feas[idx]; - for (int j = 0; j <= kReadWrite; ++j) { - ret->push_back(j == acc_fea.acc_type); - } - ret->push_back(slog(acc_fea.bytes)); - ret->push_back(slog(acc_fea.unique_bytes)); - ret->push_back(slog(acc_fea.lines)); - ret->push_back(slog(acc_fea.unique_lines)); - for (int j = 0; j <= kNoReuse; ++j) { - ret->push_back(acc_fea.reuse_type == j); - } - ret->push_back(slog(acc_fea.reuse_dis_iter)); - ret->push_back(slog(acc_fea.reuse_dis_bytes)); - ret->push_back(slog(acc_fea.reuse_ct)); - 
ret->push_back(slog(acc_fea.bytes_d_reuse_ct)); - ret->push_back(slog(acc_fea.unique_bytes_d_reuse_ct)); - ret->push_back(slog(acc_fea.lines_d_reuse_ct)); - ret->push_back(slog(acc_fea.unique_lines_d_reuse_ct)); - ret->push_back(slog(acc_fea.stride)); - } - // - fill padding - for (int i = 0; i < max_n_bufs - n_bufs; ++i) { - for (int j = 0; j <= kReadWrite; ++j) { // 3 - ret->push_back(0.0f); - } - ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); - for (int j = 0; j <= kNoReuse; ++j) { // 3 - ret->push_back(0.0f); - } - ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); - ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); - } - - /***** allocation feature *****/ - ret->push_back(slog(fea_set.alloc_size)); - ret->push_back(slog(fea_set.alloc_prod)); - ret->push_back(slog(fea_set.alloc_outer_prod)); - ret->push_back(slog(fea_set.alloc_inner_prod)); - - /***** overall feature *****/ - ret->push_back(slog(fea_set.outer_prod)); - ret->push_back(slog(fea_set.num_loops)); - ret->push_back(slog(fea_set.auto_unroll_max_step)); - } -} - - -/* \brief Get the name of every element in the feature vector. Use this for debug and inspection */ -void GetPerStmtFeatureName(int max_n_bufs, std::vector *ret) { - /***** compute feature *****/ - ret->push_back(("float_mad")); - ret->push_back(("float_addsub")); - ret->push_back(("float_mul")); - ret->push_back(("float_divmod")); - ret->push_back(("float_cmp")); - ret->push_back(("float_mathfunc")); - ret->push_back(("float_otherfunc")); - ret->push_back(("int_mad")); - ret->push_back(("int_addsub")); - ret->push_back(("int_mul")); - ret->push_back(("int_divmod")); - ret->push_back(("int_cmp")); - ret->push_back(("int_mathfunc")); - ret->push_back(("int_otherfunc")); - ret->push_back(("bool_op")); - ret->push_back(("select_op")); - ret->push_back(("vec_num")); - ret->push_back(("vec_prod")); - ret->push_back(("vec_len")); - ret->push_back(("vec_type.kPosNone")); - ret->push_back(("vec_type.kPosInnerSpatial")); - ret->push_back(("vec_type.kPosMiddleSpatial")); - ret->push_back(("vec_type.kPosOuterSpatial")); - ret->push_back(("vec_type.kPosInnerReduce")); - ret->push_back(("vec_type.kPosMiddleReduce")); - ret->push_back(("vec_type.kPosOuterReduce")); - ret->push_back(("vec_type.kPosMixed")); - ret->push_back(("unroll_num")); - ret->push_back(("unroll_prod")); - ret->push_back(("unroll_len")); - ret->push_back(("unroll_type.kPosNone")); - ret->push_back(("unroll_type.kPosInnerSpatial")); - ret->push_back(("unroll_type.kPosMiddleSpatial")); - ret->push_back(("unroll_type.kPosOuterSpatial")); - ret->push_back(("unroll_type.kPosInnerReduce")); - ret->push_back(("unroll_type.kPosMiddleReduce")); - ret->push_back(("unroll_type.kPosOuterReduce")); - ret->push_back(("unroll_type.kPosMixed")); - ret->push_back(("parallel_num")); - ret->push_back(("parallel_prod")); - ret->push_back(("parallel_len")); - ret->push_back(("parallel_type.kPosNone")); - ret->push_back(("parallel_type.kPosInnerSpatial")); - ret->push_back(("parallel_type.kPosMiddleSpatial")); - ret->push_back(("parallel_type.kPosOuterSpatial")); - ret->push_back(("parallel_type.kPosInnerReduce")); - ret->push_back(("parallel_type.kPosMiddleReduce")); - ret->push_back(("parallel_type.kPosOuterReduce")); - ret->push_back(("parallel_type.kPosMixed")); - ret->push_back(("is_gpu")); - ret->push_back(("blockIdx_x_len")); - ret->push_back(("blockIdx_y_len")); - ret->push_back(("blockIdx_z_len")); - 
ret->push_back(("threadIdx_x_len")); - ret->push_back(("threadIdx_y_len")); - ret->push_back(("threadIdx_z_len")); - ret->push_back(("vthread_len")); - for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { - ret->push_back(("arith_intensity_curve_" + std::to_string(i))); - } - // section total: 55 + ARITH_INTENSITY_CURVE_SAMPLE_N = 65 - - /***** access feature *****/ - for (size_t i = 0; i < static_cast(max_n_bufs); ++i) { - std::string prefix = "B" + std::to_string(i) + "."; - ret->push_back((prefix + "acc_type.kRead")); - ret->push_back((prefix + "acc_type.kWrite")); - ret->push_back((prefix + "acc_type.kReadWrite")); - ret->push_back((prefix + "bytes")); - ret->push_back((prefix + "unique_bytes")); - ret->push_back((prefix + "lines")); - ret->push_back((prefix + "unique_lines")); - ret->push_back((prefix + "reuse_type.kLoopMultipleRead")); - ret->push_back((prefix + "reuse_type.kSerialMultipleReadWrite")); - ret->push_back((prefix + "reuse_type.kNoReuse")); - ret->push_back((prefix + "reuse_dis_iter")); - ret->push_back((prefix + "reuse_dis_bytes")); - ret->push_back((prefix + "reuse_ct")); - ret->push_back((prefix + "bytes_d_reuse_ct")); - ret->push_back((prefix + "unique_bytes_d_reuse_ct")); - ret->push_back((prefix + "lines_d_reuse_ct")); - ret->push_back((prefix + "unique_lines_d_reuse_ct")); - ret->push_back((prefix + "stride")); - } - // section total : max_n_bufs * 18 - - /***** allocation feature *****/ - ret->push_back(("alloc_size")); - ret->push_back(("alloc_prod")); - ret->push_back(("alloc_outer_prod")); - ret->push_back(("alloc_inner_prod")); - // section total : 4 - - /***** overall feature *****/ - ret->push_back(("outer_prod")); - ret->push_back(("num_loops")); - ret->push_back(("auto_unroll_max_step")); - // section total : 2 -} - -void GetPerStmtFeaturesWorkerFunc(const SearchTask& task, const State& state, - int max_n_bufs, std::vector* feature, std::atomic* error_ct) { - te::Schedule sch; - Array tensors; - - std::tie(sch, tensors) = task->compute_dag.ApplySteps(state->transform_steps); - sch = sch.normalize(); - auto bounds = te::InferBound(sch); - - try { - auto stmt = te::ScheduleOps(sch, bounds, false); - Map out_binds; Array out_arg_list; - bool compact = te::VerifyCompactBuffer(stmt); - const std::string& name = "main"; - GlobalVar global_var(name); - - // Copied from driver_api.cc::lower - auto pass_ctx = tvm::transform::PassContext::Current(); - GetBinds(tensors, compact, std::unordered_map(), - &out_binds, &out_arg_list); - tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); - f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); - - bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); - bool disable_vectorize = - pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); - bool instrument_bound_checkers = - pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); - - if (noalias) { - f = WithAttr(std::move(f), "tir.noalias", Bool(true)); - } - auto mod = IRModule(Map({{global_var, f}})); - - if (task->target->device_type == kDLGPU) { - auto pass_list = Array(); - // Phase 0 - pass_list.push_back(tir::transform::InjectPrefetch()); - pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); - // Phase 1 - pass_list.push_back(tir::transform::NarrowDataType(32)); - pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::VectorizeLoop(disable_vectorize)); - 
pass_list.push_back(tir::transform::InjectVirtualThread()); - pass_list.push_back(tir::transform::StorageRewrite()); - pass_list.push_back(tir::transform::Simplify()); - tvm::Map gpu_params { - {"max_shared_memory_per_block", - task->hardware_params->max_shared_memory_per_block}, - {"max_local_memory_per_block", - task->hardware_params->max_registers_per_block}, - {"max_threads_per_block", - task->hardware_params->max_threads_per_block}, - {"max_vector_bytes", - task->hardware_params->vector_unit_bytes} - }; - pass_list.push_back(tir::transform::VerifyGPUCode(gpu_params)); - const auto& optimize = tir::transform::Sequential(pass_list); - optimize(mod); - } - const auto& optimize = tir::transform::Sequential( - Array{tir::transform::Simplify()}); - mod = optimize(std::move(mod)); - const auto& it = mod->functions.find(global_var); - CHECK(it != mod->functions.end()); - const auto& prim_func = (*it).second.as(); - GetPerStmtFeature(prim_func->body, - task->hardware_params->cache_line_bytes, - max_n_bufs, feature); - } catch (dmlc::Error &e) { - (*error_ct)++; - } -} - -void GetPerStmtFeaturesFromStates(const Array& states, - const SearchTask& task, - int skip_first_n_feature_extraction, - int max_n_bufs, - std::vector >* features) { - // extract features - features->assign(states.size(), std::vector()); - - std::atomic error_ct(0); - - ThreadPool& pool = ThreadPool::Global(); - pool.BeginBatch(static_cast(states.size()) - skip_first_n_feature_extraction); - for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) { - pool.Enqueue(GetPerStmtFeaturesWorkerFunc, task, states[i], - max_n_bufs, &(*features)[i], &error_ct); - // GetPerStmtFeaturesWorkerFunc(task, states[i], - // max_n_bufs, &(*features)[i], &error_ct); - } - pool.WaitBatch(); - - if (error_ct > 0) { - std::cerr << "Encountered " << error_ct - << " errors during feature extraction. Ignored." << std::endl; - } -} - - -void GetPerStmtFeaturesFromStates(const Array& states, - const std::vector& tasks, - int skip_first_n_feature_extraction, - int max_n_bufs, - std::vector >* features) { - // extract features - features->assign(states.size(), std::vector()); - - std::atomic error_ct(0); - - ThreadPool& pool = ThreadPool::Global(); - pool.BeginBatch(static_cast(states.size()) - skip_first_n_feature_extraction); - for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) { - pool.Enqueue(GetPerStmtFeaturesWorkerFunc, tasks[i], states[i], - max_n_bufs, &(*features)[i], &error_ct); - } - pool.WaitBatch(); - - if (error_ct > 0) { - std::cerr << "Encountered " << error_ct - << " errors during feature extraction. Ignored." 
<< std::endl; - } -} - -void GetPerStmtFeaturesFromFile(const std::string& filename, - int n_lines, - int max_n_bufs, - std::vector >* features, - std::vector* normalized_throughputs, - std::vector* task_ids) { - Array states; - // ArrayNode* pstates = states.CopyOnWrite(); - std::vector tasks; - - normalized_throughputs->clear(); - task_ids->clear(); - - // (workload_key, target) -> (search_task, task_id) - std::unordered_map, std::pair> task_cache; - // task_id -> min_cost - std::vector min_costs; - - // read from file - LogReader reader = LogReader(filename); - auto cur_inp = make_object(); - auto cur_res = make_object(); - while (reader->ReadNext(cur_inp.get(), cur_res.get())) { - float cost = static_cast(FloatArrayMean(cur_res->costs)); - const std::string& workload_key = cur_inp->task->workload_key; - - SearchTask task; - size_t task_id; - std::pair key(workload_key, cur_inp->task->target->str()); - auto find_res = task_cache.find(key); - if (find_res == task_cache.end()) { - // rebuild task - task = SearchTask(ComputeDAG(workload_key), workload_key, - cur_inp->task->target, cur_inp->task->target_host, - cur_inp->task->hardware_params); - task_id = task_cache.size(); - - // compute min cost for each task - task_cache.insert(std::make_pair(key, std::make_pair(task, task_id))); - min_costs.push_back(cost); - } else { - std::tie(task, task_id) = find_res->second; - min_costs[task_id] = std::min(min_costs[task_id], cost); - } - - tasks.push_back(std::move(task)); - task_ids->push_back(task_id); - // pstates->data.push_back(cur_inp->state); - states.push_back(cur_inp->state); - normalized_throughputs->push_back(cost); - - if (n_lines > 0 && static_cast(states.size()) >= n_lines) { - break; - } - } - - for (size_t i = 0; i < normalized_throughputs->size(); ++i) { - (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i]; - } - - GetPerStmtFeaturesFromStates(states, tasks, 0, max_n_bufs, features); -} - -void GetPerStmtFeaturesFromMeasurePairs(const Array& inputs, - const Array& results, - int skip_first_n_feature_extraction, - int max_n_bufs, - std::vector >* features, - std::vector* normalized_throughputs, - std::vector* task_ids) { - Array states; - // ArrayNode* pstates = states.CopyOnWrite(); - std::vector tasks; - - normalized_throughputs->clear(); - task_ids->clear(); - - // (workload_key, target) -> (search_task, task_id) - std::unordered_map, std::pair> task_cache; - // task_id -> min_cost - std::vector min_costs; - - tasks.reserve(inputs.size()); - normalized_throughputs->reserve(inputs.size()); - task_ids->reserve(inputs.size()); - for (size_t i = 0; i < inputs.size(); ++i) { - float cost = static_cast(FloatArrayMean(results[i]->costs)); - const std::string& workload_key = inputs[i]->task->workload_key; - SearchTask task; - - size_t task_id; - std::pair key(workload_key, inputs[i]->task->target->str()); - auto find_res = task_cache.find(key); - if (find_res == task_cache.end()) { - if (inputs[i]->task->compute_dag.defined()) { // the measure input is complete - task = inputs[i]->task; - } else { // the measure input is incomplete - // rebuild task for incomplete measure pairs read from file - task = SearchTask(ComputeDAG(workload_key), workload_key, - inputs[i]->task->target, inputs[i]->task->target_host, - inputs[i]->task->hardware_params); - } - task_id = task_cache.size(); - - // compute min cost for each task - task_cache.insert(std::make_pair(key, std::make_pair(task, task_id))); - min_costs.push_back(cost); - } else { - std::tie(task, 
task_id) = find_res->second; - min_costs[task_id] = std::min(min_costs[task_id], cost); - } - - tasks.push_back(std::move(task)); - task_ids->push_back(task_id); - // pstates->data.push_back(inputs[i]->state); - states.push_back(inputs[i]->state); - normalized_throughputs->push_back(cost); - } - - for (size_t i = 0; i < normalized_throughputs->size(); ++i) { - (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i]; - } - - GetPerStmtFeaturesFromStates(states, tasks, skip_first_n_feature_extraction, - max_n_bufs, features); -} - -TVMByteArray SerializeFeatures(std::vector >&& features, - std::vector&& normalized_throughputs, - std::vector&& task_ids, - std::vector* out_data) { - size_t total_bytes = 0; - std::vector size_vector; - - int n = features.size(); - - // serialize sizes - size_t size_vector_size = 1 + n + 2; - total_bytes += size_vector_size * sizeof(int); - - size_vector.reserve(size_vector_size); - size_vector.push_back(features.size()); - for (const auto& x : features) { - size_vector.push_back(static_cast(x.size())); - total_bytes += sizeof(float) * x.size(); - } - size_vector.push_back(static_cast(normalized_throughputs.size())); - total_bytes += sizeof(float) * normalized_throughputs.size(); - size_vector.push_back(static_cast(task_ids.size())); - total_bytes += sizeof(int) * task_ids.size(); - - CHECK_EQ(size_vector.size(), size_vector_size); - - // allocate memory - out_data->reserve(total_bytes); - char* ptr = out_data->data(); - - // serialize size_vector - memmove(ptr, reinterpret_cast(size_vector.data()), size_vector.size() * sizeof(int)); - ptr += size_vector.size() * sizeof(int); - - // serialize features - for (auto& x : features) { - memmove(ptr, x.data(), sizeof(float) * x.size()); - ptr += sizeof(float) * x.size(); - x.clear(); - } - - // serialize normalized_throughputs - memmove(ptr, reinterpret_cast(normalized_throughputs.data()), - normalized_throughputs.size() * sizeof(int)); - ptr += normalized_throughputs.size() * sizeof(int); - - // serialize task_ids - memmove(ptr, reinterpret_cast(task_ids.data()), task_ids.size() * sizeof(int)); - ptr += task_ids.size() * sizeof(int); - - CHECK_EQ(ptr - out_data->data(), total_bytes); - - return TVMByteArray{out_data->data(), total_bytes}; -} - - -TVM_REGISTER_GLOBAL("ansor.GetPerStmtFeaturesFromFile") -.set_body([](TVMArgs args, TVMRetValue *ret) { - std::string filename = args[0]; - int n_lines = args[1]; - int max_n_bufs = args[2]; - - std::vector > features; - std::vector normalized_throughputs; - std::vector task_ids; - - GetPerStmtFeaturesFromFile(filename, n_lines, max_n_bufs, - &features, &normalized_throughputs, &task_ids); - - // serialization format for n records: - // - // int n; - // int[n+2] sizes - // - // float[sizes[0]] feature for record 1 - // float[sizes[1]] feature for record 2 - // ... feature for record i... 
- // float[sizes[n-1]] feature for record n - // - // float[sizes[n]] normalized throughput for n records - // int[sizes[n+1]] task id for n records - - std::vector byte_data; - *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs), - std::move(task_ids), &byte_data); -}); - -TVM_REGISTER_GLOBAL("ansor.GetPerStmtFeaturesFromMeasurePairs") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Array inputs = args[0]; - Array results = args[1]; - int skip_first_n_feature_extraction = args[2]; - int max_n_bufs = args[3]; - - std::vector > features; - std::vector normalized_throughputs; - std::vector task_ids; - - GetPerStmtFeaturesFromMeasurePairs(inputs, results, skip_first_n_feature_extraction, max_n_bufs, - &features, &normalized_throughputs, &task_ids); - - // serialization format for n records: - // - // int n; - // int[n+2] sizes - // - // float[sizes[0]] feature for record 1 - // float[sizes[1]] feature for record 2 - // ... feature for record i... - // float[sizes[n-1]] feature for record n - // - // float[sizes[n]] normalized throughput for n records - // int[sizes[n+1]] task id for n records - - std::vector byte_data; - *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs), - std::move(task_ids), &byte_data); -}); - -TVM_REGISTER_GLOBAL("ansor.GetPerStmtFeaturesFromStates") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Array states = args[0]; - SearchTask task = args[1]; - int max_n_bufs = args[2]; - - std::vector > features; - std::vector normalized_throughputs; - std::vector task_ids; - - GetPerStmtFeaturesFromStates(states, task, 0, max_n_bufs, &features); - - // serialization format for n records: - // - // int n; - // int[n+2] sizes - // - // float[sizes[0]] feature for record 1 - // float[sizes[1]] feature for record 2 - // ... feature for record i... - // float[sizes[n-1]] feature for record n - // - // float[sizes[n]] normalized throughput for n records - // int[sizes[n+1]] task id for n records - - std::vector byte_data; - *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs), - std::move(task_ids), &byte_data); -}); - -TVM_REGISTER_GLOBAL("ansor.GetPerStmtFeatureNames") - .set_body([](TVMArgs args, TVMRetValue *ret) { - int max_n_bufs = args[0]; - std::vector names; - - GetPerStmtFeatureName(max_n_bufs, &names); - - Array arr; - for (const auto& x : names) { - arr.push_back(x); - } - *ret = arr; -}); - - -} // namespace ansor -} // namespace tvm diff --git a/src/ansor/feature.h b/src/ansor/feature.h deleted file mode 100644 index e507149643e2..000000000000 --- a/src/ansor/feature.h +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * \file ansor/feature.h - * \brief Feature extraction for the cost model - */ - -#ifndef TVM_ANSOR_FEATURE_H_ -#define TVM_ANSOR_FEATURE_H_ - -#include -#include -#include "compute_dag.h" -#include "measure.h" - -namespace tvm { -namespace ansor { - -/*! \brief Get PerStmt feature from a tvm IR stmt */ -void GetPerStmtFeature(const Stmt& stmt, - int cache_line_size, - int max_n_bufs, - std::vector* ret); - -/* \brief Get the name of every element in the feature vector. Use this for debug and inspection */ -void GetPerStmtFeatureName(int max_n_bufs, std::vector *ret); - - -/*! \brief Get PerStmt feature from states and the same task */ -void GetPerStmtFeaturesFromStates(const Array& states, - const SearchTask& task, - int skip_first_n_feature_extraction, - int max_n_bufs, - std::vector >* features); - -/*! \brief Get PerStmt feature from states and different tasks */ -void GetPerStmtFeaturesFromStates(const Array& states, - const std::vector& tasks, - int skip_first_n_feature_extraction, - int max_n_bufs, - std::vector >* features); - -/*! \brief Get PerStmt feature from a log file */ -void GetPerStmtFeaturesFromFile(const std::string& filename, - int n_lines, - int max_n_bufs, - std::vector >* features, - std::vector* normalized_throughputs, - std::vector* task_ids); - -/*! \brief Get PerStmt feature from measure pairs */ -void GetPerStmtFeaturesFromMeasurePairs(const Array& inputs, - const Array& results, - int skip_first_n_feature_extraction, - int max_n_bufs, - std::vector >* features, - std::vector* normalized_throughputs, - std::vector* task_ids); - -} // namespace ansor -} // namespace tvm - -#endif // TVM_ANSOR_FEATURE_H_ diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 010e5f3dc221..787e4256a181 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -153,28 +153,6 @@ std::vector State::split(int stage_id, const Iterator& it, return DoSplitStep(step); } -std::vector State::follow_split(int stage_id, const Iterator& it, - int src_step_id, int n_split) { - const Stage& stage = operator->()->stages[stage_id]; - - FollowSplitStep step = FollowSplitStep( - stage_id, GetIndex(stage->iters, it), src_step_id, n_split); - CopyOnWrite()->transform_steps.push_back(step); - return DoFollowSplitStep(step); -} - -std::vector State::follow_fused_split( - int stage_id, const Iterator& it, const std::vector& src_step_ids, - int level, bool factor_or_nparts) { - const Stage& stage = operator->()->stages[stage_id]; - - FollowFusedSplitStep step = - FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), - src_step_ids, level, factor_or_nparts); - CopyOnWrite()->transform_steps.push_back(step); - return DoFollowFusedSplitStep(step); -} - Iterator State::fuse(int stage_id, const std::vector& iters) { const Stage& stage = operator->()->stages[stage_id]; std::vector indices; @@ -184,126 +162,6 @@ Iterator State::fuse(int stage_id, const std::vector& iters) { return DoFuseStep(step); } -Iterator State::vectorize(int stage_id, const Iterator& it) { - const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = AnnotationStep( - stage_id, GetIndex(stage->iters, it), kVectorize); - CopyOnWrite()->transform_steps.push_back(step); - return DoAnnotationStep(step); -} - -Iterator State::parallel(int stage_id, const Iterator& it) { - const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = - AnnotationStep(stage_id, GetIndex(stage->iters, it), kParallel); - CopyOnWrite()->transform_steps.push_back(step); - return DoAnnotationStep(step); -} 
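-
- // NOTE: each of these State primitives follows the same pattern: build a
- // transform step, record it in transform_steps so the schedule can be
- // replayed or serialized later, then apply it eagerly via the matching
- // Do*Step method. A typical (hypothetical) call sequence from a search
- // policy might look like:
- //   Iterator fused = state.fuse(stage_id, {i, j});
- //   state.parallel(stage_id, fused);
- //   state.unroll(stage_id, inner_it, /*max_unroll=*/16);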
-
-Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) {
- const Stage& stage = operator->()->stages[stage_id];
- AnnotationStep step =
- AnnotationStep(stage_id, GetIndex(stage->iters, it), kUnroll);
-
- // don't unroll if the extent is larger than max_unroll
- if (max_unroll != -1 && it->range.defined()) {
- if (auto imm = it->range->extent.as()) {
- if (imm->value > max_unroll) {
- return it;
- }
- }
- }
-
- CopyOnWrite()->transform_steps.push_back(step);
- return DoAnnotationStep(step);
-}
-
-void State::compute_at(int stage_id, int target_stage_id,
- const Iterator& target_iter) {
- const Stage& target_stage = operator->()->stages[target_stage_id];
- ComputeAtStep step = ComputeAtStep(
- stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter));
- CopyOnWrite()->transform_steps.push_back(step);
- return DoComputeAtStep(step);
-}
-
-void State::compute_root(int stage_id) {
- ComputeRootStep step = ComputeRootStep(stage_id);
- CopyOnWrite()->transform_steps.push_back(step);
- return DoComputeRootStep(step);
-}
-
-void State::compute_inline(int stage_id) {
- ComputeInlineStep step = ComputeInlineStep(stage_id);
- CopyOnWrite()->transform_steps.push_back(step);
- return DoComputeInlineStep(step);
-}
-
-Iterator State::bind_thread(int stage_id, const Iterator& it,
- IteratorAnnotation thread_type) {
- const Stage& stage = operator->()->stages[stage_id];
- if (thread_type < kVThread || thread_type > kThreadY) {
- LOG(FATAL) << "thread_type error, valid: kVThread, kBlockX, kBlockY, "
- << "kThreadX, kThreadY";
- }
- AnnotationStep step = AnnotationStep(
- stage_id, GetIndex(stage->iters, it), thread_type);
- CopyOnWrite()->transform_steps.push_back(step);
- return DoAnnotationStep(step);
-}
-
-int State::cache_read(int stage_id, const std::string& scope_name,
- const std::vector& reader_stage_ids,
- const ComputeDAG& task_dag) {
- CacheReadStep step =
- CacheReadStep(stage_id, scope_name, reader_stage_ids);
- CopyOnWrite()->transform_steps.push_back(step);
- return DoCacheReadStep(step, task_dag);
-}
-
-int State::cache_write(int stage_id, const std::string& scope_name,
- const ComputeDAG& task_dag) {
- CacheWriteStep step = CacheWriteStep(stage_id, scope_name);
- CopyOnWrite()->transform_steps.push_back(step);
- return DoCacheWriteStep(step, task_dag);
-}
-
-void State::pragma(int stage_id, const Iterator& it,
- const std::string& pragma_type) {
- const Stage& stage = operator->()->stages[stage_id];
- PragmaStep step =
- PragmaStep(stage_id, GetIndex(stage->iters, it), pragma_type);
- CopyOnWrite()->transform_steps.push_back(step);
- return DoPragmaStep(step);
-}
-
-int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id,
- const ComputeDAG& task_dag) {
- const Stage& stage = operator->()->stages[stage_id];
- RfactorStep step = RfactorStep(stage_id, GetIndex(stage->iters, it),
- factor_iter_id);
- CopyOnWrite()->transform_steps.push_back(step);
- return DoRfactorStep(step, task_dag);
-}
-
-void State::storage_align(int stage_id, const Iterator& it, int factor,
- int offset) {
- const Stage& stage = operator->()->stages[stage_id];
- StorageAlignStep step = StorageAlignStep(
- stage_id, GetIndex(stage->iters, it), factor, offset);
- CopyOnWrite()->transform_steps.push_back(step);
- return DoStorageAlignStep(step);
-}
-
-Iterator State::tensorize(int stage_id, const Iterator& it,
- std::string ti_func_name) {
- const Stage& stage = operator->()->stages[stage_id];
- TensorizeStep step = TensorizeStep(
- stage_id, GetIndex(stage->iters, it),
ti_func_name); - CopyOnWrite()->transform_steps.push_back(step); - return DoTensorizeStep(step); -} - // Steps' implementations void State::DoReorderStep(const ReorderStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; @@ -402,20 +260,6 @@ std::vector State::DoSplitStep(const SplitStep& step) { step->inner_to_outer); } -std::vector State::DoFollowSplitStep(const FollowSplitStep& step) { - std::vector lengths; - step->ExtractSplitLengths(operator->()->transform_steps, &lengths); - return DoSplitStepCommon(step->stage_id, step->iter_id, lengths, true); -} - -std::vector State::DoFollowFusedSplitStep( - const FollowFusedSplitStep& step) { - const PrimExpr& length = - step->ExtractSplitLength(operator->()->transform_steps); - return DoSplitStepCommon(step->stage_id, step->iter_id, {length}, - step->factor_or_nparts); -} - Iterator State::DoFuseStep(const FuseStep& step) { int stage_id = step->stage_id; const Stage& stage = operator->()->stages[stage_id]; @@ -499,292 +343,13 @@ Iterator State::DoFuseStep(const FuseStep& step) { return new_it; } -Iterator State::DoAnnotationStep(const AnnotationStep& step) { - const Stage& stage = operator->()->stages[step->stage_id]; - Iterator it = stage->iters[step->iter_id]; - - CHECK_EQ(it->annotation, IteratorAnnotation::kNone); - Iterator new_it = Iterator(it->name, it->range, it->iter_type, - step->annotation, &it->ori_iters, - it->attr); - Stage new_stage = stage; - new_stage.CopyOnWrite()->iters[step->iter_id] = new_it; - StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = std::move(new_stage); - return new_it; -} - -void State::DoComputeAtStep(const ComputeAtStep& step) { - const Stage& stage = operator->()->stages[step->stage_id]; - - // after compute_at, we don't know the accurate length information any more - // If we do want to know the accurate lengths, we can call - // ComputeDAG::ReplayAndInferBound - std::vector new_iters; - for (const Iterator& it : stage->iters) { - size_t s = it->name.size(); - if (s >= 2 && it->name[s - 2] == '.' && it->name[s - 1] >= '1' && - it->name[s - 1] <= '4') { - // We use a dangerous heuristic rule here : For multi level splitted - // iterators, we assume their length does not change after compute_at. - // Reason: These iterators are generated in MultiStagePolicy by multi - // level tiling, they will be carefully compute_at their consumers. - // In this case, their lengths do not change. - // We do this to keep the AnnotateCPU pass to annotate more efficiently. 
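-      // (For reference: the test above matches iterator names that end in
-      // ".1" through ".4", e.g. "i.2", which is how multi-level tiling
-      // names the iterators it creates by splitting.)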
- new_iters.push_back(it); - } else { - new_iters.push_back(Iterator(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters, it->attr)); - } - } - - StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = - Stage(stage->op, stage->op_type, std::move(new_iters), kIter, - stage->attrs); - pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, - step->target_iter_id); -} - -void State::DoComputeRootStep(const ComputeRootStep& step) { - const Stage& stage = operator->()->stages[step->stage_id]; - - // after compute_root, we don't know the accurate length information any more - // If we do want to know the accurate lengths, we can call - // ComputeDAG::ReplayAndInferBound - std::vector new_iters; - for (const Iterator& it : stage->iters) { - new_iters.push_back(Iterator(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters, it->attr)); - } - - // update attach map - StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = Stage(stage->op, stage->op_type, - std::move(new_iters), kRoot, - stage->attrs); - pstate->attach_map.DeleteStage(step->stage_id); -} - -void State::DoComputeInlineStep(const ComputeInlineStep& step) { - const Stage& stage = operator->()->stages[step->stage_id]; - - StateNode* pstate = CopyOnWrite(); - - // CHECK the validity of compute_inline - const auto& iter_to_attached_stages = - pstate->attach_map->iter_to_attached_stages; - for (size_t i = 0; i < stage->iters.size(); ++i) { - CHECK_EQ(iter_to_attached_stages.count(std::make_pair(step->stage_id, i)), - 0) - << "Invalid compute_inline: Because there are some other stages " - "that are attached to the target stage"; - } - - pstate->stages[step->stage_id].CopyOnWrite()->compute_at = kInlined; - pstate->attach_map.DeleteStage(step->stage_id); -} - -// Common part for steps that add new stages -// (e.g. 
CacheReadStep, CacheWriteStep, RfactorStep) -void AddStageModificationSteps(size_t step_id, - const std::vector& transform_steps, - std::vector* replay_steps) { - const Step& step = transform_steps[step_id]; - if (step->IsInstance() || - step->IsInstance()) { - replay_steps->push_back(step); - } else if (step->IsInstance()) { - // add FuseStepNode required by rfactor - if (step_id >= 2 && - transform_steps[step_id - 2]->IsInstance()) { - const Step& fuse_step = transform_steps[step_id - 2]; - if (fuse_step->stage_id == step->stage_id) { - replay_steps->push_back(fuse_step); - } - } - // add SplitStepNode required by rfactor - CHECK_GE(step_id, 1); - CHECK(transform_steps[step_id - 1]->IsInstance()); - const Step& split_step = transform_steps[step_id - 1]; - CHECK_EQ(split_step->stage_id, step->stage_id); - replay_steps->push_back(split_step); - // add RfactorStepNode - replay_steps->push_back(step); - } -} - -int State::DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag) { - StateNode* pstate = CopyOnWrite(); - std::vector replay_steps; - for (size_t i = 0; i < pstate->transform_steps.size(); ++i) { - AddStageModificationSteps(i, pstate->transform_steps, &replay_steps); - if (pstate->transform_steps[i].same_as(step)) { - break; - } - } - dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag)); - - // target -> target + target_store - // Should update target's op, insert new stage, update the later stage's op - pstate->stages[step->stage_id].CopyOnWrite()->op = - operator->()->task_dag->ops[step->stage_id]; - pstate->stages.insert( - pstate->stages.begin() + step->stage_id + 1, - Stage(operator->()->task_dag->ops[step->stage_id + 1])); - for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { - pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; - } - pstate->attach_map = operator->()->attach_map.ApplyStageIdOfffset( - step->stage_id + 1, 1); - - return step->stage_id + 1; -} - -int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { - StateNode* pstate = CopyOnWrite(); - std::vector replay_steps; - for (size_t i = 0; i < pstate->transform_steps.size(); ++i) { - AddStageModificationSteps(i, pstate->transform_steps, &replay_steps); - if (pstate->transform_steps[i].same_as(step)) { - break; - } - } - - int last_dag_op_size = pstate->task_dag.defined() ? 
- pstate->task_dag->ops.size() : dag->ops.size(); - dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag)); - int added_ops = pstate->task_dag->ops.size() - last_dag_op_size; - CHECK_GE(added_ops, 1); - - // target -> target_compute + target - // Assume target stage has never been applied any steps before cache_write - // Should insert new stage, update target stage, update the later stage's op - pstate->stages.insert( - pstate->stages.begin() + step->stage_id, - Stage(operator->()->task_dag->ops[step->stage_id])); - pstate->stages[step->stage_id + 1] = - Stage(operator->()->task_dag->ops[step->stage_id + 1]); - int next_stage_id = step->stage_id + 2; - // Notice: added_ops should actually assert to be 1 - // branch of 2 here is somehow a hack to TVM's cache_write bug with - // multi outputs, see test/cpp/ansor_test.cc: CacheReadWrite test - // for more information - // TODO(jcf94): Fix this - if (added_ops == 2) { - pstate->stages.insert( - pstate->stages.begin() + next_stage_id, - Stage(operator->()->task_dag->ops[next_stage_id])); - next_stage_id++; - } else if (added_ops > 2) { - LOG(ERROR) << "Unexpected behavior of CacheWrite."; - } - for (size_t i = next_stage_id; i < operator->()->task_dag->ops.size(); ++i) { - pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; - } - pstate->attach_map = operator->()->attach_map.ApplyStageIdOfffset( - step->stage_id, added_ops); - - return step->stage_id; -} - -void State::DoPragmaStep(const PragmaStep& step) { - if (step->pragma_type == "debug_skip_region") { - StateNode* pstate = CopyOnWrite(); - pstate->attach_map.DeleteStage(step->stage_id); - } else if (StrStartsWith(step->pragma_type, "auto_unroll_max_step")) { - StateNode* pstate = CopyOnWrite(); - StageNode* stage = pstate->stages[step->stage_id].CopyOnWrite(); - size_t pos = step->pragma_type.find('$'); - stage->attrs.auto_unroll_max_step = atoi(step->pragma_type.c_str() + pos + 1); - } else if (step->pragma_type == "tensor_core") { - // Nothing needs to be done here - } else { - LOG(FATAL) << "Invalid pragma: " << step->pragma_type; - } -} - -int State::DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag) { - StateNode* pstate = CopyOnWrite(); - const auto compute_at_type = pstate->stages[step->stage_id]->compute_at; - std::vector replay_steps; - for (size_t i = 0; i < pstate->transform_steps.size(); ++i) { - AddStageModificationSteps(i, pstate->transform_steps, &replay_steps); - if (pstate->transform_steps[i].same_as(step)) { - break; - } - } - dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag)); - - // target -> target_compute + target - // Should insert new stage, update target stage, update the later stage's op - pstate->stages.insert( - pstate->stages.begin() + step->stage_id, - Stage(operator->()->task_dag->ops[step->stage_id])); - // maintain the compute_at type of target stage - Stage target_stage = - Stage(operator->()->task_dag->ops[step->stage_id + 1]); - target_stage.CopyOnWrite()->compute_at = compute_at_type; - pstate->stages[step->stage_id + 1] = target_stage; - - for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { - pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; - } - pstate->attach_map = operator->()->attach_map.ApplyStageIdOfffset( - step->stage_id, 1); - - return step->stage_id; -} - -void State::DoStorageAlignStep(const StorageAlignStep& step) { - StateNode* pstate = CopyOnWrite(); - StageNode* stage = pstate->stages[step->stage_id].CopyOnWrite(); - stage->attrs.storage_offset = step->offset; 
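-  // Note that only the offset is recorded in the loop state; the alignment
-  // factor stored in the step presumably takes effect later, when the step
-  // history is replayed onto a concrete te::Schedule.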
-} - -Iterator State::DoTensorizeStep(const TensorizeStep& step) { - const Stage& stage = operator->()->stages[step->stage_id]; - Iterator it = stage->iters[step->iter_id]; - Iterator new_it = Iterator(it->name, it->range, it->iter_type, - IteratorAnnotation::kTensorized, &it->ori_iters, step->ti_func_name); - Stage new_stage = stage; - new_stage.CopyOnWrite()->iters[step->iter_id] = new_it; - StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = std::move(new_stage); - return new_it; -} - void State::DoStep(const Step& step, const ComputeDAG& dag) { if (auto ps = step.as()) { DoReorderStep(GetRef(ps)); } else if (auto ps = step.as()) { DoSplitStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoFollowSplitStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoFollowFusedSplitStep(GetRef(ps)); } else if (auto ps = step.as()) { DoFuseStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoAnnotationStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoComputeAtStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoComputeRootStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoComputeInlineStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoCacheReadStep(GetRef(ps), dag); - } else if (auto ps = step.as()) { - DoCacheWriteStep(GetRef(ps), dag); - } else if (auto ps = step.as()) { - DoPragmaStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoRfactorStep(GetRef(ps), dag); - } else if (auto ps = step.as()) { - DoStorageAlignStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoTensorizeStep(GetRef(ps)); } else { LOG(FATAL) << "Invalid step: " << step; } @@ -1068,26 +633,6 @@ TVM_REGISTER_GLOBAL("ansor.StateSplit") return Array{state, Array(res)}; }); -TVM_REGISTER_GLOBAL("ansor.StateFollowSplit") -.set_body_typed([](State state, int stage_id, const Iterator& it, - int src_step_id, int n_split) { - const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); - return Array{state, Array(res)}; -}); - -TVM_REGISTER_GLOBAL("ansor.StateFollowFusedSplit") -.set_body_typed([](State state, int stage_id, const Iterator& it, - const Array& src_step_ids, int level, - bool factor_or_nparts) { - std::vector array_src_step_ids; - for (const auto& i : src_step_ids) { - array_src_step_ids.push_back(i->value); - } - const auto& res = state.follow_fused_split( - stage_id, it, array_src_step_ids, level, factor_or_nparts); - return Array{state, Array(res)}; -}); - TVM_REGISTER_GLOBAL("ansor.StateFuse") .set_body_typed([](State state, int stage_id, const Array& iters) { @@ -1099,100 +644,6 @@ TVM_REGISTER_GLOBAL("ansor.StateFuse") return Array{state, res}; }); -TVM_REGISTER_GLOBAL("ansor.StateVectorize") -.set_body_typed([](State state, int stage_id, const Iterator& it) { - const auto& res = state.vectorize(stage_id, it); - return Array{state, res}; -}); - -TVM_REGISTER_GLOBAL("ansor.StateParallel") -.set_body_typed([](State state, int stage_id, const Iterator& it) { - const auto& res = state.parallel(stage_id, it); - return Array{state, res}; -}); - -TVM_REGISTER_GLOBAL("ansor.StateUnroll") -.set_body_typed([](State state, int stage_id, const Iterator& it, - int max_unroll) { - const auto& res = state.unroll(stage_id, it, max_unroll); - return Array{state, res}; -}); - -TVM_REGISTER_GLOBAL("ansor.StateBindThread") -.set_body_typed([](State state, int stage_id, const Iterator& it, - int thread_type) { - const auto& res = - state.bind_thread(stage_id, it, IteratorAnnotation(thread_type)); - return Array{state, res}; -}); - 
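The registrations above and below expose the schedule primitives through TVM's packed-function FFI, which is how the Python frontend drives a State step by step. A hedged sketch of looking up one of the kept globals ("ansor.StateFuse") and invoking it from C++; the helper name FuseViaFFI and the error handling are ours, not part of this patch.

#include <tvm/runtime/registry.h>

// Sketch only: State and Iterator come from loop_state.h in this patch.
tvm::runtime::TVMRetValue FuseViaFFI(tvm::ansor::State state, int stage_id,
                                     tvm::Array<tvm::ansor::Iterator> iters) {
  // Look up the packed function registered under this global name.
  const tvm::runtime::PackedFunc* f =
      tvm::runtime::Registry::Get("ansor.StateFuse");
  CHECK(f != nullptr) << "ansor.StateFuse is not registered";
  // The registered body returns Array{updated_state, fused_iterator}.
  return (*f)(state, stage_id, iters);
}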
-TVM_REGISTER_GLOBAL("ansor.StateComputeAt") -.set_body_typed([](State state, int stage_id, int target_stage_id, - const Iterator& target_iter) { - state.compute_at(stage_id, target_stage_id, target_iter); - return state; -}); - -TVM_REGISTER_GLOBAL("ansor.StateComputeRoot") -.set_body_typed([](State state, int stage_id) { - state.compute_root(stage_id); - return state; -}); - -TVM_REGISTER_GLOBAL("ansor.StateComputeInline") -.set_body_typed([](State state, int stage_id) { - state.compute_inline(stage_id); - return state; -}); - -TVM_REGISTER_GLOBAL("ansor.StateCacheRead") -.set_body_typed([](State state, int stage_id, const std::string& scope_name, - const Array& reader_stage_ids, - const ComputeDAG& task_dag) { - std::vector array_reader_stage_ids; - for (const auto& i : reader_stage_ids) { - array_reader_stage_ids.push_back(i->value); - } - int res = state.cache_read(stage_id, scope_name, array_reader_stage_ids, - task_dag); - return Array{state, IntImm(DataType::Int(32), res)}; -}); - -TVM_REGISTER_GLOBAL("ansor.StateCacheWrite") -.set_body_typed([](State state, int stage_id, const std::string& scope_name, - const ComputeDAG& task_dag) { - int res = state.cache_write(stage_id, scope_name, task_dag); - return Array{state, IntImm(DataType::Int(32), res)}; -}); - -TVM_REGISTER_GLOBAL("ansor.StatePragma") -.set_body_typed([](State state, int stage_id, const Iterator& it, - const std::string& pragma_type) { - state.pragma(stage_id, it, pragma_type); - return state; -}); - -TVM_REGISTER_GLOBAL("ansor.StateRfactor") -.set_body_typed([](State state, int stage_id, const Iterator& it, - int factor_iter_id, const ComputeDAG& task_dag) { - int res = state.rfactor(stage_id, it, factor_iter_id, task_dag); - return Array{state, IntImm(DataType::Int(32), res)}; -}); - -TVM_REGISTER_GLOBAL("ansor.StateStorageAlign") -.set_body_typed([](State state, int stage_id, const Iterator& it, - int factor, int offset) { - state.storage_align(stage_id, it, factor, offset); - return state; -}); - -TVM_REGISTER_GLOBAL("ansor.StateTensorize") -.set_body_typed([](State state, int stage_id, const Iterator& it, - std::string ti_func) { - const auto& res = state.tensorize(stage_id, it, ti_func); - return Array{state, res}; -}); - TVM_REGISTER_GLOBAL("ansor.StateEqual") .set_body_typed([](State state1, State state2) { return std::equal_to()(state1, state2); diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 1b7bbc40bb31..2d6c85db0247 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -220,13 +220,7 @@ class StepNode: public Object { TVM_DEFINE_MUTABLE_OBJECT_REF(Step, StepNode); // Step forward decelerations -class ReorderStep; class SplitStep; class FollowSplitStep; -class FollowFusedSplitStep; -class FuseStep; class AnnotationStep; -class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep; -class CacheReadStep; class CacheWriteStep; -class PragmaStep; class RfactorStep; class StorageAlignStep; -class TensorizeStep; +class ReorderStep; class SplitStep; class FuseStep; /*! \brief A state in the search process. * It consists of the current loop structure and the history steps to reach this state. 
*/ @@ -264,55 +258,18 @@ class State : public ObjectRef { // Schedule primitives void reorder(int stage_id, const std::vector& order); - void compute_at(int stage_id, int target_stage_id, - const Iterator& target_iter); - void compute_root(int stage_id); - void compute_inline(int stage_id); - void pragma(int stage_id, const Iterator& it, const std::string& pragma_type); - void storage_align(int stage_id, const Iterator& it, int factor, int offset); std::vector split(int stage_id, const Iterator& it, const std::vector& lengths, bool inner_to_outer = true); - std::vector follow_split(int stage_id, const Iterator& it, - int src_step_id, int n_split); - std::vector follow_fused_split(int stage_id, const Iterator& it, - const std::vector& src_step_ids, - int level, bool factor_or_nparts); Iterator fuse(int stage_id, const std::vector& iters); - Iterator vectorize(int stage_id, const Iterator& it); - Iterator parallel(int stage_id, const Iterator& it); - Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); - Iterator bind_thread(int stage_id, const Iterator& it, - IteratorAnnotation thread_type); - Iterator tensorize(int stage_id, const Iterator& it, - std::string ti_func_name); - int cache_read(int stage_id, const std::string& scope_name, - const std::vector& reader_stage_ids, - const ComputeDAG& task_dag); - int cache_write(int stage_id, const std::string& scope_name, - const ComputeDAG& task_dag); - int rfactor(int stage_id, const Iterator& it, int factor_iter_id, - const ComputeDAG& task_dag); /* Do transform steps * Note: The following functions only change loop state but do not change transform_history. * We separate these functions out, * so you can call them for replay easily given history steps */ void DoReorderStep(const ReorderStep& step); - void DoComputeAtStep(const ComputeAtStep& step); - void DoComputeRootStep(const ComputeRootStep& step); - void DoComputeInlineStep(const ComputeInlineStep& step); - void DoPragmaStep(const PragmaStep& step); - void DoStorageAlignStep(const StorageAlignStep& step); std::vector DoSplitStep(const SplitStep& step); - std::vector DoFollowSplitStep(const FollowSplitStep& step); - std::vector DoFollowFusedSplitStep(const FollowFusedSplitStep& step); Iterator DoFuseStep(const FuseStep& step); - Iterator DoAnnotationStep(const AnnotationStep& step); - Iterator DoTensorizeStep(const TensorizeStep& step); - int DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag); - int DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag); - int DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag); // General do step functions with a runtime dynamic dispatcher void DoStep(const Step& step, const ComputeDAG& dag); diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index e99f41725077..c50191813b2e 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -40,9 +40,7 @@ TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); TVM_REGISTER_OBJECT_TYPE(RunnerNode); TVM_REGISTER_OBJECT_TYPE(BuilderNode); TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode); -TVM_REGISTER_OBJECT_TYPE(RPCRunnerNode); TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode); -TVM_REGISTER_OBJECT_TYPE(ProgramMeasurerNode); const char* ErrorNoToStr[] = { "NoError", @@ -127,38 +125,6 @@ Array LocalBuilderNode::Build(const Array& inputs, return Array(); } -// RPC Runner -RPCRunner::RPCRunner(const std::string& key, const std::string& host, int port, - int priority, int timeout, int n_parallel, int number, - int repeat, int min_repeat_ms, double cooldown_interval) { - auto 
node = make_object(); - node->key = key; - node->host = host; - node->port = port; - node->priority = priority; - node->timeout = timeout; - node->n_parallel = n_parallel; - node->number = number; - node->repeat = repeat; - node->min_repeat_ms = min_repeat_ms; - node->cooldown_interval = cooldown_interval; - data_ = std::move(node); -} - -Array RPCRunnerNode::Run(const Array& inputs, - const Array& build_results, - int verbose) { - if (const auto* f = runtime::Registry::Get("ansor.rpc_runner.run")) { - Array results = (*f)( - inputs, build_results, key, host, port, priority, timeout, n_parallel, - number, repeat, min_repeat_ms, cooldown_interval, verbose); - return results; - } else { - LOG(FATAL) << "ansor.rpc_runner.run is not registered"; - } - return Array(); -} - // Local Runner LocalRunner::LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval) { @@ -379,14 +345,6 @@ TVM_REGISTER_GLOBAL("ansor.LocalRunner") return LocalRunner(timeout, number, repeat, min_repeat_ms, cooldown_interval); }); -TVM_REGISTER_GLOBAL("ansor.RPCRunner") -.set_body_typed([](const std::string& key, const std::string& host, int port, - int priority, int timeout, int n_parallel, int number, - int repeat, int min_repeat_ms, double cooldown_interval){ - return RPCRunner(key, host, port, priority, timeout, n_parallel, number, - repeat, min_repeat_ms, cooldown_interval); -}); - TVM_REGISTER_GLOBAL("ansor.ProgramMeasurer") .set_body_typed([](Builder builder, Runner runner, Array callbacks, int verbose, diff --git a/src/ansor/measure.h b/src/ansor/measure.h index 760a1542944f..630365512eb6 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -219,42 +219,6 @@ class LocalBuilder: public Builder { TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, Builder, LocalBuilderNode); }; -/*! \brief RPCRunner that uses RPC call to measures the time cost of programs - * on remote devices */ -class RPCRunnerNode : public RunnerNode { - public: - std::string key; - std::string host; - int port; - int priority; - int n_parallel; - int number; - int repeat; - int min_repeat_ms; - double cooldown_interval; - - /*! \biref Run measurement and return results */ - Array Run(const Array& inputs, - const Array& build_results, - int verbose) final; - - static constexpr const char* _type_key = "ansor.RPCRunner"; - TVM_DECLARE_FINAL_OBJECT_INFO(RPCRunnerNode, RunnerNode); -}; - -/*! - * \brief Managed reference to RPCRunnerNode. - * \sa RPCRunnerNode - */ -class RPCRunner : public Runner { - public: - RPCRunner(const std::string& key, const std::string& host, int port, - int priority, int timeout, int n_parallel, int number, - int repeat, int min_repeat_ms, double cooldown_interval); - - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RPCRunner, Runner, RPCRunnerNode); -}; - /*! \brief LocalRunner that uses local CPU/GPU to measures the time cost of programs */ class LocalRunnerNode: public RunnerNode { public: diff --git a/src/ansor/search_policy/empty_policy.cc b/src/ansor/search_policy/empty_policy.cc new file mode 100644 index 000000000000..ba861f333c78 --- /dev/null +++ b/src/ansor/search_policy/empty_policy.cc @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "empty_policy.h"
+
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_NODE_TYPE(EmptyPolicyNode);
+
+State EmptyPolicyNode::Search(SearchTask task, int n_trials, int early_stopping,
+                              int num_measure_per_iter, int verbose,
+                              ProgramMeasurer measurer,
+                              Array<SearchCallback> pre_search_callbacks) {
+  cur_task = task;
+
+  // Run pre_search_callbacks before the search process.
+  // This interface is usually used to set some initial status.
+  RunCallbacks(pre_search_callbacks);
+
+  if (n_trials <= 1) {
+    const auto& res = SearchOneRound();
+    CHECK_GT(res.size(), 0);
+    return res[0];
+  } else {
+    std::vector<MeasureInput> inputs;
+    std::vector<MeasureResult> results;
+
+    measurer->Reset();
+    int ct = 0;
+    // In each round, we call SearchOneRound to get several candidate states,
+    // then use ProgramMeasurer to test their performance.
+    while (ct < n_trials) {
+      const auto& res = SearchOneRound();
+      ct += res.size();
+      inputs.clear();
+      for (const auto& state : res) {
+        inputs.emplace_back(cur_task, state);
+      }
+      measurer->Measure(cur_task, GetRef<SearchPolicy>(this), inputs, &results);
+    }
+
+    // Return the state with the best measured performance.
+    return measurer->best_state[cur_task->workload_key];
+  }
+}
+
+std::pair<Array<MeasureInput>, Array<MeasureResult> > EmptyPolicyNode::ContinueSearchOneRound(
+    SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) {
+  // The whole process is almost the same as Search, but this function is
+  // designed to be called and managed by an outer global task scheduler.
+
+  std::vector<MeasureInput> inputs;
+  std::vector<MeasureResult> results;
+
+  const auto& res = SearchOneRound();
+  for (const auto& state : res) {
+    inputs.emplace_back(cur_task, state);
+  }
+  measurer->Measure(cur_task, GetRef<SearchPolicy>(this), inputs, &results);
+
+  // Return a pair of MeasureInput Array and MeasureResult Array.
+  Array<MeasureInput> inputs_arr(std::make_move_iterator(inputs.begin()),
+                                 std::make_move_iterator(inputs.end()));
+  Array<MeasureResult> results_arr(std::make_move_iterator(results.begin()),
+                                   std::make_move_iterator(results.end()));
+  return std::make_pair(std::move(inputs_arr), std::move(results_arr));
+}
+
+std::vector<State> EmptyPolicyNode::SearchOneRound() {
+  std::vector<State> res;
+  res.push_back(cur_task->compute_dag.GetInitState());
+  // As an example policy, EmptyPolicy always returns the initial state.
+  return res;
+}
+
+TVM_REGISTER_GLOBAL("ansor.EmptyPolicy")
+.set_body_typed([]() { return EmptyPolicy(make_object<EmptyPolicyNode>()); });
+
+}  // namespace ansor
+}  // namespace tvm
diff --git a/src/ansor/search_policy/empty_policy.h b/src/ansor/search_policy/empty_policy.h
new file mode 100644
index 000000000000..5c2f52608fe0
--- /dev/null
+++ b/src/ansor/search_policy/empty_policy.h
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/search_policy/empty_policy.h
+ * \brief This is a basic example of a search policy.
+ */
+
+#ifndef TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_
+#define TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_
+
+#include <utility>
+#include <vector>
+
+#include "search_policy.h"
+
+namespace tvm {
+namespace ansor {
+
+/*!
+ * \brief This is a basic example of a search policy. EmptyPolicy always
+ * generates the init state of a ComputeDAG.
+ */
+class EmptyPolicyNode : public SearchPolicyNode {
+ public:
+  /*! \brief Search and make n_trials measurements.
+   * \returns the best state
+   */
+  State Search(SearchTask task, int n_trials,
+               int early_stopping, int num_measure_per_iter,
+               int verbose, ProgramMeasurer measurer,
+               Array<SearchCallback> pre_search_callbacks) final;
+
+  /*! \brief Continue the search for one round. This is used by JointTuner.
+   * \returns the measurement pairs
+   */
+  std::pair<Array<MeasureInput>, Array<MeasureResult> > ContinueSearchOneRound(
+      SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) final;
+
+  static constexpr const char *_type_key = "ansor.EmptyPolicy";
+  TVM_DECLARE_FINAL_OBJECT_INFO(EmptyPolicyNode, SearchPolicyNode);
+
+ private:
+  /*!
+   * \brief Usually we need a sub-function to generate several candidate states
+   * in each search round.
+   * \returns Several generated states
+   */
+  std::vector<State> SearchOneRound();
+};
+
+/*!
+ * \brief Managed reference to EmptyPolicyNode.
+ * \sa EmptyPolicyNode
+ */
+class EmptyPolicy : public SearchPolicy {
+ public:
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EmptyPolicy, SearchPolicy, EmptyPolicyNode);
+};
+
+}  // namespace ansor
+}  // namespace tvm
+
+#endif  // TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_
\ No newline at end of file
diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc
index b86bf9490851..e7a12702ba70 100644
--- a/src/ansor/search_policy/search_policy.cc
+++ b/src/ansor/search_policy/search_policy.cc
@@ -29,51 +29,8 @@
 namespace tvm {
 namespace ansor {
 
+TVM_REGISTER_OBJECT_TYPE(SearchCallbackNode);
 TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode);
-TVM_REGISTER_OBJECT_TYPE(PreloadMeasuredStatesNode);
-
-void SearchPolicyNode::PreloadMeasuredStates(const std::string& log_file) {
-  LogReader reader = LogReader(log_file);
-  const auto& res = reader->ReadLines(-1);
-  size_t log_size = res.first.size();
-  CHECK_EQ(log_size, res.second.size());
-  if (log_size) {
-    std::vector<State> measured_states;
-    std::vector<float> measured_throughputs;
-    for (size_t i = 0; i < log_size; i++) {
-      const auto& inp = res.first[i];
-      if (inp->task->workload_key == cur_task->workload_key &&
-          inp->task->target->target_name.compare(
-              cur_task->target->target_name) == 0) {
-        State state = cur_task->compute_dag.GetInitState();
-        state.CopyOnWrite()->transform_steps = inp->state->transform_steps;
-        state.DoSteps(inp->state->transform_steps, cur_task->compute_dag);
-        measured_states.emplace_back(std::move(state));
-        measured_throughputs.push_back(res.second[i]->error_no == 0 ?
- (1.0 / FloatArrayMean(res.second[i]->costs)) : 0.0); - } - } - cur_task->compute_dag.InferBound(&measured_states); - for (size_t i = 0; i < measured_states.size(); i ++) { - auto& state = measured_states[i]; - const auto& state_str = state.ToStr(); - if (!measured_states_set_.count(state_str)) { - measured_states_set_.insert(state_str); - if (measured_throughputs[i] != 0.0) { - measured_states_vector_.emplace_back(std::move(state)); - measured_states_throughputs_.emplace_back(measured_throughputs[i]); - } - } - } - - StdCout(verbose) << "Successfully load " << measured_states_set_.size() - << " measurement records from " << log_file - << " for " << cur_task->workload_key << std::endl; - } else { - StdCout(verbose) << "No measurement records found in " - << log_file << " for " << cur_task->workload_key << std::endl; - } -} void SearchPolicyNode::RunCallbacks(const Array& callbacks) { if (callbacks.defined() && callbacks.size()) { @@ -83,16 +40,6 @@ void SearchPolicyNode::RunCallbacks(const Array& callbacks) { } } -PreloadMeasuredStates::PreloadMeasuredStates(std::string filename) { - auto node = make_object(); - node->filename = std::move(filename); - data_ = std::move(node); -} - -void PreloadMeasuredStatesNode::callback(SearchPolicyNode* policy) { - policy->PreloadMeasuredStates(filename); -} - // Search Policy TVM_REGISTER_GLOBAL("ansor.SearchPolicyContinueSearchOneRound") .set_body_typed([](SearchPolicy policy, SearchTask task, int num_measure, @@ -118,10 +65,5 @@ TVM_REGISTER_GLOBAL("ansor.SearchPolicySetVerbose") policy->verbose = verbose; }); -TVM_REGISTER_GLOBAL("ansor.PreloadMeasuredStates") -.set_body_typed([](std::string filename) { - return PreloadMeasuredStates(filename); -}); - } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index 03e7c3f025df..eb4703be1914 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -48,30 +48,6 @@ class SearchCallbackNode : public Object { }; TVM_DEFINE_MUTABLE_OBJECT_REF(SearchCallback, SearchCallbackNode); -/*! \brief Preload measured states from a log file. - * This can resume the state of the search policy */ -class PreloadMeasuredStatesNode : public SearchCallbackNode { - public: - std::string filename; - - void callback(SearchPolicyNode* policy) final; - - static constexpr const char *_type_key = "ansor.PreloadMeasuredStates"; - TVM_DECLARE_FINAL_OBJECT_INFO(PreloadMeasuredStatesNode, SearchCallbackNode); -}; - -/*! - * \brief Managed reference to PreloadMeasuredStatesNode. - * \sa PreloadMeasuredStatesNode - */ -class PreloadMeasuredStates : public SearchCallback { - public: - explicit PreloadMeasuredStates(std::string filename); - - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadMeasuredStates, SearchCallback, - PreloadMeasuredStatesNode); -}; - /*! 
\brief The base class for search policy */ class SearchPolicyNode : public Object { public: @@ -94,23 +70,9 @@ class SearchPolicyNode : public Object { virtual std::pair, Array > ContinueSearchOneRound( SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) = 0; - // Preload measured states from a log file to resume the state of the search policy - void PreloadMeasuredStates(const std::string& log_file); - // Run a list of callback functions void RunCallbacks(const Array& callbacks); - // Dict keys to give hints to the policy - static constexpr const char* always_unroll_inner_key = "ansor_always_unroll_inner"; - static constexpr const char* always_unroll_key = "ansor_always_unroll"; - static constexpr const char* no_split_at_inner_key = "ansor_no_split_at_inner"; - static constexpr const char* no_split_at_outer_key = "ansor_no_split_at_outer"; - static constexpr const char* last_split_is_one_key = "ansor_last_split_is_one"; - // Flag keys to give hints to the policy - static constexpr const char* always_compute_inline_key = "ansor_always_compute_inline"; - static constexpr const char* no_cache_write_key = "ansor_no_cache_write"; - static constexpr const char* no_cache_read_key = "ansor_no_cache_read"; - static constexpr const char *_type_key = "ansor.SearchPolicy"; TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); diff --git a/src/ansor/search_policy/sketch_search_policy.cc b/src/ansor/search_policy/sketch_search_policy.cc deleted file mode 100644 index c4365a391865..000000000000 --- a/src/ansor/search_policy/sketch_search_policy.cc +++ /dev/null @@ -1,1541 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ansor/search_policy/sketch_search_policy.h - * \brief The search policy that searches in a hierarchical search space defined by sketches. - * The policy randomly samples programs from the space defined by sketches - * and use evolutionary search to fine-tune them. 
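- * A search round consists of sketch generation with derivation rules,
- * random sampling of an initial population, evolutionary search guided
- * by the learned cost model, and finally hardware measurement of the
- * most promising candidates.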
- */ - -#include "sketch_search_policy.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "utils.h" - -#define IS_GPU(task) ((task)->target->device_type == kDLGPU || \ - (task)->target->device_type == kDLOpenCL) - -namespace tvm { -namespace ansor { - -TVM_REGISTER_NODE_TYPE(SketchSearchPolicyNode); -TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode); - -// All possible candidates for auto_unroll -const std::vector SketchSearchPolicyNode::auto_unroll_configs{0, 16, 64, 512, 1024}; - -SketchSearchPolicy::SketchSearchPolicy(CostModel program_cost_model, - Map params, - int seed) { - auto node = make_object(); - node->program_cost_model = std::move(program_cost_model); - node->rand_gen_ = std::mt19937(seed); - node->params = std::move(params); - data_ = std::move(node); -} - -State SketchSearchPolicyNode::Search(SearchTask task, int n_trials, - int early_stopping, int num_measure_per_iter, int verbose, - ProgramMeasurer measurer, Array pre_search_callbacks) { - std::vector best_states, random_states; - this->cur_task = task; - this->verbose = verbose; - num_measure_per_iter_ = num_measure_per_iter; - - PrintTitle("Call search callbacks", verbose); - RunCallbacks(pre_search_callbacks); - - if (n_trials <= 1) { // no measurement is allowed - SearchOneRound(&best_states, 0, &random_states); - CHECK_GT(best_states.size(), 0); - return best_states[0]; - } else { - std::vector inputs; - std::vector results; - int num_random = static_cast(GetDoubleParam(params, "eps_greedy") * num_measure_per_iter); - - measurer->Reset(); - - early_stopping = early_stopping < 0 ? std::numeric_limits::max() >> 1 : early_stopping; - - int ct = 0; - while (ct < n_trials) { - if (!inputs.empty()) { - // retrain cost models - PrintTitle("Train cost model", verbose); - program_cost_model->Update(inputs, results); - } - - // Search one round to get promising states - PrintTitle("Search", verbose); - SearchOneRound(&best_states, num_random, &random_states); - - // Infer bound. This is necessary for computing the correct ToStr() for redundancy check - cur_task->compute_dag.InferBound(&best_states); - cur_task->compute_dag.InferBound(&random_states); - - // Pick `num_measure_per_iter` states to measure, check hash to remove already measured state - // Also pick some random states to do eps-greedy - PickStatesWithEpsGreedy(&inputs, best_states, random_states, n_trials - ct); - - // Have traversed all of search space - if (inputs.empty()) { - StdCout(verbose) << "All candidates in the search space have been measured." << std::endl; - break; - } - - // Measure candidate states - PrintTitle("Measure", verbose); - measurer->Measure(cur_task, GetRef(this), inputs, &results); - ct += inputs.size(); - - if (ct - measurer->best_ct[cur_task->workload_key] > early_stopping) { - StdCout(verbose) << "Meet the early stopping condition." << std::endl; - break; - } - - // Update measured states. 
These states will join the LocalMutation in later rounds - for (const auto& res : results) { - measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs)); - } - } - PrintTitle("Done", verbose); - - return measurer->best_state[cur_task->workload_key]; - } -} - -std::pair, Array > - SketchSearchPolicyNode::ContinueSearchOneRound( - SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) { - if (cur_task.defined()) { - CHECK_EQ(cur_task, task); - } else { - cur_task = task; - } - this->verbose = verbose; - num_measure_per_iter_ = num_measure; - - std::vector best_states, random_states; - std::vector inputs; - std::vector results; - int num_random = static_cast(GetDoubleParam(params, "eps_greedy") * num_measure); - - // Search one round to get promising states - PrintTitle("Search", verbose); - SearchOneRound(&best_states, num_random * 2, &random_states); - - // Fill correct bound. This is necessary for computing the correct ToStr() for reduncency check - cur_task->compute_dag.InferBound(&best_states); - cur_task->compute_dag.InferBound(&random_states); - - // Pick `num_measure` states to measure, check hash to remove already measured state - // Also pick some random states to do eps-greedy - PickStatesWithEpsGreedy(&inputs, best_states, random_states, num_measure); - - // Measure candidate states - PrintTitle("Measure", verbose); - measurer->Measure(cur_task, GetRef(this), inputs, &results); - - // Update throughputs of measured states. These states will join the LocalMutation in later rounds - for (const auto& res : results) { - measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs)); - } - - // Update the cost model - Array inputs_arr(std::make_move_iterator(inputs.begin()), - std::make_move_iterator(inputs.end())); - Array results_arr(std::make_move_iterator(results.begin()), - std::make_move_iterator(results.end())); - - PrintTitle("Train cost model", verbose); - program_cost_model->Update(inputs_arr, results_arr); - return std::make_pair(std::move(inputs_arr), std::move(results_arr)); -} - -void SketchSearchPolicyNode::PickStatesWithEpsGreedy( - std::vector* inputs, - const std::vector& best_states, - const std::vector& random_states, - int remaining_n_trials) { - int num_random = static_cast(GetDoubleParam(params, "eps_greedy") * num_measure_per_iter_); - int num_good = num_measure_per_iter_ - num_random; - - inputs->clear(); - size_t offset_best = 0, offset_random = 0; - - while (static_cast(inputs->size()) < std::min(num_measure_per_iter_, remaining_n_trials)) { - const State* pstate; - - bool has_best = offset_best < best_states.size(); - bool has_random = offset_random < random_states.size(); - - if (static_cast(inputs->size()) < num_good) { - // prefer best states - if (has_best) { - pstate = &best_states[offset_best++]; - } else if (has_random) { - pstate = &random_states[offset_random++]; - } else { - break; - } - } else { - // prefer random states - if (has_random) { - pstate = &random_states[offset_random++]; - } else if (has_best) { - pstate = &best_states[offset_best++]; - } else { - break; - } - } - - // Check if it has already been measured - std::string state_str = pstate->ToStr(); - - if (measured_states_set_.count(state_str)) { continue; } - measured_states_set_.insert(std::move(state_str)); - - inputs->push_back(MeasureInput(cur_task, *pstate)); - measured_states_vector_.push_back(*pstate); - } -} - -void SketchSearchPolicyNode::SearchOneRound(std::vector* best_states, - int num_random_states, std::vector* random_states) { 
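-  // One round: generate sketches with the derivation rules, sample an
-  // initial population from them, optionally seed it with previously
-  // measured good states, then refine with evolutionary search (or fall
-  // back to plain random sampling when the cost model is a random one).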
- best_states->clear(); - random_states->clear(); - - // Get parameters - int population = GetIntParam(params, "evolutionary_search_population"); - int num_use_measured = std::min(static_cast(measured_states_vector_.size()), - static_cast( - GetDoubleParam(params, "evolutionary_search_use_measured_ratio") * population)); - bool have_cost_model = !program_cost_model->IsInstance(); - - if (!have_cost_model) { - num_use_measured = 0; - } - - // Generate sketches - std::vector sketches; - GenerateSketch(&sketches); - - // PrintAllStates(sketches); - // exit(0); - - // Sample the init population - std::vector init_population; - SampleInitPopulation(sketches, population - num_use_measured, &init_population); - - // PrintAllStates(init_population); - // exit(0); - - if (have_cost_model) { - // Also insert already measured good states to the initial population - std::vector indices; - Argsort(measured_states_throughputs_, &indices); - for (int i = 0; i < num_use_measured; i++) { - init_population.push_back(measured_states_vector_[indices[i]]); - } - - // Perform evolutionary search - EvolutionarySearch(init_population, num_measure_per_iter_ * 2, best_states); - } else { - // If the cost model is useless (i.e. RandomCostModel), skip evolutionary search - RandomSampleStates(init_population, &rand_gen_, num_measure_per_iter_ * 3, best_states); - } - - // Sample some random states for eps-greedy - RandomSampleStates(init_population, &rand_gen_, num_random_states * 10, random_states); -} - -// The base class for derivation rules used in sketch generation -class SketchGenerationRule { - public: - enum ConditionEnum { - kPass, kApply, kApplyAndSkipRest - }; - - virtual ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) = 0; - virtual std::vector > Apply(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) = 0; -}; - -static inline bool ShouldBeCacheRead( - const SketchSearchPolicyNode* policy, const State& state, int stage_id) { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - if (HasAttrsFlag(state, stage_id, - SearchPolicyNode::no_cache_read_key)) { - return false; - } - - std::unordered_set consumers; - GetConsumers(task, state, stage->op, &consumers); - if (consumers.size() != 1) { - return false; - } - - int target_stage_id = OperationToStage(*consumers.begin(), state); - if (!NeedsMultilevelTiling(task, state, - state->stages[target_stage_id]->op)) { - return false; - } - - std::unordered_set producers; - GetProducers(task, state, state->stages[target_stage_id]->op, &producers); - // Only those directly mapped stages can do CacheRead - if (producers.find(stage->op) == producers.end()) { - return false; - } - - return true; -} - -static inline bool ShouldAlwaysBeInlined( - const SketchSearchPolicyNode* policy, const State& state, int stage_id) { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - if (stage->op->IsInstance()) { - return false; - } - - // Inline limitation of TVM - if (!IsOutputOp(task, state, stage->op) && !HasReduceIter(stage)) { - // Always inline condition: - // 1. Has attrs that this must be inlined - // 2. Analyse shows this is strict inlineable - // 3. 
A GPU stage can be inlined(If it should be cache read, do it first) - if (HasAttrsFlag(state, stage_id, - SearchPolicyNode::always_compute_inline_key) || - IsStrictInlineable(task, state, stage->op) || - (IS_GPU(policy->cur_task) && - !ShouldBeCacheRead(policy, state, stage_id))) { - return true; - } - } - - return false; -} - -// The rule that inlines simple elementwise ops -class RuleAlwaysInline : public SketchGenerationRule { - public: - ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - return ShouldAlwaysBeInlined(policy, state, stage_id) ? - kApplyAndSkipRest : kPass; - } - - std::vector > Apply(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - State tmp_s = state; - tmp_s.compute_inline(stage_id); - return {std::make_pair(std::move(tmp_s), stage_id - 1)}; - } -}; - -// The rule that simply skip the current stage -class RuleSkipStage : public SketchGenerationRule { - public: - ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - const auto& attrs = stage->op->attrs; - if ((attrs.count(SearchPolicyNode::no_split_at_inner_key) || - attrs.count(SearchPolicyNode::no_split_at_outer_key)) && - NeedsMultilevelTiling(task, state, stage->op)) { - // for the transform stages in Winograd - return kPass; - } - - return kApply; - } - - std::vector > Apply(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - return {std::make_pair(state, stage_id - 1)}; - } -}; - -// The rule that performs multi-level tiling -class RuleMultiLevelTiling : public SketchGenerationRule { - public: - ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - return NeedsMultilevelTiling(task, state, stage->op) ? - (IS_GPU(policy->cur_task) ? kApplyAndSkipRest : kApply) : kPass; - } - - std::vector > Apply(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - std::string multi_level_tiling_structure = IS_GPU(policy->cur_task) ? - GetStringParam(policy->params, "gpu_multi_level_tiling_structure") : - GetStringParam(policy->params, "cpu_multi_level_tiling_structure"); - - std::vector spatial_split_step_ids; - State tmp_s = state; - tmp_s = DoMultiLevelTiling(tmp_s, stage_id, multi_level_tiling_structure, - &spatial_split_step_ids); - return {std::make_pair(std::move(tmp_s), stage_id-1)}; - } -}; - -// The rule that performs multi-level tiling and fuses later consumers -class RuleMultiLevelTilingWithFusion : public SketchGenerationRule { - public: - ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - int target_stage_id; - - if (IS_GPU(policy->cur_task)) { - return NeedsMultilevelTiling(task, state, stage->op) && - HasSingleElementwiseMatchedConsumer(task, state, stage, - &target_stage_id) && - (!HasCacheReadStage(state, stage_id) || - HasCacheWriteStage(state, stage_id)) ? - kApplyAndSkipRest : kPass; - } - - return NeedsMultilevelTiling(task, state, stage->op) && - HasSingleElementwiseMatchedConsumer(task, state, stage, - &target_stage_id) ? 
- kApply : kPass; - } - - std::vector > Apply(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - std::string multi_level_tiling_structure = IS_GPU(policy->cur_task) ? - GetStringParam(policy->params, "gpu_multi_level_tiling_structure") : - GetStringParam(policy->params, "cpu_multi_level_tiling_structure"); - - std::vector spatial_split_step_ids; - int target_stage_id; - std::unordered_set consumers; - - GetConsumers(task, state, state->stages[stage_id]->op, &consumers); - CHECK(HasSingleElementwiseMatchedConsumer(task, state, stage, &target_stage_id)); - - State base_state = state; - base_state = DoMultiLevelTiling(base_state, stage_id, - multi_level_tiling_structure, &spatial_split_step_ids); - std::vector follow_tiling_levels; - if (IS_GPU(policy->cur_task)) { - follow_tiling_levels.push_back(3); - } else { - follow_tiling_levels.push_back(1); - follow_tiling_levels.push_back(2); - } - - std::vector > ret; - for (int level : follow_tiling_levels) { - if (tolower(multi_level_tiling_structure[level-1]) != 's') { - continue; - } - State tmp_s = base_state; - tmp_s = FollowTiling(tmp_s, target_stage_id, spatial_split_step_ids, level); - const Iterator &target_iter = tmp_s->stages[target_stage_id]->iters[ - level * spatial_split_step_ids.size() - 1]; - tmp_s.compute_at(stage_id, target_stage_id, target_iter); - - ret.emplace_back(std::move(tmp_s), stage_id - 1); - } - - return ret; - } -}; - -// The rule that adds a cache write stage -class RuleAddCacheWrite : public SketchGenerationRule { - public: - ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - int target_stage_id; - - // Add cache write if a stage needs multi-level tiling, - // but does not have a element-wise matched consumer - return NeedsMultilevelTiling(task, state, stage->op) && - !HasAttrsFlag(state, stage_id, SearchPolicyNode::no_cache_write_key) && - (!HasSingleElementwiseMatchedConsumer(task, state, stage, - &target_stage_id) || - (HasCacheReadStage(state, stage_id) && - !HasCacheWriteStage(state, stage_id))) ? - kApply : kPass; - } - - std::vector > Apply(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - - State tmp_s = state; - tmp_s.cache_write(stage_id, "local", task->compute_dag); - return {std::make_pair(std::move(tmp_s), stage_id)}; - } -}; - -// The rule that adds a cache read stage -// Mainly used for GPU cooperative fetching -// Currently only support 1 to 1 match cache read -class RuleAddCacheRead : public SketchGenerationRule { - public: - ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - return ShouldBeCacheRead(policy, state, stage_id) ? 
- kApplyAndSkipRest : kPass; - } - - std::vector > Apply(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - std::unordered_set consumers; - GetConsumers(task, state, stage->op, &consumers); - CHECK_EQ(consumers.size(), 1); - int target_stage_id = OperationToStage(*consumers.begin(), state); - State tmp_s = state; - int added_stage_id = tmp_s.cache_read(stage_id, "shared", - {target_stage_id}, - task->compute_dag); - target_stage_id++; - const auto& share_read_pos = GetLastReduceIteratorInOutermostReduceTile( - tmp_s->stages[target_stage_id]); - tmp_s.compute_at(added_stage_id, target_stage_id, share_read_pos); - - return {std::make_pair(std::move(tmp_s), stage_id)}; - } -}; - -// The rule that adds rfactor stage -class RuleAddRfactor : public SketchGenerationRule { - public: - ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - return NeedsRfactor(task, state, stage->op) && - !HasCacheWriteStage(state, stage_id) ? - kApply : kPass; - } - - std::vector > Apply(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - std::vector > ret; - - State tmp_s = state; - - // fuse reduce iters - std::vector space_iters, reduce_iters; - for (const auto &iter : stage->iters) { - if (iter->iter_type == kSpace) { - space_iters.push_back(iter); - } else if (iter->iter_type == kReduce) { - reduce_iters.push_back(iter); - } - } - CHECK(!reduce_iters.empty()); - Iterator fused_reduce_iter; - if (reduce_iters.size() > 1) { - fused_reduce_iter = tmp_s.fuse(stage_id, reduce_iters); - } else { - fused_reduce_iter = reduce_iters[0]; - } - - // split reduce iters - const auto &split_res = tmp_s.split(stage_id, fused_reduce_iter, {1}); - int factor_axis_id = static_cast(space_iters.size()); - State base_state = tmp_s; - for (const auto &split_iter : split_res) { - tmp_s = base_state; - tmp_s.rfactor(stage_id, split_iter, factor_axis_id, task->compute_dag); - - // reorder the space iterator to innermost for vectorization - if (split_iter == split_res[1]) { - std::vector new_order; - for (size_t i = 0; i < tmp_s->stages[stage_id]->iters.size(); ++i) { - if (i != space_iters.size()) { - new_order.push_back(tmp_s->stages[stage_id]->iters[i]); - } - } - new_order.push_back(tmp_s->stages[stage_id]->iters[space_iters.size()]); - tmp_s.reorder(stage_id, new_order); - } - ret.emplace_back(std::move(tmp_s), stage_id - 1); - } - - return ret; - } -}; - -void SketchSearchPolicyNode::GenerateSketch( - std::vector* out_states) { - State init_state = cur_task->compute_dag.GetInitState(); - std::string cpu_multi_level_tiling_structure = - GetStringParam(params, "cpu_multi_level_tiling_structure"); - - // two ping pong buffers to avoid copy - std::vector states_buf1, states_buf2; - std::vector *pnow, *pnext; - pnow = &states_buf1; - pnext = &states_buf2; - pnow->push_back(init_state); - - // A map that maps state to its current working position (stage_id) - std::unordered_map cur_stage_id_map; - cur_stage_id_map[init_state] = static_cast(init_state->stages.size() - 1); - - static RuleSkipStage rule_skip_stage; - static RuleAlwaysInline rule_always_inline; - static RuleMultiLevelTiling rule_multi_level_tiling; - static 
RuleMultiLevelTilingWithFusion rule_multi_level_tiling_with_fusion; - static RuleAddCacheWrite rule_add_cache_write_stage; - static RuleAddCacheRead rule_add_cache_read_stage; - static RuleAddRfactor rule_add_rfactor; - if (sketch_rules.empty()) { - // We may apply and skip the rest when processing some rules, - // should take care of the rule vector order here - sketch_rules.push_back(&rule_always_inline); - sketch_rules.push_back(&rule_add_cache_write_stage); - sketch_rules.push_back(&rule_multi_level_tiling_with_fusion); - sketch_rules.push_back(&rule_multi_level_tiling); - sketch_rules.push_back(&rule_add_rfactor); - sketch_rules.push_back(&rule_skip_stage); - if (IS_GPU(cur_task)) { - // Try cache read first before cache write - sketch_rules.insert(sketch_rules.begin() + 1, &rule_add_cache_read_stage); - } - // TODO(xian): Add a new rule to try combination of multi-level - // tiling + rfactor - } - - // Derivation rule based synthesizer - while (!pnow->empty()) { - pnext->clear(); - - for (const State& state : *pnow) { - int stage_id = cur_stage_id_map[state]; - - // Reaches to the terminal stage - if (stage_id < 0) { - out_states->push_back(state); - continue; - } - - // Try all derivation rules - for (const auto& rule : sketch_rules) { - auto rule_check = rule->MeetCondition(this, state, stage_id); - if (rule_check > SketchGenerationRule::ConditionEnum::kPass) { - for (const auto& pair : rule->Apply(this, state, stage_id)) { - cur_stage_id_map[pair.first] = pair.second; - pnext->push_back(pair.first); - } - // Skip the reset rules - if (rule_check == SketchGenerationRule::ConditionEnum::kApplyAndSkipRest) { - break; - } - } - } - } - - std::swap(pnow, pnext); - } - - // Hack for rfactor: Replace the split factor for rfactor to the undefined Expr(), - // so later we can sample random value for the split factor. - // Why don't we use Expr() when doing the split for rfactor at the first time? - // Because during ApplySteps, a rfactor with undefined Expr() will crash TVM. 
- - // Hack for rfactor: Replace the split factor for rfactor with the undefined Expr(), - // so that later we can sample a random value for the split factor. - // Why don't we use Expr() when doing the split for rfactor in the first place? - // Because during ApplySteps, an rfactor with undefined Expr() will crash TVM, - // since it conflicts with cache_write, cache_read and rfactor - // in other stages - for (size_t i = 0; i < out_states->size(); ++i) { - auto pstate = (*out_states)[i].CopyOnWrite(); - for (size_t step_id = 0; step_id < pstate->transform_steps.size(); ++step_id) { - if (pstate->transform_steps[step_id]->IsInstance()) { - CHECK_GE(step_id, 1); - int split_step_id = step_id - 1; - auto step = pstate->transform_steps[split_step_id].as(); - CHECK(step != nullptr); - pstate->transform_steps[split_step_id] - = SplitStep(step->stage_id, step->iter_id, step->extent, {PrimExpr()}, - step->inner_to_outer); - } - } - } - - StdCout(verbose) << "Generate Sketches\t\t#s: " << out_states->size() << std::endl; -} - -int InitPopulationFillTileSize(const SketchSearchPolicyNode* policy, - State* state, std::mt19937* rand_gen, - SplitFactorizationMemo* split_memo) { - for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) { - if (auto ps = (*state)->transform_steps[step_id].as()) { - bool defined = true; - for (const PrimExpr& len : ps->lengths) { - if (!len.defined()) { - defined = false; - } - } - - if (defined) { - continue; - } - - int extent = GetIntImm(ps->extent); - const std::vector >& candidate_lens = - split_memo->GetFactorizationSchemes( - extent, ps->lengths.size(), - policy->cur_task->hardware_params->max_innermost_split_factor); - - StateNode* pstate = state->CopyOnWrite(); - pstate->transform_steps[step_id] = SplitStep( - ps->stage_id, ps->iter_id, ps->extent, - candidate_lens[(*rand_gen)() % candidate_lens.size()], - ps->inner_to_outer); - } - } - - return 0; -}
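// ----------------------------------------------------------------------------
// For illustration: the tile-size filling above draws one factorization
// scheme uniformly at random. A stand-alone sketch of that idea with
// hypothetical names (not the real SplitFactorizationMemo API, which also
// caches the enumeration); assumes extent >= 1 and parts >= 1:
#include <random>
#include <vector>
namespace tile_demo {
inline void EnumSchemes(int extent, int parts, std::vector<int>* cur,
                        std::vector<std::vector<int>>* out) {
  if (parts == 0) {
    if (extent == 1) out->push_back(*cur);  // all of `extent` consumed
    return;
  }
  for (int f = 1; f <= extent; ++f) {
    if (extent % f != 0) continue;  // only exact factors
    cur->push_back(f);
    EnumSchemes(extent / f, parts - 1, cur, out);
    cur->pop_back();
  }
}
// e.g. extent = 12, parts = 2 -> {1,12},{2,6},{3,4},{4,3},{6,2},{12,1}
inline std::vector<int> SampleTileSizes(int extent, int parts, std::mt19937* gen) {
  std::vector<std::vector<int>> schemes;
  std::vector<int> cur;
  EnumSchemes(extent, parts, &cur, &schemes);
  return schemes[(*gen)() % schemes.size()];
}
}  // namespace tile_demo
// ----------------------------------------------------------------------------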
- -int InitPopulationThreadBind(const SketchSearchPolicyNode* policy, - State* state) { - for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { - const Stage& stage = (*state)->stages[stage_id]; - auto pop = stage->op.as(); - - if (stage->compute_at != kRoot || stage->op_type == kPlaceholder) { - continue; - } - - if (HasAnnotationIter(stage, IteratorAnnotation::kThreadX)) { - // Skip if this stage has already been thread-bound - continue; - } - - std::vector to_fuse; - - // This stage has not been tiled, but in a GPU schedule we must tile it - // to do thread binding - if (!HasSplitStep(*state, stage_id)) { - for (const auto& it : (*state)->stages[stage_id]->iters) { - if (it->iter_type == kReduce) { - break; - } - to_fuse.push_back(it); - } - const auto& fused_it = state->fuse(stage_id, to_fuse); - // Set default vthread=1 & threadIdx.x=default_warp_size - // EvolutionarySearch will try more possibilities - if (GetExtent(fused_it) <= - policy->cur_task->hardware_params->warp_size) { - state->bind_thread(stage_id, fused_it, kThreadX); - } else { - const auto& split_its = state->split(stage_id, fused_it, - {1, policy->cur_task->hardware_params->warp_size}); - state->bind_thread(stage_id, split_its[0], kBlockX); - state->bind_thread(stage_id, split_its[1], kVThread); - state->bind_thread(stage_id, split_its[2], kThreadX); - } - - continue; - } - - int total_space_extent = 1; - for (const auto& i : pop->root_iter_vars()) { - CHECK(i->dom.defined()); - const auto& pint = i->dom->extent.as(); - CHECK(pint); - total_space_extent *= pint->value; - } - - // TODO(..): Add ThreadBind support for rfactor - if (total_space_extent <= policy->cur_task->hardware_params->warp_size) { - for (const auto& it : (*state)->stages[stage_id]->iters) { - if (it->iter_type == kReduce) { - break; - } - to_fuse.push_back(it); - } - const auto& fused_it = state->fuse(stage_id, to_fuse); - state->bind_thread(stage_id, fused_it, kThreadX); - - continue; - } - - // Fuse the outermost space tile as blockIdx - for (size_t i = 0; i < pop->axis.size(); i++) { - const auto& it = (*state)->stages[stage_id]->iters[i]; - if (!StrEndsWith(it->name, ".0")) { - break; - } - to_fuse.push_back(it); - } - const auto& blockidx_it = state->fuse(stage_id, to_fuse); - state->bind_thread(stage_id, blockidx_it, kBlockX); - - // Fuse the second outermost space tile as vthread - to_fuse.clear(); - for (size_t i = 1; i < pop->axis.size() + 1; i++) { - const auto& it = (*state)->stages[stage_id]->iters[i]; - if (!StrEndsWith(it->name, ".1")) { - break; - } - to_fuse.push_back((*state)->stages[stage_id]->iters[i]); - } - const auto& vthread_it = state->fuse(stage_id, to_fuse); - if (GetExtent(vthread_it) > - policy->cur_task->hardware_params->max_vthread_extent) { - return -1; - } - state->bind_thread(stage_id, vthread_it, kVThread); - - // Fuse the third outermost space tile as threadIdx - to_fuse.clear(); - for (size_t i = 2; i < pop->axis.size() + 2; i++) { - const auto& it = (*state)->stages[stage_id]->iters[i]; - if (!StrEndsWith(it->name, ".2")) { - break; - } - to_fuse.push_back((*state)->stages[stage_id]->iters[i]); - } - const auto& threadidx_it = state->fuse(stage_id, to_fuse); - if (GetExtent(threadidx_it) < - policy->cur_task->hardware_params->warp_size) { - return -1; - } - state->bind_thread(stage_id, threadidx_it, kThreadX); - } - - return 0; -} - -int InitPopulationCooperativeFetching(const SketchSearchPolicyNode* policy, - State* state) { - for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { - // Do cooperative fetching with the cache read stage - // For two stages: A -> B - // 1. A -> A_cache_read -> B - // * - // 2. A -> A_cache_write -> A_cache_read -> B - // * - if ((stage_id > 0 && HasCacheReadStage((*state), stage_id - 1) && - !HasCacheWriteStage((*state), stage_id - 1)) || - (stage_id > 1 && HasCacheReadStage((*state), stage_id - 2) && - HasCacheWriteStage((*state), stage_id - 2))) { - const Stage& target_stage = (*state)->stages[stage_id]; - if (HasAnnotationIter(target_stage, IteratorAnnotation::kThreadX) || - HasAnnotationIter(target_stage, IteratorAnnotation::kTensorized)) { - // Skip if this stage has already been thread-bound or has been - // tensorized - continue; - } - // Get spatial_split_step_ids from the root stage - std::unordered_set consumers; - std::vector spatial_split_step_ids; - GetConsumers(policy->cur_task, (*state), target_stage->op, &consumers); - CHECK_EQ(consumers.size(), 1); - int target_stage_id = OperationToStage(*consumers.begin(), (*state)); - GetSpaceSplitStepIds((*state), target_stage_id, &spatial_split_step_ids); - - // Fuse all axes to do cooperative fetching - Iterator fused = state->fuse(stage_id, - (*state)->stages[stage_id]->iters); - // Leave a vectorized cooperative fetching split placeholder - const auto& iters0 = state->split(stage_id, fused, {1}); - state->vectorize(stage_id, iters0[1]); - // Use a follow split to keep the same thread extent as the root stage - const auto& iters1 = state->follow_fused_split(stage_id, iters0[0], - spatial_split_step_ids, - 1, true); - state->bind_thread(stage_id, iters1[1], kThreadX); - } - } - - return 0; -}
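// ----------------------------------------------------------------------------
// For intuition only: the kind of loop nest the cooperative-fetching schedule
// above aims to produce. All iterators of the cache-read stage are fused,
// split so the thread extent matches the consumer stage (which is what the
// follow_fused_split is for), and each thread copies a strided chunk. A plain
// C++ model with hypothetical names, not generated code:
#include <cstddef>
namespace coop_demo {
inline void CooperativeFetch(const float* src, float* shared, std::size_t total,
                             std::size_t num_threads, std::size_t tid) {
  const std::size_t vec = 1;  // the placeholder unit split kept for vectorization
  for (std::size_t i = tid * vec; i < total; i += num_threads * vec) {
    for (std::size_t v = 0; v < vec && i + v < total; ++v) {
      shared[i + v] = src[i + v];  // neighboring threads touch neighboring elements
    }
  }
}
}  // namespace coop_demo
// ----------------------------------------------------------------------------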
- -int InitPopulationChangeComputeLocation(const SketchSearchPolicyNode* policy, - State* state, std::mt19937* rand_gen) { - if (GetIntParam(policy->params, "disable_change_compute_location")) { - return 0; - } - - for (int stage_id = static_cast((*state)->stages.size()) - 1; stage_id >= 0; stage_id--) { - const Stage& stage = (*state)->stages[stage_id]; - - if (stage->op_type == kPlaceholder) { - continue; - } - - if (IsTiled(stage) || stage->compute_at == kInlined) { - continue; - } - - if (NeedsMultilevelTiling(policy->cur_task, (*state), stage->op)) { - continue; - } - - std::unordered_set consumers; - - GetConsumers(policy->cur_task, (*state), stage->op, &consumers); - if (consumers.empty()) { - continue; - } - - int target_stage_id; - if (consumers.size() == 1) { - target_stage_id = OperationToStage(*consumers.begin(), *state); - } else { - // Check that all consumers share a common root - int common_root_id = -1; - bool mismatch = false; - for (const auto& consumer : consumers) { - int consumer_stage_id = OperationToStage(consumer, *state); - int root_id = -1; - if ((*state)->stages[consumer_stage_id]->compute_at == kRoot) { - root_id = consumer_stage_id; - } else if ((*state)->stages[consumer_stage_id]->compute_at == kIter) { - root_id = (*state)->attach_map->stage_to_attach_iter.at(consumer_stage_id).first; - } else { - LOG(FATAL) << "Invalid case"; - } - - if (common_root_id == -1) { - common_root_id = root_id; - } else { - if (common_root_id != root_id) { - mismatch = true; - break; - } - } - } - - if (mismatch) { - continue; - } - target_stage_id = common_root_id; - } - - const Stage& target_stage = (*state)->stages[target_stage_id]; - std::set to_unroll_name_set; - if (target_stage->op->attrs.count(policy->always_unroll_key)) { - to_unroll_name_set = GetIterNameSetParam(target_stage->op->attrs, - policy->always_unroll_key); - } - - std::vector > candidates; - bool target_compute_at_other = target_stage->compute_at == kIter; - bool target_is_tiled = IsTiled(target_stage); - - bool visited_reduce = false; - // Enumerate compute_at locations at target_stage - int ct = 0; - for (const auto& target_iter : target_stage->iters) { - if (target_iter->iter_type == kReduce) { - visited_reduce = true; - if (!target_is_tiled) { // do not go into reduce iter - break; - } - } else if (target_iter->iter_type == kSpace) { - if (visited_reduce) { // do not go into inner tile - break; - } - } - - if (to_unroll_name_set.count(target_iter->name)) { - // Do not go into the always-unroll region - break; - } - - if (GetExtent(target_iter) == 1) { // skip iterators with length of 1 - continue; - } - if (target_compute_at_other && target_iter->iter_type == kSpace && - StrEndsWith(target_iter->name, ".0")) { - // Skip the first-level iterators if the target stage is compute_at another stage; - // in this case, the lengths of the first-level iterators are always one - continue; - } - candidates.emplace_back(target_stage_id, target_iter); - - if ((*state)->attach_map->iter_to_attached_stages.count( - std::make_pair(target_stage_id, ct++))) { - break; - } - } - - // If the target_stage is already compute_at another stage X, also try compute_at X. - // We call stage X `target_target_stage` - if (target_compute_at_other) { - int target_target_stage_id; - target_target_stage_id = (*state)->attach_map->stage_to_attach_iter.at( - target_stage_id).first; - const Stage& target_target_stage = (*state)->stages[target_target_stage_id]; - if (target_target_stage->op->attrs.count(policy->always_unroll_key)) { - to_unroll_name_set = GetIterNameSetParam(target_target_stage->op->attrs, - policy->always_unroll_key); - } else { - to_unroll_name_set.clear(); - } - - int ct = 0; - for (const auto& target_target_iter : target_target_stage->iters) { - if (target_target_iter->iter_type == kReduce || - (*state)->attach_map->iter_to_attached_stages.count( - std::make_pair(target_target_stage_id, ct++))) { - break; - } - - if (to_unroll_name_set.count(target_target_iter->name)) { - // Do not go into the always-unroll region - break; - } - - if (GetExtent(target_target_iter) == 1) { // skip iterators with length of 1 - continue; - } - - candidates.push_back(std::make_pair(target_target_stage_id, target_target_iter)); - } - } - - int choice = (*rand_gen)() % (candidates.size() + 2); - - if (choice == 0) { - if (!HasReduceIter(stage)) { - state->compute_inline(stage_id); - } - } else if (choice == 1) { - state->compute_root(stage_id); - } else { - choice = choice - 2; - state->compute_at(stage_id, candidates[choice].first, candidates[choice].second); - } - } - - return 0; -}
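// ----------------------------------------------------------------------------
// For illustration: the random draw at the end of the function above maps a
// single integer to one of three kinds of compute locations. Spelled out over
// N candidates (toy version; each option has equal probability):
#include <cstddef>
#include <random>
namespace loc_demo {
enum class Loc { kInline, kRoot, kAt };
inline Loc ChooseLocation(std::size_t num_candidates, std::mt19937* gen,
                          std::size_t* candidate_idx) {
  std::size_t choice = (*gen)() % (num_candidates + 2);
  if (choice == 0) return Loc::kInline;  // compute_inline (if no reduce iter)
  if (choice == 1) return Loc::kRoot;    // compute_root
  *candidate_idx = choice - 2;           // compute_at the chosen candidate
  return Loc::kAt;
}
}  // namespace loc_demo
// ----------------------------------------------------------------------------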
- -int InitPopulationParallel(const SketchSearchPolicyNode* policy, - State* state) { - std::function - annotate_parallel; - - annotate_parallel = [&annotate_parallel]( - const SketchSearchPolicyNode* policy, State* state, int stage_id, int iter_offset) { - const Stage& stage = (*state)->stages[stage_id]; - - std::vector to_fuse; - int64_t parallel_degree = 1; - - // Strategy: try to fuse and parallelize the outermost n iterators. - // Stop if we meet a reduce iterator or we already have enough parallelism - size_t iter_id = iter_offset; - for (; iter_id < stage->iters.size(); ++iter_id) { - const Iterator& it = stage->iters[iter_id]; - if (it->iter_type == kReduce || it->annotation != kNone) { - break; - } - - to_fuse.push_back(it); - parallel_degree *= GetExtent(it); - - if (parallel_degree > policy->cur_task->hardware_params->num_cores * 16) { - break; - } - - if ((*state)->attach_map->iter_to_attached_stages.count( - std::make_pair(stage_id, iter_id))) { - break; - } - } - - if (parallel_degree == 1) { - auto res = - (*state)->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id)); - if (res != (*state)->attach_map->iter_to_attached_stages.end()) { - for (int attached_stage_id : res->second) { - annotate_parallel(policy, state, attached_stage_id, 0); - } - annotate_parallel(policy, state, stage_id, iter_id + 1); - } - } - - if (!to_fuse.empty()) { - if (to_fuse.size() == 1) { - state->parallel(stage_id, to_fuse[0]); - } else { - Iterator fused_iter = state->fuse(stage_id, to_fuse); - state->parallel(stage_id, fused_iter); - } - } - }; - - for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { - const Stage& stage = (*state)->stages[stage_id]; - if (stage->compute_at != kRoot || stage->op_type == kPlaceholder) { - continue; - } - - annotate_parallel(policy, state, stage_id, 0); - } - - return 0; -}
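// ----------------------------------------------------------------------------
// For illustration: the fusion heuristic above, reduced to its core. Keep
// fusing outermost loops until the fused extent offers roughly 16 chunks per
// core (the `num_cores * 16` bound in the code). A toy version over raw
// loop extents (hypothetical helper, not the Ansor API):
#include <cstddef>
#include <cstdint>
#include <vector>
namespace par_demo {
inline std::size_t NumLoopsToFuse(const std::vector<int64_t>& extents, int num_cores) {
  int64_t degree = 1;
  std::size_t n = 0;
  while (n < extents.size()) {
    degree *= extents[n++];
    if (degree > static_cast<int64_t>(num_cores) * 16) break;  // enough parallelism
  }
  return n;  // fuse the first n loops, then annotate the result parallel
}
}  // namespace par_demo
// ----------------------------------------------------------------------------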
- -int InitPopulationVectorization(const SketchSearchPolicyNode* policy, - State* state, std::mt19937* rand_gen) { - for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { - const Stage& stage = (*state)->stages[stage_id]; - - if (stage->op_type == kPlaceholder) { - continue; - } - - // Skip cooperative fetching stage - if (IS_GPU(policy->cur_task) && - HasCacheReadStage((*state), stage_id - 1)) { - continue; - } - - if (HasAnnotationIter(stage, IteratorAnnotation::kTensorized)) { - // Skip if this stage has been tensorized - continue; - } - - // Try to fuse and vectorize the space iterators in the innermost tile - int cum_length_prod = 1; - - std::set to_unroll_name_set; - if (stage->op->attrs.count(policy->always_unroll_key)) { - to_unroll_name_set = GetIterNameSetParam(stage->op->attrs, - policy->always_unroll_key); - } - - int num_fusible = 0; - while (num_fusible < static_cast(stage->iters.size())) { - int iter_id = static_cast(stage->iters.size()) - 1 - num_fusible; - if ((*state)->attach_map->iter_to_attached_stages.count( - std::make_pair(stage_id, iter_id))) { - break; - } - - const Iterator& it = stage->iters[iter_id]; - - // Stop if we meet a reduce iterator - if (it->iter_type == kReduce || it->annotation != kNone || - to_unroll_name_set.count(it->name)) { - break; - } - - // Stop if the memory access is not contiguous (not vectorizable) - // Note: an exact check is too hard, so we use a heuristic here - if (IsTiled(stage) && num_fusible != 0) { - // If the stage is tiled, then the memory access cannot be contiguous - // for the innermost two iterators - break; - } - - cum_length_prod *= GetExtent(it); - if (cum_length_prod > policy->cur_task->hardware_params->max_unroll_vec) { - break; - } - - num_fusible++; - } - - if (num_fusible > 1) { - num_fusible = 1 + (*rand_gen)() % (num_fusible - 1); // Select a random range to fuse - } - - if (num_fusible == 1) { - state->vectorize(stage_id, stage->iters.back()); - } else if (num_fusible > 1) { - std::vector to_fuse(stage->iters.end() - num_fusible, - stage->iters.end()); - state->vectorize(stage_id, state->fuse(stage_id, to_fuse)); - } - } - - return 0; -} - -int InitPopulationUnroll(const SketchSearchPolicyNode* policy, - State* state, std::mt19937* rand_gen) { - for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { - const Stage& stage = (*state)->stages[stage_id]; - - if (stage->op_type == kPlaceholder) { - continue; - } - - if (stage->op->attrs.count(policy->always_unroll_inner_key)) { - // Special unroll policy - auto to_unroll_name_set = GetIterNameSetParam(stage->op->attrs, - policy->always_unroll_inner_key); - std::set visited_names; - - // Unroll the space iterators and reduce iterators listed in the attrs - // in the innermost tile - int n = static_cast(stage->iters.size()) - 1; - visited_names.clear(); - while (n >= 0) { - const Iterator& it = stage->iters[n]; - - // If we meet two iterators that come from the same original iterator, - // then we are out of the innermost tile - size_t size_before = visited_names.size(); - ExtractOriginalIterators(it->name, &visited_names); - if (size_before == visited_names.size()) { - break; - } - - std::set name; - ExtractOriginalIterators(it->name, &name); - if (name.size() == 1 && to_unroll_name_set.count(*name.begin())) { - state->unroll(stage_id, it); - } - - n--; - } - } else if (stage->op->attrs.count(policy->always_unroll_key)) { - // Special unroll policy - auto to_unroll_name_set = GetIterNameSetParam(stage->op->attrs, - policy->always_unroll_key); - - // Unroll the space iterators and reduce iterators listed in the attrs - int n = static_cast(stage->iters.size()) - 1; - while (n >= 0) { - const Iterator& it = stage->iters[n]; - if (to_unroll_name_set.count(it->name)) { - state->unroll(stage_id, it); - } - n--; - } - } else if (HasReduceIter(stage)) { - // Use auto unroll for multi-level tiled stages - int value = policy->auto_unroll_configs[ - (*rand_gen)() % policy->auto_unroll_configs.size()]; - state->pragma(stage_id, (*state)->stages[stage_id]->iters[0], - std::string("auto_unroll_max_step") + "$" + std::to_string(value)); - } - } - - return 0; -} - -void SketchSearchPolicyNode::SampleInitPopulation(const std::vector& sketches, - int out_size, std::vector* out_states) { - std::uniform_real_distribution<> dis(0.0, 1.0); - int continue_count = 0; - - // TODO(...): Maybe try multi-threading here - while (static_cast(out_states->size())
< out_size && - continue_count < out_size * 10) { - State tmp_s = sketches[rand_gen_() % sketches.size()]; - - InitPopulationFillTileSize(this, &tmp_s, &rand_gen_, &split_memo_); - - if (IS_GPU(cur_task)) { - tmp_s = cur_task->compute_dag.InferBound(tmp_s); - - if (InitPopulationThreadBind(this, &tmp_s)) { - continue_count++; - if (continue_count == out_size) { - StdCout(verbose) << "Initial Population Sampling..." << std::endl; - } - continue; - } - - InitPopulationCooperativeFetching(this, &tmp_s); - } else { - InitPopulationChangeComputeLocation(this, &tmp_s, &rand_gen_); - - tmp_s = cur_task->compute_dag.InferBound(tmp_s); - - InitPopulationParallel(this, &tmp_s); - } - - InitPopulationVectorization(this, &tmp_s, &rand_gen_); - - InitPopulationUnroll(this, &tmp_s, &rand_gen_); - - out_states->push_back(std::move(tmp_s)); - } - - StdCout(verbose) << "Sample Initial Population\t#s: " - << out_states->size() << std::endl; -} - -void SketchSearchPolicyNode::EvolutionarySearch( - const std::vector& init_population, - int num_best_states, std::vector* best_states) { - auto tic_begin = std::chrono::high_resolution_clock::now(); - - // Set parameters for genetic algorithm - int population = GetIntParam(params, "evolutionary_search_population"); - int num_iters = GetIntParam(params, "evolutionary_search_num_iters"); - double mutation_prob = GetDoubleParam(params, "evolutionary_search_mutation_prob"); - int num_cross_over = static_cast(population * 0.0); // NOT IMPLEMENTED currently - int num_cross_over_trial_upper_bound = num_cross_over * 3; - CostModel cost_model = program_cost_model; - - // Two ping pong buffers to avoid copy - std::vector states_buf1, states_buf2; - std::vector *pnow = &states_buf1, *pnext = &states_buf2; - states_buf1.reserve(population); - states_buf2.reserve(population); - states_buf1.insert(states_buf1.begin(), init_population.begin(), init_population.end()); - - // A heap to keep the best states during evolution - using StateItem = std::pair; - auto cmp = [](const StateItem& left, const StateItem& right) { - return left.second > right.second; - }; - std::vector heap; - std::unordered_set in_heap(measured_states_set_); - heap.reserve(num_best_states); - - // auxiliary global variables - std::vector scores; - std::vector prefix_sum_probs; - double max_score = 0.0; - scores.reserve(population); - prefix_sum_probs.reserve(population); - std::uniform_real_distribution<> dis(0.0, 1.0); - int mutation_fail_ct = 0; - - // Genetic Algorithm - for (int k = 0; k < num_iters + 1; ++k) { - // Maintain the heap - cur_task->compute_dag.InferBound(pnow); - PruneUndefined(pnow); - cost_model->Predict(cur_task, *pnow, &scores); - - for (size_t i = 0; i < pnow->size(); ++i) { - const State& state = (*pnow)[i]; - std::string state_str = state.ToStr(); - - if (in_heap.count(state_str) == 0) { - if (static_cast(heap.size()) < num_best_states) { - heap.emplace_back((*pnow)[i], scores[i]); - std::push_heap(heap.begin(), heap.end(), cmp); - in_heap.insert(state_str); - } else if (scores[i] > heap.front().second) { - std::string old_state_str = heap.front().first.ToStr(); - in_heap.erase(old_state_str); - in_heap.insert(state_str); - - std::pop_heap(heap.begin(), heap.end(), cmp); - heap.back() = StateItem(state, scores[i]); - std::push_heap(heap.begin(), heap.end(), cmp); - } - if (scores[i] > max_score) { - max_score = scores[i]; - } - } - } - - if (k % 5 == 0 || k == num_iters) { - StdCout(verbose) << "GA Iter: " << k << std::fixed << std::setprecision(4) - << "\tMax score: " << max_score - 
<< "\tMin score: " << heap.front().second - << "\tPop size: " << pnow->size() << std::endl; - } - - if (k == num_iters) { - break; - } - - // Compute selection probability - double sum = 0.0; - prefix_sum_probs.resize(scores.size()); - for (size_t i = 0; i < scores.size(); ++i) { - sum += std::max(scores[i], 0.0f); - prefix_sum_probs[i] = sum; - } - for (size_t i = 0; i < scores.size(); ++i) { - prefix_sum_probs[i] = prefix_sum_probs[i] / sum; - } - - // Do cross over - int ct = 0; - while (static_cast(pnext->size()) < num_cross_over - && ct < num_cross_over_trial_upper_bound) { - int p1 = RandomChoose(prefix_sum_probs, &rand_gen_); - int p2 = RandomChoose(prefix_sum_probs, &rand_gen_); - - if (p1 == p2) { - pnext->push_back((*pnow)[p1]); - } else { - State tmp_s = CrossOverState((*pnow)[p1], (*pnow)[p2]); - if (tmp_s.defined()) { - pnext->push_back(std::move(tmp_s)); - } - } - ct++; - } - - // Do mutation - mutation_fail_ct = 0; - while (static_cast(pnext->size()) < population) { - int id = RandomChoose(prefix_sum_probs, &rand_gen_); - - if (dis(rand_gen_) < mutation_prob) { - const std::vector rule_prefix_sum_probs{0.9, 1.0}; - - int rule_id = RandomChoose(rule_prefix_sum_probs, &rand_gen_); - - if (rule_id == 0) { - // Mutate Tile Size - State tmp_s = RandomMutateTileSize((*pnow)[id], &split_memo_, &rand_gen_, - cur_task->hardware_params->max_innermost_split_factor); - if (tmp_s.defined()) { - pnext->push_back(std::move(tmp_s)); - } else { - mutation_fail_ct++; - } - } else if (rule_id == 1) { - // Mutate auto-unroll max step. - State tmp_s = RandomMutateMaxUnrollStep((*pnow)[id], &rand_gen_, auto_unroll_configs); - if (tmp_s.defined()) { - pnext->push_back(std::move(tmp_s)); - } else { - mutation_fail_ct++; - } - } - } else { - pnext->push_back((*pnow)[id]); - } - } - - std::swap(pnext, pnow); pnext->clear(); - } - - // Copy best states in the heap to out_states - std::sort(heap.begin(), heap.end(), cmp); - best_states->clear(); - for (auto& item : heap) { - best_states->push_back(std::move(item.first)); - } - - double duration = std::chrono::duration_cast >( - std::chrono::high_resolution_clock::now()- tic_begin).count(); - StdCout(verbose) << "EvolutionarySearch\t\t#s: " << best_states->size() - << "\tTime elapsed: " - << std::fixed << std::setprecision(2) << duration << std::endl; -} - -class RuleCustomSketch : public SketchGenerationRule { - public: - RuleCustomSketch(PackedFunc meet_condition_func, PackedFunc apply_func) : - meet_condition_func_(meet_condition_func), apply_func_(apply_func) {} - - inline ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - auto ret = meet_condition_func_( - tvm::runtime::GetRef(policy), state, stage_id); - if (ret.type_code() == 0) { - return ConditionEnum(static_cast(ret)); - } else { - return kApplyAndSkipRest; - } - } - - inline std::vector > Apply( - const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - std::vector > ret; - - Array> apply_ret = apply_func_( - tvm::runtime::GetRef(policy), state, stage_id); - - for (const auto& item : apply_ret) { - CHECK_EQ(item.size(), 2); - State state = Downcast(item[0]); - auto next = item[1].as(); - ret.emplace_back(state, next->value); - } - return ret; - } - - private: - PackedFunc meet_condition_func_; - PackedFunc apply_func_; -}; - -PreloadCustomSketchRule::PreloadCustomSketchRule(PackedFunc meet_condition_func, - PackedFunc apply_func) { - auto node = make_object(); - node->meet_condition_func = 
meet_condition_func; - node->apply_func = apply_func; - data_ = std::move(node); -} - -void PreloadCustomSketchRuleNode::callback(SearchPolicyNode* policy) { - CHECK(policy->IsInstance()); - auto sketch_policy = dynamic_cast(policy); - sketch_policy->sketch_rules.emplace_back( - new RuleCustomSketch(meet_condition_func, apply_func)); - StdCout(policy->verbose) << "Custom sketch rule added." << std::endl; -} - -TVM_REGISTER_GLOBAL("ansor.SketchSearchPolicy") -.set_body_typed([](CostModel program_cost_model, Map params, - int seed){ - return SketchSearchPolicy(program_cost_model, params, seed); -}); - -TVM_REGISTER_GLOBAL("ansor.PreloadCustomSketchRule") -.set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func) { - return PreloadCustomSketchRule(meet_condition_func, apply_func); -}); - -} // namespace ansor -} // namespace tvm diff --git a/src/ansor/search_policy/sketch_search_policy.h b/src/ansor/search_policy/sketch_search_policy.h deleted file mode 100644 index 54a5cdd1fa4e..000000000000 --- a/src/ansor/search_policy/sketch_search_policy.h +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ansor/search_policy/sketch_search_policy.h - * \brief The search policy that searches in a hierarchical search space defined by sketches. - * The policy randomly samples programs from the space defined by sketches - * and use evolutionary search to fine-tune them. - */ - -#ifndef TVM_ANSOR_SEARCH_POLICY_SKETCH_SEARCH_POLICY_H_ -#define TVM_ANSOR_SEARCH_POLICY_SKETCH_SEARCH_POLICY_H_ - -#include -#include -#include -#include -#include -#include "search_policy.h" -#include "../cost_model/cost_model.h" -#include "../utils.h" - - -namespace tvm { -namespace ansor { - -class SketchGenerationRule; - -/*! - * \brief The search policy that searches in a hierarchical search space defined by sketches. - * The policy randomly samples programs from the space defined by sketches - * and use evolutionary search to fine-tune them. - */ -class SketchSearchPolicyNode: public SearchPolicyNode { - public: - /*! \brief The cost model for complete programs */ - CostModel program_cost_model; - /*! \brief Random generator */ - std::mt19937 rand_gen_; - /*! \brief The parameters for search. 
It stores the following parameters: - * int evolutionary_search_population // The population size for evolutionary search - * int evolutionary_search_mutation_prob // The probability of mutation for evolutionary search - * int evolutionary_search_num_iters; // The number of iterations for evolutionary search - * double local_mutation_use_measured_ratio; // The maximum percentage of measured states in the initial - * // population for evolutionary search - * double eps_greedy; // Always allocate this percentage of measurements to randomly sampled states - * str cpu_multi_level_tiling_structure // The structure of multi-level tiling for CPU - * str gpu_multi_level_tiling_structure // The structure of multi-level tiling for GPU - */ - Map params; - /*! \brief The rules to generate sketches */ - std::vector sketch_rules; - - /*! \brief Search and make n_trials measurements. - * \returns the best state */ - State Search(SearchTask task, int n_trials, - int early_stopping, int num_measure_per_iter, - int verbose, ProgramMeasurer measurer, - Array pre_search_callbacks) final; - - /*! \brief Continue the search for one round. This is used by JointTuner - * \returns the measurement pairs */ - std::pair, Array > ContinueSearchOneRound( - SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) final; - - static constexpr const char *_type_key = "ansor.SketchSearchPolicy"; - static const std::vector auto_unroll_configs; - - TVM_DECLARE_FINAL_OBJECT_INFO(SketchSearchPolicyNode, SearchPolicyNode); - - protected: - /*! \brief Pick states from best states and random states with eps-greedy policy */ - void PickStatesWithEpsGreedy(std::vector* inputs, - const std::vector& best_states, - const std::vector& random_states, - int remaining_n_trials); - - private: - // Run one round of the search pipeline - void SearchOneRound(std::vector* best_states, - int num_random_states, std::vector* random_states); - - // Generate sketches without tile sizes - void GenerateSketch(std::vector* out_states); - - // Sample the initial population - void SampleInitPopulation(const std::vector& sketches, - int out_size, std::vector* out_states); - - // Perform evolutionary search - void EvolutionarySearch(const std::vector& init_population, - int num_best_states, std::vector* best_states); - - SplitFactorizationMemo split_memo_; // Memorize split space for Split - int num_measure_per_iter_; // The number of states to measure per iteration -}; - -/*! - * \brief Managed reference to SketchSearchPolicyNode. - * \sa SketchSearchPolicyNode - */ -class SketchSearchPolicy : public SearchPolicy { - public: - SketchSearchPolicy(CostModel program_cost_model, - Map params, - int seed); - - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SketchSearchPolicy, SearchPolicy, - SketchSearchPolicyNode); -}; - -/*! \brief Pre-search callback function to load custom rules for sketch generation */ -class PreloadCustomSketchRuleNode : public SearchCallbackNode { - public: - // TODO(jcf94): Use tvm::runtime::TypedPackedFunc? - PackedFunc meet_condition_func; - PackedFunc apply_func; - - void callback(SearchPolicyNode* policy) final; - - static constexpr const char *_type_key = "ansor.PreloadCustomSketchRule"; - TVM_DECLARE_FINAL_OBJECT_INFO(PreloadCustomSketchRuleNode, SearchCallbackNode); -}; - -/*! - * \brief Managed reference to PreloadCustomSketchRuleNode.
- * \sa PreloadCustomSketchRuleNode - */ -class PreloadCustomSketchRule : public SearchCallback { - public: - PreloadCustomSketchRule(PackedFunc meet_condition_func, - PackedFunc apply_func); - - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadCustomSketchRule, SearchCallback, - PreloadCustomSketchRuleNode); -}; - -} // namespace ansor -} // namespace tvm - -#endif // TVM_ANSOR_SEARCH_POLICY_SKETCH_SEARCH_POLICY_H_ diff --git a/src/ansor/search_policy/utils.cc b/src/ansor/search_policy/utils.cc deleted file mode 100644 index 2d2f92ecbc20..000000000000 --- a/src/ansor/search_policy/utils.cc +++ /dev/null @@ -1,744 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ansor/search_policy/utils.cc - * \brief Common utilities for search policies - */ - -#include "utils.h" -#include "search_policy.h" - -namespace tvm { -namespace ansor { - -void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatial_split_step_ids) { - auto pop = s->stages[stage_id]->op.as(); - CHECK(pop != nullptr); - - const auto& no_split_name_pair = QueryNoSplitAxis(s->stages[stage_id]); - const std::set& no_split_at_inner_name_set = no_split_name_pair.first; - const std::set& no_split_at_outer_name_set = no_split_name_pair.second; - - size_t reduce_count = 0; - for (const auto axis : pop->reduce_axis) { - if (!no_split_at_inner_name_set.count(axis->var->name_hint) && - !no_split_at_outer_name_set.count(axis->var->name_hint)) { - reduce_count++; - } - } - - for (int i = static_cast(s->transform_steps.size()) - 1; i >= 0; --i) { - if (s->transform_steps[i]->IsInstance() || - s->transform_steps[i]->IsInstance() || - s->transform_steps[i]->IsInstance()) { - if (stage_id > s->transform_steps[i]->stage_id) { - stage_id--; - } - } else if (auto ps = s->transform_steps[i].as()) { - if (stage_id == ps->stage_id) { - // Assume SplitStep on reduction axes are always after SplitStep on spatial axes. 
- // TODO(jcf94): do not rely on this assumption - if (reduce_count) { - reduce_count--; - } else { - spatial_split_step_ids->push_back(i); - } - } - } - } -} - -State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format, - std::vector* spatial_split_step_ids) { - std::vector > space_levels; - std::vector > reduce_levels; - std::vector space_outer, space_inner, reduce_outer, reduce_inner; - std::vector split_res; - - for (const auto c : format) { - if (tolower(c) == 's') { - space_levels.emplace_back(); - } else if (tolower(c) == 'r') { - reduce_levels.emplace_back(); - } else { - LOG(FATAL) << "Invalid multi-level tiling format: " << format; - } - } - size_t n_space = space_levels.size(); - size_t n_reduce = reduce_levels.size(); - - spatial_split_step_ids->clear(); - - State tmp_s = state; - const Stage& stage = state->stages[stage_id]; - const auto& no_split_name_pair = QueryNoSplitAxis(stage); // handle special split strategy - const auto& last_split_is_one_name_set = QueryLastSplitIsOneAxis(stage); - const std::set& no_split_at_inner_name_set = no_split_name_pair.first; - const std::set& no_split_at_outer_name_set = no_split_name_pair.second; - - for (const auto& iter : state->stages[stage_id]->iters) { - if (iter->iter_type == kSpace) { - if (!no_split_at_inner_name_set.count(iter->name) && - !no_split_at_outer_name_set.count(iter->name)) { - CHECK_GE(n_space, 1); - int tmp_n_space = n_space; - - if (last_split_is_one_name_set.count(iter->name)) { - tmp_n_space--; - } - - if (tmp_n_space == 1) { - space_levels[0].push_back(iter); - } else { - split_res = tmp_s.split(stage_id, iter, std::vector(tmp_n_space - 1)); - for (int i = 0; i < tmp_n_space; i++) { - space_levels[i].push_back(std::move(split_res[i])); - } - spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1); - } - } else { - if (no_split_at_inner_name_set.count(iter->name)) { - space_inner.push_back(iter); - } - if (no_split_at_outer_name_set.count(iter->name)) { - space_outer.push_back(iter); - } - } - } else if (iter->iter_type == kReduce) { - if (!no_split_at_inner_name_set.count(iter->name) && - !no_split_at_outer_name_set.count(iter->name)) { - CHECK_GE(n_reduce, 1); - - if (n_reduce == 1) { - reduce_levels[0].push_back(iter); - } else { - split_res = tmp_s.split(stage_id, iter, std::vector(n_reduce - 1)); - for (size_t i = 0; i < n_reduce; i++) { - reduce_levels[i].push_back(std::move(split_res[i])); - } - } - } else { - if (no_split_at_inner_name_set.count(iter->name)) { - reduce_inner.push_back(iter); - } - if (no_split_at_outer_name_set.count(iter->name)) { - reduce_outer.push_back(iter); - } - } - } else { - LOG(FATAL) << "Invalid iter type: " << iter->iter_type; - } - } - - if (!space_outer.empty()) { - CHECK(!space_levels.empty()); - space_levels.front().insert(space_levels.front().begin(), - std::make_move_iterator(space_outer.begin()), - std::make_move_iterator(space_outer.end())); - } - if (!space_inner.empty()) { - CHECK(!space_levels.empty()); - space_levels.back().insert(space_levels.back().begin(), - std::make_move_iterator(space_inner.begin()), - std::make_move_iterator(space_inner.end())); - } - - if (!reduce_outer.empty()) { - CHECK(!reduce_levels.empty()); - reduce_levels.front().insert(reduce_levels.front().begin(), - std::make_move_iterator(reduce_outer.begin()), - std::make_move_iterator(reduce_outer.end())); - } - if (!reduce_inner.empty()) { - CHECK(!reduce_levels.empty()); - reduce_levels.back().insert(reduce_levels.back().begin(), - 
std::make_move_iterator(reduce_inner.begin()), - std::make_move_iterator(reduce_inner.end())); - } - - std::vector order; - int space_ct = 0, reduce_ct = 0; - for (const auto c : format) { - if (tolower(c) == 's') { - order.insert(order.end(), std::make_move_iterator(space_levels[space_ct].begin()), - std::make_move_iterator(space_levels[space_ct].end())); - space_ct++; - } else if (tolower(c) == 'r') { - order.insert(order.end(), std::make_move_iterator(reduce_levels[reduce_ct].begin()), - std::make_move_iterator(reduce_levels[reduce_ct].end())); - reduce_ct++; - } else { - LOG(FATAL) << "Invalid multi level tiling format: " << format; - } - } - - tmp_s.reorder(stage_id, order); - return tmp_s; -} - -State FollowTiling(const State& state, int stage_id, - const std::vector& split_step_ids, int n_split) { - if (n_split < 1 || n_split > 3) { - LOG(FATAL) << "Invalid split parts, currently only support 1, 2 and 3"; - } - // Apply up to three-level tiling structure: space_L0, space_L1, space_L2 - std::vector space_0, space_1, space_2, space_3; - std::vector split_res, tmp_order; - - auto pop = state->stages[stage_id]->op.as(); - CHECK(pop != nullptr); - const Stage& stage = state->stages[stage_id]; - const auto& no_split_name_pair = QueryNoSplitAxis(stage); // handle special split strategy - const std::set& no_split_at_inner_name_set = no_split_name_pair.first; - const std::set& no_split_at_outer_name_set = no_split_name_pair.second; - int no_split_at_inner_name_in_stage_cnt = 0; - int no_split_at_outer_name_in_stage_cnt = 0; - for (const auto& iter : state->stages[stage_id]->iters) { - no_split_at_inner_name_in_stage_cnt += no_split_at_inner_name_set.count(iter->name); - no_split_at_outer_name_in_stage_cnt += no_split_at_outer_name_set.count(iter->name); - } - - CHECK_EQ(state->stages[stage_id]->iters.size() - - no_split_at_inner_name_in_stage_cnt - - no_split_at_outer_name_in_stage_cnt, - split_step_ids.size()); - - State tmp_s = state; - int ct = 0; - for (const auto& iter : state->stages[stage_id]->iters) { - if (iter->iter_type == kSpace) { - // For spatial iterator, split it into multi iterators - if (!no_split_at_inner_name_set.count(iter->name) && - !no_split_at_outer_name_set.count(iter->name)) { - IteratorAnnotation ann_type = iter->annotation; - split_res = tmp_s.follow_split(stage_id, iter, split_step_ids[ct], - n_split); - // Restore annotation. 
Move unroll and vectorize to inner, move parallel - // to outer - switch (ann_type) { - case kUnroll: - split_res[n_split] = tmp_s.unroll(stage_id, split_res[n_split]); - break; - case kVectorize: - split_res[n_split] = tmp_s.vectorize(stage_id, split_res[n_split]); - break; - case kParallel: - split_res[0] = tmp_s.parallel(stage_id, split_res[0]); break; - default: - break; - } - - space_0.push_back(std::move(split_res[0])); - space_1.push_back(std::move(split_res[1])); - if (n_split >= 2) { - space_2.push_back(std::move(split_res[2])); - if (n_split == 3) { - space_3.push_back(std::move(split_res[3])); - } - } - ct++; - } else { - if (no_split_at_outer_name_set.count(iter->name)) { - space_0.push_back(iter); - } - if (no_split_at_inner_name_set.count(iter->name)) { - if (n_split == 1) { - space_1.push_back(iter); - } else if (n_split == 2) { - space_2.push_back(iter); - } else { - CHECK_EQ(n_split, 3); - space_3.push_back(iter); - } - } - } - } else { - LOG(FATAL) << "Invalid iter type: " << iter->iter_type; - } - } - - if (n_split == 3) { - ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2, &space_3); - } else if (n_split == 2) { - ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2); - } else { - ConcatenateMove(&tmp_order, &space_0, &space_1); - } - tmp_s.reorder(stage_id, tmp_order); - return tmp_s; -} - -State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split_memo, - std::mt19937* random_gen, int max_innermost_split_factor) { - State tmp_s = old_state; - - // Extract all SplitStep - std::vector split_step_ids; - for (size_t i = 0; i < tmp_s->transform_steps.size(); ++i) { - if (auto ps = tmp_s->transform_steps[i].as()) { - if (ps->extent.defined() && ps->extent->IsInstance() && - GetIntImm(ps->lengths.back()) <= max_innermost_split_factor) { - split_step_ids.push_back(i); - } - } - } - if (split_step_ids.empty()) { - return State(); - } - - // Find a SplitStep with extent != 1 - int retry_ct = 0; - int64_t extent = 1; - int step_id; - const SplitStepNode* ps; - - do { - step_id = split_step_ids[(*random_gen)() % split_step_ids.size()]; - ps = tmp_s->transform_steps[step_id].as(); - CHECK(ps != nullptr); - extent = GetIntImm(ps->extent); - retry_ct += 1; - } while (retry_ct < static_cast(split_step_ids.size()) << 2 && - (extent == 1 || extent == 0)); - - if (extent == 0 || extent == 1) { - return State(); - } - - // Mutate tile size - std::vector lengths(ps->lengths.size() + 1, 1); - for (int i = 0; i < static_cast(ps->lengths.size()); ++i) { - lengths[i + 1] = GetIntImm(ps->lengths[i]); - } - lengths[0] = extent / ElementProduct(lengths); - - std::vector random_perm; - RandomPermutation(lengths.size(), &random_perm, random_gen); - - for (size_t i = 0; i < random_perm.size(); ++i) { - size_t src_idx = random_perm[i]; - int length = lengths[src_idx]; - - if (length == 1) { - continue; - } - - // Divide one factor from lengths[src_idx] and multiply it to lengths[dst_idx] - size_t dst_idx = random_perm[(i + 1) % random_perm.size()]; - - const std::vector& factors = split_memo->GetFactors(length); - CHECK_GE(factors.size(), 1); - - int divide_factor; - if (dst_idx == lengths.size() - 1) { - // Maintain the restriction of hardware_params.max_innermost_split_factor - int max_factor_index = static_cast(factors.size()) - 1; - for (; max_factor_index >= 1; max_factor_index--) { - if (factors[max_factor_index] * lengths[dst_idx] <= max_innermost_split_factor) { - break; - } - } - if (max_factor_index == 0) { - // failed on this dst_idx, try next one - 
continue; - } - divide_factor = factors[1 + (*random_gen)() % (max_factor_index)]; - } else { - divide_factor = factors[1 + (*random_gen)() % (factors.size() - 1)]; - } - - std::vector new_lengths; - for (size_t j = 1; j < lengths.size(); ++j) { - if (j == src_idx) { - new_lengths.emplace_back(lengths[j] / divide_factor); - } else if (j == dst_idx) { - new_lengths.emplace_back(lengths[j] * divide_factor); - } else { - new_lengths.emplace_back(lengths[j]); - } - } - - CHECK_LE(GetIntImm(new_lengths.back()), max_innermost_split_factor); - - auto pstate = tmp_s.CopyOnWrite(); - pstate->transform_steps[step_id] = - SplitStep(ps->stage_id, ps->iter_id, ps->extent, new_lengths, ps->inner_to_outer); - return tmp_s; - } - - return State(); -} - -State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen, - const std::vector& auto_unroll_configs) { - State tmp_s = old_state; - - // Extract all auto_unroll_max_step pragma steps. - std::vector annotate_steps; - for (size_t i = 0; i < old_state->transform_steps.size(); ++i) { - if (auto ps = tmp_s->transform_steps[i].as()) { - if (ps->pragma_type.find("auto_unroll_max_step") != std::string::npos) { - annotate_steps.push_back(i); - } - } - } - if (annotate_steps.empty()) { - return State(); - } - - // Randomly pick one step. - auto step_id = annotate_steps[(*random_gen)() % annotate_steps.size()]; - auto ps = tmp_s->transform_steps[step_id].as(); - auto val = std::to_string(auto_unroll_configs[(*random_gen)() % auto_unroll_configs.size()]); - - auto pstate = tmp_s.CopyOnWrite(); - pstate->transform_steps[step_id] = PragmaStep( - ps->stage_id, ps->iter_id, std::string("auto_unroll_max_step") + "$" + val); - return tmp_s; -}
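// ----------------------------------------------------------------------------
// For illustration: the pragma mutation above only swaps the number after the
// '$' separator in "auto_unroll_max_step$<N>". The same rewrite as plain
// string surgery (a sketch for intuition; the real code rebuilds a PragmaStep
// instead of editing a string):
#include <random>
#include <string>
#include <vector>
namespace unroll_demo {
inline std::string MutateUnrollPragma(const std::string& pragma,
                                      const std::vector<int>& configs,
                                      std::mt19937* gen) {
  const std::string key = "auto_unroll_max_step";
  if (pragma.compare(0, key.size(), key) != 0) return pragma;  // not this pragma
  return key + "$" + std::to_string(configs[(*gen)() % configs.size()]);
}
}  // namespace unroll_demo
// ----------------------------------------------------------------------------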
- -State RandomMutateParallel(const State& old_state, std::mt19937* random_gen, - const SearchTask& task, int verbose) { - // To make this mutation simple but promising, we focus only on the specific case where - // parallel was applied to the outermost loop and that loop was generated by fusing other loops. - // In short, we mutate the step pattern of (fuse -> parallel). - - // Extract all parallel steps. - std::vector parallel_steps; - for (size_t s = 0; s < old_state->transform_steps.size(); ++s) { - auto ps = old_state->transform_steps[s].as(); - if (!ps || ps->annotation != kParallel) { - continue; - } - parallel_steps.push_back(s); - } - if (parallel_steps.empty()) { - StdCout(verbose) << "Parallel mutation failed: No parallel annotations" << std::endl; - return State(); - } - - // Randomly pick one step. - int retry_ct = 0; - size_t step_id = 0; - size_t stage_id = 0; - do { - step_id = parallel_steps[(*random_gen)() % parallel_steps.size()]; - auto step = old_state->transform_steps[step_id].as(); - stage_id = step->stage_id; - - // Check assumptions. - auto iter_id = step->iter_id; - if (iter_id == 0 && step_id > 0 && old_state->transform_steps[step_id - 1].as()) { - break; - } - retry_ct++; - } while (retry_ct <= 3); - - if (retry_ct > 3) { - StdCout(verbose) << "Parallel mutation failed: No valid parallel annotations" << std::endl; - return State(); - } - - // Replay a new state until the picked fuse step. - State tmp_s = task->compute_dag.GetInitState(); - for (size_t s = 0; s < step_id - 1; ++s) { - auto step = old_state->transform_steps[s]; - tmp_s.CopyOnWrite()->transform_steps.push_back(step); - tmp_s.DoStep(step, task->compute_dag); - } - - // Determine the fuse direction. - // 0: fuse less; 1: fuse more. - auto fuse_step = old_state->transform_steps[step_id - 1].as(); - std::vector fused_ids = fuse_step->fused_ids; - std::vector fuse_dir = {0.5, 1.0}; - - // The case where we can only fuse more. - if (fused_ids.size() == 1) { - fuse_dir[0] = 0.0; - } - - // The cases where we cannot fuse the next iters. - if (old_state->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, 0)) > 0 || - tmp_s->stages[stage_id]->iters.size() == fused_ids.size() || - tmp_s->stages[stage_id]->iters[1]->iter_type == kReduce) { - // If we cannot fuse fewer iters either, give up. - if (fuse_dir[0] == 0.0) { - StdCout(verbose) << "Parallel mutation failed: Cannot fuse more or less iters" << std::endl; - return State(); - } - fuse_dir[0] = 1.0; - } - - int iter_offset = 0; - if (RandomChoose(fuse_dir, random_gen) == 0) { - StdCout(verbose) << "Parallel mutation: release iter " << fused_ids.back() << std::endl; - fused_ids.pop_back(); - iter_offset = 1; - } else { - StdCout(verbose) << "Parallel mutation: include iter " << fused_ids.back() + 1 << std::endl; - fused_ids.push_back(fused_ids.back() + 1); - iter_offset = -1; - } - - // Replay the mutated fuse and annotation steps. - auto new_fuse_step = FuseStep(stage_id, fused_ids); - tmp_s.CopyOnWrite()->transform_steps.push_back(new_fuse_step); - tmp_s.DoStep(new_fuse_step, task->compute_dag); - tmp_s.CopyOnWrite()->transform_steps.push_back(old_state->transform_steps[step_id]); - tmp_s.DoStep(old_state->transform_steps[step_id], task->compute_dag); - - // Replay the remaining steps. - for (size_t s = step_id + 1; s < old_state->transform_steps.size(); ++s) { - auto step = old_state->transform_steps[s]; - if (step->stage_id == static_cast(stage_id)) { - // Since we changed the loop structure, iter IDs in later steps on the same stage - // have to be adjusted. - auto ps = step.as(); - if (ps) { - if (ps->iter_id == 0) { - step = AnnotationStep(ps->stage_id, 0, ps->annotation); - } else { - CHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size()); - step = AnnotationStep(ps->stage_id, ps->iter_id + iter_offset, ps->annotation); - } - } else { - StdCout(verbose) << "Parallel mutation: Cannot apply " << step << " after fuse" - << std::endl; - return State(); - } - } - tmp_s.CopyOnWrite()->transform_steps.push_back(step); - tmp_s.DoStep(step, task->compute_dag); - } - return tmp_s; -} - - -State RandomMutateComputeLocation(const State& old_state, std::mt19937* random_gen, - const SearchTask& task) { - // Extract all compute_at steps.
- std::vector compute_at_steps; - for (size_t s = 0; s < old_state->transform_steps.size(); ++s) { - if (auto ps = old_state->transform_steps[s].as()) { - const Stage& stage = old_state->stages[ps->stage_id]; - if (IsTiled(stage)) { - continue; - } - - if (NeedsMultilevelTiling(task, old_state, stage->op)) { - continue; - } - compute_at_steps.push_back(s); - } - } - if (compute_at_steps.empty()) { - return State(); - } - - // Randomly pick one step - size_t step_id = compute_at_steps[(*random_gen)() % compute_at_steps.size()]; - auto ps = old_state->transform_steps[step_id].as(); - CHECK(ps != nullptr); - const Stage& stage = old_state->stages[ps->stage_id]; - - // Randomly pick one tile level - int new_compute_at_stage_id; - int new_compute_at_iter_id; - - // Copied from InitPopulationChangeComputeLocation - { - std::unordered_set consumers; - GetConsumers(task, old_state, stage->op, &consumers); - if (consumers.empty()) { - return State(); - } - - int target_stage_id; - if (consumers.size() == 1) { - target_stage_id = OperationToStage(*consumers.begin(), old_state); - } else { - // check all consumers share a common root - int common_root_id = -1; - bool mismatch = false; - for (const auto& consumer : consumers) { - int consumer_stage_id = OperationToStage(consumer, old_state); - int root_id = -1; - if ((old_state)->stages[consumer_stage_id]->compute_at == kRoot) { - root_id = consumer_stage_id; - } else if ((old_state)->stages[consumer_stage_id]->compute_at == kIter) { - root_id = (old_state)->attach_map->stage_to_attach_iter.at(consumer_stage_id).first; - } else { - LOG(FATAL) << "Invalid case"; - } - - if (common_root_id == -1) { - common_root_id = root_id; - } else { - if (common_root_id != root_id) { - mismatch = true; - break; - } - } - } - - if (mismatch) { - return State(); - } - target_stage_id = common_root_id; - } - - const Stage& target_stage = old_state->stages[target_stage_id]; - std::set to_unroll_name_set; - if (target_stage->op->attrs.count(SearchPolicyNode::always_unroll_key)) { - to_unroll_name_set = GetIterNameSetParam(target_stage->op->attrs, - SearchPolicyNode::always_unroll_key); - } - - std::vector > candidates; - bool target_compute_at_other = target_stage->compute_at == kIter; - bool target_is_tiled = IsTiled(target_stage); - - bool visited_reduce = false; - // enumerate compute_at location at target_stage - int ct = 0; - for (size_t iter_id = 0; iter_id < target_stage->iters.size(); ++iter_id) { - const auto& target_iter = target_stage->iters[iter_id]; - if (target_iter->iter_type == kReduce) { - visited_reduce = true; - if (!target_is_tiled) { // do not go into reduce iter - break; - } - } else if (target_iter->iter_type == kSpace) { - if (visited_reduce) { // do not go into inner tile - break; - } - } - - if (to_unroll_name_set.count(target_iter->name)) { - // Do not go into always unroll region - break; - } - - if (GetExtent(target_iter) == 1) { // skip iterators with length of 1 - continue; - } - if (target_compute_at_other && target_iter->iter_type == kSpace && - StrEndsWith(target_iter->name, ".0")) { - // skip the first level iterators if target stage compute_at another stage - // In this case, the lengths of first level iterators are always one - continue; - } - candidates.emplace_back(target_stage_id, iter_id); - - if ((old_state)->attach_map->iter_to_attached_stages.count( - std::make_pair(target_stage_id, ct++))) { - break; - } - } - - // if the target_stage is already compute_at another stage X, try also compute_at X - // We call stage X as 
`target_target_stage` - if (target_compute_at_other) { - int target_target_stage_id; - target_target_stage_id = (old_state)->attach_map->stage_to_attach_iter.at( - target_stage_id).first; - const Stage& target_target_stage = (old_state)->stages[target_target_stage_id]; - if (target_target_stage->op->attrs.count(SearchPolicyNode::always_unroll_key)) { - to_unroll_name_set = GetIterNameSetParam(target_target_stage->op->attrs, - SearchPolicyNode::always_unroll_key); - } else { - to_unroll_name_set.clear(); - } - - int ct = 0; - for (size_t iter_id = 0; iter_id < target_target_stage->iters.size(); ++iter_id) { - const auto& target_target_iter = target_target_stage->iters[iter_id]; - if (target_target_iter->iter_type == kReduce || - (old_state)->attach_map->iter_to_attached_stages.count( - std::make_pair(target_target_stage_id, ct++))) { - break; - } - - if (to_unroll_name_set.count(target_target_iter->name)) { - // Do not go into the always-unroll region - break; - } - - if (GetExtent(target_target_iter) == 1) { // skip iterators with length of 1 - continue; - } - - candidates.emplace_back(target_target_stage_id, iter_id); - } - } - - if (candidates.empty()) { - return State(); - } - - int choice = (*random_gen)() % (candidates.size()); - new_compute_at_stage_id = candidates[choice].first; - new_compute_at_iter_id = candidates[choice].second; - } - - // Replay a new state. - State tmp_s = task->compute_dag.GetInitState(); - for (size_t s = 0; s < old_state->transform_steps.size(); ++s) { - if (s == step_id) { - tmp_s.CopyOnWrite()->transform_steps.push_back( - ComputeAtStep(ps->stage_id, new_compute_at_stage_id, new_compute_at_iter_id)); - } else { - tmp_s.CopyOnWrite()->transform_steps.push_back(old_state->transform_steps[s]); - } - try { - tmp_s.DoStep(tmp_s->transform_steps.back(), task->compute_dag); - } catch (dmlc::Error &e) { - return State(); - } - } - - return tmp_s; -} - -void PruneUndefined(std::vector* states) { - size_t pt = 0; - for (size_t i = 0; i < states->size(); ++i) { - if (!(*states)[i].defined()) { - continue; - } - if (i != pt) { - (*states)[pt++] = std::move((*states)[i]); - } else { - pt++; - } - } - - if (pt == 0) { - LOG(FATAL) << "All states are undefined."; - } else { - states->resize(pt); - } -} - -State CrossOverState(const State& p1, const State& p2) { return State(); } - -} // namespace ansor -} // namespace tvm -
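// ----------------------------------------------------------------------------
// For comparison: PruneUndefined above is a stable in-place compaction; the
// standard erase-remove idiom gives the same result. A sketch over a toy
// state type with a defined() predicate (hypothetical, illustration only):
#include <algorithm>
#include <vector>
namespace prune_demo {
struct ToyState { bool valid; bool defined() const { return valid; } };
inline void Prune(std::vector<ToyState>* states) {
  states->erase(std::remove_if(states->begin(), states->end(),
                               [](const ToyState& s) { return !s.defined(); }),
                states->end());
}
}  // namespace prune_demo
// ----------------------------------------------------------------------------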
diff --git a/src/ansor/search_policy/utils.h b/src/ansor/search_policy/utils.h deleted file mode 100644 index 107e2ee72521..000000000000 --- a/src/ansor/search_policy/utils.h +++ /dev/null @@ -1,483 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ansor/search_policy/utils.h - * \brief Common utilities for search policies - */ - -#ifndef TVM_ANSOR_SEARCH_POLICY_UTILS_H_ -#define TVM_ANSOR_SEARCH_POLICY_UTILS_H_ - -#include -#include -#include -#include -#include -#include -#include "../cost_model/cost_model.h" -#include "../utils.h" -#include "../loop_state.h" -#include "../transform_step.h" -#include "search_policy.h" - -namespace tvm { -namespace ansor { - -// Get an integer from a tvm str Map -inline int GetIntParam(const Map& attr_dict, - const std::string& key) { - CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; - auto pint = attr_dict[key].as(); - CHECK(pint != nullptr); - return pint->value; -} - -// Get a double from a tvm str Map -inline double GetDoubleParam(const Map& attr_dict, - const std::string& key) { - CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; - auto pdouble = attr_dict[key].as(); - CHECK(pdouble != nullptr); - return pdouble->value; -} - -// Get a string from a tvm str Map -inline std::string GetStringParam(const Map& attr_dict, - const std::string& key) { - CHECK_GT(attr_dict.count(key), 0) - << "Cannot find key: \"" << key << "\" in " << attr_dict; - const auto& target = attr_dict[key]; - if (auto pstr = target.as()) { - return pstr->value; - } - auto pstr = target.as(); - CHECK(pstr != nullptr); - return pstr->data; -} - -// Get an iterator name set from a tvm str Map -inline std::set GetIterNameSetParam(const Map& attr_dict, - const std::string& key) { - std::set ret; - CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; - auto names = attr_dict[key].as(); - CHECK(names != nullptr); - for (const auto & name : *names) { - ret.insert(name.as()->value); - } - return ret; -} - -// Convert operation to stage id -inline int OperationToStage(const te::Operation& op, const State& state) { - for (size_t i = 0; i < state->stages.size(); ++i) { - if (op == state->stages[i]->op) { - return i; - } - } - LOG(FATAL) << "Cannot find op: " << op; - return -1; -} - -// Return the extent of an iterator -inline int64_t GetExtent(const Iterator& it) { - if (it->range.defined()) { - if (auto pint = it->range->extent.as()) { - return pint->value; - } - } - return -1; -} - -// Return whether an op is strictly inlineable -inline bool IsStrictInlineable(const SearchTask& task, - const State& state, const te::Operation& op) { - if (state->task_dag.defined()) { - return state->task_dag->access_analyzer.IsStrictInlineable(op); - } else { - return task->compute_dag->access_analyzer.IsStrictInlineable(op); - } -} - -// Return whether an op is an output op -inline bool IsOutputOp(const SearchTask& task, const State& state, const te::Operation& op) { - if (state->task_dag.defined()) { - return state->task_dag->access_analyzer.IsOutput(op); - } else { - return task->compute_dag->access_analyzer.IsOutput(op); - } -} - -// Return whether the stage has an attribute flag -inline bool HasAttrsFlag(const State& state, int stage_id, const char* target) { - if (state->stages[stage_id]->op->attrs.count(target)) { - return GetStringParam(state->stages[stage_id]->op->attrs, target) == "True"; - } - return false; -} - -// Return whether the stage has reduce iterators -inline bool HasReduceIter(const Stage& stage) { - for (const auto& iter : stage->iters) { - if (iter->iter_type != kSpace) { - return true; - } - } - return false; -} - -// Return whether the stage has specific annotated iterators -inline bool
HasAnnotationIter(const Stage& stage, IteratorAnnotation type) { - for (const auto& iter : stage->iters) { - if (iter->annotation == type) { - return true; - } - } - return false; -} - -// Return whether an op needs multi level tiling -inline bool NeedsMultilevelTiling(const SearchTask& task, - const State& state, const te::Operation& op) { - if (state->task_dag.defined()) { - return state->task_dag->access_analyzer.NeedsMultiLevelTiling(op); - } else { - return task->compute_dag->access_analyzer.NeedsMultiLevelTiling(op); - } -} - -// Get all consumers for an op. This will take inline into consideration -inline void GetConsumers(const SearchTask& task, const State& state, const te::Operation& op, - std::unordered_set* consumers) { - if (state->task_dag.defined()) { - state->task_dag->access_analyzer.GetConsumers(state, op, consumers); - } else { - task->compute_dag->access_analyzer.GetConsumers(state, op, consumers); - } -} - -inline void GetProducers(const SearchTask& task, const State& state, const te::Operation& op, - std::unordered_set* producers) { - if (state->task_dag.defined()) { - state->task_dag->access_analyzer.GetProducers(state, op, producers); - } else { - task->compute_dag->access_analyzer.GetProducers(state, op, producers); - } -} - -// Return whether two ops are elementwise-matched -inline bool ElementwiseMatch(const SearchTask& task, const State& state, const te::Operation& op, - const te::Operation& target_op) { - if (state->task_dag.defined()) { - return state->task_dag->access_analyzer.ElementWiseMatch(op, target_op); - } else { - return task->compute_dag->access_analyzer.ElementWiseMatch(op, target_op); - } -} - -// Return whether the stage has only one consumer and they are elementwise-matched -inline bool HasSingleElementwiseMatchedConsumer(const SearchTask& task, - const State& state, const Stage& stage, int* target_stage_id) { - std::unordered_set consumers; - - GetConsumers(task, state, stage->op, &consumers); - if (consumers.size() == 1) { - *target_stage_id = OperationToStage(*consumers.begin(), state); - const Stage& target_stage = state->stages[*target_stage_id]; - if (ElementwiseMatch(task, state, stage->op, target_stage->op) && - (!(HasReduceIter(stage) && HasReduceIter(target_stage)))) { - return true; - } - } - return false; -} - -// Return whether this stage needs rfactor -inline bool NeedsRfactor(const SearchTask& task, const State& state, const te::Operation& op) { - if (op->IsInstance()) { - // Compute the product of lengths of all space iters and all reduce iters - int64_t cum_space_len = 1, cum_reduce_len = 1; - int stage_id = OperationToStage(op, state); - for (const auto& iter : state->stages[stage_id]->iters) { - if (iter->iter_type == kSpace) { - cum_space_len *= GetExtent(iter); - } else if (iter->iter_type == kReduce) { - cum_reduce_len *= GetExtent(iter); - } - } - - if (NeedsMultilevelTiling(task, state, op)) { - // Do not use rfactor if we have enough parallelism on space iters - if (cum_space_len > cum_reduce_len || - cum_space_len > task->hardware_params->num_cores * 16) { - return false; - } else { - return true; - } - } else if (cum_reduce_len > 1) { - // Always try rfactor for reduction ops - return true; - } - } - - return false; -} - -// Return whether the state did cache_write for stage_id -inline bool HasCacheWriteStage(const State& s, int stage_id) { - for (int i = static_cast(s->transform_steps.size()) - 1; i >= 0; --i) { - if (auto ps = s->transform_steps[i].as()) { - if (stage_id > ps->stage_id) { - stage_id--; - } else if 
(stage_id == ps->stage_id) { - return true; - } - } else if (auto ps = s->transform_steps[i].as()) { - if (stage_id > ps->stage_id) { - stage_id--; - } - } else if (auto ps = s->transform_steps[i].as()) { - if (stage_id > ps->stage_id) { - stage_id--; - } - } - } - return false; -} - -// Return whether the state did cache_read for stage_id -inline bool HasCacheReadStage(const State& s, int stage_id) { - for (int i = static_cast(s->transform_steps.size()) - 1; i >= 0; --i) { - if (auto ps = s->transform_steps[i].as()) { - if (stage_id > ps->stage_id) { - stage_id--; - } - } else if (auto ps = s->transform_steps[i].as()) { - if (stage_id > ps->stage_id) { - stage_id--; - } else if (stage_id == ps->stage_id) { - return true; - } - } else if (auto ps = s->transform_steps[i].as()) { - if (stage_id > ps->stage_id) { - stage_id--; - } - } - } - return false; -} - -// Return whether the state did split/follow_split/follow_fused_split in stage_id -inline bool HasSplitStep(const State& s, int stage_id) { - for (int i = static_cast(s->transform_steps.size()) - 1; i >= 0; --i) { - if (s->transform_steps[i]->IsInstance() || - s->transform_steps[i]->IsInstance() || - s->transform_steps[i]->IsInstance()) { - if (stage_id > s->transform_steps[i]->stage_id) { - stage_id--; - } - } else if (s->transform_steps[i]->IsInstance() || - s->transform_steps[i]->IsInstance() || - s->transform_steps[i]->IsInstance()) { - if (stage_id == s->transform_steps[i]->stage_id) { - return true; - } - } - } - return false; -} - -// Return whether the stage has been tiled already -inline bool IsTiled(const Stage& stage) { - auto op = stage->op.as(); - CHECK(op != nullptr); - return stage->iters.size() != op->axis.size() + op->reduce_axis.size(); -} - -// Query axes that should not be splitted according to the attribute from tvm.compute -inline std::pair, std::set > QueryNoSplitAxis( - const Stage& stage) { - std::pair, std::set > ret; - if (stage->op->attrs.count(SearchPolicyNode::no_split_at_inner_key)) { - ret.first = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::no_split_at_inner_key); - } - if (stage->op->attrs.count(SearchPolicyNode::no_split_at_outer_key)) { - ret.second = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::no_split_at_outer_key); - } - return ret; -} - -// Query axes that last split is one -inline std::set QueryLastSplitIsOneAxis(const Stage& stage) { - std::set ret; - if (stage->op->attrs.count(SearchPolicyNode::last_split_is_one_key)) { - ret = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::last_split_is_one_key); - } - return ret; -} - -// Extract primitive iterators from a nested fused or splitted iterator's name -inline void ExtractOriginalIterators(const std::string& name, std::set* rets) { - size_t last_pos = 0; - for (size_t i = 0; i < name.size(); ++i) { - if (name[i] == '@' || name[i] == '.') { // '@' for fuse and '.' 
for split
-      if (!isdigit(name[last_pos]) && name[last_pos] != '@' && name[last_pos] != '.') {
-        rets->insert(name.substr(last_pos, i - last_pos));
-      }
-      last_pos = i + 1;
-    }
-  }
-
-  if (last_pos < name.size() && !isdigit(name[last_pos]) &&
-      name[last_pos] != '@' && name[last_pos] != '.') {
-    rets->insert(name.substr(last_pos, name.size() - last_pos));
-  }
-}
-
-// Get the last space iterator in the outermost tile
-inline const Iterator& GetLastSpaceIteratorInOutermostTile(const Stage& stage) {
-  auto pop = stage->op.as<te::ComputeOpNode>();
-  CHECK(pop != nullptr);
-  std::set<std::string> original_names;
-
-  for (const auto& iter : stage->iters) {
-    ExtractOriginalIterators(iter->name, &original_names);
-    if (original_names.size() == pop->axis.size()) {
-      return iter;
-    }
-  }
-
-  LOG(FATAL) << "Cannot find the iterator.";
-  return stage->iters[0];
-}
-
-// Get the last reduce iterator in the outermost reduce tile
-inline const Iterator& GetLastReduceIteratorInOutermostReduceTile(const Stage& stage) {
-  auto pop = stage->op.as<te::ComputeOpNode>();
-  CHECK(pop != nullptr);
-  std::set<std::string> original_names;
-
-  auto no_split_name_pair = QueryNoSplitAxis(stage);
-  std::set<std::string> no_split_at_inner_name_set = no_split_name_pair.first;
-  size_t axis_size = 0;
-  for (const auto axis : pop->axis) {
-    if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
-      axis_size++;
-    }
-  }
-  size_t reduce_axis_size = 0;
-  for (const auto axis : pop->reduce_axis) {
-    if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
-      reduce_axis_size++;
-    }
-  }
-
-  if (reduce_axis_size) {
-    for (const auto& iter : stage->iters) {
-      ExtractOriginalIterators(iter->name, &original_names);
-      if (original_names.size() == axis_size + reduce_axis_size) {
-        return iter;
-      }
-    }
-  } else {
-    for (size_t i = 0; i < stage->iters.size(); i++) {
-      ExtractOriginalIterators(stage->iters[i]->name, &original_names);
-      if (original_names.size() == axis_size + 1) {
-        return stage->iters[i - 1];
-      }
-    }
-  }
-
-  LOG(FATAL) << "Cannot find the iterator.";
-  return stage->iters[0];
-}
-
-// Randomly sample states
-inline void RandomSampleStates(const std::vector<State>& in_states, std::mt19937* random_gen,
-                               size_t out_size, std::vector<State>* out_states) {
-  out_states->clear();
-  for (size_t i = 0; i < out_size; i++) {
-    out_states->push_back(in_states[(*random_gen)() % in_states.size()]);
-  }
-}
-
-// Randomly choose an index according to prefix-sum probabilities
-inline int RandomChoose(const std::vector<double>& prefix_sum_probs, std::mt19937* random_gen) {
-  std::uniform_real_distribution<> dis(0.0, 1.0);
-  double x = dis(*random_gen);
-
-  CHECK(!prefix_sum_probs.empty());
-
-  return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) -
-         prefix_sum_probs.begin();
-}
-
-// Print all states
-inline void PrintAllStates(const std::vector<State>& states) {
-  for (size_t i = 0; i < states.size(); ++i) {
-    std::cerr << i << std::endl;
-    std::cerr << states[i];
-    std::cerr << "==============================================" << std::endl;
-  }
-}
-
-// Get all split steps on spatial iterators for one stage
-void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector<int>* spatial_split_step_ids);
-
-// Apply a multi-level tiling structure according to a string format,
-// where "S" stands for a space level and "R" stands for a reduction level.
-// For example, if the format is "SSRSRS", then we will
-// use the tiling structure: space_L0, space_L1, reduce_L0, space_L2, reduce_L1, space_L3.
-// For example, if we apply "SSRSRS" to matrix multiplication,
-// we have space iterators i and j, and reduce iterator k.
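The mapping from a format string to a loop order can be derived mechanically. A minimal standalone sketch (a hypothetical demo program, not part of this patch, assuming the two space iterators i, j and one reduce iterator k of the matmul example):

    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
      // Each 'S' opens one more tiling level for every space iterator;
      // each 'R' opens one more level for the reduce iterator.
      std::string format = "SSRSRS";
      std::vector<std::string> order;
      int s = 0, r = 0;
      for (char c : format) {
        if (c == 'S') {
          order.push_back("i" + std::to_string(s));
          order.push_back("j" + std::to_string(s));
          ++s;
        } else {  // 'R'
          order.push_back("k" + std::to_string(r));
          ++r;
        }
      }
      for (const std::string& it : order) std::cout << it << ' ';
      std::cout << '\n';  // prints: i0 j0 i1 j1 k0 i2 j2 k1 i3 j3
    }

Its output matches the loop order stated in the comment that follows.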
-// Then the tiling structure is : i0, j0, i1, j1, k0, i2, j2, k1, i3, j3 -State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format, - std::vector* spatial_split_step_ids); - -// Apply tiling structure: space, space, space, ..., with tile sizes from other SplitStep -State FollowTiling(const State& state, int stage_id, - const std::vector& split_step_ids, int n_split); - -// Randomly mutate the tile size of one SplitStep -State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split_memo, - std::mt19937* random_gen, int max_innermost_split_factor); - -// Randomly mutate the value of one auto_unroll_max_step PragmaStep -State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen, - const std::vector& auto_unroll_configs); - -// Randomly mutate the parallel degree of one stage. -State RandomMutateParallel(const State& old_state, std::mt19937* random_gen, - const SearchTask& task, int verbose = 0); - -// Randomly mutate the computation location of one stage. -State RandomMutateComputeLocation(const State& old_state, std::mt19937* random_gen, - const SearchTask& task); - -// GA: Crossover two states -State CrossOverState(const State& p1, const State& p2); - -// Prune undefined states. -void PruneUndefined(std::vector* states); - -} // namespace ansor -} // namespace tvm - -#endif // TVM_ANSOR_SEARCH_POLICY_UTILS_H_ diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index d84c3c57dc86..939fca83f1fb 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -96,26 +96,6 @@ struct Handler > { } writer->WriteArrayItem(IntArrayToVector(&tmp, ps->lengths)); writer->WriteArrayItem(static_cast(ps->inner_to_outer)); - } else if (auto ps = data[i].as<::tvm::ansor::FollowSplitStepNode>()) { - writer->WriteArrayItem(std::string("FSP")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->iter_id); - writer->WriteArrayItem(ps->src_step_id); - writer->WriteArrayItem(ps->n_split); - } else if (auto ps = data[i].as<::tvm::ansor::FollowFusedSplitStepNode>()) { - writer->WriteArrayItem(std::string("FFSP")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->iter_id); - - writer->WriteArraySeperator(); - writer->BeginArray(false); - for (int x : ps->src_step_ids) { - writer->WriteArrayItem(x); - } - writer->EndArray(); - - writer->WriteArrayItem(ps->level); - writer->WriteArrayItem(static_cast(ps->factor_or_nparts)); } else if (auto ps = data[i].as<::tvm::ansor::FuseStepNode>()) { writer->WriteArrayItem(std::string("FU")); writer->WriteArrayItem(ps->stage_id); @@ -126,52 +106,6 @@ struct Handler > { writer->WriteArrayItem(x); } writer->EndArray(); - } else if (auto ps = data[i].as<::tvm::ansor::AnnotationStepNode>()) { - writer->WriteArrayItem(std::string("AN")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->iter_id); - writer->WriteArrayItem(static_cast(ps->annotation)); - } else if (auto ps = data[i].as<::tvm::ansor::ComputeAtStepNode>()) { - writer->WriteArrayItem(std::string("CA")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->target_stage_id); - writer->WriteArrayItem(ps->target_iter_id); - } else if (auto ps = data[i].as<::tvm::ansor::ComputeRootStepNode>()) { - writer->WriteArrayItem(std::string("CR")); - writer->WriteArrayItem(ps->stage_id); - } else if (auto ps = data[i].as<::tvm::ansor::ComputeInlineStepNode>()) { - writer->WriteArrayItem(std::string("CI")); - writer->WriteArrayItem(ps->stage_id); - } else if (auto ps = 
data[i].as<::tvm::ansor::CacheReadStepNode>()) { - writer->WriteArrayItem(std::string("CHR")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->scope_name); - writer->WriteArrayItem(ps->reader_stage_ids); - } else if (auto ps = data[i].as<::tvm::ansor::CacheWriteStepNode>()) { - writer->WriteArrayItem(std::string("CHW")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->scope_name); - } else if (auto ps = data[i].as<::tvm::ansor::PragmaStepNode>()) { - writer->WriteArrayItem(std::string("PR")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->iter_id); - writer->WriteArrayItem(ps->pragma_type); - } else if (auto ps = data[i].as<::tvm::ansor::RfactorStepNode>()) { - writer->WriteArrayItem(std::string("RF")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->iter_id); - writer->WriteArrayItem(ps->factor_iter_id); - } else if (auto ps = data[i].as<::tvm::ansor::StorageAlignStepNode>()) { - writer->WriteArrayItem(std::string("SA")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->iter_id); - writer->WriteArrayItem(ps->factor); - writer->WriteArrayItem(ps->offset); - } else if (auto ps = data[i].as<::tvm::ansor::TensorizeStepNode>()) { - writer->WriteArrayItem(std::string("TS")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->iter_id); - writer->WriteArrayItem(ps->ti_func_name); } else { LOG(FATAL) << "Invalid step: " << data[i]; } @@ -183,10 +117,9 @@ struct Handler > { inline static void Read(dmlc::JSONReader* reader, std::vector<::tvm::ansor::Step> * data) { std::vector int_list; - bool s, inner_to_outer, factor_or_nparts; + bool s, inner_to_outer; std::string name, scope_name, pragma_type, ti_func_name; - int stage_id, target_stage_id, iter_id, src_step_id, n_split, ann, extent; - int level, factor_iter_id, factor, offset; + int stage_id, iter_id, extent; reader->BeginArray(); data->clear(); @@ -215,116 +148,12 @@ struct Handler > { stage_id, iter_id, extent, std::vector<::tvm::PrimExpr>(int_list.begin(), int_list.end()), inner_to_outer)); - } else if (name == "FSP") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&iter_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&src_step_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&n_split); - data->push_back(::tvm::ansor::FollowSplitStep( - stage_id, iter_id, src_step_id, n_split)); - } else if (name == "FFSP") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&iter_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&int_list); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&level); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&factor_or_nparts); - data->push_back(::tvm::ansor::FollowFusedSplitStep( - stage_id, iter_id, int_list, level, factor_or_nparts)); } else if (name == "FU") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); data->push_back(::tvm::ansor::FuseStep(stage_id, int_list)); - } else if (name == "AN") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&iter_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&ann); - data->push_back(::tvm::ansor::AnnotationStep(stage_id, - iter_id, ::tvm::ansor::IteratorAnnotation(ann))); - } else if (name == 
"CA") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&target_stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&iter_id); - data->push_back(::tvm::ansor::ComputeAtStep( - stage_id, target_stage_id, iter_id)); - } else if (name == "CR") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - data->push_back(::tvm::ansor::ComputeRootStep(stage_id)); - } else if (name == "CI") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - data->push_back(::tvm::ansor::ComputeInlineStep(stage_id)); - } else if (name == "CHR") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&scope_name); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&int_list); - data->push_back(::tvm::ansor::CacheReadStep( - stage_id, scope_name, int_list)); - } else if (name == "CHW") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&scope_name); - data->push_back(::tvm::ansor::CacheWriteStep( - stage_id, scope_name)); - } else if (name == "PR") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&iter_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&pragma_type); - data->push_back(::tvm::ansor::PragmaStep( - stage_id, iter_id, pragma_type)); - } else if (name == "RF") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&iter_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&factor_iter_id); - data->push_back(::tvm::ansor::RfactorStep( - stage_id, iter_id, factor_iter_id)); - } else if (name == "SA") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&iter_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&factor); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&offset); - data->push_back(::tvm::ansor::StorageAlignStep( - stage_id, iter_id, factor, offset)); - } else if (name == "TS") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&iter_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&ti_func_name); - data->push_back(::tvm::ansor::TensorizeStep( - stage_id, iter_id, ti_func_name)); } else { LOG(FATAL) << "Invalid step format"; } diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index e882a0495263..1bcea3f690c9 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -183,107 +183,6 @@ std::string SplitStepNode::PrintAsPythonAPI( lengths, inner_to_outer); } -/********** Follow Split **********/ -FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, - int src_step_id, int n_split) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->src_step_id = src_step_id; - node->n_split = n_split; - data_ = std::move(node); -} - -void FollowSplitStepNode::ExtractSplitLengths( - const std::vector& transform_steps, - std::vector* lengths) const { - CHECK_LT(src_step_id, transform_steps.size()); - auto ps = transform_steps[src_step_id].as(); - CHECK(ps != nullptr); - - // get lengths from src step - lengths->reserve(n_split); - int j = 0; - for (; j < n_split - 1; ++j) { - lengths->push_back(ps->lengths[j]); - } - 
PrimExpr last_factor = 1; - for (; j < static_cast(ps->lengths.size()); ++j) { - if (ps->lengths[j].defined()) { - last_factor *= ps->lengths[j]; - } else { - last_factor = PrimExpr(); - break; - } - } - lengths->push_back(std::move(last_factor)); -} - -std::vector FollowSplitStepNode::ApplyToSchedule( - std::vector *stages, StageToAxesMap *stage_to_axes, - const std::vector& transform_steps) const { - std::vector lengths; - ExtractSplitLengths(transform_steps, &lengths); - return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, - lengths, true); -} - -std::string FollowSplitStepNode::PrintAsPythonAPI( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule, const std::vector& transform_steps) const { - std::vector lengths; - ExtractSplitLengths(transform_steps, &lengths); - return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, - lengths, true); -} - -/********** Follow Fused Split **********/ -FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id, - const std::vector& src_step_ids, int level, bool factor_or_nparts) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->src_step_ids = src_step_ids;; - node->level = level; - node->factor_or_nparts = factor_or_nparts; - data_ = std::move(node); -} - -PrimExpr FollowFusedSplitStepNode::ExtractSplitLength( - const std::vector& transform_steps) const { - PrimExpr ret(1); - - for (int src_step_id : src_step_ids) { - CHECK_LT(src_step_id, transform_steps.size()); - auto ps = transform_steps[src_step_id].as(); - CHECK(ps != nullptr); - if (ps->lengths[level].defined() && ret.defined()) { - ret *= ps->lengths[level]; - } else { - return PrimExpr(); - } - } - - return ret; -} - -std::vector FollowFusedSplitStepNode::ApplyToSchedule( - std::vector *stages, StageToAxesMap *stage_to_axes, - const std::vector& transform_steps) const { - const PrimExpr& length = ExtractSplitLength(transform_steps); - return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, - {length}, factor_or_nparts); -} - -std::string FollowFusedSplitStepNode::PrintAsPythonAPI( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule, const std::vector& transform_steps) const { - const PrimExpr& length = ExtractSplitLength(transform_steps); - return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, - {length}, factor_or_nparts); -} - - /********** Fuse **********/ FuseStep::FuseStep(int stage_id, const std::vector& fused_ids) { auto node = make_object(); @@ -337,506 +236,5 @@ std::string FuseStepNode::PrintAsPythonAPI(std::vector *stages, return ss.str(); } -/********** Annotation **********/ -AnnotationStep::AnnotationStep(int stage_id, int iter_id, - IteratorAnnotation ann) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->annotation = ann; - data_ = std::move(node); -} - -void AnnotationStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - - switch (annotation) { - case kUnroll: stage.unroll(axes[iter_id]); break; - case kVectorize: stage.vectorize(axes[iter_id]); break; - case kParallel: stage.parallel(axes[iter_id]); break; - case kVThread: stage.bind(axes[iter_id], te::thread_axis(Range(), "vthread")); break; - case kBlockX: stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.x")); break; - case kBlockY: stage.bind(axes[iter_id], 
te::thread_axis(Range(), "blockIdx.y")); break;
-    case kThreadX:
-      if (axes[iter_id]->iter_type == kCommReduce) {
-        const auto& thread_x = te::thread_axis(Range(), "threadIdx.x");
-        stage.bind(axes[iter_id], thread_x);
-        stage.set_store_predicate(thread_x->var == 0);
-      } else {
-        stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.x"));
-      }
-      break;
-    case kThreadY: stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.y")); break;
-    case kNone: break;
-    default: LOG(FATAL) << "Invalid annotation " << annotation; break;
-  }
-}
-
-std::string AnnotationStepNode::PrintAsPythonAPI(std::vector<te::Stage> *stages,
-                                                 StageToAxesMap *stage_to_axes,
-                                                 te::Schedule *schedule,
-                                                 const std::vector<Step>& transform_steps) const {
-  std::stringstream ss;
-  const auto& stage = (*stages)[stage_id];
-  const auto& iter = (*stage_to_axes)[stage][iter_id];
-
-  bool bind_reduce_iter = iter->iter_type == kCommReduce && annotation == kThreadX;
-  if (bind_reduce_iter) {
-    ss << "thread_x = tvm.thread_axis(\"threadIdx.x\")\n";
-  }
-
-  ss << "s[" << CleanName(stage->op->name) << "].";
-  switch (annotation) {
-    case kUnroll: ss << "unroll("; break;
-    case kVectorize: ss << "vectorize("; break;
-    case kParallel: ss << "parallel("; break;
-    case kVThread:
-    case kBlockX:
-    case kBlockY:
-    case kThreadX:
-    case kThreadY: ss << "bind("; break;
-    case kNone: break;
-    default:
-      LOG(FATAL) << "Invalid annotation " << annotation; break;
-  }
-  ss << CleanName(iter->var->name_hint);
-  switch (annotation) {
-    case kVThread: ss << ", tvm.thread_axis(\"vthread\")"; break;
-    case kBlockX: ss << ", tvm.thread_axis(\"blockIdx.x\")"; break;
-    case kBlockY: ss << ", tvm.thread_axis(\"blockIdx.y\")"; break;
-    case kThreadX:
-      if (bind_reduce_iter) {
-        ss << ", thread_x";
-      } else {
-        ss << ", tvm.thread_axis(\"threadIdx.x\")";
-      }
-      break;
-    case kThreadY: ss << ", tvm.thread_axis(\"threadIdx.y\")"; break;
-    default: break;
-  }
-  ss << ")\n";
-
-  if (bind_reduce_iter) {
-    ss << "s[" << CleanName(stage->op->name) << "]"
-       << ".set_store_predicate(thread_x.var.equal(0))\n";
-  }
-
-  ApplyToSchedule(stages, stage_to_axes);
-  return ss.str();
-}
-
-/********** Compute At **********/
-ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id) {
-  auto node = make_object<ComputeAtStepNode>();
-  node->stage_id = stage_id;
-  node->target_stage_id = target_stage_id;
-  node->target_iter_id = target_iter_id;
-  data_ = std::move(node);
-}
-
-void ComputeAtStepNode::ApplyToSchedule(std::vector<te::Stage> *stages,
-                                        StageToAxesMap *stage_to_axes) const {
-  te::Stage& stage = (*stages)[stage_id];
-  const IterVar& target_axis =
-      (*stage_to_axes)[(*stages)[target_stage_id]][target_iter_id];
-  stage.compute_at((*stages)[target_stage_id], target_axis);
-}
-
-std::string ComputeAtStepNode::PrintAsPythonAPI(std::vector<te::Stage> *stages,
-                                                StageToAxesMap *stage_to_axes,
-                                                te::Schedule *schedule,
-                                                const std::vector<Step>& transform_steps) const {
-  std::stringstream ss;
-  const auto& stage = (*stages)[stage_id];
-  const auto& target_stage = (*stages)[target_stage_id];
-
-  ss << "s[" << CleanName(stage->op->name) << "].compute_at(s["
-     << CleanName(target_stage->op->name) << "], "
-     << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint);
-
-  ss << ")\n";
-  ApplyToSchedule(stages, stage_to_axes);
-  return ss.str();
-}
-
-/********** Compute Root **********/
-ComputeRootStep::ComputeRootStep(int stage_id) {
-  auto node = make_object<ComputeRootStepNode>();
-  node->stage_id = stage_id;
-  data_ = std::move(node);
-}
-
-void ComputeRootStepNode::ApplyToSchedule(std::vector<te::Stage>
*stages, - StageToAxesMap *stage_to_axes) const { - (*stages)[stage_id].compute_root(); -} - -std::string ComputeRootStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - - ss << "s[" << CleanName(stage->op->name) << "].compute_root()\n"; - ApplyToSchedule(stages, stage_to_axes); - - return ss.str(); -} - -/********** Compute Inline **********/ -ComputeInlineStep::ComputeInlineStep(int stage_id) { - auto node = make_object(); - node->stage_id = stage_id; - data_ = std::move(node); -} - -void ComputeInlineStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - (*stages)[stage_id].compute_inline(); -} - -std::string ComputeInlineStepNode::PrintAsPythonAPI( - std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - - ss << "s[" << CleanName(stage->op->name) << "].compute_inline()\n"; - ApplyToSchedule(stages, stage_to_axes); - - return ss.str(); -} - -/********** Cache Read **********/ -CacheReadStep::CacheReadStep(int stage_id, std::string scope_name, - const std::vector& reader_stage_ids) { - auto node = make_object(); - node->stage_id = stage_id; - node->scope_name = std::move(scope_name); - node->reader_stage_ids = reader_stage_ids; - data_ = std::move(node); -} - -te::Tensor CacheReadStepNode::ApplyToSchedule(std::vector* stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { - te::Stage& stage = (*stages)[stage_id]; - - Array readers; - for (const auto& i : reader_stage_ids) { - readers.push_back((*stages)[i]->origin_op); - } - auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, readers); - - const auto& new_stage = (*schedule)[out->op]; - UpdateStageAxis(new_stage, stage_to_axes); - stages->insert(stages->begin() + stage_id + 1, new_stage); - - return out; -} - -std::string CacheReadStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - // copy stage here, for the original stage will change after apply - auto stage = (*stages)[stage_id]; - std::vector reader_stages; - for (size_t i = 0; i < reader_stage_ids.size(); ++i) { - reader_stages.push_back((*stages)[reader_stage_ids[i]]); - } - - auto out = ApplyToSchedule(stages, stage_to_axes, schedule); - - ss << CleanName(out->op->name) << " = " - << "s.cache_read(" << CleanName(stage->op->name) << ", \"" - << scope_name << "\", [" - << CleanName(reader_stages[0]->op->name); - for (size_t i = 1; i < reader_stage_ids.size(); ++i) { - ss << ", " << CleanName(reader_stages[i]->op->name); - } - ss << "])\n"; - - const auto& iters = out->op->root_iter_vars(); - for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); - if (i != iters.size() - 1) { - ss << ", "; - } - } - ss << " = " << "tuple(" << CleanName(out->op->name) - << ".op.axis)\n"; - - return ss.str(); -} - -/********** Cache Write **********/ -CacheWriteStep::CacheWriteStep(int stage_id, std::string scope_name) { - auto node = make_object(); - node->stage_id = stage_id; - node->scope_name = std::move(scope_name); - data_ = std::move(node); -} - -Array CacheWriteStepNode::ApplyToSchedule( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule 
*schedule) const { - te::Stage& stage = (*stages)[stage_id]; - - Array tensor_array; - // If the target stage has multi outputs, TVM requires to cache_write - // all of them or schedule.cache_write will raise an error - for (auto i = 0; i < stage->op->num_outputs(); ++i) { - tensor_array.push_back(stage->origin_op.output(i)); - } - auto outs = schedule->cache_write(tensor_array, scope_name); - - UpdateStageAxis(stage, stage_to_axes); - // Even if there is multi outputs, TVM schedule only generate one - // new stage - const auto& new_stage = (*schedule)[outs[0]->op]; - UpdateStageAxis(new_stage, stage_to_axes); - stages->insert(stages->begin() + stage_id, new_stage); - - return outs; -} - -std::string CacheWriteStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - // copy stage here, for the original stage will change after apply - te::Stage stage = (*stages)[stage_id]; - - auto outs = ApplyToSchedule(stages, stage_to_axes, schedule); - - for (size_t i = 0; i < outs.size(); ++i) { - ss << CleanName(outs[i]->op->name) << ", "; - } - ss << "= " << "s.cache_write([" - << CleanName(stage->op.output(0)->op->name); - for (auto i = 1; i < stage->op->num_outputs(); ++i) { - ss << ", " << CleanName(stage->op.output(i)->op->name); - } - ss << "], \"" << scope_name << "\")\n"; - - for (const auto& out : outs) { - const auto& iters = out->op->root_iter_vars(); - for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); - if (i != iters.size() - 1) { - ss << ", "; - } - } - ss << " = " << "tuple(" << CleanName(out->op->name) - << ".op.axis)" - << " + " << "tuple(" << CleanName(out->op->name) - << ".op.reduce_axis)\n"; - } - - return ss.str(); -} - -/********** Pragma **********/ -PragmaStep::PragmaStep(int stage_id, int iter_id, std::string pragma_type) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->pragma_type = std::move(pragma_type); - data_ = std::move(node); -} - -void PragmaStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { - size_t pos = pragma_type.find('$'); - int value = atoi(pragma_type.c_str() + pos + 1); - stage.pragma(axes[iter_id], "auto_unroll_max_step", value); - stage.pragma(axes[iter_id], "unroll_explicit", true); - } else { - stage.pragma(axes[iter_id], pragma_type); - } -} - -std::string PragmaStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - - if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { - size_t pos = pragma_type.find('$'); - int value = atoi(pragma_type.c_str() + pos + 1); - ss << "s[" << CleanName(stage->op->name) << "].pragma(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) - << ", \"auto_unroll_max_step\", " << value << ")\n"; - ss << "s[" << CleanName(stage->op->name) << "].pragma(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) - << ", \"unroll_explicit\", True)\n"; - } else { - ss << "s[" << CleanName(stage->op->name) << "].pragma(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" - << pragma_type << "\")\n"; - } - - ApplyToSchedule(stages, 
stage_to_axes); - return ss.str(); -} - -/********** Rfactor **********/ -RfactorStep::RfactorStep(int stage_id, int iter_id, int factor_iter_id) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->factor_iter_id = factor_iter_id; - data_ = std::move(node); -} - -Array RfactorStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { - const auto& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - - const te::Tensor& tensor = stage->origin_op.output(0); - const IterVar& axis = axes[iter_id]; - auto outs = schedule->rfactor(tensor, axis, factor_iter_id); - - UpdateStageAxis(stage, stage_to_axes); - - const auto& new_stage = (*schedule)[outs[0]->op]; - UpdateStageAxis(new_stage, stage_to_axes); - stages->insert(stages->begin() + stage_id, new_stage); - - return outs; -} - -std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - - const auto& tensor_name = CleanName(stage->origin_op.output(0)->op->name); - const auto& axis_name = CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint); - - const auto& outs = ApplyToSchedule(stages, stage_to_axes, schedule); - - for (size_t i = 0; i < outs.size(); ++i) { - ss << CleanName(outs[i]->op->name); - if (i != outs.size() - 1) { - ss << ", "; - } - } - ss << " = " << "s.rfactor(" - << tensor_name << ", " - << axis_name << ", " - << factor_iter_id << ")\n"; - - for (const auto& out : outs) { - const auto& iters = out->op->root_iter_vars(); - for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); - if (i != iters.size() - 1) { - ss << ", "; - } - } - ss << " = " << "tuple(" << CleanName(out->op->name) - << ".op.axis)" - << " + " << "tuple(" << CleanName(out->op->name) - << ".op.reduce_axis)\n"; - } - - const auto& output = (*stages)[stage_id + 1]->op.output(0); - const auto& iters = output->op->root_iter_vars(); - for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); - if (i != iters.size() - 1) { - ss << ", "; - } - } - ss << " = " << "tuple(s[" << CleanName(output->op->name) - << "].op.axis)" - << " + " << "tuple(s[" << CleanName(output->op->name) - << "].op.reduce_axis)\n"; - - return ss.str(); -} - -/********** Storage Align **********/ -StorageAlignStep::StorageAlignStep(int stage_id, int iter_id, - int factor, int offset) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->factor = factor; - node->offset = offset; - data_ = std::move(node); -} - -void StorageAlignStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - stage.storage_align(axes[iter_id], factor, offset); -} - -std::string StorageAlignStepNode::PrintAsPythonAPI( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule, const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->name) << "].storage_align(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " - << factor << ", " << offset << ")\n"; - - ApplyToSchedule(stages, stage_to_axes); - return ss.str(); -} - -/********** Tensorize **********/ 
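The step defined next resolves its intrinsic through TVM's global function registry rather than storing the intrinsic itself. A minimal sketch of that lookup pattern; GetIntrinBuilder is a hypothetical helper, not part of this patch:

    #include <stdexcept>
    #include <string>

    #include <tvm/runtime/registry.h>

    // Resolve a tensorize-intrinsic builder by name, mirroring the check
    // performed in TensorizeStepNode::ApplyToSchedule below.
    inline const tvm::runtime::PackedFunc& GetIntrinBuilder(const std::string& name) {
      const tvm::runtime::PackedFunc* func = tvm::runtime::Registry::Get(name);
      if (func == nullptr) {
        throw std::runtime_error("Cannot find the tensorize intrinsic func: " + name);
      }
      return *func;
    }

Keeping only a registered name in the step keeps the transform history serializable: replaying a log on another machine just needs the same registration to exist there.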
-TensorizeStep::TensorizeStep(int stage_id, int iter_id, - std::string ti_func_name) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->ti_func_name = ti_func_name; - data_ = std::move(node); -} - -void TensorizeStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - auto func = tvm::runtime::Registry::Get(ti_func_name); - CHECK(func != nullptr) << "Cannot find the tensorize intrinsic func"; - tvm::te::TensorIntrin res = (*func)(); - CHECK(res.defined()) << "Tensorize intrinsic func must return a " - << "tvm::te::TensorIntrin object"; - stage.tensorize(axes[iter_id], res); -} - -std::string TensorizeStepNode::PrintAsPythonAPI( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule, const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->name) << "].tensorize(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " - << ti_func_name << "())\n"; - - ApplyToSchedule(stages, stage_to_axes); - return ss.str(); -} - } // namespace ansor } // namespace tvm diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index f8283b876f18..8eff6a4e7536 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -114,80 +114,6 @@ class SplitStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); }; -/*! \brief Similar to SplitStepNode, but use split factor from another step - * (i.e. Follow another split step) */ -class FollowSplitStepNode: public StepNode { - public: - int iter_id; // The id of the iter to split - int src_step_id; // The index of the split step to follow in the history - int n_split; // The number of split level - - void ExtractSplitLengths(const std::vector& transform_steps, - std::vector* lengths) const; - - std::vector ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - const std::vector& transform_steps) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.FollowSplitStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object); -}; - -/*! - * \brief Managed reference to FollowSplitStepNode. - * \sa FollowSplitStepNode - */ -class FollowSplitStep : public Step { - public: - FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); - - TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); -}; - - -/*! \brief Similar to FollowSplitStep, but use split factors from multiple steps. - * \Note This can be used for the split in cooperative fetching - */ -class FollowFusedSplitStepNode: public StepNode { - public: - int iter_id; // The id of the iter to split - std::vector src_step_ids; // The indices of the split steps to follow in the history - int level; // Use the length in this split level - bool factor_or_nparts; // If this is true, use factor. 
Otherwise, use nparts - - PrimExpr ExtractSplitLength(const std::vector& transform_steps) const; - - std::vector ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - const std::vector& transform_steps) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.FollowFusedSplitStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); -}; - -/*! - * \brief Managed reference to FollowFusedSplitStepNode. - * \sa FollowFusedSplitStepNode - */ -class FollowFusedSplitStep : public Step { - public: - FollowFusedSplitStep(int stage_id, int iter_id, - const std::vector& src_step_ids, - int level, bool factor_or_nparts); - - TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); -}; - /*! \brief Fuse step that corresponds to te::Stage::fuse */ class FuseStepNode: public StepNode { public: @@ -216,298 +142,6 @@ class FuseStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); }; -/*! \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding. - * (i.e. te::Stage::vectorize, te::Stage::parallel, te::Stage::vectorize, te::Stage::bind) - */ -class AnnotationStepNode: public StepNode { - public: - int iter_id; - IteratorAnnotation annotation; - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.AnnotationStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object); -}; - -/*! - * \brief Managed reference to AnnotationStepNode. - * \sa AnnotationStepNode - */ -class AnnotationStep : public Step { - public: - AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann); - - TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode); -}; - -/*! \brief Compute at step that corresponds to te::Stage::compute_at */ -class ComputeAtStepNode: public StepNode { - public: - int target_stage_id; - int target_iter_id; - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.ComputeAtStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object); -}; - -/*! - * \brief Managed reference to ComputeAtStepNode. - * \sa ComputeAtStepNode - */ -class ComputeAtStep : public Step { - public: - ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id); - - TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode); -}; - -/*! \brief Compute root step that corresponds to te::Stage::compute_root */ -class ComputeRootStepNode: public StepNode { - public: - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.ComputeRootStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object); -}; - -/*! - * \brief Managed reference to ComputeRootStepNode. 
- * \sa ComputeRootStepNode - */ -class ComputeRootStep : public Step { - public: - explicit ComputeRootStep(int stage_id); - - TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode); -}; - -/*! \brief Compute inline step that corresponds to te::Stage::compute_inline */ -class ComputeInlineStepNode: public StepNode { - public: - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.ComputeInlineStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object); -}; - -/*! - * \brief Managed reference to ComputeInlineStepNode. - * \sa ComputeInlineStepNode - */ -class ComputeInlineStep : public Step { - public: - explicit ComputeInlineStep(int stage_id); - - TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode); -}; - -/*! \brief Cache read step that corresponds to te::Schedule::cache_read */ -class CacheReadStepNode: public StepNode { - public: - std::string scope_name; - std::vector reader_stage_ids; - - te::Tensor ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.CacheReadStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object); -}; - -/*! - * \brief Managed reference to CacheReadStepNode. - * \sa CacheReadStepNode - */ -class CacheReadStep : public Step { - public: - CacheReadStep(int stage_id, std::string scope_name, - const std::vector& reader_stage_id); - - TVM_DEFINE_OBJECT_REF_METHODS(CacheReadStep, Step, CacheReadStepNode); -}; - -/*! \brief Cache write step that corresponds to te::Schedule::cache_write - * \Note This step will cache_write all output tensors of target stage */ -class CacheWriteStepNode: public StepNode { - public: - std::string scope_name; - - Array ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.CacheWriteStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object); -}; - -/*! - * \brief Managed reference to CacheWriteStepNode. - * \sa CacheWriteStepNode - */ -class CacheWriteStep : public Step { - public: - CacheWriteStep(int stage_id, std::string scope_name); - - TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode); -}; - -/*! \brief Pragma step that corresponds to te::Schedule::pragma */ -class PragmaStepNode: public StepNode { - public: - int iter_id; - std::string pragma_type; - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.PragmaStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object); -}; - -/*! - * \brief Managed reference to PragmaStepNode. 
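- *
- * Example: a pragma_type of "auto_unroll_max_step$512" is parsed at the
- * '$' separator and lowered to two pragmas on the chosen axis:
- * auto_unroll_max_step = 512 plus unroll_explicit = true.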
- * \sa PragmaStepNode - */ -class PragmaStep : public Step { - public: - PragmaStep(int stage_id, int iter_id, std::string pragma_type); - - TVM_DEFINE_OBJECT_REF_METHODS(PragmaStep, Step, PragmaStepNode); -}; - -/*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */ -class RfactorStepNode: public StepNode { - public: - int iter_id; - int factor_iter_id; - - Array ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.RfactorStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object); -}; - -/*! - * \brief Managed reference to RfactorStepNode. - * \sa RfactorStepNode - */ -class RfactorStep : public Step { - public: - RfactorStep(int stage_id, int iter_id, int factor_iter_id); - - TVM_DEFINE_OBJECT_REF_METHODS(RfactorStep, Step, RfactorStepNode); -}; - -/*! \brief Storage align step that corresponds to te::Schedule::storage_align */ -class StorageAlignStepNode: public StepNode { - public: - int iter_id; - int factor; - int offset; - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.StorageAlignStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object); -}; - -/*! - * \brief Managed reference to StorageAlignStepNode. - * \sa StorageAlignStepNode - */ -class StorageAlignStep : public Step { - public: - StorageAlignStep(int stage_id, int iter_id, int factor, int offset); - - TVM_DEFINE_OBJECT_REF_METHODS(StorageAlignStep, Step, StorageAlignStepNode); -}; - -/*! \brief Tensorize step that corresponds to te::Schedule::tensorize - * \Note This step takes a global registered function name as input. */ -class TensorizeStepNode: public StepNode { - public: - int iter_id; - std::string ti_func_name; - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.TensorizeStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeStepNode, Object); -}; - -/*! - * \brief Managed reference to TensorizeStepNode. 
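- *
- * The ti_func_name must refer to a globally registered function (e.g. one
- * registered via TVM_REGISTER_GLOBAL) that returns a te::TensorIntrin;
- * ApplyToSchedule aborts if the lookup fails or the result is undefined.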
- * \sa TensorizeStepNode - */ -class TensorizeStep : public Step { - public: - TensorizeStep(int stage_id, int iter_id, std::string ti_func_name); - - TVM_DEFINE_OBJECT_REF_METHODS(TensorizeStep, Step, TensorizeStepNode); -}; - } // namespace ansor } // namespace tvm @@ -536,69 +170,10 @@ struct hash<::tvm::ansor::Step> { } } return ret; - } else if (auto ps = step.as<::tvm::ansor::FollowSplitStepNode>()) { - return ::dmlc::HashCombine(3, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ::dmlc::HashCombine(std::hash()(ps->src_step_id), - ps->n_split)))); - } else if (auto ps = step.as<::tvm::ansor::FollowFusedSplitStepNode>()) { - return ::dmlc::HashCombine(4, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ::dmlc::HashCombine(std::hash>()(ps->src_step_ids), - ::dmlc::HashCombine(std::hash()(ps->level), - ps->factor_or_nparts))))); } else if (auto ps = step.as<::tvm::ansor::FuseStepNode>()) { - return ::dmlc::HashCombine(5, + return ::dmlc::HashCombine(3, ::dmlc::HashCombine(std::hash()(ps->stage_id), ps->fused_ids)); - } else if (auto ps = step.as<::tvm::ansor::AnnotationStepNode>()) { - return ::dmlc::HashCombine(6, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - static_cast(ps->annotation)))); - } else if (auto ps = step.as<::tvm::ansor::ComputeAtStepNode>()) { - return ::dmlc::HashCombine(7, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->target_stage_id), - ps->target_iter_id))); - } else if (auto ps = step.as<::tvm::ansor::ComputeRootStepNode>()) { - return ::dmlc::HashCombine(8, - ps->stage_id); - } else if (auto ps = step.as<::tvm::ansor::ComputeInlineStepNode>()) { - return ::dmlc::HashCombine(9, - ps->stage_id); - } else if (auto ps = step.as<::tvm::ansor::CacheReadStepNode>()) { - return ::dmlc::HashCombine(10, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->scope_name), - ps->reader_stage_ids))); - } else if (auto ps = step.as<::tvm::ansor::CacheWriteStepNode>()) { - return ::dmlc::HashCombine(11, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ps->scope_name)); - } else if (auto ps = step.as<::tvm::ansor::PragmaStepNode>()) { - return ::dmlc::HashCombine(12, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ps->pragma_type))); - } else if (auto ps = step.as<::tvm::ansor::RfactorStepNode>()) { - return ::dmlc::HashCombine(13, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ps->factor_iter_id))); - } else if (auto ps = step.as<::tvm::ansor::StorageAlignStepNode>()) { - return ::dmlc::HashCombine(14, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ::dmlc::HashCombine(std::hash()(ps->factor), - ps->offset)))); - } else if (auto ps = step.as<::tvm::ansor::TensorizeStepNode>()) { - return ::dmlc::HashCombine(15, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ps->ti_func_name))); } else { LOG(FATAL) << "Invalid step"; } diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index 485679d6aa4e..62ebeb99a6c8 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -60,15 +60,10 @@ def get_tiled_matmul(): dag = ansor.ComputeDAG([A, B, C]) s0 = 
dag.get_init_state() - C_global = s0.cache_write(C, "global") its0 = s0.split(C, s0[C].iters[0], [4, 8, 8]) its1 = s0.split(C, s0[C].iters[4], [8, 4, 4]) - s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], its1[3]]) - s0.compute_at(C_global, C, s0[C].iters[3]) - s0.split(C_global, s0[C_global].iters[2], [16]) - B_global = s0.cache_read(B, "global", [C_global]) - s0.compute_at(B_global, C_global, s0[C_global].iters[0]) - A_global = s0.cache_read(A, "global", [C_global]) - s0.compute_at(A_global, C_global, s0[C_global].iters[2]) + s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], its1[3], + s0[C].iters[8]]) + return dag, s0 diff --git a/tests/python/unittest/test_ansor_compute_dag.py b/tests/python/unittest/test_ansor_compute_dag.py index 0768f82b805a..e5af07b31e0d 100644 --- a/tests/python/unittest/test_ansor_compute_dag.py +++ b/tests/python/unittest/test_ansor_compute_dag.py @@ -34,15 +34,6 @@ def test_infer_bound(): dag, s = get_tiled_matmul() s = dag.infer_bound_from_state(s) - A_global = s.stage_ops[1] - B_global = s.stage_ops[3] - C_global = s.stage_ops[4] - assert s[B_global].iters[0].range.extent == 512 - assert s[B_global].iters[1].range.extent == 16 - assert s[A_global].iters[0].range.extent == 1 - assert s[A_global].iters[1].range.extent == 16 - assert s[C_global].iters[0].range.extent == 64 - def test_estimate_flop(): dag, s = get_tiled_matmul() @@ -50,25 +41,7 @@ def test_estimate_flop(): assert abs(dag.flop_ct - 2 * 512 ** 3) < 0.5 -def test_lower_legalize_invalid_attach(): - N, M = 10, 10 - - A = te.compute((N, M), lambda i, j: 1.0, name='A') - B = te.compute((N, M), lambda i, j: A[i][j], name='B') - - dag = ansor.ComputeDAG([A, B]) - s = dag.get_init_state() - - s.compute_at(A, B, s[B].iters[1]) - s.split(B, s[B].iters[1], [2]) - - sch, tensors = dag.apply_steps_from_state(s) - stmt = tvm.lower(sch, tensors, simple_mode=True) - - if __name__ == "__main__": test_apply_steps() test_infer_bound() test_estimate_flop() - test_lower_legalize_invalid_attach() - diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index 984434b9c58b..b701dad6d8c0 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -28,7 +28,7 @@ from test_ansor_common import matmul_ansor_test def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local', - cost_model=ansor.RandomModel(), n_trials=2, params=None, + cost_model=None, n_trials=2, params=None, pre_search_callbacks=None): print("Test %s schedule search with the default search policy" % (target)) @@ -42,7 +42,8 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' with tempfile.NamedTemporaryFile() as fp: log_file = fp.name - search_policy = ansor.SketchSearchPolicy(cost_model, params=params, seed=seed) + search_policy = ansor.EmptyPolicy() + # search_policy = ansor.SketchSearchPolicy(cost_model, params=params, seed=seed) tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner, measure_callbacks=[ansor.LogToFile(log_file)], pre_search_callbacks=pre_search_callbacks) From 4042cfabe5f3076ca672547547787f64f6bcd51d Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sun, 28 Jun 2020 10:55:55 +0800 Subject: [PATCH 42/78] Bug fix & Delete AccessAnalyzer --- python/tvm/ansor/compute_dag.py | 15 + python/tvm/ansor/loop_state.py | 19 - python/tvm/ansor/measure.py | 1 + python/tvm/ansor/utils.py | 22 - 
 src/ansor/compute_dag.cc                      | 433 +-----------------
 src/ansor/compute_dag.h                       |  67 ---
 src/ansor/measure.cc                          |  24 +-
 src/ansor/serialization.cc                    |  27 --
 src/ansor/serialization.h                     |   5 -
 src/ansor/utils.cc                            |  66 ---
 src/ansor/utils.h                             | 115 -----
 .../python/unittest/test_ansor_compute_dag.py |   1 -
 12 files changed, 29 insertions(+), 766 deletions(-)

diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py
index d591d615d1c5..e57fbbc08843 100644
--- a/python/tvm/ansor/compute_dag.py
+++ b/python/tvm/ansor/compute_dag.py
@@ -75,3 +75,20 @@ def print_python_code_from_state(self, state):
         """
         state_obj = state if isinstance(state, StateObject) else state.state_object
         return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state_obj)
+
+    def infer_bound_from_state(self, state):
+        """
+        Infer and fill the bounds of all iterators of a state
+
+        Parameters
+        ----------
+        state : StateObject
+            The state whose iterator bounds are to be inferred
+
+        Returns
+        -------
+        state : State
+            The new state with the inferred bounds filled in
+        """
+        state_obj = state if isinstance(state, StateObject) else state.state_object
+        return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self)

diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py
index bf81311ed664..791ba2e74ad4 100644
--- a/python/tvm/ansor/loop_state.py
+++ b/python/tvm/ansor/loop_state.py
@@ -188,25 +188,6 @@ def _update_stage_id_map(self):
         for index, stage in enumerate(self.stages_cache):
             self.stage_id_map[stage.op] = index

-    def _insert_new_stage(self, new_stage_id):
-        new_stage_id = int(new_stage_id)
-        self.stages_cache = _ffi_api.StateGetStages(self.state_object)
-        added_op = self.stages_cache[new_stage_id].op
-
-        # Add a new stage will change all ops. But we still want to use the old ops to index stages,
-        # So we keep updating them and do not remove the old ops.
-
-        # Update stage_id_map for old ops, so we can still use the old ops to index stages.
-        for key, value in self.stage_id_map.items():
-            if value >= new_stage_id:
-                self.stage_id_map[key] = value + 1
-        self.stage_id_map[added_op] = new_stage_id
-
-        # Update stage_id_map for new ops
-        self._update_stage_id_map()
-
-        return added_op
-
     def _clear_cache(self):
         self.stages_cache = None

diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py
index af0eddc59653..66d1eb74fac9 100644
--- a/python/tvm/ansor/measure.py
+++ b/python/tvm/ansor/measure.py
@@ -35,6 +35,7 @@
 from tvm.runtime import Object, module, ndarray
 from tvm.driver import build_module
 from tvm.ir import transform
+from tvm.contrib import tar
 from .
import _ffi_api from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout diff --git a/python/tvm/ansor/utils.py b/python/tvm/ansor/utils.py index b406824ba811..dd701234aef9 100644 --- a/python/tvm/ansor/utils.py +++ b/python/tvm/ansor/utils.py @@ -90,28 +90,6 @@ def get_const_tuple(in_tuple): return tuple(get_const_int(x) for x in in_tuple) -def to_str_round(x, decimal=6): - """Convert object to str and round float numbers""" - if isinstance(x, str): - return x - if isinstance(x, (list, tuple)) or isinstance(x, np.ndarray): - return "[" + ", ".join([to_str_round(y, decimal=decimal) - for y in x]) + "]" - if isinstance(x, dict): - return str({k: eval(to_str_round(v)) for k, v in x.items()}) - if isinstance(x, int): - return str(x) - if isinstance(x, (np.float32, np.float64, float)): - format_str = "%%.%df" % decimal - return format_str % x - raise ValueError("Invalid value: " + str(x) + "\ttype: " + str(type(x))) - - -def array_mean(arr): - """Mean function for tvm array (Array)""" - return sum(x.value for x in arr) / len(arr) - - class NoDaemonProcess(multiprocessing.Process): @property def daemon(self): diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 7638f98e65ea..6b0d8d5fcc4b 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -45,11 +45,6 @@ using namespace tvm::tir; TVM_REGISTER_NODE_TYPE(ComputeDAGNode); -template -using OperationMap = AccessAnalyzerNode::OperationMap; - -using OperationSet = std::unordered_set; - // Topo-sort ops from tensors according to their read-write relations. // Results are stored in ops void TopoSortOps(const Array& tensors, @@ -120,357 +115,6 @@ void TopoSortOps(const Array& tensors, } } -// Extract all tensor accesses in an expr -class TensorAccessExtractor : public StmtExprVisitor { - public: - void Extract(PrimExpr expr) { - this->VisitExpr(expr); - } - - void VisitExpr_(const CallNode* op) final { - if (op->name == tir::intrinsic::tvm_if_then_else) { - has_branch = true; - } - StmtExprVisitor::VisitExpr_(op); - } - - void VisitExpr_(const ProducerLoadNode* op) final { - buf_accesses[Downcast(op->producer)->op].emplace_back( - op->indices.begin(), op->indices.end()); - StmtExprVisitor::VisitExpr_(op); - } - - void VisitStmt_(const IfThenElseNode* op) final { - has_branch = true; - StmtExprVisitor::VisitStmt_(op); - } - - void VisitExpr_(const SelectNode* op) final { - has_branch = true; - StmtExprVisitor::VisitExpr_(op); - } - - OperationMap > > buf_accesses; - bool has_branch{false}; -}; - -// Returns whether the expr equals to the var with a const shift -bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) { - if (auto pv = expr.as()) { - return pv == var.get(); - } else if (auto padd = expr.as()) { - return ((padd->a.get() == var.get() && padd->b->IsInstance()) || - (padd->b.get() == var.get() && padd->a->IsInstance())); - } else if (auto psub = expr.as()) { - return ((psub->a.get() == var.get() && psub->b->IsInstance()) || - (psub->b.get() == var.get() && psub->a->IsInstance())); - } else { - return false; - } -} - -// Return whether the access is injective -bool IsInjective(const te::Operation& op, const std::vector& index, - bool* axis_missing, bool* axis_duplicated, bool* same_order) { - auto cop = op.as(); - if (cop == nullptr) { return false; } - - std::vector index_to_var_idx; - std::vector var_idx_ct(cop->axis.size(), 0); - - for (const auto& expr : index) { - if (!is_const(expr)) { - bool found = false; - for (size_t i = 0; i < cop->axis.size(); ++i) { - if 
(IsConstShiftEqual(cop->axis[i]->var, expr)) { - index_to_var_idx.push_back(i); - var_idx_ct[i]++; - found = true; - break; - } - } - if (!found) { - return false; - } - } - } - - *axis_missing = false; // Some axes are missing - *axis_duplicated = false; // Some axes appear more than once - *same_order = true; // The axis order is the same as op->axis - for (int ct : var_idx_ct) { - if (ct == 0) { - *axis_missing = true; - } else if (ct > 1) { - *axis_duplicated = true; - } - } - for (size_t i = 1; i < index_to_var_idx.size(); ++i) { - if (index_to_var_idx[i] < index_to_var_idx[i - 1]) { - *same_order = false; - break; - } - } - - return true; -} - -// Gather all VarNodes in an expr -static void GatherVars(const PrimExpr& expr, - std::unordered_set* vars) { - PostOrderVisit(expr, [&vars](const ObjectRef &node) { - if (const VarNode* op = node.as()) { - vars->insert(op); - } - }); -} - -// Check whether an expr has expensive operations (e.g. exp) -static bool HasExpensiveOp(const PrimExpr& expr) { - bool found = false; - PostOrderVisit(expr, [&found](const ObjectRef &node) { - if (const CallNode* op = node.as()) { - if (op->call_type == CallNode::CallType::PureIntrinsic && - op->name == "exp") { - found = true; - } - } - }); - return found; -} - -AccessAnalyzer::AccessAnalyzer(const Array& tensors) { - auto node = make_object(); - OperationMap has_branch; - - // get all ops - TopoSortOps(tensors, &node->ops_topo_order); - - // build read & write access map - for (const auto& op : node->ops_topo_order) { - if (op->IsInstance()) { - node->read_from[op] = - OperationMap > >(); - } else if (auto cop = op.as()) { - TensorAccessExtractor extractor; - for (const auto& exp : cop->body) { - extractor.Extract(exp); - } - - for (const auto& iter : extractor.buf_accesses) { - std::vector >& accesses = - node->read_by[iter.first][op]; - accesses.insert(accesses.begin(), iter.second.begin(), - iter.second.end()); - } - - node->read_from[op] = std::move(extractor.buf_accesses); - has_branch[op] = extractor.has_branch; - } else { - LOG(FATAL) << "Invalid op: " << op; - } - } - - // do some static analysis - for (const auto& op : node->ops_topo_order) { - if (op->IsInstance()) { - node->is_injective[op] = true; - node->needs_multi_level_tiling[op] = false; - node->is_strict_inlineable[op] = false; - node->is_output[op] = false; - } else if (auto pop = op.as()) { - // check whether is element-wise and strict-inlineable - // (see definition in compute_dag.h) - bool is_injective = true; - bool is_strict_inlineable = true; - - bool axis_missing, axis_duplicated, same_order; - for (const auto& pair : node->read_from[op]) { - const std::vector >& access = pair.second; - for (const auto& index : access) { - if (!ansor::IsInjective(op, index, &axis_missing, &axis_duplicated, - &same_order)) { - is_injective = false; - is_strict_inlineable = false; - break; - } - if (!same_order || axis_duplicated) { - // do not strictly inline transpose - is_strict_inlineable = false; - } - } - if (!is_injective) { break; } - } - if (has_branch[op]) { - is_strict_inlineable = false; - } - - // don't strictly inline expensive op (e.g. 
exp) - bool has_expensive_op = false; - for (const auto& expr : pop->body) { - has_expensive_op |= HasExpensiveOp(expr); - } - - node->is_injective[op] = is_injective; - node->is_strict_inlineable[op] = is_strict_inlineable && - !has_expensive_op; - - // check whether the op needs multi-level tiling - // (see definition in compute_dag.h) - bool needs_multi_level_tiling = false; - int n_missing = 0; - - for (const auto& pair : node->read_from[op]) { - const std::vector > &access = pair.second; - std::unordered_set vars; - for (const std::vector &indices : access) { - for (const PrimExpr& expr : indices) { - GatherVars(expr, &vars); - } - } - bool missing = false; - for (const auto& axis : pop->axis) { - if (GetIntImm(axis->dom->extent) > 1 && - vars.count(axis->var.get()) == 0) { - missing = true; - } - } - if (missing) { - n_missing++; - } - - if (n_missing >= 2 || (n_missing >= 1 && !pop->reduce_axis.empty())) { - needs_multi_level_tiling = true; - break; - } - } - - node->needs_multi_level_tiling[op] = needs_multi_level_tiling; - - // check whether is output - node->is_output[op] = node->read_by[op].empty(); - } else { - LOG(FATAL) << "Invalid op" << op; - } - } - - data_ = std::move(node); -} - -bool AccessAnalyzer::NeedsMultiLevelTiling(const te::Operation &op) const { - return operator->()->needs_multi_level_tiling.at(op); -} - -bool AccessAnalyzer::IsOutput(const te::Operation& op) const { - return operator->()->is_output.at(op); -} - -bool AccessAnalyzer::IsInjective(const te::Operation& op) const { - return operator->()->is_injective.at(op); -} - -bool AccessAnalyzer::IsStrictInlineable(const te::Operation &op) const { - return operator->()->is_strict_inlineable.at(op); -} - -void AccessAnalyzer::GetProducers(const State& state, const te::Operation& op, - OperationSet* producers) const { - producers->clear(); - for (const auto& iter : operator->()->read_from.at(op)) { - producers->insert(iter.first); - } -} - -void AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op, - OperationSet* consumers) const { - OperationSet inlined_ops; - - for (const auto& stage : state->stages) { - if (stage->compute_at == kInlined) { - inlined_ops.insert(stage->op); - } - } - std::function collect; - - collect = [this, &collect, &inlined_ops, &consumers](const te::Operation& op) { - for (const auto& iter : operator->()->read_by.at(op)) { - if (inlined_ops.count(iter.first)) { - collect(iter.first); - } else { - consumers->insert(iter.first); - } - } - }; - - consumers->clear(); - collect(op); -} - -// Return whether two int arrays are elementwise-equal -bool IntArrayEqual(const Array& arr1, const Array& arr2) { - if (arr1.size() != arr2.size()) { - return false; - } - - for (size_t i = 0; i < arr1.size(); ++i) { - auto int1 = arr1[i].as(); - auto int2 = arr2[i].as(); - CHECK(int1 != nullptr); - CHECK(int2 != nullptr); - if (int1->value != int2->value) { - return false; - } - } - return true; -} - -bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op, - const te::Operation& target_op) const { - te::Operation cur_op = op; - while (cur_op != target_op) { - const AccessAnalyzerNode::OperationMap > >& map = - operator->()->read_by.at(cur_op); - - if (map.size() != 1) { - return false; - } - te::Operation next_op = map.begin()->first; - - // Check condition 1: has the same output size - auto p_cur = cur_op.as(); - auto p_next = next_op.as(); - if (p_cur == nullptr || p_next == nullptr) { - return false; - } - - Array output_shape = p_cur->output_shape(0); - for (int i = 1; i < 
p_cur->num_outputs(); ++i) { - if (!IntArrayEqual(p_cur->output_shape(i), output_shape)) { - return false; - } - } - for (int i = 0; i < p_next->num_outputs(); ++i) { - if (!IntArrayEqual(p_next->output_shape(i), output_shape)) { - return false; - } - } - - // Check condition 2: read is elementwise - const std::vector > reads = map.begin()->second; - bool is_injective, axis_missing, axis_duplicated, same_order; - for (const auto& read : reads) { - is_injective = ::tvm::ansor::IsInjective( - next_op, read, &axis_missing, &axis_duplicated, &same_order); - if (!is_injective || axis_missing || axis_duplicated || !same_order) { - return false; - } - } - - cur_op = std::move(next_op); - } - return true; -} - // Estimate number of float operations in an expression class FlopEstimator: public ExprFunctor { public: @@ -568,8 +212,9 @@ ComputeDAG::ComputeDAG(Array tensors) { auto node = make_object(); FlopEstimator estimator; node->tensors = std::move(tensors); - node->access_analyzer = AccessAnalyzer(node->tensors); - node->ops = Array(node->access_analyzer->ops_topo_order); + std::vector ops; + TopoSortOps(node->tensors, &ops); + node->ops = Array(ops); node->flop_ct = estimator.EstimateFlop(node->ops); node->init_state = State(node->ops); data_ = std::move(node); @@ -587,8 +232,9 @@ ComputeDAG::ComputeDAG(const std::string& workload_key) { auto node = make_object(); FlopEstimator estimator; node->tensors = std::move(tens); - node->access_analyzer = AccessAnalyzer(node->tensors); - node->ops = Array(node->access_analyzer->ops_topo_order); + std::vector ops; + TopoSortOps(node->tensors, &ops); + node->ops = Array(ops); node->flop_ct = estimator.EstimateFlop(node->ops); node->init_state = State(node->ops); data_ = std::move(node); @@ -708,30 +354,6 @@ void ComputeDAG::InferBound(std::vector* states) const { *states = std::move(out_states); } -void ComputeDAG::ReplayAndGetDAG(const std::vector &transform_steps, - ComputeDAG *task_dag) const { - std::vector stages; - StageToAxesMap stage_to_axes; - te::Schedule sch; - Array old_tensors; - - std::tie(sch, old_tensors) = ReplaySteps(transform_steps, &stages, - &stage_to_axes); - - Array new_tensors; - for (auto stage : sch->stages) { - if (stage->op->IsInstance() || - stage->is_output) { - for (auto i = 0; i < stage->op->num_outputs(); ++i) { - new_tensors.push_back(stage->op.output(i)); - } - } - } - - *task_dag = ComputeDAG(new_tensors); -} - - void ComputeDAG::InferBoundCommon(StateNode* pstate) const { std::vector stages; StageToAxesMap stage_to_axes; @@ -871,49 +493,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ss.str(); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { - auto* node = static_cast(ref.get()); - for (const auto& op : node->ops_topo_order) { - p->stream << op << std::endl; - p->stream << "is_injective:\t" << node->is_injective.at(op) << "\t\t"; - p->stream << "needs_multi_level_tiling:\t" - << node->needs_multi_level_tiling.at(op) << std::endl; - p->stream << "is_strict_inlinable:\t" << node->is_strict_inlineable.at(op) - << "\t"; - p->stream << "is_output:\t" << node->is_output.at(op) << std::endl; - p->stream << "Read from:\t"; - for (const auto& pair : node->read_from.at(op)) { - for (const auto& index : pair.second) { - p->stream << pair.first->name << Array(index) << ", "; - } - } - p->stream << "\n"; - p->stream << "Read by:\t"; - for (const auto& pair : node->read_by.at(op)) { - for (const auto& index : pair.second) { - p->stream << pair.first->name << 
Array(index) << ", "; - } - } - p->stream << "\n"; - p->stream << "==================================================\n"; - } - - AccessAnalyzer ana = GetRef(node); - - p->stream << "ElementwiseMatch: \n"; - for (size_t i = 0; i < node->ops_topo_order.size(); ++i) { - for (size_t j = 0; j < node->ops_topo_order.size(); ++j) { - if (i == j) { continue; } - if (ana.ElementWiseMatch(node->ops_topo_order[i], - node->ops_topo_order[j])) { - p->stream << node->ops_topo_order[i]->name << " -> " - << node->ops_topo_order[j]->name << "\n"; - } - } - } -}); - TVM_REGISTER_GLOBAL("ansor.ComputeDAG") .set_body_typed([](Array tensors) { return ComputeDAG(tensors); diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 2f1330d612dd..d9520e388ae0 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -39,66 +39,6 @@ namespace ansor { class StateNode; class State; class Step; -/*! \brief Read/Write access static analysis result */ -class AccessAnalyzerNode : public Object { - public: - template - using OperationMap = std::unordered_map; - - OperationMap > > > read_from; - OperationMap > > > read_by; - OperationMap is_injective; - OperationMap is_strict_inlineable; - OperationMap needs_multi_level_tiling; - OperationMap is_output; - std::vector ops_topo_order; - - static constexpr const char* _type_key = "ansor.AccessAnalyzer"; - TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object); -}; - -/*! - * \brief Managed reference to AccessAnalyzerNode. - * \sa AccessAnalyzerNode - */ -class AccessAnalyzer : public ObjectRef { - public: - explicit AccessAnalyzer(const Array& tensors); - // read/write access analysis - bool NeedsMultiLevelTiling(const te::Operation& op) const; - bool IsInjective(const te::Operation& op) const; - bool IsStrictInlineable(const te::Operation& op) const; - bool IsOutput(const te::Operation& op) const; - - // Get all producers of an op - void GetProducers(const State& state, const te::Operation& op, - std::unordered_set* producers) const; - - // Get all consumers of an op. This func deals with inlined op correctly. - void GetConsumers(const State& state, const te::Operation& op, - std::unordered_set* consumers) const; - - // Check whether two ops are elementwise matched - // (e.g. conv2d and relu are elementwise matched) - bool ElementWiseMatch(const te::Operation& op, - const te::Operation& target_op) const; - - /*! \Note The current implementation follows these (rough) definitions. - * - * Definition of data-reuse : Exists axis in (op->axis union op->reduce_axis) - * and acc in read accesses, such that axis not in acc. - * (e.g. A[i][j] = B[i] has data reuse, while A[i][j] = B[i][j] does not) - * Definition of NeedsMultiLevelTiling: Exists two acc, both of them make this op have data reuse. - * Definition of injective : For all index expressions, they are single axis variable - * plus an optional const shift. - * (e.g. A[i][j] = B[i][j], A[i][j] = B[i+1][j] are injective, while A[i][j] = B[i*j] is not) - * Definition of strict-inlineable : All read accesses are elementwise, and no branch in the body - * (e.g. 
A[i][j] = B[i][j] + C[i][j] is strict-inlineable, - * while A[i][j] = tvm_if_then_else(B[i][j] > 0, C[i][j], 0) is not - */ - TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode); -}; - typedef std::unordered_map, ObjectHash, ObjectEqual> StageToAxesMap; @@ -112,14 +52,12 @@ class ComputeDAGNode : public Object { Array tensors; // Input and output tensors Array ops; // All related operations in topo order double flop_ct; // Number of float operations - AccessAnalyzer access_analyzer; // Read/Write accesss static analyzer ObjectRef init_state; // The initial state void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("tensors", &tensors); v->Visit("ops", &ops); v->Visit("flop_ct", &flop_ct); - v->Visit("access_analyzer", &access_analyzer); } static constexpr const char* _type_key = "ansor.ComputeDAG"; @@ -161,14 +99,9 @@ class ComputeDAG: public ObjectRef { // Return the new states inplace void InferBound(std::vector* states) const; - // Replay the transform steps and get the new DAG - void ReplayAndGetDAG(const std::vector& steps, ComputeDAG* task_dag) const; - // Get the init state State GetInitState() const; - static constexpr const char* layout_free_placeholders_key = "layout_free_placeholders"; - TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode); diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index c50191813b2e..7802d571d702 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -218,15 +218,13 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, } ct++; - if (verbose >= 1) { - std::cout << std::fixed << std::setprecision(2); - std::cout << "===============================================\n"; - std::cout << "No: " << ct << "\tGFLOPS: " << flops / 1e9 << " / " - << best_flops[workload_key] / 1e9 - << "\tresults: " << result_batch[j] << "\n"; - std::cout << "===============================================\n"; - std::cout << input_batch[j]->state << "\n"; - } + StdCout(verbose) << std::fixed << std::setprecision(2) + << "===============================================\n" + << "No: " << ct << "\tGFLOPS: " << flops / 1e9 << " / " + << best_flops[workload_key] / 1e9 + << "\tresults: " << result_batch[j] << "\n" + << "===============================================\n" + << input_batch[j]->state << "\n"; } // Call callback functions @@ -345,13 +343,5 @@ TVM_REGISTER_GLOBAL("ansor.LocalRunner") return LocalRunner(timeout, number, repeat, min_repeat_ms, cooldown_interval); }); -TVM_REGISTER_GLOBAL("ansor.ProgramMeasurer") -.set_body_typed([](Builder builder, Runner runner, - Array callbacks, int verbose, - int max_continous_error = -1) { - return ProgramMeasurer(builder, runner, callbacks, verbose, - max_continous_error); -}); - } // namespace ansor } // namespace tvm diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index 939fca83f1fb..d4e4d645d192 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -385,33 +385,6 @@ std::pair, Array > LogReaderNode::ReadLines( return std::make_pair(inputs, results); } -std::pair BestMeasurePairInFile( - const std::string& filename, const std::string& workload_key, - const Target& target) { - std::pair best_pair; - double best_cost = 1e30; - - auto inp = make_object(); - auto res = make_object(); - LogReader reader = LogReader(filename); - - while (reader->ReadNext(inp.get(), res.get())) { - if (res->error_no != kNoError || inp->task->workload_key != workload_key - || inp->task->target->target_name != 
target->target_name) { - continue; - } - - double cost = FloatArrayMean(res->costs); - - if (cost < best_cost) { - best_cost = cost; - best_pair = std::make_pair(inp->copy(), res->copy()); - } - } - - return best_pair; -} - TVM_REGISTER_GLOBAL("ansor.LogToFile").set_body_typed([](const std::string& filename) { return LogToFile(filename); }); diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h index 82dd036991e6..a1559c62062b 100644 --- a/src/ansor/serialization.h +++ b/src/ansor/serialization.h @@ -105,11 +105,6 @@ void ReadMeasureRecord(const std::string& str, MeasureResultNode* res, std::string* log_version); -/*! \brief Return the best measure pair with lowest cost in a file */ -std::pair BestMeasurePairInFile(const std::string& filename, - const std::string& workload_key, - const Target& target); - } // namespace ansor } // namespace tvm diff --git a/src/ansor/utils.cc b/src/ansor/utils.cc index 27aac7e8b315..ed41321c4639 100644 --- a/src/ansor/utils.cc +++ b/src/ansor/utils.cc @@ -23,7 +23,6 @@ */ #include "utils.h" -#include namespace tvm { namespace ansor { @@ -33,56 +32,6 @@ NullStream& NullStream::Global() { return stream; } -const std::vector >& SplitFactorizationMemo::GetFactorizationSchemes( - int extent, int n_lengths, int max_innermost_factor) { - QueryKey key = std::make_tuple(extent, n_lengths, max_innermost_factor); - auto it = memory_.find(key); - if (it != memory_.end()) { - return it->second; - } - - tmp_stack_.assign(n_lengths, PrimExpr()); - results_ = &memory_[key]; - n_lengths_ = n_lengths; - - DfsEnumerate(0, extent, max_innermost_factor); - - return *results_; -} - -void SplitFactorizationMemo::DfsEnumerate(int now, int remaining_lenght, int max_innermost_factor) { - if (now == n_lengths_) { - if (tmp_stack_.back().as()->value <= max_innermost_factor) { - results_->push_back(tmp_stack_); - } - } else { - for (const auto& f : GetFactors(remaining_lenght)) { - tmp_stack_[now] = PrimExpr(f); - DfsEnumerate(now + 1, remaining_lenght / f, max_innermost_factor); - } - } -} - -const std::vector& SplitFactorizationMemo::GetFactors(int n) { - auto it = factor_memory_.find(n); - if (it != factor_memory_.end()) { - return it->second; - } - - std::vector& res = factor_memory_[n]; - int step = n % 2 == 0 ? 1 : 2; - for (size_t i = 1; i < static_cast(std::sqrt(n)) + 1; i += step) { - if (n % i == 0) { - res.push_back(i); - if (n / i != i) { - res.push_back(n/i); - } - } - } - std::sort(res.begin(), res.end()); - return res; -} - ThreadPool& ThreadPool::Global() { static ThreadPool* pool = new ThreadPool(); static int ct = 0; @@ -102,20 +51,5 @@ ThreadPool& ThreadPool::Global() { return *pool; } -TVM_REGISTER_GLOBAL("ansor.utils.GetFactorizationSchemes") -.set_body([](TVMArgs args, TVMRetValue *ret) { - int extent = args[0]; - int n_lengths = args[1]; - int max_innermost_factor = args[2]; - SplitFactorizationMemo memo; - - Array > result; - for (const auto& lens : memo.GetFactorizationSchemes(extent, n_lengths, max_innermost_factor)) { - result.push_back(lens); - } - - *ret = result; -}); - } // namespace ansor } // namespace tvm diff --git a/src/ansor/utils.h b/src/ansor/utils.h index 4e98bb907af9..a0a00ef947cd 100644 --- a/src/ansor/utils.h +++ b/src/ansor/utils.h @@ -125,98 +125,6 @@ inline void DeleteItem(std::vector* array, const T& to_delete) { } } -/*! \brief Compute the product of all elements in a vector */ -inline int64_t ElementProduct(const std::vector& array) { - int64_t ret = 1; - for (auto x : array) { - ret *= x; - } - return ret; -} - -/*! 
\brief Get the maximum element in a vector */ -template -T MaximumElement(const std::vector& array) { - CHECK(!array.empty()); - const T* pmax = &array[0]; - for (size_t i = 1; i < array.size(); ++i) { - if (array[i] > *pmax) { - pmax = &array[i]; - } - } - return *pmax; -} - -/*! \brief Move elements from multiple vectors to one vector */ -template -std::vector& ConcatenateMove(std::vector* out, std::vector* in) { - out->insert(out->end(), std::make_move_iterator(in->begin()), - std::make_move_iterator(in->end())); - return *out; -} - -/*! \brief Move elements from multiple vectors to one vector */ -template -std::vector& ConcatenateMove(std::vector* out, std::vector* first, Args... args) { - ConcatenateMove(out, first); - ConcatenateMove(out, args...); - return *out; -} - -/*! \brief Get a random permutation of integers [0, n-1] */ -template -void RandomPermutation(int n, std::vector* out, G* gen) { - out->assign(n, 0); - std::iota(out->begin(), out->end(), 0); - std::shuffle(out->begin(), out->end(), *gen); -} - -/*! \brief Random sample without replacement */ -template -void RandomSample(std::vector* in_data, size_t out_size, G* gen) { - // Note: This function is inefficient in the cases when out_size << in_data.size() - out_size = std::min(in_data->size(), out_size); - - if (in_data->size() <= out_size) { // return all - return; - } - std::vector indices; - RandomPermutation(in_data->size(), &indices, gen); - - std::vector tmp_data; - tmp_data.reserve(out_size); - for (size_t i = 0; i < out_size; ++i) { - tmp_data.push_back(std::move((*in_data)[indices[i]])); - } - - *in_data = std::move(tmp_data); -} - -/*! \brief Argsort. Order: largest to smallest */ -template -inline void Argsort(const std::vector& scores, std::vector* index) { - index->clear(); index->reserve(scores.size()); - for (size_t i = 0; i < scores.size(); ++i) { - index->push_back(i); - } - auto cmp = [&scores](int l, int r) { - return scores[l] > scores[r]; - }; - std::sort(index->begin(), index->end(), cmp); -} - -/*! \brief Return whether a string ends with another substring */ -inline bool StrEndsWith(const std::string& a, const std::string& b) { - if (b.size() > a.size()) return false; - return std::equal(a.begin() + a.size() - b.size(), a.end(), b.begin()); -} - -/*! \brief Return whether a string starts with another substring */ -inline bool StrStartsWith(const std::string& a, const std::string& b) { - if (b.size() > a.size()) return false; - return std::equal(a.begin(), a.begin() + b.size(), b.begin()); -} - /*! \brief Replace a sub-string to another sub-string in a string */ inline void StrReplace(std::string* base, const std::string& from, const std::string& to) { auto pos = base->find(from); @@ -399,29 +307,6 @@ class ThreadPool { std::condition_variable finish_signal_; }; -/*! - * \brief Enumerate all possible factorization schemes for splitting an axes. - * \note This class will memorize the results for reuse. 
- */ -class SplitFactorizationMemo { - public: - using QueryKey = std::tuple; - - const std::vector >& GetFactorizationSchemes( - int extent, int n_lengths, int max_innermost_factor); - const std::vector& GetFactors(int n); - - private: - void DfsEnumerate(int now, int remaining_lenght, int max_innermost_factor); - - std::unordered_map > > memory_; - - int n_lengths_; - std::vector tmp_stack_; - std::vector >* results_; - std::unordered_map> factor_memory_; -}; - } // namespace ansor } // namespace tvm diff --git a/tests/python/unittest/test_ansor_compute_dag.py b/tests/python/unittest/test_ansor_compute_dag.py index e5af07b31e0d..934c13f158ef 100644 --- a/tests/python/unittest/test_ansor_compute_dag.py +++ b/tests/python/unittest/test_ansor_compute_dag.py @@ -37,7 +37,6 @@ def test_infer_bound(): def test_estimate_flop(): dag, s = get_tiled_matmul() - assert abs(dag.flop_ct - 2 * 512 ** 3) < 0.5 From 7695defec07d6f41cd1e76c35c5337b23a6e39e1 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sun, 28 Jun 2020 11:48:39 +0800 Subject: [PATCH 43/78] Delete attachmap & Code clean --- python/tvm/rpc/server.py | 3 +- src/ansor/compute_dag.cc | 9 +-- src/ansor/compute_dag.h | 10 +-- src/ansor/loop_state.cc | 139 ------------------------------------ src/ansor/loop_state.h | 59 +++------------- tests/cpp/ansor_test.cc | 148 --------------------------------------- 6 files changed, 13 insertions(+), 355 deletions(-) delete mode 100644 tests/cpp/ansor_test.cc diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 42bcb00a9117..15a3c7de789d 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -348,8 +348,7 @@ def __init__(self, cmd = [sys.executable, "-m", "tvm.exec.rpc_server", "--host=%s" % host, - "--port=%s" % port, - "--port-end=%s" % port_end] + "--port=%s" % port] if tracker_addr: assert key cmd += ["--tracker=%s:%d" % tracker_addr, diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 6b0d8d5fcc4b..fc37a09872ea 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -262,8 +262,7 @@ void UpdateStageAxis(const te::Stage& stage, StageToAxesMap *stage_to_axes) { } std::pair > ComputeDAG::ApplySteps( - const std::vector& transform_steps, - LayoutRewriteLevel layout_rewrite_level) const { + const std::vector& transform_steps) const { std::vector stages; StageToAxesMap stage_to_axes; return ReplaySteps(transform_steps, &stages, &stage_to_axes); @@ -505,14 +504,10 @@ TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") .set_body([](TVMArgs args, TVMRetValue *ret) { ComputeDAG dag = args[0]; State state = args[1]; - LayoutRewriteLevel layout_rewrite_level = kNoRewrite; - if (args.size() >= 3) { - layout_rewrite_level = LayoutRewriteLevel(static_cast((args[2]))); - } te::Schedule sch; Array return_tensors; - std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps, layout_rewrite_level); + std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps); *ret = Array{sch, return_tensors}; }); diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index d9520e388ae0..620b28699186 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -64,13 +64,6 @@ class ComputeDAGNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object); }; -enum LayoutRewriteLevel { - kNoRewrite = 0, // No layout rewrite - kPlaceholderRewrite = 1, // Only rewrite layout of placeholder in the compute dag - kComputeRewrite = 2, // Only rewrite compute body for new layout in the compute dag - kBothRewrite 
= 3, // Rewrite both placeholder and compute body in the compute dag -}; - /*! * \brief Managed reference to ComputeDAGNode. * \sa ComputeDAGNode @@ -83,8 +76,7 @@ class ComputeDAG: public ObjectRef { // Apply transform steps to the init state of this DAG, and get the equivalent tvm::schedule. // The return values can be used as arguments to tvm.build or tvm.lower std::pair > ApplySteps( - const std::vector& transform_steps, - LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const; + const std::vector& transform_steps) const; // Print transform steps as equivalent python schedule API std::string PrintStepsAsPython(const std::vector& steps) const; diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 787e4256a181..4d0ec24cbee2 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -109,7 +109,6 @@ State::State(const Array& ops) { for (const auto& op : ops) { node->stages.push_back(Stage(op)); } - node->attach_map = AttachMap(make_object()); node->complete = true; node->aux_info = ObjectRef(); data_ = std::move(node); @@ -121,7 +120,6 @@ State::State(const std::vector& stages, auto node = make_object(); node->stages = stages; node->transform_steps = transform_steps; - node->attach_map = AttachMap(make_object()); node->complete = complete; node->aux_info = std::move(aux_info); data_ = std::move(node); @@ -183,7 +181,6 @@ std::vector State::DoSplitStepCommon( bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; const Iterator& it = stage->iters[iter_id]; - size_t old_iter_size = stage->iters.size(); PrimExpr tosplit_min, tosplit_extent; if (it->range.defined()) { @@ -243,15 +240,6 @@ std::vector State::DoSplitStepCommon( stage->op, stage->op_type, std::move(new_iters), stage->compute_at, stage->attrs); - // we have to replace the iterators in attach map, - // these two vectors keep the replacement mapping - std::vector from_iters; - std::vector to_iters; - for (size_t i = iter_id; i < old_iter_size; ++i) { - from_iters.emplace_back(stage_id, i); - to_iters.emplace_back(stage_id, i + lengths.size()); - } - pstate->attach_map.ReplaceIters(from_iters, to_iters); return outs; } @@ -263,7 +251,6 @@ std::vector State::DoSplitStep(const SplitStep& step) { Iterator State::DoFuseStep(const FuseStep& step) { int stage_id = step->stage_id; const Stage& stage = operator->()->stages[stage_id]; - int old_iter_size = static_cast(stage->iters.size()); std::string new_name; PrimExpr new_extent = 1; @@ -275,16 +262,6 @@ Iterator State::DoFuseStep(const FuseStep& step) { CHECK_EQ(step->fused_ids[i], step->fused_ids[i - 1] + 1); } - if (i != step->fused_ids.size() - 1) { - const auto& iter_to_attached_stage = - operator->()->attach_map->iter_to_attached_stages; - if (iter_to_attached_stage.find(std::make_pair( - stage_id, step->fused_ids[i])) != iter_to_attached_stage.end()) { - LOG(FATAL) << "Invalid Fuse. 
Because you want to fuse iterators " - "that have been attached by some stages"; - } - } - const Iterator& it = stage->iters[step->fused_ids[i]]; ori_iters.push_back(it); new_name += it->name + "@"; @@ -323,23 +300,6 @@ Iterator State::DoFuseStep(const FuseStep& step) { stage->op, stage->op_type, std::move(new_iters), stage->compute_at, stage->attrs); - // we have to replace the iterators in attach map, - // these two vectors keep the replacement mapping - std::vector from_iters; - std::vector to_iters; - const int begin_id = step->fused_ids.front(), end_id = step->fused_ids.back(); - for (int i = 0; i < old_iter_size; ++i) { - if (i <= begin_id) { - continue; - } else if (i > end_id) { // move forward - from_iters.emplace_back(stage_id, i); - to_iters.emplace_back(stage_id, i - end_id + begin_id); - } else { // move to the fused id - from_iters.emplace_back(stage_id, i); - to_iters.emplace_back(stage_id, begin_id); - } - } - pstate->attach_map.ReplaceIters(from_iters, to_iters); return new_it; } @@ -447,17 +407,6 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, indent += 2; } - - if (state != nullptr) { - AttachMap::IterKey iter_key(stage_id, i); - auto pair = state->attach_map->iter_to_attached_stages.find(iter_key); - if (pair != state->attach_map->iter_to_attached_stages.end()) { - for (const auto& attach_stage_id : pair->second) { - PrintStage(os, attach_stage_id, state, base_indent + indent, - delete_trivial_loop); - } - } - } } for (size_t j = 0; j < base_indent + indent; ++j) { @@ -506,94 +455,6 @@ std::string State::ToStr(bool delete_trivial_loop) const { return os.str(); } -void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, - int target_iter_id) { - AttachMapNode* pnode = CopyOnWrite(); - - // delete the current entry of stage - DeleteStageEntry(pnode, stage_id); - - // store the new relation - IterKey iter_key(target_stage_id, target_iter_id); - pnode->stage_to_attach_iter[stage_id] = - std::make_pair(target_stage_id, target_iter_id); - pnode->iter_to_attached_stages[iter_key].push_back(stage_id); -} - -void AttachMap::DeleteStage(int stage_id) { - AttachMapNode* pnode = CopyOnWrite(); - - // delete the entry of old stage - DeleteStageEntry(pnode, stage_id); -} - -void AttachMap::ReplaceIters(const std::vector& old_iters, - const std::vector& new_iters) { - AttachMapNode* pnode = CopyOnWrite(); - - CHECK_EQ(old_iters.size(), new_iters.size()); - for (size_t i = 0; i < old_iters.size(); ++i) { - auto entry = pnode->iter_to_attached_stages.find(old_iters[i]); - if (entry == pnode->iter_to_attached_stages.end()) { - continue; - } - - // replace iter in the value of `stage_to_attach_iter` - for (const auto& s : entry->second) { - pnode->stage_to_attach_iter[s] = new_iters[i]; - } - - // replace iter in the key of `iter_to_attached_stages` - std::vector attached_stages = std::move(entry->second); - pnode->iter_to_attached_stages.erase(entry); - pnode->iter_to_attached_stages[new_iters[i]] = std::move(attached_stages); - } -} - -void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) { - auto old_entry = pnode->stage_to_attach_iter.find(stage_id); - if (old_entry != pnode->stage_to_attach_iter.end()) { - // delete value in `iter_to_attached_stages` - auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second); - DeleteItem(&entry2->second, stage_id); - if (entry2->second.size() == 0) { - pnode->iter_to_attached_stages.erase(entry2); - } - // delete key in `stage_to_attach_iter` - 
pnode->stage_to_attach_iter.erase(old_entry); - } -} - -AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const { - AttachMap map = AttachMap(make_object()); - auto pmap = map.CopyOnWrite(); - for (const auto& x : operator->()->stage_to_attach_iter) { - auto key = x.first; - if (key >= start_id) { - key += offset; - } - auto value = x.second; - if (value.first >= start_id) { - value.first += offset; - } - pmap->stage_to_attach_iter.insert(std::make_pair(key, value)); - } - for (const auto& x : operator->()->iter_to_attached_stages) { - auto key = x.first; - if (key.first >= start_id) { - key.first += offset; - } - auto value = x.second; - for (auto& i : value) { - if (i >= start_id) { - i += offset; - } - } - pmap->iter_to_attached_stages.insert(std::make_pair(key, value)); - } - return map; -} - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 2d6c85db0247..fbd1a8a8a91c 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -163,46 +163,6 @@ class Stage : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode); }; -/*! \brief stores the compute_at relation between stages - * This stores a bi-directional mapping from stages and iter: - * 1. Stage to its attached iterator 2. Iterator to the stage attached to it - * - * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages - * to query the relations */ -class AttachMapNode: public Object { - public: - using StageKey = int; - using IterKey = std::pair; // stage_id and iter_id - - std::unordered_map stage_to_attach_iter; - std::unordered_map> iter_to_attached_stages; - - static constexpr const char* _type_key = "ansor.AttachMap"; - TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object); -}; - -/*! - * \brief Managed reference to AttachMapNode. - * \sa AttachMapNode - */ -class AttachMap : public ObjectRef { - public: - using StageKey = int; - using IterKey = std::pair; // stage_id and iter_id - - void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id); - void DeleteStage(int stage_id); - void ReplaceIters(const std::vector& old_iters, - const std::vector& new_iters); - AttachMap ApplyStageIdOfffset(int start_id, int offset) const; - - TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode); - - private: - static void DeleteStageEntry(AttachMapNode* pnode, int stage_id); -}; - /*! \brief The base class for a transformation step */ class StepNode: public Object { public: @@ -229,7 +189,6 @@ class StateNode: public Object { std::vector stages; // Current stages and loop structures std::vector transform_steps; // History transformation steps bool complete; // Indicate whether this state has unfilled tile sizes - AttachMap attach_map; // stores the compute_at relation between stages ObjectRef aux_info; // Used to store any auxiliary info about this state ComputeDAG task_dag; // The up-to-date ComputeDAG of this state. // The default value is an empty NodeRef @@ -263,16 +222,7 @@ class State : public ObjectRef { bool inner_to_outer = true); Iterator fuse(int stage_id, const std::vector& iters); - /* Do transform steps - * Note: The following functions only change loop state but do not change transform_history. 
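// A minimal replay sketch for the note above, assuming a previously recorded
// `history` vector of steps (the variable names here are illustrative only):
//
//   State state = dag.GetInitState();
//   state.DoSteps(history, dag);  // rebuilds the loop structure from the steps
//
// Because the Do*Step functions only mutate the loop state and never append to
// transform_steps, replaying a history read back from a serialized log is safe.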
- * We separate these functions out, - * so you can call them for replay easily given history steps */ - void DoReorderStep(const ReorderStep& step); - std::vector DoSplitStep(const SplitStep& step); - Iterator DoFuseStep(const FuseStep& step); - // General do step functions with a runtime dynamic dispatcher - void DoStep(const Step& step, const ComputeDAG& dag); void DoSteps(const std::vector& step, const ComputeDAG& dag); // Print the state to a string @@ -282,6 +232,15 @@ class State : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); private: + void DoStep(const Step& step, const ComputeDAG& dag); + + /* Do transform steps + * Note: The following functions only change loop state but do not change transform_history. + * We separate these functions out, + * so you can call them for replay easily given history steps */ + void DoReorderStep(const ReorderStep& step); + std::vector DoSplitStep(const SplitStep& step); + Iterator DoFuseStep(const FuseStep& step); // Common function for DoSplitStep and DoFollowSplitStep std::vector DoSplitStepCommon(int stage_id, int iter_id, const std::vector& lengths, diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc deleted file mode 100644 index 36ac46f49551..000000000000 --- a/tests/cpp/ansor_test.cc +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -#include -#include -#include -#include -#include -#include - -// todo(jcf94): do not use relative path -#include "../../src/ansor/loop_state.h" - -// Compute declaration for test -tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, - int CI, int CO, - int kernel_size, - int strides, int padding, - int dilation = 1) { - using namespace tvm; - using namespace tvm::te; - - Tensor data = placeholder({N, CI, H, W}, DataType::Float(32), "Data"); - Tensor kernel = placeholder({CO, CI, kernel_size, kernel_size}, - DataType::Float(32), "Kernel"); - Tensor bias = placeholder({CO, 1, 1}, DataType::Float(32), "Bias"); - Tensor bn_scale = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_scale"); - Tensor bn_offset = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_offset"); - - int OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1; - int OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1; - - const auto& conv = - topi::conv2d_nchw(data, kernel, padding, padding, strides, strides); - CHECK(conv->shape[2].as()->value == OH); - CHECK(conv->shape[3].as()->value == OW); - - const auto& bias_add = compute( - {N, CO, OH, OW}, - [&](Var i, Var j, Var k, Var l) { - return conv[i][j][k][l] + bias[j][0][0]; - }, - "Bias_add"); - const auto& bn_mul = compute( - {N, CO, OH, OW}, - [&](Var i, Var j, Var k, Var l) { - return bias_add[i][j][k][l] * bn_scale[j][0][0]; - }, - "Bn_mul"); - const auto& bn_add = compute( - {N, CO, OH, OW}, - [&](Var i, Var j, Var k, Var l) { - return bn_mul[i][j][k][l] + bn_offset[j][0][0]; - }, - "Bn_add"); - const auto& out = topi::relu(bn_add); - - return {data, kernel, bias, bn_scale, bn_offset, out}; -} - -using namespace tvm::ansor; - -// Test Access Analyzer -TEST(ComputeDAG, GetProducersConsumers) { - const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); - const auto& dag = tvm::ansor::ComputeDAG(tensors); - int data = 0, padding = 1, kernel = 2, conv = 3, bias = 4, bias_add = 5; - int bn_scale = 6, bn_mul = 7, bn_offset = 8, bn_add = 9, relu = 10; - - State s0 = dag.GetInitState(); - std::unordered_set set; - { - std::vector> consumer_list = { - {data, padding}, {padding, conv}, {kernel, conv}, - {conv, bias_add}, {bias, bias_add}, {bias_add, bn_mul}, - {bn_scale, bn_mul}, {bn_mul, bn_add}, {bn_offset, bn_add}, - {bn_add, relu}}; - for (const auto& pair : consumer_list) { - dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &set); - CHECK_EQ(set.size(), 1); - CHECK_EQ((*set.begin()), s0->stages[pair.second]->op); - } - std::vector>> producer_list = { - {padding, {data}}, - {conv, {padding, kernel}}, - {bias_add, {conv, bias}}, - {bn_mul, {bias_add, bn_scale}}, - {bn_add, {bn_mul, bn_offset}}, - {relu, {bn_add}}}; - for (const auto& pair : producer_list) { - dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op, &set); - CHECK_EQ(set.size(), pair.second.size()); - for (const auto& target : pair.second) { - CHECK(set.count(s0->stages[target]->op)); - } - } - } - - s0.compute_inline(bn_add); - s0.compute_inline(bn_mul); - s0.compute_inline(bias_add); - s0.compute_inline(padding); - { - std::vector> consumer_list = { - {data, conv}, {kernel, conv}, {conv, relu}}; - for (const auto& pair : consumer_list) { - dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &set); - CHECK_EQ(set.size(), 1); - CHECK_EQ((*set.begin()), s0->stages[pair.second]->op); - } - std::vector>> producer_list = { - {padding, {data}}, - {conv, {padding, kernel}}, - {bias_add, {conv, bias}}, - 
{bn_mul, {bias_add, bn_scale}}, - {bn_add, {bn_mul, bn_offset}}, - {relu, {bn_add}}}; - for (const auto& pair : producer_list) { - dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op, &set); - CHECK_EQ(set.size(), pair.second.size()); - for (const auto& target : pair.second) { - CHECK(set.count(s0->stages[target]->op)); - } - } - } -} - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} From 0c200cd5625cb9c0dc020c529c6e1d949a5e9cd2 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sun, 28 Jun 2020 15:35:09 +0800 Subject: [PATCH 44/78] Doc update Update statenode::stages from vector to Array --- python/tvm/ansor/auto_schedule.py | 6 +- python/tvm/ansor/workload_registry.py | 27 ++- src/ansor/auto_schedule.cc | 19 +- src/ansor/auto_schedule.h | 63 ++++-- src/ansor/compute_dag.cc | 60 ++--- src/ansor/compute_dag.h | 89 ++++++-- src/ansor/loop_state.cc | 52 +++-- src/ansor/loop_state.h | 275 ++++++++++++++++------- src/ansor/measure.cc | 21 +- src/ansor/measure.h | 243 +++++++++++++++----- src/ansor/search_policy/empty_policy.cc | 34 +-- src/ansor/search_policy/empty_policy.h | 24 +- src/ansor/search_policy/search_policy.cc | 14 +- src/ansor/search_policy/search_policy.h | 81 +++++-- src/ansor/search_task.cc | 4 +- src/ansor/search_task.h | 46 +++- src/ansor/serialization.cc | 10 +- src/ansor/serialization.h | 48 +++- src/ansor/transform_step.cc | 6 +- src/ansor/transform_step.h | 121 ++++++++-- src/ansor/utils.cc | 2 +- src/ansor/utils.h | 23 +- 22 files changed, 875 insertions(+), 393 deletions(-) diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 8fddac567529..750c3743c0eb 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -93,7 +93,7 @@ class TuneOption(Object): Number of total measurement trials early_stopping: int Stops early the tuning if no improvement after n measurements - num_measure_per_iter: int + num_measure_per_round: int The number of programs to be measured at each iteration verbose: int Verbosity level. 0 means silent. @@ -111,7 +111,7 @@ class TuneOption(Object): - ansor.PreloadMeasuredStates - ansor.PreloadCustomSketchRule """ - def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, + def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_round=64, verbose=1, builder='local', runner='local', measure_callbacks=None, pre_search_callbacks=None): if isinstance(builder, str): @@ -133,7 +133,7 @@ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, pre_search_callbacks = [] self.__init_handle_by_constructor__( - _ffi_api.TuneOption, n_trials, early_stopping, num_measure_per_iter, + _ffi_api.TuneOption, n_trials, early_stopping, num_measure_per_round, verbose, builder, runner, measure_callbacks, pre_search_callbacks) diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py index 025b5f03c661..d6df6f36f046 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/ansor/workload_registry.py @@ -29,8 +29,7 @@ When we need the dag, we decode the string and call the function, which will return the dag. 
""" -from typing import List, Tuple, Callable, Union -from collections import Hashable +from typing import Hashable import pickle import json import hashlib @@ -43,7 +42,7 @@ WORKLOAD_FUNC_REGISTRY = {} -def register_workload_func(func: Callable): +def register_workload_func(func): """Register a workload generation function The input function should take hashable and jsonable arguments (int, float, tuple of int, tvm.tensor.Tensor, ...) and return a list of tvm.tensor.Tensor. @@ -65,7 +64,7 @@ def matmul(N, M, K): return func -def compute_dag_hash(dag: ComputeDAG): +def compute_dag_hash(dag): """ Get hash value for a ComputeDAG """ # todo: implement this more carefully and move this to c++ as a member function of ComputeDAG @@ -87,7 +86,7 @@ def compute_dag_hash(dag: ComputeDAG): return hashlib.md5(str_key).hexdigest() -def register_workload_bufs(bufs: List[Tensor]) -> str: +def register_workload_bufs(bufs): """Directly register buffers of a workload and return the workload_key The buffers can be looked up with workload_key_to_tensors by the workload_key """ @@ -97,13 +96,13 @@ def register_workload_bufs(bufs: List[Tensor]) -> str: return json.dumps((key,)) -def list_to_tuple(x: List) -> Tuple: +def list_to_tuple(x): """Convert a list to a tuple recursively""" assert isinstance(x, list) return tuple(list_to_tuple(y) if isinstance(y, list) else y for y in x) -def serialize_args(args: Tuple) -> Tuple: +def serialize_args(args): """ Serialize arguments of a function to a hashable and jsonable tuple. Currently this is mainly used for tvm.tensor.Tensor @@ -121,7 +120,7 @@ def serialize_args(args: Tuple) -> Tuple: return tuple(ret) -def deserialize_args(args: Tuple) -> List: +def deserialize_args(args): """The inverse function of :code:`serialize_args`""" ret = [] for t in args: @@ -133,7 +132,7 @@ def deserialize_args(args: Tuple) -> List: @tvm._ffi.register_func("ansor.workload_key_to_tensors") -def workload_key_to_tensors(workload_key: str) -> List[Tensor]: +def workload_key_to_tensors(workload_key): """Decode a workload key to the input/output tensors""" workload = json.loads(workload_key) name = workload[0] @@ -146,13 +145,13 @@ def workload_key_to_tensors(workload_key: str) -> List[Tensor]: @ tvm._ffi.register_func("ansor.workload_key_to_dag") -def workload_key_to_dag(workload_key: str) -> ComputeDAG: +def workload_key_to_dag(workload_key): """Decode a workload key to a compute dag""" tensors = workload_key_to_tensors(workload_key) return ComputeDAG(tensors) -def make_workload_key_func(func: Union[str, Callable], args: Tuple) -> str: +def make_workload_key_func(func, args): """make a workload key from function and arguments""" args = serialize_args(args) @@ -169,21 +168,21 @@ def make_workload_key_func(func: Union[str, Callable], args: Tuple) -> str: return json.dumps((func_name,) + args) -def make_workload_key_bufs(bufs: List[Tensor]) -> str: +def make_workload_key_bufs(bufs): """make a workload key from bufs""" dag = ComputeDAG(bufs) key = compute_dag_hash(dag) return json.dumps((key,)) -def dump_workload_func_registry(filename: str): +def dump_workload_func_registry(filename): """Dump workload function registry to a pickle binary file""" global WORKLOAD_FUNC_REGISTRY pickle.dump(WORKLOAD_FUNC_REGISTRY, open(filename, 'wb')) -def load_workload_func_registry(filename: str): +def load_workload_func_registry(filename): """Load workload function registry from a pickle binary file""" global WORKLOAD_FUNC_REGISTRY diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index 
82ec07930adc..a2e3b7c11f4e 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -19,7 +19,7 @@ /*! * \file ansor/auto_schedule.cc - * \brief The user interface of the auto-scheduler + * \brief The user interface of the Ansor auto-scheduler. */ #include "auto_schedule.h" @@ -33,13 +33,13 @@ namespace ansor { TVM_REGISTER_NODE_TYPE(TuneOptionNode); TuneOption::TuneOption(int n_trials, int early_stopping, - int num_measure_per_iter, int verbose, Builder builder, + int num_measure_per_round, int verbose, Builder builder, Runner runner, Array measure_callbacks, Array pre_search_callbacks) { auto node = make_object(); node->n_trials = n_trials; node->early_stopping = early_stopping; - node->num_measure_per_iter = num_measure_per_iter; + node->num_measure_per_round = num_measure_per_round; node->verbose = verbose; node->builder = std::move(builder); node->runner = std::move(runner); @@ -50,17 +50,16 @@ TuneOption::TuneOption(int n_trials, int early_stopping, std::pair > AutoSchedule(SearchTask task, SearchPolicy search_policy, TuneOption tune_option) { - // Search for the best schedule + // Create a ProgramMeasurer to handle the schedule build and performance measure ProgramMeasurer measurer = ProgramMeasurer(tune_option->builder, tune_option->runner, tune_option->measure_callbacks, tune_option->verbose); - + // Search for the best schedule State state = search_policy->Search( task, tune_option->n_trials, tune_option->early_stopping, - tune_option->num_measure_per_iter, tune_option->verbose, measurer, + tune_option->num_measure_per_round, tune_option->verbose, measurer, tune_option->pre_search_callbacks); - return task->compute_dag.ApplySteps(state->transform_steps); } @@ -68,20 +67,22 @@ std::pair > AutoSchedule( std::string workload_key, Target target, Target target_host, SearchPolicy search_policy, HardwareParams hardware_params, TuneOption tune_option) { + // Create SearchTask from the given workload key ComputeDAG dag = ComputeDAG(workload_key); SearchTask task = SearchTask( std::move(dag), std::move(workload_key), std::move(target), std::move(target_host), std::move(hardware_params)); + // Search for the best schedule return AutoSchedule(std::move(task), std::move(search_policy), std::move(tune_option)); } TVM_REGISTER_GLOBAL("ansor.TuneOption") .set_body_typed([](int n_trials, int early_stopping, - int num_measure_per_iter, int verbose, Builder builder, + int num_measure_per_round, int verbose, Builder builder, Runner runner, Array measure_callbacks, Array pre_search_callbacks) { - return TuneOption(n_trials, early_stopping, num_measure_per_iter, verbose, + return TuneOption(n_trials, early_stopping, num_measure_per_round, verbose, builder, runner, measure_callbacks, pre_search_callbacks); }); diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index 7ffd2c4d3a70..9df15519b419 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -19,7 +19,7 @@ /*! * \file ansor/auto_schedule.h - * \brief The user interface of the auto-scheduler + * \brief The user interface of the Ansor auto-scheduler. */ #ifndef TVM_ANSOR_AUTO_SCHEDULE_H_ @@ -33,23 +33,30 @@ namespace tvm { namespace ansor { -/*! \brief Tuning and measurement options */ +/*! \brief Tuning and measurement options. 
*/ class TuneOptionNode : public Object { public: - int n_trials; // Number of total measurement trials - int early_stopping; // Stops early the tuning if no improvement after n measurements - int num_measure_per_iter; // The number of programs to be measured at each iteration - int verbose; // Verbosity level. 0 means silent. - Builder builder; // Builder which builds the program - Runner runner; // Runner which runs the program and measure time costs - Array measure_callbacks; // MeasureCallback functions - Array pre_search_callbacks; // SearchCallback functions - // run before search + /*! \brief Number of total measurement trials. */ + int n_trials; + /*! \brief Stop the tuning early if no improvement is found after n measurements. */ + int early_stopping; + /*! \brief The number of programs to be measured at each search round. */ + int num_measure_per_round; + /*! \brief Verbosity level. (0 means silent) */ + int verbose; + /*! \brief Builder which builds the program. */ + Builder builder; + /*! \brief Runner which runs the program and measures time costs. */ + Runner runner; + /*! \brief MeasureCallback functions to be called after each measure batch. */ + Array measure_callbacks; + /*! \brief SearchCallback functions to be called before schedule search. */ + Array pre_search_callbacks; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("n_trials", &n_trials); v->Visit("early_stopping", &early_stopping); - v->Visit("num_measure_per_iter", &num_measure_per_iter); + v->Visit("num_measure_per_round", &num_measure_per_round); v->Visit("verbose", &verbose); v->Visit("builder", &builder); v->Visit("runner", &runner); @@ -67,7 +74,18 @@ class TuneOptionNode : public Object { */ class TuneOption : public ObjectRef { public: - TuneOption(int n_trials, int early_stopping, int num_measure_per_iter, + /*! + * \brief The constructor. + * \param n_trials Number of total measurement trials. + * \param early_stopping Stop the tuning early if no improvement is found after n measurements. + * \param num_measure_per_round The number of programs to be measured at each search round. + * \param verbose Verbosity level. (0 means silent) + * \param builder Builder which builds the program. + * \param runner Runner which runs the program and measures time costs. + * \param measure_callbacks MeasureCallback functions to be called after each measure batch. + * \param pre_search_callbacks SearchCallback functions to be called before schedule search. + */ + TuneOption(int n_trials, int early_stopping, int num_measure_per_round, int verbose, Builder builder, Runner runner, Array measure_callbacks, Array pre_search_callbacks); @@ -75,11 +93,26 @@ class TuneOption : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(TuneOption, ObjectRef, TuneOptionNode); }; -/*! \brief Auto schedule for a compute declaration */ +/*! + * \brief Auto schedule search for a given compute declaration, by SearchTask. + * \param task The target search task. + * \param search_policy The search policy to be used for schedule search. + * \param tune_option Tuning and measurement options. + * \return A `te::Schedule` and the target `te::Tensor` to be used in `tvm.lower` or `tvm.build`. + */ std::pair > AutoSchedule( SearchTask task, SearchPolicy search_policy, TuneOption tune_option); -/*! \brief Auto schedule for a compute declaration */ +/*! + * \brief Auto schedule search for a given compute declaration, by workload key. + * \param workload_key The target workload key. + * \param target A `tvm::target`. + * \param target_host A `tvm::target` for host device. 
+ * \param search_policy The search policy to be used for schedule search. + * \param hardware_params Hardware parameters. + * \param tune_option Tuning and measurement options. + * \return A `te::Schedule` and the target `te::Tensor` to be used in `tvm.lower` or `tvm.build` + */ std::pair > AutoSchedule( std::string workload_key, Target target, Target target_host, SearchPolicy search_policy, HardwareParams hardware_params, diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index fc37a09872ea..b9a83733c116 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -19,15 +19,17 @@ /*! * \file ansor/compute_dag.cc - * \brief Compute declaration graph and its related analysis tools + * \brief Compute declaration graph and its related analysis tools. */ #include "compute_dag.h" + #include #include #include #include #include + #include #include #include @@ -36,7 +38,9 @@ #include #include #include -#include "transform_step.h" + +#include "loop_state.h" +#include "utils.h" namespace tvm { namespace ansor { @@ -45,6 +49,23 @@ using namespace tvm::tir; TVM_REGISTER_NODE_TYPE(ComputeDAGNode); +void UpdateStageAxis(const te::Stage& stage, StageToAxesMap *stage_to_axes) { + if (auto pop = stage->op.as()) { + std::vector& axes = (*stage_to_axes)[stage]; + axes.clear(); + for (const auto& axis : pop->axis) { + axes.push_back(axis); + } + for (const auto& axis : pop->reduce_axis) { + axes.push_back(axis); + } + } else if (stage->op->IsInstance()) { + {} // do nothing + } else { + LOG(FATAL) << "Invalid op " << stage->op; + } +} + // Topo-sort ops from tensors according to their read-write relations. // Results are stored in ops void TopoSortOps(const Array& tensors, @@ -228,7 +249,6 @@ ComputeDAG::ComputeDAG(const std::string& workload_key) { } else { LOG(FATAL) << "ansor.workload_key_to_tensors is not registered"; } - auto node = make_object(); FlopEstimator estimator; node->tensors = std::move(tens); @@ -240,27 +260,6 @@ ComputeDAG::ComputeDAG(const std::string& workload_key) { data_ = std::move(node); } -std::string BaseName(const std::string& str) { - return str.substr(0, str.rfind("_")); -} - -void UpdateStageAxis(const te::Stage& stage, StageToAxesMap *stage_to_axes) { - if (auto pop = stage->op.as()) { - std::vector& axes = (*stage_to_axes)[stage]; - axes.clear(); - for (const auto& axis : pop->axis) { - axes.push_back(axis); - } - for (const auto& axis : pop->reduce_axis) { - axes.push_back(axis); - } - } else if (stage->op->IsInstance()) { - {} // do nothing - } else { - LOG(FATAL) << "Invalid op " << stage->op; - } -} - std::pair > ComputeDAG::ApplySteps( const std::vector& transform_steps) const { std::vector stages; @@ -300,7 +299,7 @@ std::string ComputeDAG::PrintStepsAsPython(const std::vector& transform_st << " + " << "tuple(" << stage->op->name << ".op.reduce_axis)\n"; } } - + // Call each step's PrintAsPythonAPI method for (const auto& step : transform_steps) { ss << step->PrintAsPythonAPI(&stages, &stage_to_axes, &schedule, transform_steps); @@ -360,11 +359,13 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { Array tensors; Map bounds; + // Replay steps to tvm::Schedule std::tie(sch, tensors) = ReplaySteps(pstate->transform_steps, &stages, &stage_to_axes); sch = sch.normalize(); bounds = te::InferBound(sch); + // Update the state bound information for (size_t i = 0; i < pstate->stages.size(); ++i) { const Stage& stage = pstate->stages[i]; @@ -388,8 +389,8 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { } } - pstate->stages[i] = 
Stage(stage->op, stage->op_type, std::move(new_iters), - stage->compute_at, stage->attrs); + pstate->stages.Set(i, Stage(stage->op, stage->op_type, std::move(new_iters), + stage->compute_at, stage->attrs)); } } @@ -427,7 +428,10 @@ std::pair > ComputeDAG::ReplaySteps( if (complete_rate >= 0 && ct++ > transform_steps.size() * complete_rate) { break; } - + // Call each step's ApplyToSchedule method + // Note: some steps require extra parameters and may need different return values, + // so ApplyToSchedule cannot be merged into a single interface the way + // PrintAsPythonAPI is if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 620b28699186..0d1473126ad6 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -19,7 +19,7 @@ /*! * \file ansor/compute_dag.h - * \brief Compute declaration graph and its related analysis tools + * \brief Compute declaration graph and its related analysis tools. */ #ifndef TVM_ANSOR_COMPUTE_DAG_H_ @@ -27,12 +27,12 @@ #include #include + #include #include #include #include #include -#include "utils.h" namespace tvm { namespace ansor { @@ -42,17 +42,24 @@ class StateNode; class State; class Step; typedef std::unordered_map, ObjectHash, ObjectEqual> StageToAxesMap; -// Update StageToAxes Map during replay +/*! + * \brief Update stage and axes mapping during replay. + * \param stage A `te::Stage`. + * \param stage_to_axes A pointer to StageToAxesMap. + */ void UpdateStageAxis(const tvm::te::Stage& stage, StageToAxesMap *stage_to_axes); - /*! \brief Computation declaration graph */ class ComputeDAGNode : public Object { public: - Array tensors; // Input and output tensors - Array ops; // All related operations in topo order - double flop_ct; // Number of float operations - ObjectRef init_state; // The initial state + /*! \brief Input and output tensors. */ + Array tensors; + /*! \brief All related operations in topo order. */ + Array ops; + /*! \brief Number of float operations. */ + double flop_ct; + /*! \brief The initial state. */ + ObjectRef init_state; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("tensors", &tensors); @@ -70,40 +77,76 @@ class ComputeDAGNode : public Object { */ class ComputeDAG: public ObjectRef { public: + /*! \brief The constructor. + * \param tensors `te::Tensor`s for a compute declaration. + */ explicit ComputeDAG(Array tensors); + /*! \brief The constructor. + * \param workload_key Workload key for a compute declaration. + */ explicit ComputeDAG(const std::string& workload_key); - // Apply transform steps to the init state of this DAG, and get the equivalent tvm::schedule. - // The return values can be used as arguments to tvm.build or tvm.lower + /*! + * \brief Apply transform steps to the init state of this DAG, and get the + * equivalent `tvm::schedule`. + * \param transform_steps Transform steps of the target state. + * \return The return values can be used as arguments to `tvm.build` or `tvm.lower`. + */ std::pair > ApplySteps( const std::vector& transform_steps) const; - - // Print transform steps as equivalent python schedule API - std::string PrintStepsAsPython(const std::vector& steps) const; - - // Replay the transform steps and call ir_pass::InferBound to fill correct bound information + /*! + * \brief Print transform steps as equivalent python schedule API. + * \param transform_steps Transform steps of the target state. + * \return Python schedule code. 
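+ * A hedged usage sketch (assumes a ComputeDAG `dag` and a State `state` that already
+ * carries transform steps):
+ * \code
+ * te::Schedule sch;
+ * Array<te::Tensor> tensors;
+ * std::tie(sch, tensors) = dag.ApplySteps(state->transform_steps);
+ * std::string code = dag.PrintStepsAsPython(state->transform_steps);
+ * \endcode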
+ */ + std::string PrintStepsAsPython(const std::vector& transform_steps) const; + + /*! + * \brief Replay the transform steps and call ir_pass::InferBound to fill + * correct bound information. + * \param transform_steps Transform steps of the target state. + * \return The State after InferBound. + */ State ReplayAndInferBound(const std::vector& transform_steps) const; - - // Fill the correct bound information for a given state by calling ir_pass::InferBound + /*! + * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound. + * \param state The target state. + * \return The State after InferBound. + */ State InferBound(const State& state) const; - - // Fill the correct bound information for a list of given states. - // Return the new states inplace + /*! + * \brief Fill the correct bound information for a list of given states. + * Return the new states in place. + * \param states A pointer to a State vector. + */ void InferBound(std::vector* states) const; - // Get the init state + /*! + * \brief Get the init state. + * \return The init state. + */ State GetInitState() const; TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode); private: - // Internal common parts for replaying steps + /*! + * \brief Internal common parts for replaying steps. + * \param transform_steps Transform steps of the target state. + * \param stages A pointer to `te::Stage` vector. + * \param stage_to_axes A pointer to StageToAxesMap. + * \return The return values can be used as arguments to `tvm.build` or `tvm.lower`. + */ std::pair > ReplaySteps( const std::vector& transform_steps, std::vector* stages, StageToAxesMap* stage_to_axes) const; - // Internal common parts for inferring bound + /*! + * \brief Internal common parts for inferring bound. + * \param pstate A pointer to StateNode; the target state will be updated with filled + * bound information. + */ void InferBoundCommon(StateNode* pstate) const; }; diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 4d0ec24cbee2..3843e7954500 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -19,7 +19,7 @@ /*! * \file ansor/loop_state.cc - * \brief An lightweight IR (intermediate representation) for loop structures. + * \brief A lightweight IR (intermediate representation) for loop structures. * see ansor/loop_state.h for more explanation. 
*/ @@ -37,7 +37,7 @@ TVM_REGISTER_NODE_TYPE(StageNode); TVM_REGISTER_NODE_TYPE(StateNode); TVM_REGISTER_NODE_TYPE(IteratorNode); -// Maker for other classes +/********** Iterator **********/ Iterator::Iterator(std::string name, Range range, IteratorType iter_type, IteratorAnnotation annotation, const std::vector* ori_iters, @@ -54,6 +54,7 @@ Iterator::Iterator(std::string name, Range range, IteratorType iter_type, data_ = std::move(node); } +/********** Stage **********/ Stage::Stage(te::Operation op) { auto node = make_object(); if (op->IsInstance()) { @@ -104,28 +105,26 @@ Stage::Stage(te::Operation op, StageType op_type, std::vector&& iters, data_ = std::move(node); } +/********** State **********/ State::State(const Array& ops) { auto node = make_object(); for (const auto& op : ops) { node->stages.push_back(Stage(op)); } node->complete = true; - node->aux_info = ObjectRef(); data_ = std::move(node); } State::State(const std::vector& stages, - const std::vector& transform_steps, bool complete, - ObjectRef aux_info) { + const std::vector& transform_steps, bool complete) { auto node = make_object(); node->stages = stages; node->transform_steps = transform_steps; node->complete = complete; - node->aux_info = std::move(aux_info); data_ = std::move(node); } -// Schedule primitives api +/********** Schedule primitives apis for state **********/ void State::reorder(int stage_id, const std::vector& order) { const Stage& stage = operator->()->stages[stage_id]; @@ -160,7 +159,7 @@ Iterator State::fuse(int stage_id, const std::vector& iters) { return DoFuseStep(step); } -// Steps' implementations +/********** Step implementations for state **********/ void State::DoReorderStep(const ReorderStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; @@ -170,9 +169,9 @@ void State::DoReorderStep(const ReorderStep& step) { } StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = Stage( + pstate->stages.Set(step->stage_id, Stage( stage->op, stage->op_type, std::move(iters), stage->compute_at, - stage->attrs); + stage->attrs)); } // common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep @@ -236,9 +235,9 @@ std::vector State::DoSplitStepCommon( stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages[stage_id] = Stage( + pstate->stages.Set(stage_id, Stage( stage->op, stage->op_type, std::move(new_iters), stage->compute_at, - stage->attrs); + stage->attrs)); return outs; } @@ -296,25 +295,13 @@ Iterator State::DoFuseStep(const FuseStep& step) { stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages[stage_id] = Stage( + pstate->stages.Set(stage_id, Stage( stage->op, stage->op_type, std::move(new_iters), stage->compute_at, - stage->attrs); + stage->attrs)); return new_it; } -void State::DoStep(const Step& step, const ComputeDAG& dag) { - if (auto ps = step.as()) { - DoReorderStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoSplitStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoFuseStep(GetRef(ps)); - } else { - LOG(FATAL) << "Invalid step: " << step; - } -} - void State::DoSteps(const std::vector& steps, const ComputeDAG& dag) { // Use complete rate for the study in the paper const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); @@ -328,10 +315,19 @@ void State::DoSteps(const std::vector& steps, const ComputeDAG& dag) { if (complete_rate >= 0 && ct++ > steps.size() * complete_rate) { break; } - DoStep(step, dag); + if (auto ps = step.as()) { + DoReorderStep(GetRef(ps)); + } else if (auto ps = 
step.as()) { + DoSplitStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoFuseStep(GetRef(ps)); + } else { + LOG(FATAL) << "Invalid step: " << step; + } } } +// Print stage to ostream void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t base_indent, bool delete_trivial_loop) { const Stage& stage = state->stages[stage_id]; @@ -415,6 +411,7 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, *os << stage->op->name << " = ...\n"; } +// Print state to ostream void PrintState(std::ostream* os, const StateNode* node, bool delete_trivial_loop) { // Gather placeholders @@ -461,6 +458,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) PrintState(&p->stream, node, true); }); +/********** State interface API for ffi **********/ TVM_REGISTER_GLOBAL("ansor.StageGetIterators").set_body_typed([](const Stage& stage) { return Array(stage->iters); }); diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index fbd1a8a8a91c..723f0b78fb04 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -19,10 +19,11 @@ /*! * \file ansor/loop_state.h - * \brief The definition of the "state" in search. A state consists a current loop structure + * \brief The definition of the "state" in search. A state consists of the current loop structure * and the transform history to reach its current loop structure. * To enable flexible manipulation of the loop structure, we implemented a lightweight - * loop structure IR (Intermediate Representation) specifically for search. + * loop structure IR (Intermediate Representation) specifically for search. This can be seen as + * a preview of how the schedule looks after tvm.lower or tvm.build. * * Basically this is a simplified TVM IR with schedule primitives. * We don't use the existing TVM IR because * 3. We may create some macro schedule primitives * * After the search is done, we will lower this IR to TVM IR with TVM schedule primitives. - * Because we share a lot common objects during search, the transformation is - * implemented in copy on write style. All objects are immutable, which is - * similar to TVM IR. + * Because we share a lot of common objects during search, the transformation is + * implemented in copy-on-write style. + * All objects are immutable, which is similar to TVM IR. */ #ifndef TVM_ANSOR_LOOP_STATE_H_ @@ -44,39 +45,67 @@ #include #include #include + #include "compute_dag.h" +#include "transform_step.h" namespace tvm { namespace ansor { using namespace tvm::tir; -/*! \brief The type of a stage */ +/*! \brief The type of a stage. */ enum StageType { - kPlaceholder, // A placeholder stage - kCompute // A compute stage + /*! \brief A placeholder stage. */ + kPlaceholder = 0, + /*! \brief A compute stage. */ + kCompute = 1 }; -/*! \brief The type of compute location */ +/*! \brief The type of compute location. */ enum ComputeAtType { - kRoot, // compute at root - kInlined, // inlined - kIter, // compute at some iterator + /*! \brief Compute at root. */ + kRoot = 0, + /*! \brief Compute inlined. */ + kInlined = 1, + /*! \brief Compute at some iterator. */ + kIter = 2, }; -/*! \brief The type of an iterator */ +/*! \brief The type of an iterator. */ enum IteratorType { - kSpace, // spatial iterator - kReduce, // reduction iterator - kMixed, // fused spatial and reduction iterator - kSpecial // special iterator (e.g. virtual root iterator) + /*! \brief Spatial iterator. */ + kSpace = 0, + /*! \brief Reduction iterator. */ + kReduce = 1, + /*! 
\brief Fused spatial and reduction iterator. */ + kMixed = 2, + /*! \brief Special iterator. (e.g. virtual root iterator) */ + kSpecial = 3 }; -/*! \brief The type of an iterator's annotation */ +/*! \brief The type of an iterator's annotation. */ enum IteratorAnnotation { - kNone, kUnroll, kVectorize, kParallel, - kVThread, kBlockX, kThreadX, kBlockY, kThreadY, - kTensorized + /*! \brief This iterator has no annotation. */ + kNone = 0, + /*! \brief This iterator has been unrolled. */ + kUnroll = 1, + /*! \brief This iterator has been vectorized. */ + kVectorize = 2, + /*! \brief This iterator has been parallelized. */ + kParallel = 3, + /*! \brief This iterator has been bound to vthread. */ + kVThread = 4, + /*! \brief This iterator has been bound to blockIdx.x. */ + kBlockX = 5, + /*! \brief This iterator has been bound to threadIdx.x. */ + kThreadX = 6, + /*! \brief This iterator has been bound to blockIdx.y. */ + kBlockY = 7, + /*! \brief This iterator has been bound to threadIdx.y. */ + kThreadY = 8, + /*! \brief This iterator has been mapped with a tensorize intrinsic. */ + kTensorized = 9 }; // forward declaration @@ -88,12 +117,18 @@ class Iterator; */ class IteratorNode : public Object { public: + /*! \brief The name of this iterator. */ std::string name; + /*! \brief The target range of this iterator. */ Range range; + /*! \brief The iterator type of this iterator. */ IteratorType iter_type; + /*! \brief The annotation type of this iterator. */ IteratorAnnotation annotation; - std::vector ori_iters; // The original iterators before fusion - std::string attr; // Todo(jcf94): Document this + /*! \brief The original iterators before fusion. */ + std::vector ori_iters; + /*! \brief The extra attribute of this iterator. */ + std::string attr; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); @@ -111,6 +146,15 @@ class IteratorNode : public Object { */ class Iterator : public ObjectRef { public: + /*! + * \brief The constructor. + * \param name The name of this iterator. + * \param range The target range of this iterator. + * \param iter_type The iterator type of this iterator. + * \param annotation The annotation type of this iterator. + * \param ori_iters The original iterators before fusion. + * \param attr The extra attribute of this iterator. + */ Iterator(std::string name, Range range, IteratorType iter_type, IteratorAnnotation annotation, const std::vector* ori_iters = nullptr, @@ -119,23 +163,30 @@ class Iterator : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode); }; -/*! \brief Stage-level attributes */ +/*! \brief Stage-level attributes. */ struct StageAttributes { - int auto_unroll_max_step; // The maximum steps for the pragma `auto_unroll_max_step` - int storage_offset; // The storage offset for the schedule primitive `storage_align` + /*! \brief The maximum steps for the pragma `auto_unroll_max_step`. */ + int auto_unroll_max_step; + /*! \brief The storage offset for the schedule primitive `storage_align`. */ + int storage_offset; }; /*! - * \brief A stage in the compute declaration - * Similar to te::Stage in `include/schedule.h` + * \brief A stage in the compute declaration. + * Similar to te::Stage in `include/schedule.h`. 
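+ * A Stage can be built directly from a `te::Operation` (a sketch of the intent; see the
+ * constructors below): `Stage stage(op);` populates the stage's iterators from the op's
+ * spatial and reduce axes.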
 */ class StageNode : public Object { public: - te::Operation op; // The operator of this stage - StageType op_type; // The type of this stage - std::vector iters; // The iterators in this stage - ComputeAtType compute_at; // The compute location of this stage - StageAttributes attrs; // Other stage-level attributes + /*! \brief The operator of this stage. */ + te::Operation op; + /*! \brief The type of this stage. */ + StageType op_type; + /*! \brief The iterators in this stage. */ + std::vector iters; + /*! \brief The compute location of this stage. */ + ComputeAtType compute_at; + /*! \brief Other stage-level attributes. */ + StageAttributes attrs; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("op", &op); @@ -151,10 +202,30 @@ */ class Stage : public ObjectRef { public: + /*! + * \brief The constructor. + * \param op A `te::Operation`. + */ explicit Stage(te::Operation op); + /*! + * \brief The constructor. + * \param op A `te::Operation`. + * \param op_type The stage type of this op. + * \param iters The iterators of this op. (copy) + * \param compute_at The compute at type of this op. + * \param attrs Other stage-level attributes. + */ Stage(te::Operation op, StageType op_type, const std::vector& iters, ComputeAtType compute_at, StageAttributes attrs); + /*! + * \brief The constructor. + * \param op A `te::Operation`. + * \param op_type The stage type of this op. + * \param iters The iterators of this op. (move) + * \param compute_at The compute at type of this op. + * \param attrs Other stage-level attributes. + */ Stage(te::Operation op, StageType op_type, std::vector&& iters, ComputeAtType compute_at, StageAttributes attrs); @@ -163,40 +234,28 @@ TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode); }; -/*! \brief The base class for a transformation step */ -class StepNode: public Object { - public: - int stage_id; - - // Print step as equivalent python schedule API - virtual std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const = 0; - - static constexpr const char* _type_key = "ansor.Step"; - TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object); -}; -TVM_DEFINE_MUTABLE_OBJECT_REF(Step, StepNode); - -// Step forward decelerations -class ReorderStep; class SplitStep; class FuseStep; - -/*! \brief A state in the search process. - * It consists of the current loop structure and the history steps to reach this state. */ +/*! + * \brief A state in the search process. + * It consists of the current loop structure and the history steps to reach this state. + */ class StateNode: public Object { public: - std::vector stages; // Current stages and loop structures - std::vector transform_steps; // History transformation steps - bool complete; // Indicate whether this state has unfilled tile sizes - ObjectRef aux_info; // Used to store any auxiliary info about this state - ComputeDAG task_dag; // The up-to-date ComputeDAG of this state. - // The default value is an empty NodeRef - // (means no modification to the DAG) + /*! \brief Current stages and loop structures. */ + Array stages; + /*! \brief History transformation steps. */ + std::vector transform_steps; + /*! \brief Indicate whether this state has unfilled tile sizes. */ + bool complete; + /*! + * \brief The up-to-date ComputeDAG of this state, used for some steps that may change the + * stage structure of the ComputeDAG, e.g. CacheReadStep/CacheWriteStep (will be added later). 
+ * The default value is an empty NodeRef. (means no modification to the original DAG) + */ + ComputeDAG task_dag; void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("stages", &stages); v->Visit("complete", &complete); - v->Visit("aux_info", &aux_info); v->Visit("task_dag", &task_dag); } @@ -210,53 +269,99 @@ class StateNode: public Object { */ class State : public ObjectRef { public: + /*! + * \brief The constructor. + * \param ops `te::Operation`s for a compute declaration. + */ explicit State(const Array& ops); + /*! + * \brief The constructor. + * \param stages Stages of the target state. + * \param transform_steps Transform steps of the target state. + * \param complete Indicate whether this state has unfilled tile sizes. + */ State(const std::vector& stages, - const std::vector& transform_steps, bool complete, - ObjectRef aux_info); + const std::vector& transform_steps, bool complete); - // Schedule primitives + /*! + * \brief Schedule primitive corresponds to te.reorder. + * \param stage_id The index of the target stage. + * \param order The target iterator order. + */ void reorder(int stage_id, const std::vector& order); + /*! + * \brief Schedule primitive corresponds to te.split. + * \param stage_id The index of the target stage. + * \param it The target iterator. + * \param lengths The target split factors. Can be None to be filled by search policy. + * \param inner_to_outer True for split from inner to outer & False for outer to inner. + * \return The iterator results after split. + */ std::vector split(int stage_id, const Iterator& it, const std::vector& lengths, bool inner_to_outer = true); + /*! + * \brief Schedule primitive corresponds to te.fuse. + * \param stage_id The index of the target stage. + * \param iters The target iterators to be fused. + * \return The iterator result after fuse. + */ Iterator fuse(int stage_id, const std::vector& iters); - // General do step functions with a runtime dynamic dispatcher - void DoSteps(const std::vector& step, const ComputeDAG& dag); - - // Print the state to a string + /*! + * \brief General do step functions with a runtime dynamic dispatcher. + * \param steps The target transform steps. + * \param dag The target ComputeDAG. + */ + void DoSteps(const std::vector& steps, const ComputeDAG& dag); + + /*! + * \brief Print the state to a string. + * \param delete_trivial_loop True for skipping the trivial loops. + * (undefined or extent == 1, default set to True) + * \return The human readable state structure. + */ std::string ToStr(bool delete_trivial_loop = true) const; TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); private: - void DoStep(const Step& step, const ComputeDAG& dag); - /* Do transform steps * Note: The following functions only change loop state but do not change transform_history. - * We separate these functions out, - * so you can call them for replay easily given history steps */ + * We separate these functions out, so you can call them for replay easily given history steps */ + + /*! + * \brief Apply reorder step to current state. + * \param step A ReorderStep. + */ void DoReorderStep(const ReorderStep& step); + /*! + * \brief Apply split step to current state. + * \param step A SplitStep. + * \return The iterator results after split. + */ std::vector DoSplitStep(const SplitStep& step); + /*! + * \brief Apply fuse step to current state. + * \param step A FuseStep. + * \return The iterator result after fuse. 
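+ * A hedged sketch of the public primitives these replay functions mirror (the stage id,
+ * iterator picks, and split factor are made up):
+ * \code
+ * const auto& its = state->stages[0]->iters;
+ * Iterator fused = state.fuse(0, {its[0], its[1]});
+ * state.split(0, fused, {PrimExpr(32)});
+ * \endcode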
+ */ Iterator DoFuseStep(const FuseStep& step); - // Common function for DoSplitStep and DoFollowSplitStep + + /*! + * \brief Common function for DoSplitStep and DoFollowSplitStep (will be added later). + * \param stage_id The index of the target stage. + * \param iter_id The index of the target iterator. + * \param lengths The target split factors. + * \param inner_to_outer The split direction. + * \return The iterator results after split. + */ std::vector DoSplitStepCommon(int stage_id, int iter_id, const std::vector& lengths, bool inner_to_outer); }; -/*! \brief Clean the name of an iterator to make it valid in python code */ -inline std::string CleanName(const std::string& str) { - std::string ret = str; - StrReplace(&ret, ".", "_"); - StrReplace(&ret, "@", "_"); - StrReplace(&ret, "outer", "o"); - StrReplace(&ret, "inner", "i"); - return ret; -} - } // namespace ansor } // namespace tvm @@ -264,6 +369,7 @@ // Hash and equal function for State namespace std { +/*! \brief The hash function for ansor::State. */ template <> struct hash<::tvm::ansor::State> { std::size_t operator()(const ::tvm::ansor::State& state) const { @@ -271,6 +377,7 @@ } }; +/*! \brief The equal_to function for ansor::State. */ template <> struct equal_to<::tvm::ansor::State> { bool operator() (const ::tvm::ansor::State& lhs, diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 7802d571d702..98a439c195cc 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -19,17 +19,21 @@ /*! * \file ansor/measure.cc - * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs + * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs. 
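+ * The flow here: ProgramMeasurer packs MeasureInputs, calls Builder::Build to produce
+ * BuildResults, then Runner::Run to produce MeasureResults, and finally invokes each
+ * registered MeasureCallback.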
*/ #include "measure.h" + #include #include + #include #include #include #include +#include "utils.h" + namespace tvm { namespace ansor { @@ -42,7 +46,7 @@ TVM_REGISTER_OBJECT_TYPE(BuilderNode); TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode); TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode); -const char* ErrorNoToStr[] = { +static const char* ErrorNoToStr[] = { "NoError", "InstantiationError", "CompileHostError", @@ -54,7 +58,7 @@ const char* ErrorNoToStr[] = { "UnknownError", }; -// Measure input and result +/********** Measure input and result **********/ MeasureInput::MeasureInput(SearchTask task, State state) { auto node = make_object(); node->task = std::move(task); @@ -103,7 +107,7 @@ MeasureResult MeasureResultNode::copy() const { return MeasureResult(node); } -// LocalBuilder +/********** LocalBuilder **********/ LocalBuilder::LocalBuilder(int timeout, int n_parallel, const std::string& build_func) { auto node = make_object(); @@ -125,7 +129,7 @@ Array LocalBuilderNode::Build(const Array& inputs, return Array(); } -// Local Runner +/********** LocalRunner **********/ LocalRunner::LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval) { ObjectPtr node = make_object(); @@ -151,7 +155,7 @@ Array LocalRunnerNode::Run( return Array(); } -// Program Measurer +/********** ProgramMeasurer **********/ ProgramMeasurer::ProgramMeasurer(Builder builder, Runner runner, Array callbacks, int verbose, int max_continous_error) { @@ -229,7 +233,7 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, // Call callback functions for (const auto& callback : callbacks) { - callback->callback(policy, input_batch, result_batch); + callback->Callback(policy, input_batch, result_batch); } // Store result batch @@ -264,7 +268,7 @@ void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, } } -// Printing functions +/********** Printing functions **********/ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { p->stream << "MeasureInput()"; @@ -305,6 +309,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ", " << node->time_cost << ")"; }); +/********** Measure interface API for ffi **********/ TVM_REGISTER_GLOBAL("ansor.MeasureInput").set_body_typed([](SearchTask task, State state) { return MeasureInput(task, state); }); diff --git a/src/ansor/measure.h b/src/ansor/measure.h index 630365512eb6..e46c42c8312c 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -19,7 +19,7 @@ /*! * \file ansor/measure.h - * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs + * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs. */ #ifndef TVM_ANSOR_MEASURE_H_ @@ -35,38 +35,47 @@ namespace tvm { namespace ansor { -class SearchPolicy; -class MeasureInput; class BuildResult; class MeasureResult; -class Builder; class Runner; class MeasureCallback; class ProgramMeasurer; +class SearchPolicy; class MeasureInput; class MeasureResult; -/* \brief The error code of one measurement */ +/*! 
\brief The error code of one measurement */ enum MeasureErrorNO { - kNoError = 0, // No error - kInstantiationError = 1, // Errors happen when apply transform steps from init state - kCompileHostError = 2, // Errors happen when compiling code on host (when build module) - kCompileDeviceError = 3, // Errors happen when compiling code on device (when load module) - kRuntimeDeviceError = 4, // Errors happen when run program on device - kWrongAnswerError = 5, // Answer is wrong when compared to a reference output - kBuildTimeoutError = 6, // Timeout during compilation - kRunTimeoutError = 7, // Timeout during run - kUnknonwError = 8, // Unknown error + /*! \brief No error. */ + kNoError = 0, + /*! \brief Errors happen when applying transform steps from the init state. */ + kInstantiationError = 1, + /*! \brief Errors happen when compiling code on host. (when building the module) */ + kCompileHostError = 2, + /*! \brief Errors happen when compiling code on device. (when loading the module) */ + kCompileDeviceError = 3, + /*! \brief Errors happen when running the program on the device. */ + kRuntimeDeviceError = 4, + /*! \brief Answer is wrong when compared to a reference output. */ + kWrongAnswerError = 5, + /*! \brief Timeout during compilation. */ + kBuildTimeoutError = 6, + /*! \brief Timeout during run. */ + kRunTimeoutError = 7, + /*! \brief Unknown error. */ + kUnknonwError = 8, }; -extern const char *ErrorNoToStr[]; // Inputs and results of one measurement /*! \brief Store the input of a measurement */ class MeasureInputNode: public Object { public: - SearchTask task; // The search task - State state; // The program state to be measured + /*! \brief The search task. */ + SearchTask task; + /*! \brief The program state to be measured. */ + State state; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("task", &task); v->Visit("state", &state); } - MeasureInput copy() const; // Do deep copy + /*! \brief Do deep copy. */ + MeasureInput copy() const; static constexpr const char* _type_key = "ansor.MeasureInput"; TVM_DECLARE_FINAL_OBJECT_INFO(MeasureInputNode, Object); @@ -78,20 +87,29 @@ */ class MeasureInput : public ObjectRef { public: + /*! + * \brief The constructor. + * \param task The target SearchTask. + * \param state The target State. + */ MeasureInput(SearchTask task, State state); TVM_DEFINE_OBJECT_REF_METHODS(MeasureInput, ObjectRef, MeasureInputNode); }; -/*! \brief Store the input of a build */ +/*! \brief Store the input of a build. */ class BuildResultNode: public Object { public: - std::string filename; // The filename of built binary file - Array args; // The arguments - int error_no; // The error code (see MeasureErrorNO). - // 0 means no error. - std::string error_msg; // The error message if there is any error - double time_cost; // The time cost of build + /*! \brief The filename of built binary file. */ + std::string filename; + /*! \brief The arguments. */ + Array args; + /*! \brief The error code. (0 means no error, see MeasureErrorNO) */ + int error_no; + /*! \brief The error message if there is any error. */ + std::string error_msg; + /*! \brief The time cost of build. */ + double time_cost; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("filename", &filename); @@ -111,20 +129,32 @@ */ class BuildResult : public ObjectRef { public: + /*! + * \brief The constructor. + * \param filename The filename of built binary file. + * \param args The arguments. + * \param error_no The error code. 
+ * \param error_msg The error message if there is any error. + * \param time_cost The time cost of build. + */ BuildResult(std::string filename, Array args, int error_no, std::string error_msg, double time_cost); TVM_DEFINE_OBJECT_REF_METHODS(BuildResult, ObjectRef, BuildResultNode); }; -/*! \brief Store the results of a measurement */ +/*! \brief Store the results of a measurement. */ class MeasureResultNode: public Object { public: - Array costs; // The time costs of execution - int error_no; // The error code (see MeasureErrorNO). - // 0 means no error. - std::string error_msg; // The error message if there is any error - double all_cost; // The time cost of build and run - double timestamp; // The time stamps of this measurement + /*! \brief The time costs of execution. */ + Array costs; + /*! \brief The error code. (0 means no error, see MeasureErrorNO) */ + int error_no; + /*! \brief The error message if there is any error. */ + std::string error_msg; + /*! \brief The time cost of build and run. */ + double all_cost; + /*! \brief The time stamps of this measurement. */ + double timestamp; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("costs", &costs); @@ -134,7 +164,8 @@ v->Visit("timestamp", &timestamp); } - MeasureResult copy() const; // Do deep copy + /*! \brief Do deep copy. */ + MeasureResult copy() const; static constexpr const char* _type_key = "ansor.MeasureResult"; TVM_DECLARE_FINAL_OBJECT_INFO(MeasureResultNode, Object); @@ -146,6 +177,14 @@ */ class MeasureResult : public ObjectRef { public: + /*! + * \brief The constructor. + * \param costs The time costs of execution. + * \param error_no The error code. + * \param error_msg The error message if there is any error. + * \param all_cost The time cost of build and run. + * \param timestamp The time stamps of this measurement. + */ MeasureResult(Array costs, int error_no, std::string error_msg, double all_cost, double timestamp); @@ -155,37 +194,73 @@ /*! \brief Base class of measurement callbacks */ class MeasureCallbackNode: public Object { public: - /*! \biref Callback function that will be called on measurement input/result pairs - * after measurement */ - virtual void callback(const SearchPolicy& policy, + /*! + * \brief Callback function that will be called on measurement input/result pairs + * after measurement. + * \param policy The current search policy. + * \param inputs An Array of MeasureInput. + * \param results An Array of MeasureResult. + */ + virtual void Callback(const SearchPolicy& policy, const Array& inputs, const Array& results) = 0; static constexpr const char *_type_key = "ansor.MeasureCallback"; TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object); }; -TVM_DEFINE_MUTABLE_OBJECT_REF(MeasureCallback, MeasureCallbackNode); + +/*! + * \brief Managed reference to MeasureCallbackNode. + * \sa MeasureCallbackNode + */ +class MeasureCallback : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode); +}; // Base class for builder and runner + /*! \brief Builder that builds the programs */ class BuilderNode: public Object { public: - int n_parallel; // The number of tasks to run in parallel - int timeout; // Timeout of a build - - /*! \biref Build programs and return results */ + /*! \brief The number of tasks to run in parallel. */ + int n_parallel; + /*! \brief Timeout of a build. */ + int timeout; + + /*! 
+ * \brief Build programs and return results. + * \param inputs An Array of MeasureInput. + * \param verbose Verbosity level. (0 means silent) + * \return An Array of MeasureResult. + */ virtual Array Build(const Array& inputs, int verbose) = 0; static constexpr const char* _type_key = "ansor.Builder"; TVM_DECLARE_BASE_OBJECT_INFO(BuilderNode, Object); }; -TVM_DEFINE_MUTABLE_OBJECT_REF(Builder, BuilderNode); -/*! \brief Runner that runs the built programs and measure the time cost */ -class RunnerNode: public Object { +/*! + * \brief Managed reference to BuilderNode. + * \sa BuilderNode + */ +class Builder : public ObjectRef { public: - int timeout; // Timeout of a run + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Builder, ObjectRef, BuilderNode); +}; - /*! \biref Run measurement and return results */ +/*! \brief Runner that runs the built programs and measures the time cost. */ +class RunnerNode: public Object { + public: + /*! \brief Timeout of a run. */ + int timeout; + + /*! + * \brief Run measurement and return results. + * \param inputs An Array of MeasureInput. + * \param build_results An Array of BuildResult. + * \param verbose Verbosity level. (0 means silent) + * \return An Array of MeasureResult. + */ virtual Array Run(const Array& inputs, const Array& build_results, int verbose) = 0; @@ -193,14 +268,23 @@ static constexpr const char* _type_key = "ansor.Runner"; TVM_DECLARE_BASE_OBJECT_INFO(RunnerNode, Object); }; -TVM_DEFINE_MUTABLE_OBJECT_REF(Runner, RunnerNode); +/*! + * \brief Managed reference to RunnerNode. + * \sa RunnerNode + */ +class Runner : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Runner, ObjectRef, RunnerNode); +}; // Implementation of various builders and runners + /*! \brief LocalBuilder uses local CPU cores to build programs in parallel */ class LocalBuilderNode: public BuilderNode { public: - std::string build_func; // Build function + /*! \brief Build function. */ + std::string build_func; Array Build(const Array& inputs, int verbose) final; @@ -214,6 +298,12 @@ */ class LocalBuilder: public Builder { public: + /*! + * \brief The constructor. + * \param timeout The timeout limit for each build. + * \param n_parallel Number of threads used to build in parallel. + * \param build_func The name of the registered build function. + */ LocalBuilder(int timeout, int n_parallel, const std::string& build_func); TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, Builder, LocalBuilderNode); @@ -222,12 +312,15 @@ /*! \brief LocalRunner that uses local CPU/GPU to measure the time cost of programs */ class LocalRunnerNode: public RunnerNode { public: + /*! \brief Number of measure times. */ int number; + /*! \brief Number of repeat times in each measure. */ int repeat; + /*! \brief The minimum duration of one repeat in milliseconds. */ int min_repeat_ms; + /*! \brief The cool down interval between two measurements. */ double cooldown_interval; - /*! \biref Run measurement and return results */ Array Run(const Array& inputs, const Array& build_results, int verbose) final; @@ -242,6 +335,14 @@ */ class LocalRunner: public Runner { public: + /*! + * \brief The constructor. + * \param timeout The timeout limit for each run. + * \param number Number of measure times. + * \param repeat Number of repeat times in each measure. + * \param min_repeat_ms The minimum duration of one repeat in milliseconds. 
+ * \param cooldown_interval The cool down interval between two measurements. + */ LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval); @@ -254,35 +355,57 @@ * This class combines Builder and Runner, and provides a simpler API */ class ProgramMeasurerNode: public Object { public: - static const int DEFAULT_MAX_CONTINOUS_ERROR = 150; - + /*! \brief Measured programs counter. */ int ct; - int error_ct; // continuous error counter + /*! \brief Continuous error counter. */ + int error_ct; + /*! \brief Workload key to best flops map. */ std::unordered_map best_flops; + /*! \brief Workload key to best state map. */ std::unordered_map best_state; + /*! \brief Workload key to best state's count index map. */ std::unordered_map best_ct; - + /*! \brief The Builder to build each program. */ Builder builder; + /*! \brief The Runner to measure each program. */ Runner runner; + /*! \brief MeasureCallback to be called after each measure batch. */ Array callbacks; + /*! \brief Verbose level. */ int verbose; + /*! \brief The maximum number of continuous errors. */ int max_continous_error; /*! \brief Reset book keeping variables */ void Reset(); - /*! \biref Do measurement */ + /*! + * \brief Do measurement. + * \param task The current SearchTask. + * \param policy The current SearchPolicy. + * \param inputs The target MeasureInputs. + * \param results A pointer to MeasureResult vector; this is used as output. + * \param batch_size Number of programs to be measured in one batch. + */ void Measure(const SearchTask& task, const SearchPolicy& policy, const std::vector& inputs, std::vector* results, int batch_size = -1); - - /*! \biref Do measurement silently */ + /*! + * \brief Do measurement silently. + * This API will not print the measure results to screen. + * \param task The current SearchTask. + * \param inputs The target MeasureInputs. + * \param results A pointer to MeasureResult vector; this is used as output. + */ void SilentMeasure(const SearchTask& task, const std::vector& inputs, std::vector* results); + /*! \brief The default max continuous error setting. */ + static const int DEFAULT_MAX_CONTINOUS_ERROR = 150; + static constexpr const char* _type_key = "ansor.ProgramMeasurer"; TVM_DECLARE_FINAL_OBJECT_INFO(ProgramMeasurerNode, Object); }; @@ -293,6 +416,14 @@ */ class ProgramMeasurer : public ObjectRef { public: + /*! + * \brief The constructor. + * \param builder The Builder to build each program. + * \param runner The Runner to measure each program. + * \param callbacks MeasureCallback to be called after each measure batch. + * \param verbose Verbose level. + * \param max_continous_error The maximum number of continuous errors. + */ ProgramMeasurer(Builder builder, Runner runner, Array callbacks, int verbose, int max_continous_error = -1); diff --git a/src/ansor/search_policy/empty_policy.cc b/src/ansor/search_policy/empty_policy.cc index ba861f333c78..fe880b6cf262 100644 --- a/src/ansor/search_policy/empty_policy.cc +++ b/src/ansor/search_policy/empty_policy.cc @@ -17,6 +17,11 @@ * under the License. */ +/*! + * \file ansor/search_policy/empty_policy.cc + * \brief This is a brief example of a search policy. 
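+ * A hedged sketch of how a policy is driven through the user interface in auto_schedule.h
+ * (`task` and `option` are assumed to be a prepared SearchTask and TuneOption):
+ * \code
+ * SearchPolicy policy(make_object<EmptyPolicyNode>());
+ * auto sch_and_tensors = AutoSchedule(task, policy, option);
+ * \endcode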
+ */ + #include "empty_policy.h" #include @@ -27,7 +32,7 @@ namespace ansor { TVM_REGISTER_NODE_TYPE(EmptyPolicyNode); State EmptyPolicyNode::Search(SearchTask task, int n_trials, int early_stopping, - int num_measure_per_iter, int verbose, ProgramMeasurer measurer, + int num_measure_per_round, int verbose, ProgramMeasurer measurer, Array pre_search_callbacks) { cur_task = task; @@ -35,6 +40,9 @@ State EmptyPolicyNode::Search(SearchTask task, int n_trials, int early_stopping, // This Interface is usually used to set some init status RunCallbacks(pre_search_callbacks); + // Basic design principle: run `SearchOneRound()` several times to get candidate states, + // measure them, and return the best one + // Measure is disabled if n_trials <= 1 if (n_trials <= 1) { const auto& res = SearchOneRound(); CHECK_GT(res.size(), 0); @@ -62,32 +70,10 @@ State EmptyPolicyNode::Search(SearchTask task, int n_trials, int early_stopping, } } -std::pair, Array > EmptyPolicyNode::ContinueSearchOneRound( - SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) { - // The whole process is almost the same as Search, while this function is designed to be - // called and managed by another global task scheduler - - std::vector inputs; - std::vector results; - - const auto& res = SearchOneRound(); - for (const auto& state : res) { - inputs.emplace_back(cur_task, state); - } - measurer->Measure(cur_task, GetRef(this), inputs, &results); - - // Return a pair of MeasureInput Array and MeasureResult Array - Array inputs_arr(std::make_move_iterator(inputs.begin()), - std::make_move_iterator(inputs.end())); - Array results_arr(std::make_move_iterator(results.begin()), - std::make_move_iterator(results.end())); - return std::make_pair(std::move(inputs_arr), std::move(results_arr)); -} - +// As an example policy, EmptyPolicy always returns an init state std::vector EmptyPolicyNode::SearchOneRound() { std::vector res; res.push_back(cur_task->compute_dag.GetInitState()); - // As an example policy, EmptyPolicy always return a init state return res; } diff --git a/src/ansor/search_policy/empty_policy.h b/src/ansor/search_policy/empty_policy.h index 5c2f52608fe0..c7cf721d4bc7 100644 --- a/src/ansor/search_policy/empty_policy.h +++ b/src/ansor/search_policy/empty_policy.h @@ -19,7 +19,7 @@ /*! * \file ansor/search_policy/empty_policy.h - * \brief This is an basic example of search policy + * \brief This is a brief example of a search policy. */ #ifndef TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_ @@ -34,33 +34,25 @@ namespace tvm { namespace ansor { /*! - * \file ansor/search_policy/empty_policy.h - * \brief This is an basic example for search policy. The EmptyPolicy will - * always generates the init state of a ComputeDAG. + * \brief The EmptyPolicy will always generate the init state of a ComputeDAG. + * This is a brief example of a search policy: it shows the basic design of a search policy, + * which the formal search policies will continue to follow. + * The key implementation for this structure is `Search()`; check `empty_policy.cc` for more + * details. */ class EmptyPolicyNode : public SearchPolicyNode { public: - /*! \brief Search and make n_trails measurements. - * \returns the best state - */ State Search(SearchTask task, int n_trials, - int early_stopping, int num_measure_per_iter, + int early_stopping, int num_measure_per_round, int verbose, ProgramMeasurer measurer, Array pre_search_callbacks) final; - /*! \brief Continue search for one round. 
This is used by JointTuner - * \returns the measurement pairs - */ - std::pair, Array > ContinueSearchOneRound( - SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) final; - static constexpr const char *_type_key = "ansor.EmptyPolicy"; TVM_DECLARE_FINAL_OBJECT_INFO(EmptyPolicyNode, SearchPolicyNode); private: /*! - * \brief Usually we need a sub function to generate several candidate states in each - * search round. + * \brief Use a sub-function to generate several candidate states in each search round. * \returns Several generated states */ std::vector SearchOneRound(); diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index e7a12702ba70..8786f67edb22 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -19,7 +19,7 @@ /*! * \file ansor/search_policy/search_policy.cc - * \brief The base class for search policy + * \brief The base class for search policy. */ #include "search_policy.h" @@ -35,21 +35,11 @@ TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); void SearchPolicyNode::RunCallbacks(const Array& callbacks) { if (callbacks.defined() && callbacks.size()) { for (const auto& callback : callbacks) { - callback->callback(this); + callback->Callback(this); } } } -// Search Policy -TVM_REGISTER_GLOBAL("ansor.SearchPolicyContinueSearchOneRound") -.set_body_typed([](SearchPolicy policy, SearchTask task, int num_measure, - int verbose, ProgramMeasurer measurer) { - Array inputs; - Array results; - std::tie(inputs, results) = policy->ContinueSearchOneRound(task, num_measure, verbose, measurer); - return Array{inputs, results}; -}); - TVM_REGISTER_GLOBAL("ansor.SearchPolicyRunCallbacks") .set_body_typed([](SearchPolicy policy, Array callbacks) { policy->RunCallbacks(callbacks); diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index eb4703be1914..cc54822e925f 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -19,7 +19,7 @@ /*! * \file ansor/search_policy/search_policy.h - * \brief The base class for search policy + * \brief The base class for search policy. */ #ifndef TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ @@ -38,54 +38,95 @@ namespace ansor { class SearchPolicyNode; -/*! \brief Callback function to be called before or after the search process */ +/*! + * \brief Callback function to be called by the search process. + * This interface allows extra initialization before the schedule search and extra + * checks during/after the schedule search. + */ class SearchCallbackNode : public Object { public: - virtual void callback(SearchPolicyNode* policy) = 0; + /*! + * \brief Run the registered callback function. + * \param policy A pointer to SearchPolicyNode. + */ + virtual void Callback(SearchPolicyNode* policy) = 0; static constexpr const char *_type_key = "ansor.SearchCallback"; TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object); }; -TVM_DEFINE_MUTABLE_OBJECT_REF(SearchCallback, SearchCallbackNode); +/*! + * \brief Managed reference to SearchCallbackNode. + * \sa SearchCallbackNode + */ +class SearchCallback : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode); +}; + +/*! + * \brief The base class for search policy. + */ class SearchPolicyNode : public Object { public: - SearchTask cur_task; // The current task - int verbose; // Verbose level (0 means silent) + /*! 
\brief The current search task. */ + SearchTask cur_task; + /*! + * \brief Verbose level to control the screen output during schedule search. + * (0 means silent) + */ + int verbose; void VisitAttrs(AttrVisitor* v) { v->Visit("cur_task", &cur_task); v->Visit("verbose", &verbose); } - // Search for a task + /*! + * \brief Do schedule search for a task. + * \param task The target search task. + * \param n_trials Total schedules to be tried during this search. + * \param early_stopping Early stop if no better schedule is found. + * \param num_measure_per_round The max number of programs to be measured in one search round. + * \param verbose Verbose level. (0 means silent) + * \param measurer A ProgramMeasurer which packs Builder & Runner inside. + * \param pre_search_callbacks SearchCallback to be called before schedule search. + * \return The best state found. + */ virtual State Search(SearchTask task, int n_trials, - int early_stopping, int num_measure_per_iter, + int early_stopping, int num_measure_per_round, int verbose, ProgramMeasurer measurer, Array pre_search_callbacks) = 0; - // Continue search one round for a task. - // This is used in the task scheduler for searching for multiple tasks together. - virtual std::pair, Array > ContinueSearchOneRound( - SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) = 0; - - // Run a list of callback functions + /*! + * \brief Call SearchCallback with the current SearchPolicyNode. + * \param callbacks SearchCallback to be called. + */ void RunCallbacks(const Array& callbacks); static constexpr const char *_type_key = "ansor.SearchPolicy"; TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); protected: - // The set of the already measured states. - // We store the string format for redundancy check + /*! + * \brief The set of already measured states. + * We store the string format for redundancy check. + */ std::unordered_set measured_states_set_; - // The array of already measured states. + /*! \brief The array of already measured states. */ std::vector measured_states_vector_; - // The throughputs of already measured states + /*! \brief The throughputs of already measured states. */ std::vector measured_states_throughputs_; }; -TVM_DEFINE_MUTABLE_OBJECT_REF(SearchPolicy, SearchPolicyNode); + +/*! + * \brief Managed reference to SearchPolicyNode. + * \sa SearchPolicyNode + */ +class SearchPolicy : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchPolicy, ObjectRef, SearchPolicyNode); +}; } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index 6be4773fe780..4ef07819bbef 100644 --- a/src/ansor/search_task.cc +++ b/src/ansor/search_task.cc @@ -19,7 +19,7 @@ /*! * \file ansor/search_task.cc - * \brief Meta information and hardware parameters for a search task + * \brief Meta information and hardware parameters for a search task. */ #include "search_task.h" @@ -51,7 +51,7 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams( const Target& target, const Target& target_host) { if (target->target_name == "llvm") { return HardwareParams(tvm::runtime::threading::MaxConcurrency(), - 32, 64, 16, 64); + 64, 64, 64, 64); } else { LOG(FATAL) << "No default hardware parameters for target: " << target; } diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h index 0f270d105d73..5c02db5afddf 100644 --- a/src/ansor/search_task.h +++ b/src/ansor/search_task.h @@ -19,7 +19,7 @@ /*! 
* \file ansor/search_task.h - * \brief Meta information and hardware parameters for a search task + * \brief Meta information and hardware parameters for a search task. */ #ifndef TVM_ANSOR_SEARCH_TASK_H_ @@ -37,22 +37,28 @@ class HardwareParams; /*! \brief Hardware related parameters */ class HardwareParamsNode : public Object { public: - // The number of cores + /*! \brief The number of cores. */ int num_cores; - // The width of vector units in bytes + /*! \brief The width of vector units in bytes. */ int vector_unit_bytes; - // The size of cache line in bytes + /*! \brief The size of cache line in bytes. */ int cache_line_bytes; - // The max length of an axis to be unrolled or vectorized + /*! \brief The max length of an axis to be unrolled or vectorized. */ int max_unroll_vec; - // The max split factor for the innermost tile + /*! \brief The max split factor for the innermost tile. */ int max_innermost_split_factor; // Limitation params for GPU + + /*! \brief The max shared memory per block. */ int max_shared_memory_per_block{INT32_MAX}; + /*! \brief The max register memory per block. */ int max_registers_per_block{INT32_MAX}; + /*! \brief The max threads per block. */ int max_threads_per_block{INT32_MAX}; + /*! \brief The max vthread extent. */ int max_vthread_extent{INT32_MAX}; + /*! \brief The thread numbers of a warp. */ int warp_size{INT32_MAX}; void VisitAttrs(tvm::AttrVisitor* v) { @@ -61,7 +67,6 @@ class HardwareParamsNode : public Object { v->Visit("cache_line_bytes", &cache_line_bytes); v->Visit("max_unroll_vec", &max_unroll_vec); v->Visit("max_innermost_split_factor", &max_innermost_split_factor); - v->Visit("max_shared_memory_per_block", &max_shared_memory_per_block); v->Visit("max_registers_per_block", &max_registers_per_block); v->Visit("max_threads_per_block", &max_threads_per_block); @@ -69,6 +74,12 @@ class HardwareParamsNode : public Object { v->Visit("warp_size", &warp_size); } + /*! + * \brief Get the default hardware params. + * \param target A `tvm.target`. + * \param target_host A `tvm.target` for host device. + * \return A HardwareParams object. + */ static HardwareParams GetDefaultHardwareParams(const Target& target, const Target& target_host); @@ -82,6 +93,14 @@ class HardwareParamsNode : public Object { */ class HardwareParams : public ObjectRef { public: + /*! + * \brief The constructor. + * \param num_cores The number of cores. + * \param vector_unit_bytes The width of vector units in bytes. + * \param cache_line_bytes The size of cache line in bytes. + * \param max_unroll_vec The max length of an axis to be unrolled or vectorized. + * \param max_innermost_split_factor The max split factor for the innermost tile. + */ HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes, int max_unroll_vec, int max_innermost_split_factor); @@ -92,10 +111,15 @@ class HardwareParams : public ObjectRef { /*! \brief Meta-info for a search task */ class SearchTaskNode : public Object { public: + /*! \brief The ComputeDAG for target compute declaration. */ ComputeDAG compute_dag; + /*! \brief The workload key for target compute declaration. */ std::string workload_key; + /*! \brief The target device of this search task. */ Target target; + /*! \brief The target host device of this search task. */ Target target_host; + /*! \brief Hardware parameters used in this search task. */ HardwareParams hardware_params; void VisitAttrs(tvm::AttrVisitor* v) { @@ -116,6 +140,14 @@ class SearchTaskNode : public Object { */ class SearchTask : public ObjectRef { public: + /*! 
+ * \brief The constructor. + * \param compute_dag The ComputeDAG for target compute declaration. + * \param workload_key The workload key for target compute declaration. + * \param target The target device of this search task. + * \param target_host The target host device of this search task. + * \param hardware_params Hardware parameters used in this search task. + */ SearchTask(ComputeDAG compute_dag, std::string workload_key, Target target, Target target_host, HardwareParams hardware_params); diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index d4e4d645d192..d62847ef2248 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -19,7 +19,7 @@ /*! * \file ansor/serialization.cc - * \brief Json serialization format for dumping and loading tuning records + * \brief Json serialization format for dumping and loading tuning records. */ #include @@ -52,14 +52,14 @@ inline std::vector& IntArrayToVector(std::vector* out, } template <> -struct Handler > { +struct Handler<::tvm::Array<::tvm::ansor::Stage> > { inline static void Write(dmlc::JSONWriter* writer, - const std::vector<::tvm::ansor::Stage> & data) { + const ::tvm::Array<::tvm::ansor::Stage> & data) { writer->BeginArray(false); writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, - std::vector<::tvm::ansor::Stage> * data) { + ::tvm::Array<::tvm::ansor::Stage> * data) { bool s; reader->BeginArray(); s = reader->NextArrayItem(); CHECK(!s); @@ -328,7 +328,7 @@ void ReadMeasureRecord(const std::string& str, } } -void LogToFileNode::callback(const SearchPolicy& policy, +void LogToFileNode::Callback(const SearchPolicy& policy, const Array& inputs, const Array& results) { std::ofstream ofs(filename, std::ofstream::app); diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h index a1559c62062b..c4e4d1c334fb 100644 --- a/src/ansor/serialization.h +++ b/src/ansor/serialization.h @@ -19,7 +19,7 @@ /*! * \file ansor/serialization.h - * \brief Json serialization format for dumping and loading tuning records + * \brief Json serialization format for dumping and loading tuning records. */ #ifndef TVM_ANSOR_SERIALIZATION_H_ @@ -36,10 +36,10 @@ namespace ansor { /*! \brief Callback for logging the input and results of measurements to file */ class LogToFileNode : public MeasureCallbackNode { public: + /*! \brief File name for this callback to write log to. */ std::string filename; - /*! \brief Log measure pairs to file. This is called by the search policy */ - void callback(const SearchPolicy& policy, + void Callback(const SearchPolicy& policy, const Array& inputs, const Array& results) final; @@ -53,6 +53,10 @@ class LogToFileNode : public MeasureCallbackNode { */ class LogToFile : public MeasureCallback { public: + /*! + * \brief The constructor. + * \param filename File name for this callback to write log. + */ explicit LogToFile(std::string filename); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LogToFile, MeasureCallback, LogToFileNode); @@ -61,18 +65,26 @@ class LogToFile : public MeasureCallback { /*! \brief Log reader to load step logs from a target file.*/ class LogReaderNode : public Object { public: + /*! \brief File name for this reader to load log from. */ std::string filename; + /*! \brief The reading file stream. */ std::ifstream infile; ~LogReaderNode(); - /*! \brief Read next line in the log file - * \return Whether the read is successful */ + /*! + * \brief Read next line in the log file. + * \param inp A pointer to MeasureInputNode, this is used as output. 
+ * \param res A pointer to MeasureResultNode, this is used as output.
+ * \return Whether the read is successful. */
 bool ReadNext(MeasureInputNode* inp, MeasureResultNode* res); - /*! \brief Read multiple lines from the log file - * \param max_size The maximum number of lines. -1 means read all lines - * \param skip_size Skip the first n lines */ + /*! + * \brief Read multiple lines from the log file. + * \param max_size The maximum number of lines. -1 means read all lines. + * \param skip_size Skip the first n lines. + * \return The MeasureInputs and MeasureResults loaded from the log file. + */ std::pair, Array > ReadLines( int max_size = -1, int skip_size = 0); @@ -80,6 +92,7 @@ class LogReaderNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(LogReaderNode, Object); private: + /*! \brief A string object to store the next line. */ std::string cur_line; }; @@ -89,17 +102,32 @@ class LogReaderNode : public Object { */ class LogReader : public ObjectRef { public: + /*! + * \brief The constructor. + * \param filename File name for this reader to load log from. + */ explicit LogReader(std::string filename); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LogReader, ObjectRef, LogReaderNode); }; -/*! \brief Write measure records to an output stream */ +/*! + * \brief Write measure records to an output stream. + * \param os A pointer to output stream. + * \param inputs The target MeasureInputs to be written. + * \param results The target MeasureResults to be written. + */ void WriteMeasureRecords(std::ostream* os, const Array& inputs, const Array& results); -/*! \brief Read one measure record from a string */ +/*! + * \brief Read one measure record from a string. + * \param str The target record string to be extracted. + * \param inp A pointer to MeasureInputNode, this is used as output. + * \param res A pointer to MeasureResultNode, this is used as output. + * \param log_version A pointer to log version string. + */ void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res, diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index 1bcea3f690c9..5d41054975b4 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -20,14 +20,16 @@ /*! * \file ansor/transform_step.cc * \brief Transformation steps. For each schedule primitive, there is a corresponding transform step. - * - * See the note in transform_step.h on how to add a new step */ #include "transform_step.h" + #include #include + #include + +#include "loop_state.h" #include "utils.h" namespace tvm { diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index 8eff6a4e7536..8064c59b4c82 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -32,30 +32,74 @@ * CopyOnWrite style * 4. Add your step to `ComputeDAG::ReplaySteps` and make sure it works. * 5. Add serialization support in `struct Handler >` - * in `serialization.cc` - * 6. Add hash support in `struct hash<::tvm::ansor::Step>` (search for this function in this file) - * 7. Add its corresponding Python API to `loop_state.py` and necessary unit test + * in `serialization.cc`. + * 6. Add hash support in `struct hash<::tvm::ansor::Step>`. (search for this function in this file) + * 7. Add its corresponding Python API to `loop_state.py` and necessary unit tests.
*/ #ifndef TVM_ANSOR_TRANSFORM_STEP_H_ #define TVM_ANSOR_TRANSFORM_STEP_H_ #include +#include +#include + #include #include -#include "loop_state.h" +#include + +#include "utils.h" namespace tvm { namespace ansor { -using namespace tvm::tir; +typedef std::unordered_map, ObjectHash, ObjectEqual> + StageToAxesMap; + +class Step; + +/*! \brief The base class for a transformation step */ +class StepNode: public Object { + public: + /*! \brief The index of the target stage. */ + int stage_id; + + /*! + * \brief Print step as equivalent python schedule API. + * \param stages A pointer to `te::Stage` vector. + * \param stage_to_axes A pointer to StageToAxesMap. + * \param schedule A pointer to `te::Schedule`. + * \param transform_steps Transform steps of the target state. + * \return Python schedule code. + */ + virtual std::string PrintAsPythonAPI(std::vector* stages, + StageToAxesMap* stage_to_axes, + te::Schedule* schedule, + const std::vector& transform_steps) const = 0; + + static constexpr const char* _type_key = "ansor.Step"; + TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object); +}; + +class Step : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode); +}; /*! \brief Reorder step that corresponds to te::Stage::reorder */ class ReorderStepNode: public StepNode { public: - std::vector after_ids; // The iterator ids after reorder. - // This array should specify the order of all iterators. + /*! + * \brief The iterator ids after reorder. + * This array should specify the order of all iterators. + */ + std::vector after_ids; + /*! + * \brief Apply the current state to tvm.schedule + * \param stages A pointer to `te::Stage` vector. + * \param stage_to_axes A pointer to StageToAxesMap. + */ void ApplyToSchedule(std::vector *stages, StageToAxesMap *stage_to_axes) const; @@ -74,23 +118,42 @@ class ReorderStepNode: public StepNode { */ class ReorderStep : public Step { public: + /*! + * \brief The constructor. + * \param stage_id The index of the target stage. + * \param after_ids The index of the iterators after reorder. + */ ReorderStep(int stage_id, const std::vector& after_ids); TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode); }; -/*! \brief Split step that corresponds to te::Stage::split with additional - * support of multiple-level of factors */ +/*! + * \brief Split step that corresponds to te::Stage::split with additional + * support of multiple-level of factors + */ class SplitStepNode: public StepNode { public: - int iter_id; // The id of the iter to split - PrimExpr extent; // the extent length of the axis to split - std::vector lengths; // The split factors - bool inner_to_outer; // If true, the `lengths` denote the lengths of - // iterators from inner level to outer level + /*! \brief The id of the iter to split. */ + int iter_id; + /*! \brief The extent length of the axis to split. */ + PrimExpr extent; + /*! \brief The split factors. */ + std::vector lengths; + /*! + * \brief If true, the `lengths` denote the lengths of iterators + * from inner level to outer level + */ + bool inner_to_outer; - std::vector ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; + /*! + * \brief Apply the current state to tvm.schedule + * \param stages A pointer to `te::Stage` vector. + * \param stage_to_axes A pointer to StageToAxesMap. + * \return The iterator results after split. 
+ */
+ std::vector ApplyToSchedule(std::vector *stages,
+ StageToAxesMap *stage_to_axes) const;
 std::string PrintAsPythonAPI(std::vector *stages, StageToAxesMap *stage_to_axes, @@ -107,6 +170,13 @@ */ class SplitStep : public Step { public: + /*! + * \brief The constructor. + * \param stage_id The index of the target stage. + * \param iter_id The index of the target iterator. + * \param extent The extent length of the axis to split. + * \param lengths The split factors. + * \param inner_to_outer The split direction. + */ SplitStep(int stage_id, int iter_id, PrimExpr extent, const std::vector& lengths, bool inner_to_outer); @@ -117,10 +187,17 @@ /*! \brief Fuse step that corresponds to te::Stage::fuse */ class FuseStepNode: public StepNode { public: - std::vector fused_ids; // The ids of iterators to fuse + /*! \brief The ids of iterators to fuse. */ + std::vector fused_ids; - IterVar ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; + /*! + * \brief Apply the current state to tvm.schedule. + * \param stages A pointer to `te::Stage` vector. + * \param stage_to_axes A pointer to StageToAxesMap. + * \return The iterator result after fuse. + */ + tir::IterVar ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; std::string PrintAsPythonAPI(std::vector *stages, StageToAxesMap *stage_to_axes, @@ -137,6 +214,11 @@ */ class FuseStep : public Step { public: + /*! + * \brief The constructor. + * \param stage_id The index of the target stage. + * \param fused_ids The indices of the target iterators to be fused. + */ FuseStep(int stage_id, const std::vector& fused_ids); TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); @@ -148,6 +230,7 @@ class FuseStep : public Step { // Hash and equal function for Step namespace std { +/*! \brief The hash function of each transform step. */ template <> struct hash<::tvm::ansor::Step> { std::size_t operator()(const ::tvm::ansor::Step& step) const { diff --git a/src/ansor/utils.cc b/src/ansor/utils.cc index ed41321c4639..93a8e2257604 100644 --- a/src/ansor/utils.cc +++ b/src/ansor/utils.cc @@ -19,7 +19,7 @@ /*! * \file ansor/utils.cc - * \brief Common utilities + * \brief Common utilities. */ #include "utils.h" diff --git a/src/ansor/utils.h b/src/ansor/utils.h index a0a00ef947cd..cad27d51ba7e 100644 --- a/src/ansor/utils.h +++ b/src/ansor/utils.h @@ -19,7 +19,7 @@ /*! * \file ansor/utils.h - * \brief Common utilities + * \brief Common utilities. */ #ifndef TVM_ANSOR_UTILS_H_ @@ -81,13 +81,6 @@ struct hash > { namespace tvm { namespace ansor { -/*! \brief Macro to make it easy to define mutable object ref type given node */ -#define TVM_DEFINE_MUTABLE_OBJECT_REF(TypeName, ObjectName) \ - class TypeName : public ObjectRef { \ - public: \ - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ObjectRef, ObjectName); \ - }; \ - /********** Utilities for std::vector, std::set, std::string **********/ /*! \brief Get the first appearance index of elements in a vector */ template @@ -171,6 +164,20 @@ inline int64_t AxisLengthProd(const Array& axes) { return ret; } +/*! + * \brief Clean the name of an iterator to make it valid in python code. + * \param str The original name. + * \return The cleaned name. + */ +inline std::string CleanName(const std::string& str) { + std::string ret = str; + StrReplace(&ret, ".", "_"); + StrReplace(&ret, "@", "_"); + StrReplace(&ret, "outer", "o"); + StrReplace(&ret, "inner", "i"); + return ret; +} + /*!
\brief An empty output stream */ class NullStream : public std::ostream { public: From 9c35e50cd5a5511a7c03b7b4860a12d2ff26c4cd Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sun, 28 Jun 2020 20:48:53 +0800 Subject: [PATCH 45/78] Headfile update & Python doc update --- python/tvm/ansor/auto_schedule.py | 28 ++++++--- python/tvm/ansor/compute_dag.py | 16 ++--- python/tvm/ansor/loop_state.py | 9 +++ python/tvm/ansor/measure.py | 58 +++++++++++++------ python/tvm/ansor/serialization.py | 37 ++++++++---- python/tvm/ansor/utils.py | 7 ++- src/ansor/auto_schedule.cc | 2 + src/ansor/auto_schedule.h | 13 +++-- src/ansor/compute_dag.cc | 4 +- src/ansor/compute_dag.h | 5 +- src/ansor/loop_state.cc | 14 +++-- src/ansor/loop_state.h | 3 +- src/ansor/measure.h | 2 + src/ansor/search_policy/empty_policy.cc | 2 + src/ansor/search_policy/search_policy.cc | 4 +- src/ansor/search_policy/search_policy.h | 26 +++++++-- src/ansor/search_task.cc | 2 + src/ansor/search_task.h | 2 + src/ansor/serialization.cc | 11 ++-- src/ansor/serialization.h | 1 + src/ansor/transform_step.h | 7 ++- .../unittest/test_ansor_search_policy.py | 2 +- 22 files changed, 177 insertions(+), 78 deletions(-) diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 750c3743c0eb..0e88b8183686 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -33,10 +33,15 @@ class HardwareParams(Object): Parameters ---------- num_cores : int + The number of device cores. vector_unit_bytes : int + The width of vector units in bytes. cache_line_bytes : int + The size of cache line in bytes. max_unroll_vec : int + The max length of an axis to be unrolled or vectorized. max_innermost_split_factor : int + The max split factor for the innermost tile. """ def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes, max_unroll_vec, max_innermost_split_factor): @@ -52,10 +57,15 @@ class SearchTask(Object): Parameters ---------- dag : ComputeDAG + The ComputeDAG for target compute declaration. workload_key : str + The workload key for target compute declaration. target : tvm.target.Target + The target device of this search task. target_host : tvm.target.Target + The target host device of this search task. hardware_params : HardwareParams + Hardware parameters used in this search task. """ def __init__(self, dag, workload_key, target, target_host=None, hardware_params=None): @@ -78,11 +88,6 @@ def __init__(self): self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy) -@tvm._ffi.register_object("ansor.SearchCallback") -class SearchCallback(Object): - """ Callback function before or after search process """ - - @tvm._ffi.register_object("ansor.TuneOption") class TuneOption(Object): """ The options for tuning @@ -108,8 +113,8 @@ class TuneOption(Object): pre_search_callbacks: List[SearchCallback] Callback functions called before the search process Candidates: - - ansor.PreloadMeasuredStates - - ansor.PreloadCustomSketchRule + - ansor.PreloadMeasuredStates(will be added later) + - ansor.PreloadCustomSketchRule(will be added later) """ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_round=64, verbose=1, builder='local', runner='local', measure_callbacks=None, @@ -148,16 +153,21 @@ def auto_schedule(workload, target=None, Parameters ---------- workload : Union[SearchTask, str] + The target search task or workload key. target : Target + The target device of this schedule search. target_host : Target = None + The target host device of this schedule search. 
search_policy : Union[SearchPolicy, str]
+ The search policy to be used for schedule search.
 hardware_params : HardwareParams + The hardware parameters of this schedule search. tune_option : TuneOption + Tuning and measurement options. Returns ------- - sch : tvm.Schedule - tensors : List[Tensor] + A `te.schedule` and the target `te.Tensor`s to be used in `tvm.lower` or `tvm.build` """ if isinstance(search_policy, str): if search_policy == 'default': diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index e57fbbc08843..41ba40a7f481 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -26,17 +26,18 @@ @tvm._ffi.register_object("ansor.ComputeDAG") class ComputeDAG(Object): """ - Computation declaration graph + Computation declaration graph. Parameters ---------- tensors : List[Tensor] + `Tensor`s for a compute declaration. """ def __init__(self, tensors): self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, tensors) def get_init_state(self): - """ Get init state of this ComputeDAG + """ Get init state of this ComputeDAG. Returns ------- @@ -50,13 +51,12 @@ def apply_steps_from_state(self, state): Parameters ---------- - state : StateObject - layout_rewrite_level : LayoutRewriteLevel + state : StateObject or State + The target state to be applied to TVM schedule. Returns ------- - sch : Schedule - args : List[Tensor] + A `te.schedule` and the target `te.Tensor`s to be used in `tvm.lower` or `tvm.build` """ state_obj = state if isinstance(state, StateObject) else state.state_object return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj) @@ -67,7 +67,8 @@ def print_python_code_from_state(self, state): Parameters ---------- - state : StateObject + state : StateObject or State + The target state to be applied to TVM schedule. Returns ------- @@ -83,6 +84,7 @@ def infer_bound_from_state(self, state): Parameters ---------- state : StateObject + The target state to be applied to TVM schedule. Returns ------- diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index 791ba2e74ad4..d08d27300f16 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -73,6 +73,13 @@ class State: A state in the search process. It consists of the current loop structure and the history steps to reach this state. + Parameters + ---------- + state_object : StateObject + The target StateObject, corresponding to C++ internal State object. + dag : ComputeDAG + The original target ComputeDAG of this State. + Notes ----- This is a wrapper class of StateObject to deal with copy-on-write property @@ -192,6 +199,8 @@ def _clear_cache(self): self.stages_cache = None def copy(self): + """ Make a deep copy of this State. + """ state = State(self.state_object, self.compute_dag) state.stage_id_map = self.stage_id_map.copy() return state diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 66d1eb74fac9..d99fb3f8995a 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -57,7 +57,9 @@ class MeasureInput(Object): Parameters ---------- task : SearchTask + The target SearchTask. state : State + The current State to be measured. """ def __init__(self, task, state): @@ -66,14 +68,20 @@ def __init__(self, task, state): @tvm._ffi.register_object("ansor.BuildResult") class BuildResult(Object): - """ + """ Store the result of a build. + Parameters ---------- filename : Str + The file name of the built binary file. args : List[Tensor] + The arguments. error_no : Int + The error code.
error_msg : Str
+ The error message if there is any error.
 time_cost : Float + The time cost of build. """ def __init__(self, filename, args, error_no, error_msg, time_cost): @@ -88,10 +96,15 @@ class MeasureResult(Object): Parameters ---------- costs : List[Float] + The time costs of execution. error_no : Int + The error code. error_msg : Str + The error message if there is any error. all_cost : Float + The time cost of build and run. timestamp : Float + The time stamp of this measurement. """ def __init__(self, costs, error_no, error_msg, all_cost, timestamp): @@ -102,14 +115,15 @@ class Builder(Object): - """ Base class of Builder - """ + """ Base class of Builder """ def build(self, measure_inputs, verbose=1): """ Parameters ---------- measure_inputs : List[MeasureInput] + A List of MeasureInput. verbose : Int + Verbosity level. (0 means silent) Returns ------- @@ -120,14 +134,15 @@ class Runner(Object): - """ Base class of Runner - """ + """ Base class of Runner """ def run(self, measure_inputs, build_results, verbose=1): """ Parameters ---------- measure_inputs : List[MeasureInput] + A List of MeasureInput. build_results : List[BuildResult] + A List of BuildResult to be run. Returns ------- @@ -138,12 +153,16 @@ class LocalBuilder(Builder): - """ + """ LocalBuilder uses local CPU cores to build programs in parallel. + Parameters ---------- timeout : Int + The timeout limit for each build. n_parallel : Int + Number of threads used to build in parallel. build_func : Str + The name of registered build function. """ def __init__(self, @@ -156,14 +175,20 @@ def __init__(self, @tvm._ffi.register_object("ansor.LocalRunner") class LocalRunner(Runner): - """ + """ LocalRunner that uses local CPU/GPU to measure the time cost of programs. + Parameters ---------- timeout : Int + The timeout limit for each run. number : Int + Number of times to measure. repeat : Int + Number of repeat times in each measure. min_repeat_ms : Int + The minimum duration of one repeat in milliseconds. cooldown_interval : Float + The cool down interval between two measurements. """ def __init__(self, @@ -192,7 +217,7 @@ def __init__(self, class MeasureErrorNo(object): def make_error_msg(): - """Get the error message from traceback""" + """ Get the error message from traceback """ error_msg = str(traceback.format_exc()) if len(error_msg) > MAX_ERROR_MSG_LEN: error_msg = error_msg[:MAX_ERROR_MSG_LEN//2] + \ @@ -205,8 +230,7 @@ def make_error_msg(): def local_build_worker(index): - """ Local builder function - """ + """ Local builder function """ # We use fork to copy arguments from a global variable. # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool measure_inputs, build_func, timeout, verbose = global_build_arguments @@ -267,10 +291,8 @@ def timed_func(): @tvm._ffi.register_func("ansor.local_builder.build") -def local_builder_build(inputs: List[MeasureInput], timeout: float, n_parallel: int, - build_func: str, verbose: int): - """ Local builder build function - """ +def local_builder_build(inputs, timeout, n_parallel, build_func, verbose): + """ Local builder build function """ # We use fork to copy arguments from a global variable.
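As an aside to the Builder/Runner docstrings above, the measurement pipeline can also be assembled explicitly instead of through the 'local' string shorthand. This is a minimal sketch, assuming the patched `tvm.ansor` package is importable and that the TuneOption wrapper accepts Builder/Runner objects as well as strings; the parameter values are illustrative, not tuned defaults:

    from tvm import ansor

    # Explicit local measurement pipeline; the values below are assumptions.
    builder = ansor.LocalBuilder(timeout=15, n_parallel=4, build_func='default')
    runner = ansor.LocalRunner(timeout=10, number=3, repeat=1,
                               min_repeat_ms=0, cooldown_interval=0.0)
    tune_option = ansor.TuneOption(n_trials=64, builder=builder, runner=runner,
                                   measure_callbacks=[ansor.LogToFile("ansor_tuning.json")])
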
# This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
 global global_build_arguments @@ -289,11 +311,9 @@ def local_builder_build(inputs: List[MeasureInput], timeout: float, n_parallel: return results @tvm._ffi.register_func("ansor.local_runner.run") -def local_run(inputs: List[MeasureInput], build_results: List[BuildResult], - timeout: float, number: int, repeat: int, min_repeat_ms: int, - cooldown_interval: float, verbose: int): - """ ... - """ +def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval, + verbose): + """ Local runner run function """ MAX_FLOAT = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log def timed_func(inp, build_res): diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py index 1bd9d8cf64e6..deae5df68229 100644 --- a/python/tvm/ansor/serialization.py +++ b/python/tvm/ansor/serialization.py @@ -34,6 +34,7 @@ class LogToFile(MeasureCallback): Parameters ---------- filename : Str + File name for this callback to write log to. """ def __init__(self, filename="ansor_tuning.json"): @@ -48,6 +49,7 @@ class LogReader(Object): Parameters ---------- filename : Str + File name for this reader to load log from. """ def __init__(self, filename="ansor_tuning.json"): self.__init_handle_by_constructor__(_ffi_api.LogReader, filename) @@ -66,20 +68,31 @@ def __iter__(self): def load_from_file(filename: str): - """Load measurement records from a file""" + """ + Load measurement records from a file. + + Parameters + ---------- + filename : Str + File name to load log from. + """ return zip(*LogReader(filename).read_lines()) def write_measure_records_to_file(filename, inputs, results): - """Write(append) measure records to file""" - _ffi_api.WriteMeasureRecordsToFile(filename, inputs, results) - - -def get_states_from_measure_inputs(inputs, task): - """Get states from measure inputs""" - state_objects = _ffi_api.GetStatesFromMeasureInputs(inputs, task) - return [State(s, task.compute_dag) for s in state_objects] + """ + Write (append) measure records to a file. + + Parameters + ---------- + filename : Str + File name to write log to. + inputs: List[MeasureInput] + The target MeasureInputs to be written. + results: List[MeasureResult] + The target MeasureResults to be written. + """ + _ffi_api.WriteMeasureRecordsToFile(filename, inputs, results) def best_measure_pair_in_file(filename, workload_key=None, target=None): """ Return the best measurement pair from a log file Parameters ---------- filename : Str + File name to load log from. workload_key : Str + The workload key of the target compute declaration. target : Str + The target device. Returns ------- - inp : MeasureInput - res : MeasureResult + The best measure pair from this log file, in the form (MeasureInput, MeasureResult). """ log_reader = LogReader(filename) best_cost = 1e30 diff --git a/python/tvm/ansor/utils.py b/python/tvm/ansor/utils.py index dd701234aef9..e83a9deb3f0a 100644 --- a/python/tvm/ansor/utils.py +++ b/python/tvm/ansor/utils.py @@ -37,16 +37,17 @@ def get_func_name(func): - """Get name of a function + """Get name of a function. Parameters ---------- func: Function - The function + The target function. + Returns ------- name: str - The function name.
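As an aside to the serialization helpers documented above, here is a minimal sketch of reading a tuning log back. The file name is an assumption, and per the docstrings both entry points yield (MeasureInput, MeasureResult) pairs:

    from tvm import ansor

    # Lazily iterate over measurement records in a log file.
    for inp, res in ansor.LogReader("ansor_tuning.json"):
        print(res)  # a MeasureResult; pair it with inp for the full record

    # Or query the best record for a given workload/target directly.
    best_inp, best_res = ansor.best_measure_pair_in_file(
        "ansor_tuning.json", workload_key=None, target=None)
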
""" return func.func_name if hasattr(func, 'func_name') else func.__name__ diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index a2e3b7c11f4e..3e9192000dfd 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -23,7 +23,9 @@ */ #include "auto_schedule.h" + #include + #include #include diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index 9df15519b419..6aecbaf591c6 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -19,7 +19,9 @@ /*! * \file ansor/auto_schedule.h - * \brief The user interface of the Ansor auto-scheduler. + * \brief The user interface of the Ansor auto-scheduler. This is the entry structure to get + * schedule search requirements from upper level (Python API), and returns a high performance + * schedule after search process. */ #ifndef TVM_ANSOR_AUTO_SCHEDULE_H_ @@ -27,6 +29,7 @@ #include #include + #include "measure.h" #include "search_policy/search_policy.h" @@ -106,12 +109,12 @@ std::pair > AutoSchedule( /*! * \brief Auto schedule search for a given compute declaration, by workload key. * \param workload_key The target workload key. - * \param target A `tvm::target`. - * \param target_host A `tvm::target` for host device. + * \param target The target device of this schedule search. + * \param target_host The target host device of this schedule search. * \param search_policy The search policy to be used for schedule search. - * \param hardware_params Hardware parameters. + * \param hardware_params The hardware parameters of this schedule search. * \param tune_option Tuning and measurement options. - * \return A `te::Schedule` and the target `te::Tensor` to be used in `tvm.lower` or `tvm.build` + * \return A `te::Schedule` and the target `te::Tensor` to be used in `tvm.lower` or `tvm.build`. */ std::pair > AutoSchedule( std::string workload_key, Target target, Target target_host, diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index b9a83733c116..0e6397b87c75 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -389,8 +389,8 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { } } - pstate->stages.Set(i, Stage(stage->op, stage->op_type, std::move(new_iters), - stage->compute_at, stage->attrs)); + pstate->stages[i] = Stage(stage->op, stage->op_type, std::move(new_iters), + stage->compute_at, stage->attrs); } } diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 0d1473126ad6..b28c5d782fb3 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -20,6 +20,8 @@ /*! * \file ansor/compute_dag.h * \brief Compute declaration graph and its related analysis tools. + * ComputeDAG is responsible for the interaction with the original TVM schedule system, to apply + * state to a runable TVM schedule or provide the schedule Python code. */ #ifndef TVM_ANSOR_COMPUTE_DAG_H_ @@ -31,7 +33,6 @@ #include #include #include -#include #include namespace tvm { @@ -49,7 +50,7 @@ typedef std::unordered_map, ObjectHash */ void UpdateStageAxis(const tvm::te::Stage& stage, StageToAxesMap *stage_to_axes); -/*! \brief Computation declaration graph */ +/*! \brief Computation declaration graph. */ class ComputeDAGNode : public Object { public: /*! \brief Input and output tensors. 
*/ diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 3843e7954500..098a66e78da5 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -24,8 +24,10 @@ */ #include "loop_state.h" + #include #include + #include "transform_step.h" #include "utils.h" @@ -169,9 +171,9 @@ void State::DoReorderStep(const ReorderStep& step) { } StateNode* pstate = CopyOnWrite(); - pstate->stages.Set(step->stage_id, Stage( + pstate->stages[step->stage_id] = Stage( stage->op, stage->op_type, std::move(iters), stage->compute_at, - stage->attrs)); + stage->attrs); } // common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep @@ -235,9 +237,9 @@ std::vector State::DoSplitStepCommon( stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages.Set(stage_id, Stage( + pstate->stages[stage_id] = Stage( stage->op, stage->op_type, std::move(new_iters), stage->compute_at, - stage->attrs)); + stage->attrs); return outs; } @@ -295,9 +297,9 @@ Iterator State::DoFuseStep(const FuseStep& step) { stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages.Set(stage_id, Stage( + pstate->stages[stage_id] = Stage( stage->op, stage->op_type, std::move(new_iters), stage->compute_at, - stage->attrs)); + stage->attrs); return new_it; } diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 723f0b78fb04..da03b3474e2d 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -241,7 +241,7 @@ class Stage : public ObjectRef { class StateNode: public Object { public: /*! \brief Current stages and loop structures. */ - Array stages; + std::vector stages; /*! \brief History transformation steps. */ std::vector transform_steps; /*! \brief Indicate whether this state has unfilled tile sizes. */ @@ -254,7 +254,6 @@ class StateNode: public Object { ComputeDAG task_dag; void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("stages", &stages); v->Visit("complete", &complete); v->Visit("task_dag", &task_dag); } diff --git a/src/ansor/measure.h b/src/ansor/measure.h index e46c42c8312c..a593182d0cba 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -20,6 +20,7 @@ /*! * \file ansor/measure.h * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs. + * MeasureInput -> BuildResult -> MeasureResult */ #ifndef TVM_ANSOR_MEASURE_H_ @@ -29,6 +30,7 @@ #include #include #include + #include "search_task.h" #include "loop_state.h" diff --git a/src/ansor/search_policy/empty_policy.cc b/src/ansor/search_policy/empty_policy.cc index fe880b6cf262..0f5b5b1e9ae1 100644 --- a/src/ansor/search_policy/empty_policy.cc +++ b/src/ansor/search_policy/empty_policy.cc @@ -26,6 +26,8 @@ #include +#include "../measure.h" + namespace tvm { namespace ansor { diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index 8786f67edb22..bd664ac7da91 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -23,8 +23,10 @@ */ #include "search_policy.h" + #include -#include "../serialization.h" + +// #include "../serialization.h" namespace tvm { namespace ansor { diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index cc54822e925f..4324d3b785aa 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -19,24 +19,42 @@ /*! * \file ansor/search_policy/search_policy.h - * \brief The base class for search policy.
+ * \brief The base class for search policy, including the abstract definition of search policy and
+ * some other supporting structures.
+ *
+ * \note Adding a new search policy.
+ * By design, there is no need for users to implement their own search policy; the formal search
+ * policy (will be brought later) should be enough to cover auto schedule generation for different
+ * ops/subgraphs, and in the meantime, a custom rule mechanism will be provided to enable
+ * user-defined template search. (It should play the same role as the current AutoTVM template.)
+ * This guide is to help understand the design better, in case some advanced users have special
+ * requirements.
+ * 1. The only function that must be implemented is Search(). The design principle for it is to be
+ * the entry point of a schedule search, returning the best schedule found.
+ * 2. Information about the target ops/subgraphs can be acquired from SearchTask; this structure
+ * also contains HardwareParams, which can be used to limit the search space. (For example, limit
+ * the max vectorization size depending on the vector unit width of a specific device.)
+ * 3. SearchCallback provides more flexibility to do extra work during the search process.
+ * 4. ProgramMeasurer provides a simple but useful API to check the performance of states found
+ * during the search process.
 */ #ifndef TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ #define TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ -#include "../search_task.h" #include + #include #include #include #include -#include "../measure.h" + +#include "../search_task.h" namespace tvm { namespace ansor { -class SearchPolicyNode; +class ProgramMeasurer; class SearchPolicyNode; /*! * \brief Callback function to be called by the search process. diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index 4ef07819bbef..5248a3a0342a 100644 --- a/src/ansor/search_task.cc +++ b/src/ansor/search_task.cc @@ -23,9 +23,11 @@ */ #include "search_task.h" + #include #include #include + #include #include diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h index 5c02db5afddf..16601bc09516 100644 --- a/src/ansor/search_task.h +++ b/src/ansor/search_task.h @@ -26,7 +26,9 @@ #define TVM_ANSOR_SEARCH_TASK_H_ #include + #include + #include "compute_dag.h" namespace tvm { diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index d62847ef2248..8648debc17b7 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -22,15 +22,18 @@ * \brief Json serialization format for dumping and loading tuning records.
*/

+#include "serialization.h"
+
 #include #include + #include #include #include #include #include #include -#include "serialization.h" + #include "loop_state.h" #include "transform_step.h" #include "utils.h" @@ -52,14 +55,14 @@ inline std::vector& IntArrayToVector(std::vector* out, } template <> -struct Handler<::tvm::Array<::tvm::ansor::Stage> > { +struct Handler > { inline static void Write(dmlc::JSONWriter* writer, - const ::tvm::Array<::tvm::ansor::Stage> & data) { + const std::vector<::tvm::ansor::Stage> & data) { writer->BeginArray(false); writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, - ::tvm::Array<::tvm::ansor::Stage> * data) { + std::vector<::tvm::ansor::Stage> * data) { bool s; reader->BeginArray(); s = reader->NextArrayItem(); CHECK(!s); diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h index c4e4d1c334fb..4dfb72b02b7b 100644 --- a/src/ansor/serialization.h +++ b/src/ansor/serialization.h @@ -28,6 +28,7 @@ #include #include #include + #include "measure.h" namespace tvm { diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index 8064c59b4c82..2c408289bcc8 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -20,9 +20,12 @@ /*! * \file ansor/transform_step.h * \brief Transformation steps. For each schedule primitive, there is a corresponding transform step. + * The implementation of each step consists of 2 parts: + * - transform_step.cc: How each step interacts with the TVM system + * - loop_state.cc: How each step is reflected in LoopState * - * \note How to add a new transform step. - * Take fuse for example: + * \note Adding a new transform step. + * Take the fuse step for example: * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its construction * function `FuseStep::FuseStep(...)` in `transform_steps.cc` * 2. Implement `FuseStepNode::ApplyToSchedule` and `FuseStepNode::PrintAsPythonAPI`.
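The test diff below drives exactly this machinery from Python. For orientation, here is a condensed, self-contained sketch of the flow; the toy workload and the workload key string are assumptions, and EmptyPolicy only replays the initial state, so this exercises the plumbing rather than real tuning:

    import tvm
    from tvm import te, ansor

    # Toy compute declaration; any te workload would do.
    A = te.placeholder((128, 128), name="A")
    B = te.placeholder((128, 128), name="B")
    k = te.reduce_axis((0, 128), name="k")
    C = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
                   name="C")

    task = ansor.SearchTask(ansor.ComputeDAG([A, B, C]), "toy_matmul",
                            tvm.target.create("llvm"))
    tune_option = ansor.TuneOption(n_trials=2, verbose=0,
                                   measure_callbacks=[ansor.LogToFile("ansor_tuning.json")])
    sch, args = ansor.auto_schedule(task, search_policy=ansor.EmptyPolicy(),
                                    tune_option=tune_option)
    func = tvm.build(sch, args)  # the returned pair feeds tvm.lower/tvm.build directly
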
diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index b701dad6d8c0..20d93b8681e7 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -44,7 +44,7 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' search_policy = ansor.EmptyPolicy() # search_policy = ansor.SketchSearchPolicy(cost_model, params=params, seed=seed) - tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner, + tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner, verbose=0, measure_callbacks=[ansor.LogToFile(log_file)], pre_search_callbacks=pre_search_callbacks) sch, args = ansor.auto_schedule(task, search_policy=search_policy, From a015051aa869efe2b6c6cba8090a83f7da817e02 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Mon, 29 Jun 2020 14:40:58 +0800 Subject: [PATCH 46/78] clang-format fix --- src/ansor/auto_schedule.cc | 59 +++++++++++------------- src/ansor/auto_schedule.h | 7 ++- src/ansor/compute_dag.cc | 46 +++++++++--------- src/ansor/compute_dag.h | 6 +-- src/ansor/loop_state.cc | 19 ++++---- src/ansor/loop_state.h | 5 +- src/ansor/measure.cc | 3 +- src/ansor/measure.h | 33 ++++++------- src/ansor/search_policy/empty_policy.cc | 9 ++-- src/ansor/search_policy/empty_policy.h | 7 ++- src/ansor/search_policy/search_policy.cc | 14 ++---- src/ansor/search_policy/search_policy.h | 12 ++--- src/ansor/search_task.cc | 3 +- src/ansor/serialization.cc | 19 +++----- src/ansor/serialization.h | 5 +- src/ansor/transform_step.cc | 2 +- src/ansor/transform_step.h | 58 ++++++++++------------- 17 files changed, 137 insertions(+), 170 deletions(-) diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index 3e9192000dfd..d409988a8007 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -34,9 +34,8 @@ namespace ansor { TVM_REGISTER_NODE_TYPE(TuneOptionNode); -TuneOption::TuneOption(int n_trials, int early_stopping, - int num_measure_per_round, int verbose, Builder builder, - Runner runner, Array measure_callbacks, +TuneOption::TuneOption(int n_trials, int early_stopping, int num_measure_per_round, int verbose, + Builder builder, Runner runner, Array measure_callbacks, Array pre_search_callbacks) { auto node = make_object(); node->n_trials = n_trials; @@ -58,10 +57,9 @@ std::pair > AutoSchedule(SearchTask task, tune_option->measure_callbacks, tune_option->verbose); // Search for the best schedule - State state = search_policy->Search( - task, tune_option->n_trials, tune_option->early_stopping, - tune_option->num_measure_per_round, tune_option->verbose, measurer, - tune_option->pre_search_callbacks); + State state = search_policy->Search(task, tune_option->n_trials, tune_option->early_stopping, + tune_option->num_measure_per_round, tune_option->verbose, + measurer, tune_option->pre_search_callbacks); return task->compute_dag.ApplySteps(state->transform_steps); } @@ -80,36 +78,31 @@ std::pair > AutoSchedule( } TVM_REGISTER_GLOBAL("ansor.TuneOption") -.set_body_typed([](int n_trials, int early_stopping, - int num_measure_per_round, int verbose, Builder builder, - Runner runner, Array measure_callbacks, - Array pre_search_callbacks) { - return TuneOption(n_trials, early_stopping, num_measure_per_round, verbose, - builder, runner, measure_callbacks, pre_search_callbacks); -}); + .set_body_typed([](int n_trials, int early_stopping, int num_measure_per_round, int verbose, + Builder builder, Runner runner, Array 
measure_callbacks, + Array pre_search_callbacks) { + return TuneOption(n_trials, early_stopping, num_measure_per_round, verbose, builder, runner, + measure_callbacks, pre_search_callbacks); + }); TVM_REGISTER_GLOBAL("ansor.AutoScheduleBySearchTask") -.set_body_typed([](SearchTask task, SearchPolicy search_policy, - TuneOption tune_option) { - te::Schedule sch; - Array return_tensors; - std::tie(sch, return_tensors) = AutoSchedule(task, search_policy, tune_option); - - return Array{sch, return_tensors}; -}); + .set_body_typed([](SearchTask task, SearchPolicy search_policy, TuneOption tune_option) { + te::Schedule sch; + Array return_tensors; + std::tie(sch, return_tensors) = AutoSchedule(task, search_policy, tune_option); + return Array{sch, return_tensors}; + }); TVM_REGISTER_GLOBAL("ansor.AutoScheduleByWorkloadKey") -.set_body_typed([](std::string workload_key, Target target, - Target target_host, SearchPolicy search_policy, - HardwareParams hardware_params, TuneOption tune_option) { - te::Schedule sch; - Array return_tensors; - std::tie(sch, return_tensors) = - AutoSchedule(workload_key, target, target_host, search_policy, - hardware_params, tune_option); - - return Array{sch, return_tensors}; -}); + .set_body_typed([](std::string workload_key, Target target, Target target_host, + SearchPolicy search_policy, HardwareParams hardware_params, + TuneOption tune_option) { + te::Schedule sch; + Array return_tensors; + std::tie(sch, return_tensors) = AutoSchedule(workload_key, target, target_host, search_policy, + hardware_params, tune_option); + return Array{sch, return_tensors}; + }); } // namespace ansor } // namespace tvm diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index 6aecbaf591c6..a7bfb8449537 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -27,8 +27,8 @@ #ifndef TVM_ANSOR_AUTO_SCHEDULE_H_ #define TVM_ANSOR_AUTO_SCHEDULE_H_ -#include #include +#include #include "measure.h" #include "search_policy/search_policy.h" @@ -88,9 +88,8 @@ class TuneOption : public ObjectRef { * \param measure_callbacks MeasureCallback functions to be called after each measure batch. * \param pre_search_callbacks SearchCallback functions to be called before schedule search. 
*/ - TuneOption(int n_trials, int early_stopping, int num_measure_per_round, - int verbose, Builder builder, Runner runner, - Array measure_callbacks, + TuneOption(int n_trials, int early_stopping, int num_measure_per_round, int verbose, + Builder builder, Runner runner, Array measure_callbacks, Array pre_search_callbacks); TVM_DEFINE_OBJECT_REF_METHODS(TuneOption, ObjectRef, TuneOptionNode); diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 0e6397b87c75..c8fd7f8f6d37 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -24,19 +24,19 @@ #include "compute_dag.h" +#include +#include #include #include -#include #include -#include +#include +#include +#include +#include #include #include -#include -#include #include -#include -#include #include #include "loop_state.h" @@ -49,7 +49,8 @@ using namespace tvm::tir; TVM_REGISTER_NODE_TYPE(ComputeDAGNode); -void UpdateStageAxis(const te::Stage& stage, StageToAxesMap *stage_to_axes) { +// Update stage to axis mapping +void UpdateStageAxis(const te::Stage& stage, StageToAxesMap* stage_to_axes) { if (auto pop = stage->op.as()) { std::vector& axes = (*stage_to_axes)[stage]; axes.clear(); @@ -389,8 +390,8 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { } } - pstate->stages[i] = Stage(stage->op, stage->op_type, std::move(new_iters), - stage->compute_at, stage->attrs); + pstate->stages[i] = + Stage(stage->op, stage->op_type, std::move(new_iters), stage->compute_at, stage->attrs); } } @@ -505,25 +506,22 @@ TVM_REGISTER_GLOBAL("ansor.ComputeDAGGetInitState") .set_body_method(&ComputeDAG::GetInitState); TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") -.set_body([](TVMArgs args, TVMRetValue *ret) { - ComputeDAG dag = args[0]; - State state = args[1]; - - te::Schedule sch; - Array return_tensors; - std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps); - *ret = Array{sch, return_tensors}; -}); + .set_body_typed([](const ComputeDAG& dag, const State& state) { + te::Schedule sch; + Array return_tensors; + std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps); + return Array{sch, return_tensors}; + }); TVM_REGISTER_GLOBAL("ansor.ComputeDAGPrintPythonCodeFromState") -.set_body_typed([](const ComputeDAG& dag, const State& state) { - return dag.PrintStepsAsPython(state->transform_steps); -}); + .set_body_typed([](const ComputeDAG& dag, const State& state) { + return dag.PrintStepsAsPython(state->transform_steps); + }); TVM_REGISTER_GLOBAL("ansor.ComputeDAGInferBoundFromState") -.set_body_typed([](const ComputeDAG& dag, const State& state) { - return dag.ReplayAndInferBound(state->transform_steps); -}); + .set_body_typed([](const ComputeDAG& dag, const State& state) { + return dag.ReplayAndInferBound(state->transform_steps); + }); } // namespace ansor } // namespace tvm diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index b28c5d782fb3..25841ff4f268 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -30,10 +30,10 @@ #include #include -#include #include -#include #include +#include +#include namespace tvm { namespace ansor { @@ -87,7 +87,7 @@ class ComputeDAG: public ObjectRef { */ explicit ComputeDAG(const std::string& workload_key); - /*! + /*! * \brief Apply transform steps to the init state of this DAG, and get the * equivalent `tvm::schedule`. * \param transform_steps Transform steps of the target state. 
diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 098a66e78da5..6d877a21a6a6 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -117,8 +117,8 @@ State::State(const Array& ops) { data_ = std::move(node); } -State::State(const std::vector& stages, - const std::vector& transform_steps, bool complete) { +State::State(const std::vector& stages, const std::vector& transform_steps, + bool complete) { auto node = make_object(); node->stages = stages; node->transform_steps = transform_steps; @@ -171,9 +171,8 @@ void State::DoReorderStep(const ReorderStep& step) { } StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = Stage( - stage->op, stage->op_type, std::move(iters), stage->compute_at, - stage->attrs); + pstate->stages[step->stage_id] = + Stage(stage->op, stage->op_type, std::move(iters), stage->compute_at, stage->attrs); } // common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep @@ -237,9 +236,8 @@ std::vector State::DoSplitStepCommon( stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages[stage_id] = Stage( - stage->op, stage->op_type, std::move(new_iters), stage->compute_at, - stage->attrs); + pstate->stages[stage_id] = + Stage(stage->op, stage->op_type, std::move(new_iters), stage->compute_at, stage->attrs); return outs; } @@ -297,9 +295,8 @@ Iterator State::DoFuseStep(const FuseStep& step) { stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages[stage_id] = Stage( - stage->op, stage->op_type, std::move(new_iters), stage->compute_at, - stage->attrs); + pstate->stages[stage_id] = + Stage(stage->op, stage->op_type, std::move(new_iters), stage->compute_at, stage->attrs); return new_it; } diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index da03b3474e2d..94f3faf71b45 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -42,9 +42,9 @@ #include #include +#include #include #include -#include #include "compute_dag.h" #include "transform_step.h" @@ -279,8 +279,7 @@ class State : public ObjectRef { * \param transform_steps Transform steps of the target state. * \param complete Indicate whether this state has unfilled tile sizes. */ - State(const std::vector& stages, - const std::vector& transform_steps, bool complete); + State(const std::vector& stages, const std::vector& transform_steps, bool complete); /*! * \brief Schedule primitive corresponds to te.reorder. diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 98a439c195cc..b31b993618f4 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -225,8 +225,7 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, StdCout(verbose) << std::fixed << std::setprecision(2) << "===============================================\n" << "No: " << ct << "\tGFLOPS: " << flops / 1e9 << " / " - << best_flops[workload_key] / 1e9 - << "\tresults: " << result_batch[j] << "\n" + << best_flops[workload_key] / 1e9 << "\tresults: " << result_batch[j] << "\n" << "===============================================\n" << input_batch[j]->state << "\n"; } diff --git a/src/ansor/measure.h b/src/ansor/measure.h index a593182d0cba..aae332386367 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -28,16 +28,18 @@ #include #include -#include #include +#include -#include "search_task.h" #include "loop_state.h" +#include "search_task.h" namespace tvm { namespace ansor { -class SearchPolicy; class MeasureInput; class MeasureResult; +class SearchPolicy; +class MeasureInput; +class MeasureResult; /*! 
\brief The error code of one measurement */ enum MeasureErrorNO { @@ -64,7 +66,7 @@ enum MeasureErrorNO { // Inputs and results of one measurement /*! \brief Store the input of a measurement */ -class MeasureInputNode: public Object { +class MeasureInputNode : public Object { public: /*! \brief The search task. */ SearchTask task; @@ -100,7 +102,7 @@ class MeasureInput : public ObjectRef { }; /*! \brief Store the input of a build. */ -class BuildResultNode: public Object { +class BuildResultNode : public Object { public: /*! \brief The filename of built binary file. */ std::string filename; @@ -145,7 +147,7 @@ class BuildResult : public ObjectRef { }; /*! \brief Store the results of a measurement. */ -class MeasureResultNode: public Object { +class MeasureResultNode : public Object { public: /*! \brief The time costs of execution. */ Array costs; @@ -194,7 +196,7 @@ class MeasureResult : public ObjectRef { }; /*! \brief Bass class of measurement callbacks */ -class MeasureCallbackNode: public Object { +class MeasureCallbackNode : public Object { public: /*! * \brief Callback function that will be called on measurement input/result pairs @@ -203,8 +205,7 @@ class MeasureCallbackNode: public Object { * \param inputs An Array of MeasureInput. * \param results An Array of MeasureResult. */ - virtual void Callback(const SearchPolicy& policy, - const Array& inputs, + virtual void Callback(const SearchPolicy& policy, const Array& inputs, const Array& results) = 0; static constexpr const char *_type_key = "ansor.MeasureCallback"; TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object); @@ -222,7 +223,7 @@ class MeasureCallback : public ObjectRef { // Base class for builder and runner /*! \brief Builder that builds the programs */ -class BuilderNode: public Object { +class BuilderNode : public Object { public: /*! \brief The number of tasks to run in parallel */ int n_parallel; @@ -251,7 +252,7 @@ class Builder : public ObjectRef { }; /*! \brief Runner that runs the built programs and measure the time cost. */ -class RunnerNode: public Object { +class RunnerNode : public Object { public: /*! \brief Timeout of a run. */ int timeout; @@ -283,7 +284,7 @@ class Runner : public ObjectRef { // Implementation of various builders and runners /*! \brief LocalBuilder use local CPU cores to build programs in parallel */ -class LocalBuilderNode: public BuilderNode { +class LocalBuilderNode : public BuilderNode { public: /*! \brief Build function. */ std::string build_func; @@ -298,7 +299,7 @@ class LocalBuilderNode: public BuilderNode { * \brief Managed reference to LocalBuilderNode. * \sa LocalBuilderNode */ -class LocalBuilder: public Builder { +class LocalBuilder : public Builder { public: /*! * \brief The constructor. @@ -312,7 +313,7 @@ class LocalBuilder: public Builder { }; /*! \brief LocalRunner that uses local CPU/GPU to measures the time cost of programs */ -class LocalRunnerNode: public RunnerNode { +class LocalRunnerNode : public RunnerNode { public: /*! \brief Number of measure times. */ int number; @@ -335,7 +336,7 @@ class LocalRunnerNode: public RunnerNode { * \brief Managed reference to LocalRunnerNode. * \sa LocalRunnerNode */ -class LocalRunner: public Runner { +class LocalRunner : public Runner { public: /*! * \brief The constructor. @@ -355,7 +356,7 @@ class LocalRunner: public Runner { /*! 
* \brief Measurer that measures the time costs of tvm programs * This class combines Builder and Runner, and provides a simpler API */ -class ProgramMeasurerNode: public Object { +class ProgramMeasurerNode : public Object { public: /*! \brief Measured programs counter. */ int ct; diff --git a/src/ansor/search_policy/empty_policy.cc b/src/ansor/search_policy/empty_policy.cc index 0f5b5b1e9ae1..287b68c7c0ce 100644 --- a/src/ansor/search_policy/empty_policy.cc +++ b/src/ansor/search_policy/empty_policy.cc @@ -34,8 +34,8 @@ namespace ansor { TVM_REGISTER_NODE_TYPE(EmptyPolicyNode); State EmptyPolicyNode::Search(SearchTask task, int n_trials, int early_stopping, - int num_measure_per_round, int verbose, ProgramMeasurer measurer, - Array pre_search_callbacks) { + int num_measure_per_round, int verbose, ProgramMeasurer measurer, + Array pre_search_callbacks) { cur_task = task; // Run pre_search_callbacks before the search process @@ -79,8 +79,9 @@ std::vector EmptyPolicyNode::SearchOneRound() { return res; } -TVM_REGISTER_GLOBAL("ansor.EmptyPolicy") -.set_body_typed([]() { return EmptyPolicy(make_object()); }); +TVM_REGISTER_GLOBAL("ansor.EmptyPolicy").set_body_typed([]() { + return EmptyPolicy(make_object()); +}); } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/empty_policy.h b/src/ansor/search_policy/empty_policy.h index c7cf721d4bc7..01a47a9d4120 100644 --- a/src/ansor/search_policy/empty_policy.h +++ b/src/ansor/search_policy/empty_policy.h @@ -42,12 +42,11 @@ namespace ansor { */ class EmptyPolicyNode : public SearchPolicyNode { public: - State Search(SearchTask task, int n_trials, - int early_stopping, int num_measure_per_round, + State Search(SearchTask task, int n_trials, int early_stopping, int num_measure_per_round, int verbose, ProgramMeasurer measurer, Array pre_search_callbacks) final; - static constexpr const char *_type_key = "ansor.EmptyPolicy"; + static constexpr const char* _type_key = "ansor.EmptyPolicy"; TVM_DECLARE_FINAL_OBJECT_INFO(EmptyPolicyNode, SearchPolicyNode); private: @@ -70,4 +69,4 @@ class EmptyPolicy : public SearchPolicy { } // namespace ansor } // namespace tvm -#endif // TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_ \ No newline at end of file +#endif // TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_ diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index bd664ac7da91..41a0e650fb1e 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -43,19 +43,15 @@ void SearchPolicyNode::RunCallbacks(const Array& callbacks) { } TVM_REGISTER_GLOBAL("ansor.SearchPolicyRunCallbacks") -.set_body_typed([](SearchPolicy policy, Array callbacks) { - policy->RunCallbacks(callbacks); -}); + .set_body_typed([](SearchPolicy policy, Array callbacks) { + policy->RunCallbacks(callbacks); + }); TVM_REGISTER_GLOBAL("ansor.SearchPolicySetTask") -.set_body_typed([](SearchPolicy policy, SearchTask task) { - policy->cur_task = task; -}); + .set_body_typed([](SearchPolicy policy, SearchTask task) { policy->cur_task = task; }); TVM_REGISTER_GLOBAL("ansor.SearchPolicySetVerbose") -.set_body_typed([](SearchPolicy policy, int verbose) { - policy->verbose = verbose; -}); + .set_body_typed([](SearchPolicy policy, int verbose) { policy->verbose = verbose; }); } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index 4324d3b785aa..8389320eb085 100644 --- a/src/ansor/search_policy/search_policy.h +++ 
b/src/ansor/search_policy/search_policy.h @@ -21,7 +21,7 @@ * \file ansor/search_policy/search_policy.h * \brief The base class for search policy, including the abstract defination of search policy and * some other supporting structures. - * + * * \note Adding a new search policy. * In design, there's no need for users to implement their own search policy, our formal search * policy(will be brought later) should be enough to cover auto schedule generation for different @@ -44,17 +44,18 @@ #include +#include #include -#include #include -#include +#include #include "../search_task.h" namespace tvm { namespace ansor { -class ProgramMeasurer; class SearchPolicyNode; +class ProgramMeasurer; +class SearchPolicyNode; /*! * \brief Callback function to be called by the search process. @@ -111,8 +112,7 @@ class SearchPolicyNode : public Object { * \param pre_search_callbacks SearchCallback to be called before schedule search. * \return The best state get. */ - virtual State Search(SearchTask task, int n_trials, - int early_stopping, int num_measure_per_round, + virtual State Search(SearchTask task, int n_trials, int early_stopping, int num_measure_per_round, int verbose, ProgramMeasurer measurer, Array pre_search_callbacks) = 0; diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index 5248a3a0342a..e7ea6eb05e90 100644 --- a/src/ansor/search_task.cc +++ b/src/ansor/search_task.cc @@ -52,8 +52,7 @@ HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, HardwareParams HardwareParamsNode::GetDefaultHardwareParams( const Target& target, const Target& target_host) { if (target->target_name == "llvm") { - return HardwareParams(tvm::runtime::threading::MaxConcurrency(), - 64, 64, 64, 64); + return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64, 64, 64); } else { LOG(FATAL) << "No default hardware parameters for target: " << target; } diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index 8648debc17b7..b6dafdc80625 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -29,10 +29,10 @@ #include #include -#include #include #include #include +#include #include "loop_state.h" #include "transform_step.h" @@ -55,14 +55,12 @@ inline std::vector& IntArrayToVector(std::vector* out, } template <> -struct Handler > { - inline static void Write(dmlc::JSONWriter* writer, - const std::vector<::tvm::ansor::Stage> & data) { +struct Handler> { + inline static void Write(dmlc::JSONWriter* writer, const std::vector<::tvm::ansor::Stage>& data) { writer->BeginArray(false); writer->EndArray(); } - inline static void Read(dmlc::JSONReader* reader, - std::vector<::tvm::ansor::Stage> * data) { + inline static void Read(dmlc::JSONReader* reader, std::vector<::tvm::ansor::Stage>* data) { bool s; reader->BeginArray(); s = reader->NextArrayItem(); CHECK(!s); @@ -71,8 +69,7 @@ struct Handler > { template <> struct Handler > { - inline static void Write(dmlc::JSONWriter* writer, - const std::vector<::tvm::ansor::Step> & data) { + inline static void Write(dmlc::JSONWriter* writer, const std::vector<::tvm::ansor::Step>& data) { std::vector tmp; writer->BeginArray(false); for (size_t i = 0; i < data.size(); ++i) { @@ -117,8 +114,7 @@ struct Handler > { writer->EndArray(); } - inline static void Read(dmlc::JSONReader* reader, - std::vector<::tvm::ansor::Step> * data) { + inline static void Read(dmlc::JSONReader* reader, std::vector<::tvm::ansor::Step>* data) { std::vector int_list; bool s, inner_to_outer; std::string name, scope_name, pragma_type, 
ti_func_name; @@ -331,8 +327,7 @@ void ReadMeasureRecord(const std::string& str, } } -void LogToFileNode::Callback(const SearchPolicy& policy, - const Array& inputs, +void LogToFileNode::Callback(const SearchPolicy& policy, const Array& inputs, const Array& results) { std::ofstream ofs(filename, std::ofstream::app); WriteMeasureRecords(&ofs, inputs, results); diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h index 4dfb72b02b7b..fbe79270bb70 100644 --- a/src/ansor/serialization.h +++ b/src/ansor/serialization.h @@ -25,8 +25,8 @@ #ifndef TVM_ANSOR_SERIALIZATION_H_ #define TVM_ANSOR_SERIALIZATION_H_ -#include #include +#include #include #include "measure.h" @@ -40,8 +40,7 @@ class LogToFileNode : public MeasureCallbackNode { /*! \brief File name for this callback to write log to. */ std::string filename; - void Callback(const SearchPolicy& policy, - const Array& inputs, + void Callback(const SearchPolicy& policy, const Array& inputs, const Array& results) final; static constexpr const char *_type_key = "ansor.LogToFile"; diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index 5d41054975b4..193a11ffd191 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -24,8 +24,8 @@ #include "transform_step.h" -#include #include +#include #include diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index 2c408289bcc8..b18b0b5248e5 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -19,8 +19,8 @@ /*! * \file ansor/transform_step.h - * \brief Transformation steps. For each schedule primitive, there is a corresponding transform step. - * The implementation of each step consists of 2 parts: + * \brief Transformation steps. For each schedule primitive, there is a corresponding transform + * step. The implementation of each step consists of 2 parts: * - transform_step.cc: How each step interact with TVM system * - loop_state.cc: How each step reflect on LoopState * @@ -48,8 +48,8 @@ #include #include -#include #include +#include #include "utils.h" @@ -62,7 +62,7 @@ typedef std::unordered_map, ObjectHash class Step; /*! \brief The base class for a transformation step */ -class StepNode: public Object { +class StepNode : public Object { public: /*! \brief The index of the target stage. */ int stage_id; @@ -76,8 +76,7 @@ class StepNode: public Object { * \return Python schedule code. */ virtual std::string PrintAsPythonAPI(std::vector* stages, - StageToAxesMap* stage_to_axes, - te::Schedule* schedule, + StageToAxesMap* stage_to_axes, te::Schedule* schedule, const std::vector& transform_steps) const = 0; static constexpr const char* _type_key = "ansor.Step"; @@ -90,7 +89,7 @@ class Step : public ObjectRef { }; /*! \brief Reorder step that corresponds to te::Stage::reorder */ -class ReorderStepNode: public StepNode { +class ReorderStepNode : public StepNode { public: /*! * \brief The iterator ids after reorder. @@ -103,12 +102,10 @@ class ReorderStepNode: public StepNode { * \param stages A pointer to `te::Stage` vector. * \param stage_to_axes A pointer to StageToAxesMap. 
*/ - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; + void ApplyToSchedule(std::vector* stages, StageToAxesMap* stage_to_axes) const; - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, + std::string PrintAsPythonAPI(std::vector* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule, const std::vector& transform_steps) const final; static constexpr const char* _type_key = "ansor.ReorderStep"; @@ -135,7 +132,7 @@ class ReorderStep : public Step { * \brief Split step that corresponds to te::Stage::split with additional * support of multiple-level of factors */ -class SplitStepNode: public StepNode { +class SplitStepNode : public StepNode { public: /*! \brief The id of the iter to split. */ int iter_id; @@ -155,12 +152,11 @@ class SplitStepNode: public StepNode { * \param stage_to_axes A pointer to StageToAxesMap. * \return The iterator results after split. */ - std::vector ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; + std::vector ApplyToSchedule(std::vector* stages, + StageToAxesMap* stage_to_axes) const; - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, + std::string PrintAsPythonAPI(std::vector* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule, const std::vector& transform_steps) const final; static constexpr const char* _type_key = "ansor.SplitStep"; @@ -180,15 +176,14 @@ class SplitStep : public Step { * \param lengths The extent length of the axis to split. * \param inner_to_outer The split direction. */ - SplitStep(int stage_id, int iter_id, PrimExpr extent, - const std::vector& lengths, + SplitStep(int stage_id, int iter_id, PrimExpr extent, const std::vector& lengths, bool inner_to_outer); TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); }; /*! \brief Fuse step that corresponds to te::Stage::fuse */ -class FuseStepNode: public StepNode { +class FuseStepNode : public StepNode { public: /*! \brief The ids of iterators to fuse. */ std::vector fused_ids; @@ -199,12 +194,10 @@ class FuseStepNode: public StepNode { * \param stage_to_axes A pointer to StageToAxesMap. * \return The iterator result after fuse. 
*/ - tir::IterVar ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; + tir::IterVar ApplyToSchedule(std::vector* stages, StageToAxesMap* stage_to_axes) const; - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, + std::string PrintAsPythonAPI(std::vector* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule, const std::vector& transform_steps) const final; static constexpr const char* _type_key = "ansor.FuseStep"; @@ -242,10 +235,10 @@ struct hash<::tvm::ansor::Step> { ::dmlc::HashCombine(std::hash()(ps->stage_id), ps->after_ids)); } else if (auto ps = step.as<::tvm::ansor::SplitStepNode>()) { - size_t ret = ::dmlc::HashCombine(2, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ps->inner_to_outer))); + size_t ret = ::dmlc::HashCombine( + 2, ::dmlc::HashCombine( + std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->iter_id), ps->inner_to_outer))); for (const auto& len : ps->lengths) { if (len.defined()) { auto pint = len.as<::tvm::tir::IntImmNode>(); @@ -257,9 +250,8 @@ struct hash<::tvm::ansor::Step> { } return ret; } else if (auto ps = step.as<::tvm::ansor::FuseStepNode>()) { - return ::dmlc::HashCombine(3, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ps->fused_ids)); + return ::dmlc::HashCombine( + 3, ::dmlc::HashCombine(std::hash()(ps->stage_id), ps->fused_ids)); } else { LOG(FATAL) << "Invalid step"; } From 682380228eb79c6f9315a4ac9d4330f96e9165d4 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Mon, 29 Jun 2020 15:12:26 +0800 Subject: [PATCH 47/78] pylint fix --- python/tvm/ansor/auto_schedule.py | 2 -- python/tvm/ansor/loop_state.py | 7 +++---- python/tvm/ansor/measure.py | 32 +++++++++++++++---------------- python/tvm/ansor/serialization.py | 3 +-- python/tvm/ansor/utils.py | 7 +++---- 5 files changed, 23 insertions(+), 28 deletions(-) diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 0e88b8183686..dcd53b282d6e 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -17,8 +17,6 @@ """User interface for auto-scheduler""" -import random - import tvm._ffi from tvm.runtime import Object from .measure import LocalBuilder, LocalRunner diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index d08d27300f16..565ae66435c4 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -182,12 +182,11 @@ def fuse(self, stage_id, iters): def _resolve_stage_id(self, stage_id): if isinstance(stage_id, Operation): return self.stage_id_map[stage_id] - elif isinstance(stage_id, tvm.te.Tensor): + if isinstance(stage_id, tvm.te.Tensor): return self.stage_id_map[stage_id.op] - elif isinstance(stage_id, int): + if isinstance(stage_id, int): return stage_id - else: - raise ValueError("Invalid stage_id") + raise ValueError("Invalid stage_id") def _update_stage_id_map(self): if not self.stages_cache: diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index d99fb3f8995a..9bc82b543c7b 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -22,11 +22,10 @@ We implement these in python to utilize python's multiprocessing and error handling """ -from typing import List + import os import time import shutil -import logging import traceback import tempfile import multiprocessing @@ -35,17 +34,14 @@ from tvm.runtime import Object, module, ndarray from tvm.driver import build_module from 
tvm.ir import transform -from tvm.contrib import tar +from tvm.contrib import tar, ndk from . import _ffi_api from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout -LOGGER = logging.getLogger('ansor') - # The maximum length of error message MAX_ERROR_MSG_LEN = 512 - @tvm._ffi.register_object("ansor.MeasureCallback") class MeasureCallback(Object): """Base class for measurement callback function""" @@ -225,15 +221,15 @@ def make_error_msg(): return error_msg -global global_build_arguments -global global_run_arguments +GLOBAL_BUILD_ARGUMENTS = None +GLOBAL_RUN_ARGUMENTS = None def local_build_worker(index): """ Local builder function """ # We use fork to copy arguments from a global variable. # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool - measure_inputs, build_func, timeout, verbose = global_build_arguments + measure_inputs, build_func, timeout, verbose = GLOBAL_BUILD_ARGUMENTS assert isinstance(build_func, str) if build_func == 'default': build_func = tar.tar @@ -254,6 +250,7 @@ def timed_func(): try: sch, args = task.compute_dag.apply_steps_from_state( inp.state) + # pylint: disable=W0703 except Exception: error_no = MeasureErrorNo.INSTANTIATION_ERROR error_msg = make_error_msg() @@ -268,6 +265,7 @@ def timed_func(): func = build_module.build( sch, args, target=task.target, target_host=task.target_host) func.export_library(filename, build_func) + # pylint: disable=W0703 except Exception: error_no = MeasureErrorNo.COMPILE_HOST error_msg = make_error_msg() @@ -295,8 +293,8 @@ def local_builder_build(inputs, timeout, n_parallel, build_func, verbose): """ Local builder build function """ # We use fork to copy arguments from a global variable. # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool - global global_build_arguments - global_build_arguments = (inputs, build_func, timeout, verbose) + global GLOBAL_BUILD_ARGUMENTS + GLOBAL_BUILD_ARGUMENTS = (inputs, build_func, timeout, verbose) pool = NoDaemonPool(n_parallel) tuple_res = pool.map(local_build_worker, range(len(inputs))) @@ -314,7 +312,7 @@ def local_builder_build(inputs, timeout, n_parallel, build_func, verbose): def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval, verbose): """ Local runner run function """ - MAX_FLOAT = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log + max_float = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log def timed_func(inp, build_res): tic = time.time() @@ -325,8 +323,9 @@ def timed_func(inp, build_res): ctx = ndarray.context(str(inp.task.target), 0) time_f = func.time_evaluator( func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms) + # pylint: disable=W0703 except Exception: - costs = (MAX_FLOAT,) + costs = (max_float,) error_no = MeasureErrorNo.COMPILE_DEVICE error_msg = make_error_msg() @@ -337,8 +336,9 @@ def timed_func(inp, build_res): ctx.sync() costs = time_f(*args).results + # pylint: disable=W0703 except Exception: - costs = (MAX_FLOAT,) + costs = (max_float,) error_no = MeasureErrorNo.RUNTIME_DEVICE error_msg = make_error_msg() @@ -358,7 +358,7 @@ def timed_func(inp, build_res): "Measure input size should be equal to build results" for inp, build_res in zip(inputs, build_results): if build_res.error_no != 0: - res = (MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, \ + res = (max_float,), build_res.error_no, build_res.error_msg, build_res.time_cost, \ 
time.time() else: res = call_func_with_timeout( @@ -366,7 +366,7 @@ def timed_func(inp, build_res): if isinstance(res, TimeoutError): if verbose >= 1: print("*T", end="") # Run timeout - res = (MAX_FLOAT,), MeasureErrorNo.RUN_TIMEOUT, None, \ + res = (max_float,), MeasureErrorNo.RUN_TIMEOUT, None, \ build_res.time_cost + timeout, time.time() measure_results.append(MeasureResult(*res)) diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py index deae5df68229..1546bb693076 100644 --- a/python/tvm/ansor/serialization.py +++ b/python/tvm/ansor/serialization.py @@ -22,7 +22,6 @@ import tvm._ffi from tvm.runtime import Object from .measure import MeasureCallback, MeasureErrorNo -from .loop_state import State from . import _ffi_api @@ -62,7 +61,7 @@ def read_lines(self, max_size=-1, skip_size=0): def __iter__(self): while True: ret = _ffi_api.LogReaderReadNext(self) - if ret is None or not len(ret): + if not ret: break yield ret[0], ret[1] # (input, result) diff --git a/python/tvm/ansor/utils.py b/python/tvm/ansor/utils.py index e83a9deb3f0a..041327d147d5 100644 --- a/python/tvm/ansor/utils.py +++ b/python/tvm/ansor/utils.py @@ -21,16 +21,12 @@ import multiprocessing.pool import queue import signal -import threading -import os -import numpy as np try: import psutil except ImportError: psutil = None -from tvm import rpc from tvm.tir import expr from tvm.tir.transform import Simplify from tvm.ir.transform import Sequential @@ -113,6 +109,9 @@ def __init__(self, *args, **kwargs): kwargs['context'] = NoDaemonContext() super().__init__(*args, **kwargs) + def __reduce__(self): + pass + def kill_child_processes(parent_pid, sig=signal.SIGTERM): """kill all child processes recursively""" From a82dbb80758c54d6b6c55293e095a8f337a1c791 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Mon, 29 Jun 2020 15:18:09 +0800 Subject: [PATCH 48/78] Update --- src/ansor/transform_step.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index b18b0b5248e5..1fc00d3f6ff0 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -230,15 +230,14 @@ namespace std { template <> struct hash<::tvm::ansor::Step> { std::size_t operator()(const ::tvm::ansor::Step& step) const { + // clang-format off if (auto ps = step.as<::tvm::ansor::ReorderStepNode>()) { return ::dmlc::HashCombine(1, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ps->after_ids)); + ::dmlc::HashCombine(std::hash()(ps->stage_id), ps->after_ids)); } else if (auto ps = step.as<::tvm::ansor::SplitStepNode>()) { - size_t ret = ::dmlc::HashCombine( - 2, ::dmlc::HashCombine( - std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), ps->inner_to_outer))); + size_t ret = ::dmlc::HashCombine(2, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->iter_id), ps->inner_to_outer))); for (const auto& len : ps->lengths) { if (len.defined()) { auto pint = len.as<::tvm::tir::IntImmNode>(); @@ -250,12 +249,13 @@ struct hash<::tvm::ansor::Step> { } return ret; } else if (auto ps = step.as<::tvm::ansor::FuseStepNode>()) { - return ::dmlc::HashCombine( - 3, ::dmlc::HashCombine(std::hash()(ps->stage_id), ps->fused_ids)); + return ::dmlc::HashCombine(3, + ::dmlc::HashCombine(std::hash()(ps->stage_id), ps->fused_ids)); } else { LOG(FATAL) << "Invalid step"; } return 0; + // clang-format on } }; } // namespace std From ac36c4615eecf7502dbdefa67d00f3c886a8c59d Mon Sep 17 00:00:00 2001 From: 
"chengfan.jcf" Date: Mon, 29 Jun 2020 20:07:14 +0800 Subject: [PATCH 49/78] Doc update --- python/tvm/ansor/auto_schedule.py | 2 +- python/tvm/ansor/compute_dag.py | 15 +++- python/tvm/ansor/loop_state.py | 30 ++++---- python/tvm/ansor/measure.py | 33 ++++---- python/tvm/ansor/serialization.py | 32 ++++++-- python/tvm/ansor/workload_registry.py | 97 +++++++++++++++++++++--- src/ansor/compute_dag.cc | 15 ++-- src/ansor/compute_dag.h | 23 ++++-- src/ansor/loop_state.cc | 13 +--- src/ansor/loop_state.h | 9 ++- src/ansor/measure.h | 2 +- src/ansor/search_policy/empty_policy.cc | 5 ++ src/ansor/search_policy/search_policy.cc | 2 - src/ansor/search_policy/search_policy.h | 5 +- src/ansor/serialization.h | 2 +- 15 files changed, 203 insertions(+), 82 deletions(-) diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index dcd53b282d6e..7a9d7c322c9e 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -50,7 +50,7 @@ def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes, @tvm._ffi.register_object("ansor.SearchTask") class SearchTask(Object): - """ The meta-information of a search task + """ The meta-information of a search task. Parameters ---------- diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index 41ba40a7f481..1e289aaafe0c 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -47,7 +47,7 @@ def get_init_state(self): def apply_steps_from_state(self, state): """ - Apply transform steps according to the history of a state + Apply transform steps according to the history of a State. Parameters ---------- @@ -63,7 +63,7 @@ def apply_steps_from_state(self, state): def print_python_code_from_state(self, state): """ - Print transform steps in the history of a state as TVM's python schedule primitive + Print transform steps in the history of a State as TVM's python schedule primitive. Parameters ---------- @@ -79,7 +79,16 @@ def print_python_code_from_state(self, state): def infer_bound_from_state(self, state): """ - Infer bound for a state + Infer bound for a state using TVM schedule. + + State api supports to define a split step with its split factor to be a blank placeholder, + so sometimes we may get a State will incomplete iterator extent information. + And another situation is after some steps (for exp. compute_at), it may be hard to track + the extent change of all iterators. + + We perform infer bound using TVM schedule and fill the State with those informations. After + applying this methods, the State is guaranteed to have complete interator extent + information. Parameters ---------- diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index 565ae66435c4..121a2e70b7ac 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -42,7 +42,7 @@ @tvm._ffi.register_object("ansor.Iterator") class Iterator(Object): - """A for loop iterator""" + """ A loop iterator structure. """ @tvm._ffi.register_object("ansor.Stage") @@ -63,7 +63,7 @@ def iters(self): @tvm._ffi.register_object("ansor.State") class StateObject(Object): - """The internal State object """ + """ The internal State object """ def __eq__(self, other): return _ffi_api.StateEqual(self, other) @@ -73,6 +73,8 @@ class State: A state in the search process. It consists of the current loop structure and the history steps to reach this state. + Each State corresponds to a specific schedule for the target ComputeDAG. 
+ Parameters ---------- state_object : StateObject @@ -115,12 +117,13 @@ def stage_ops(self): return [stage.op for stage in self.stages_cache] def transform_steps_size(self): - """ Return the size of transform_steps + """ Return the size of current transform_steps """ return _ffi_api.StateGetTransformStepsSize(self.state_object) def reorder(self, stage_id, order): - """ + """ Schedule primitive corresponds to te.reorder. + Parameters ---------- stage_id : Union[int, Operation, Tensor] @@ -134,7 +137,8 @@ def reorder(self, stage_id, order): self._clear_cache() def split(self, stage_id, iterator, lengths, inner_to_outer=True): - """ + """ Schedule primitive corresponds to te.split. + Parameters ---------- stage_id : Union[int, Operation, Tensor] @@ -160,7 +164,8 @@ def split(self, stage_id, iterator, lengths, inner_to_outer=True): return res def fuse(self, stage_id, iters): - """ + """ Schedule primitive corresponds to te.fuse. + Parameters ---------- stage_id : Union[int, Operation, Tensor] @@ -179,6 +184,12 @@ def fuse(self, stage_id, iters): self._clear_cache() return res + def copy(self): + """ Do deep copy of this State. """ + state = State(self.state_object, self.compute_dag) + state.stage_id_map = self.stage_id_map.copy() + return state + def _resolve_stage_id(self, stage_id): if isinstance(stage_id, Operation): return self.stage_id_map[stage_id] @@ -197,13 +208,6 @@ def _update_stage_id_map(self): def _clear_cache(self): self.stages_cache = None - def copy(self): - """ Do deep copy of this State. - """ - state = State(self.state_object, self.compute_dag) - state.stage_id_map = self.stage_id_map.copy() - return state - def __getitem__(self, key): if not self.stages_cache: self.stages_cache = _ffi_api.StateGetStages(self.state_object) diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 9bc82b543c7b..06e7c6fa1a4e 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -44,12 +44,13 @@ @tvm._ffi.register_object("ansor.MeasureCallback") class MeasureCallback(Object): - """Base class for measurement callback function""" + """ Base class for measurement callback function. """ @tvm._ffi.register_object("ansor.MeasureInput") class MeasureInput(Object): - """ + """ Store the input of a measurement. + Parameters ---------- task : SearchTask @@ -57,14 +58,13 @@ class MeasureInput(Object): state : State The current State to be measured. """ - def __init__(self, task, state): self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object) @tvm._ffi.register_object("ansor.BuildResult") class BuildResult(Object): - """ Store the input of a build. + """ Store the result of a build. Parameters ---------- @@ -79,7 +79,6 @@ class BuildResult(Object): time_cost : Float The time cost of build. """ - def __init__(self, filename, args, error_no, error_msg, time_cost): self.__init_handle_by_constructor__( _ffi_api.BuildResult, filename if filename else "", args, error_no, @@ -88,7 +87,8 @@ def __init__(self, filename, args, error_no, error_msg, time_cost): @tvm._ffi.register_object("ansor.MeasureResult") class MeasureResult(Object): - """ + """ Store the results of a measurement. + Parameters ---------- costs : List[Float] @@ -102,7 +102,6 @@ class MeasureResult(Object): timestamp : Float The time stamps of this measurement. 
""" - def __init__(self, costs, error_no, error_msg, all_cost, timestamp): self.__init_handle_by_constructor__( _ffi_api.MeasureResult, costs, error_no, @@ -111,9 +110,11 @@ def __init__(self, costs, error_no, error_msg, all_cost, timestamp): @tvm._ffi.register_object("ansor.Builder") class Builder(Object): - """ Base class of Builder """ + """ Base class of Builder. """ + def build(self, measure_inputs, verbose=1): - """ + """ Build programs and return results. + Parameters ---------- measure_inputs : List[MeasureInput] @@ -131,8 +132,10 @@ def build(self, measure_inputs, verbose=1): @tvm._ffi.register_object("ansor.Runner") class Runner(Object): """ Base class of Runner """ + def run(self, measure_inputs, build_results, verbose=1): - """ + """ Run measurement and return results. + Parameters ---------- measure_inputs : List[MeasureInput] @@ -198,7 +201,7 @@ def __init__(self, class MeasureErrorNo(object): - """Error type for MeasureResult""" + """ Error type for MeasureResult. """ NO_ERROR = 0 # No error INSTANTIATION_ERROR = 1 # Errors happen when apply transform steps from init state # Errors happen when compiling code on host (e.g. tvm.build) @@ -213,7 +216,7 @@ class MeasureErrorNo(object): def make_error_msg(): - """ Get the error message from traceback """ + """ Get the error message from traceback. """ error_msg = str(traceback.format_exc()) if len(error_msg) > MAX_ERROR_MSG_LEN: error_msg = error_msg[:MAX_ERROR_MSG_LEN//2] + \ @@ -226,7 +229,7 @@ def make_error_msg(): def local_build_worker(index): - """ Local builder function """ + """ Local builder function. """ # We use fork to copy arguments from a global variable. # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool measure_inputs, build_func, timeout, verbose = GLOBAL_BUILD_ARGUMENTS @@ -290,7 +293,7 @@ def timed_func(): @tvm._ffi.register_func("ansor.local_builder.build") def local_builder_build(inputs, timeout, n_parallel, build_func, verbose): - """ Local builder build function """ + """ Local builder build function. """ # We use fork to copy arguments from a global variable. # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool global GLOBAL_BUILD_ARGUMENTS @@ -311,7 +314,7 @@ def local_builder_build(inputs, timeout, n_parallel, build_func, verbose): @tvm._ffi.register_func("ansor.local_runner.run") def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval, verbose): - """ Local runner run function """ + """ Local runner run function. """ max_float = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log def timed_func(inp, build_res): diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py index 1546bb693076..9db85dc98ef9 100644 --- a/python/tvm/ansor/serialization.py +++ b/python/tvm/ansor/serialization.py @@ -28,14 +28,13 @@ @tvm._ffi.register_object("ansor.LogToFile") class LogToFile(MeasureCallback): """ - A measurement callback that writes measurement records into a file + A measurement callback that writes measurement records into a file. Parameters ---------- filename : Str File name for this callback to write log to. """ - def __init__(self, filename="ansor_tuning.json"): self.__init_handle_by_constructor__(_ffi_api.LogToFile, filename) @@ -43,7 +42,7 @@ def __init__(self, filename="ansor_tuning.json"): @tvm._ffi.register_object("ansor.LogReader") class LogReader(Object): """ - Reader of the json log file + Reader of the json log file. 
Parameters ---------- @@ -54,6 +53,22 @@ def __init__(self, filename="ansor_tuning.json"): self.__init_handle_by_constructor__(_ffi_api.LogReader, filename) def read_lines(self, max_size=-1, skip_size=0): + """ Read multiple lines from the log file. + + Parameters + ---------- + max_size : Int + The maximum number of lines. -1 means read all lines. + skip_size : Int + Skip the first n lines. + + Returns + ------- + inputs : List[MeasureInput] + The MeasureInputs loaded from the log file. + results : List[MeasureResult] + The MeasureResults loaded from the log file. + """ inputs, results = _ffi_api.LogReaderReadLines( self, max_size, skip_size) return inputs, results @@ -74,13 +89,17 @@ def load_from_file(filename: str): ---------- filename : Str File name to load log from. + + Returns + ------- + logs : List[MeasureInput, MeasureResult] """ return zip(*LogReader(filename).read_lines()) def write_measure_records_to_file(filename, inputs, results): """ - Write(append) measure records to file + Write(append) measure records to file. Parameters ---------- @@ -107,7 +126,10 @@ def best_measure_pair_in_file(filename, workload_key=None, target=None): Returns ------- - The best state from this log fine in form (MeasureInput, MeasureResult). + input : MeasureInput + The best State's MeasureInput from this log fine. + result : MeasureResult + The best State's MeasureResult from this log fine. """ log_reader = LogReader(filename) best_cost = 1e30 diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py index d6df6f36f046..d423c689bf99 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/ansor/workload_registry.py @@ -65,7 +65,17 @@ def matmul(N, M, K): def compute_dag_hash(dag): - """ Get hash value for a ComputeDAG + """ Get hash value for a ComputeDAG. + + Parameters + ---------- + dag : ComputeDAG + The target ComputeDAG. + + Returns + ------- + hash_value : Str + The hash value of this ComputeDAG in hex digest. """ # todo: implement this more carefully and move this to c++ as a member function of ComputeDAG str_key = '' @@ -87,8 +97,19 @@ def compute_dag_hash(dag): def register_workload_bufs(bufs): - """Directly register buffers of a workload and return the workload_key - The buffers can be looked up with workload_key_to_tensors by the workload_key + """ Directly register buffers of a workload and return the workload_key. + + The buffers can be looked up with workload_key_to_tensors by the workload_key. + + Parameters + ---------- + bufs : List[Tensor] + A list of Tensors for the target compute declaration. + + Returns + ------- + workload_key : Str + A workload key mapping to the registered compute declaration. """ dag = ComputeDAG(bufs) key = compute_dag_hash(dag) @@ -133,7 +154,18 @@ def deserialize_args(args): @tvm._ffi.register_func("ansor.workload_key_to_tensors") def workload_key_to_tensors(workload_key): - """Decode a workload key to the input/output tensors""" + """ Decode a workload key to the input/output tensors. + + Parameters + ---------- + workload_key : Str + The target workload key. + + Returns + ------- + tensors : List[Tensor] + The registered compute declaration Tensors. + """ workload = json.loads(workload_key) name = workload[0] lookup = WORKLOAD_FUNC_REGISTRY[name] @@ -146,13 +178,37 @@ def workload_key_to_tensors(workload_key): @ tvm._ffi.register_func("ansor.workload_key_to_dag") def workload_key_to_dag(workload_key): - """Decode a workload key to a compute dag""" + """ Decode a workload key to a compute dag. 
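Taken together, the registry functions documented above give a round trip from a compute declaration to a workload key and back. A short sketch, assuming the module is importable as below (the import path is an assumption):

    from tvm import te
    from tvm.ansor import workload_registry as wr   # assumed import path

    A = te.placeholder((128, 128), name='A')
    B = te.compute((128, 128), lambda i, j: A[i, j] * 2, name='B')

    # Register the buffers; the returned key is a JSON string wrapping the DAG hash.
    key = wr.register_workload_bufs([A, B])

    # The key can later be decoded back into tensors or a ComputeDAG.
    tensors = wr.workload_key_to_tensors(key)
    dag = wr.workload_key_to_dag(key)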
+
+    Parameters
+    ----------
+    workload_key : Str
+        The target workload key.
+
+    Returns
+    -------
+    dag : ComputeDAG
+        ComputeDAG to the registered compute declaration.
+    """
     tensors = workload_key_to_tensors(workload_key)
     return ComputeDAG(tensors)


 def make_workload_key_func(func, args):
-    """make a workload key from function and arguments"""
+    """ make a workload key from function and arguments.
+
+    Parameters
+    ----------
+    func : Function
+        The target function that returns the compute declaration Tensors.
+    args : Args
+        The args of the target function.
+
+    Returns
+    -------
+    workload_key : Str
+        The workload key of the target function.
+    """
     args = serialize_args(args)

     if callable(func):
@@ -169,21 +225,44 @@

 def make_workload_key_bufs(bufs):
-    """make a workload key from bufs"""
+    """ make a workload key from bufs.
+
+    Parameters
+    ----------
+    bufs : List[Tensor]
+        A list of Tensors for the target compute declaration.
+
+    Returns
+    -------
+    workload_key : Str
+        A workload key mapping to the registered compute declaration.
+    """
     dag = ComputeDAG(bufs)
     key = compute_dag_hash(dag)
     return json.dumps((key,))


 def dump_workload_func_registry(filename):
-    """Dump workload function registry to a pickle binary file"""
+    """ Dump workload function registry to a pickle binary file.
+
+    Parameters
+    ----------
+    filename : Str
+        The filename to dump workload function registry to.
+    """
     global WORKLOAD_FUNC_REGISTRY

     pickle.dump(WORKLOAD_FUNC_REGISTRY, open(filename, 'wb'))


 def load_workload_func_registry(filename):
-    """Load workload function registry from a pickle binary file"""
+    """ Load workload function registry from a pickle binary file.
+
+    Parameters
+    ----------
+    filename : Str
+        The filename to load workload function registry from.
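Since WORKLOAD_FUNC_REGISTRY is a module-level, per-process table, the dump/load pair documented above is presumably what lets a separate process decode previously created workload keys. A sketch (the import path and file name are assumptions):

    from tvm.ansor import workload_registry as wr   # assumed import path

    # Process A: persist the registry after registering workloads.
    wr.dump_workload_func_registry('/tmp/ansor_workloads.pkl')

    # Process B: restore it before decoding workload keys from a log file.
    wr.load_workload_func_registry('/tmp/ansor_workloads.pkl')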
+ """ global WORKLOAD_FUNC_REGISTRY WORKLOAD_FUNC_REGISTRY = pickle.load(open(filename, 'rb')) diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index c8fd7f8f6d37..644696fc2a82 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -226,10 +226,6 @@ class FlopEstimator: public ExprFunctor { bool fail{false}; }; -State ComputeDAG::GetInitState() const { - return Downcast(operator->()->init_state); -} - ComputeDAG::ComputeDAG(Array tensors) { auto node = make_object(); FlopEstimator estimator; @@ -261,6 +257,8 @@ ComputeDAG::ComputeDAG(const std::string& workload_key) { data_ = std::move(node); } +State ComputeDAG::GetInitState() const { return Downcast(operator->()->init_state); } + std::pair > ComputeDAG::ApplySteps( const std::vector& transform_steps) const { std::vector stages; @@ -309,8 +307,7 @@ std::string ComputeDAG::PrintStepsAsPython(const std::vector& transform_st return ss.str(); } -State ComputeDAG::ReplayAndInferBound( - const std::vector& transform_steps) const { +State ComputeDAG::ReplayAndInferBound(const std::vector& transform_steps) const { State ret_state = GetInitState(); StateNode* pstate = ret_state.CopyOnWrite(); pstate->transform_steps = transform_steps; @@ -364,6 +361,7 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { std::tie(sch, tensors) = ReplaySteps(pstate->transform_steps, &stages, &stage_to_axes); sch = sch.normalize(); + // Get bound information from TVM schedule bounds = te::InferBound(sch); // Update the state bound information @@ -396,9 +394,8 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { } std::pair > ComputeDAG::ReplaySteps( - const std::vector &transform_steps, - std::vector *stages, - StageToAxesMap *stage_to_axes) const { + const std::vector& transform_steps, std::vector* stages, + StageToAxesMap* stage_to_axes) const { std::vector ops; for (const auto& op : operator->()->ops) { if (!op->IsInstance()) { diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 25841ff4f268..ad98b38479f3 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -38,7 +38,9 @@ namespace tvm { namespace ansor { -class StateNode; class State; class Step; +class StateNode; +class State; +class Step; typedef std::unordered_map, ObjectHash, ObjectEqual> StageToAxesMap; @@ -57,9 +59,9 @@ class ComputeDAGNode : public Object { Array tensors; /*! \brief All related operations in topo order. */ Array ops; - /*! \brief Number of float operations. */ + /*! \brief Number of total float operations for this ComputeDAG. */ double flop_ct; - /*! \brief The initial state. */ + /*! \brief The initial state without any transform steps. */ ObjectRef init_state; void VisitAttrs(tvm::AttrVisitor* v) { @@ -103,8 +105,14 @@ class ComputeDAG: public ObjectRef { std::string PrintStepsAsPython(const std::vector& transform_steps) const; /*! - * \brief Replay the transform steps and call ir_pass::InferBound to fill - * correct bound information. + * \brief Replay the transform steps and call ir_pass::InferBound to fill correct bound + * information. + * State api supports to define a split step with its split factor to be a blank placeholder, + * so sometimes we may get a State will incomplete iterator extent information. + * And another situation is after some steps (for exp. compute_at), it may be hard to track the + * extent change of all iterators. + * We perform infer bound using TVM schedule and fill the State with those informations. 
After + * applying this methods, the State is guaranteed to have complete interator extent information. * \param transform_steps Transform steps of the target state. * \return The State after inferbound. */ @@ -118,7 +126,7 @@ class ComputeDAG: public ObjectRef { /*! * \brief Fill the correct bound information for a list of given states. * Return the new states inplace. - * \param states A pointer to a State vector. + * \param states A pointer to a State vector, States are updated inplace. */ void InferBound(std::vector* states) const; @@ -133,7 +141,8 @@ class ComputeDAG: public ObjectRef { private: /*! - * \brief Internal common parts for replaying steps. + * \brief Internal common parts for replaying steps. This is the key method to apply steps to + * TVM schedule. * \param transform_steps Transform steps of the target state. * \param stages A pointer to `te::Stage` vector. * \param stage_to_axes A pointer to StageToAxesMap. diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 6d877a21a6a6..000b76a69922 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -62,7 +62,6 @@ Stage::Stage(te::Operation op) { if (op->IsInstance()) { node->op_type = kCompute; auto* pop = op.as(); - for (const auto& axis : pop->axis) { node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, kSpace, kNone)); @@ -129,9 +128,8 @@ State::State(const std::vector& stages, const std::vector& transfor /********** Schedule primitives apis for state **********/ void State::reorder(int stage_id, const std::vector& order) { const Stage& stage = operator->()->stages[stage_id]; - CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " - "should be specified"; + << "should be specified"; std::vector after_ids; GetIndices(stage->iters, order, &after_ids); ReorderStep step = ReorderStep(stage_id, after_ids); @@ -143,7 +141,6 @@ std::vector State::split(int stage_id, const Iterator& it, const std::vector& lengths, bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; - SplitStep step = SplitStep(stage_id, GetIndex(stage->iters, it), it->range.defined() ? 
it->range->extent : PrimExpr(), @@ -164,12 +161,10 @@ Iterator State::fuse(int stage_id, const std::vector& iters) { /********** Step implementations for state **********/ void State::DoReorderStep(const ReorderStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; - std::vector iters; for (auto x : step->after_ids) { iters.push_back(stage->iters[x]); } - StateNode* pstate = CopyOnWrite(); pstate->stages[step->stage_id] = Stage(stage->op, stage->op_type, std::move(iters), stage->compute_at, stage->attrs); @@ -284,14 +279,12 @@ Iterator State::DoFuseStep(const FuseStep& step) { if (new_extent.defined()) { range = Range::make_by_min_extent(0, new_extent); } - Iterator new_it = - Iterator(new_name, range, new_iter_type, kNone, &ori_iters); + Iterator new_it = Iterator(new_name, range, new_iter_type, kNone, &ori_iters); std::vector new_iters; new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + step->fused_ids.front()); new_iters.push_back(new_it); - new_iters.insert(new_iters.end(), - stage->iters.begin() + step->fused_ids.back() + 1, + new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back() + 1, stage->iters.end()); StateNode* pstate = CopyOnWrite(); diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 94f3faf71b45..dab32d46156d 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -127,7 +127,7 @@ class IteratorNode : public Object { IteratorAnnotation annotation; /*! \brief The original iterators before fusion. */ std::vector ori_iters; - /*! \brief The extra attribute of this iterator. */ + /*! \brief The extra attributes of this iterator. */ std::string attr; void VisitAttrs(tvm::AttrVisitor* v) { @@ -172,7 +172,7 @@ struct StageAttributes { }; /*! - * \brief A stage in the compute declaration. + * \brief A op stage in the compute declaration. * Similar to te::Stage in `include/schedule.h`. */ class StageNode : public Object { @@ -235,8 +235,9 @@ class Stage : public ObjectRef { }; /*! - * \brief A state in the search process. - * It consists of the current loop structure and the history steps to reach this state. + * \brief A State in the search process. + * It consists of the current loop structure and the history steps to reach this State. + * Each State corresponds to a specific schedule for the target ComputeDAG. */ class StateNode: public Object { public: diff --git a/src/ansor/measure.h b/src/ansor/measure.h index aae332386367..ee71ad558680 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -101,7 +101,7 @@ class MeasureInput : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(MeasureInput, ObjectRef, MeasureInputNode); }; -/*! \brief Store the input of a build. */ +/*! \brief Store the result of a build. */ class BuildResultNode : public Object { public: /*! \brief The filename of built binary file. 
*/ diff --git a/src/ansor/search_policy/empty_policy.cc b/src/ansor/search_policy/empty_policy.cc index 287b68c7c0ce..d4ebc829f7a8 100644 --- a/src/ansor/search_policy/empty_policy.cc +++ b/src/ansor/search_policy/empty_policy.cc @@ -48,6 +48,7 @@ State EmptyPolicyNode::Search(SearchTask task, int n_trials, int early_stopping, if (n_trials <= 1) { const auto& res = SearchOneRound(); CHECK_GT(res.size(), 0); + return res[0]; } else { std::vector inputs; @@ -60,10 +61,14 @@ State EmptyPolicyNode::Search(SearchTask task, int n_trials, int early_stopping, while (ct < n_trials) { const auto& res = SearchOneRound(); ct += res.size(); + // Build MeasureInputs for measuring inputs.clear(); for (const auto& state : res) { + // The class members measured_states_set_ provided by SearchPolicy can be used to filter + // out the already measured states inputs.emplace_back(cur_task, state); } + // ProgramMeasurer will record the state with best performance during measure process measurer->Measure(cur_task, GetRef(this), inputs, &results); } diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index 41a0e650fb1e..7d782cb0eba2 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -26,8 +26,6 @@ #include -// #include "../serialization.h" - namespace tvm { namespace ansor { diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index 8389320eb085..5f43f5352695 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -102,7 +102,8 @@ class SearchPolicyNode : public Object { } /*! - * \brief Do schedule search for a task. + * \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state + * get during the search process. * \param task The target search task. * \param n_trials Total schedules to be tried during this search. * \param early_stopping Early stop if no better schedule is found. @@ -117,7 +118,7 @@ class SearchPolicyNode : public Object { Array pre_search_callbacks) = 0; /*! - * \brief Call SearchCallback with the current SearchPolicyNode.u + * \brief Call SearchCallback with the current SearchPolicyNode * \param callbacks SearchCallback to be called. */ void RunCallbacks(const Array& callbacks); diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h index fbe79270bb70..f8ab6b42dda2 100644 --- a/src/ansor/serialization.h +++ b/src/ansor/serialization.h @@ -82,7 +82,7 @@ class LogReaderNode : public Object { /*! * \brief Read multiple lines from the log file. * \param max_size The maximum number of lines. -1 means read all lines. - * \param skip_size Skip the first n lines + * \param skip_size Skip the first n lines. * \return The MeasureInputs and MeasureResults loaded from the log file. 
*/ std::pair, Array > ReadLines( From a62b1e0a7a5dfea4b884900789c542a89245cfbe Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 30 Jun 2020 13:03:54 +0800 Subject: [PATCH 50/78] Update --- python/tvm/ansor/loop_state.py | 20 ++++++++++---------- src/ansor/compute_dag.h | 4 ++-- src/ansor/loop_state.cc | 1 + src/ansor/loop_state.h | 23 +++++++++++------------ src/ansor/transform_step.h | 5 ++++- 5 files changed, 28 insertions(+), 25 deletions(-) diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index 121a2e70b7ac..a796373b9393 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -19,19 +19,19 @@ """ The definition of the "state" in search. A state consists a current loop structure and the transform history to reach its current loop structure. -To enable flexible manipulation of the loop structure, we implemented a lightweight -loop structure IR (Intermediate Representation) specifically for search. +To enable flexible manipulation of the loop structures, we implemented a lightweight loop +structure IR (Intermediate Representation) based on the original TVM IR but specifically +for schedule search. -Basically this is a simplified TVM IR with schedule primitives. -We don't use the existing TVM IR because -1. We want fast incremental change to the loop structures -2. We want serializable transformation history for replay, backtracking, and mutation -3. We may create some new macro schedule primitives +We don't use the existing TVM IR but to extend a new Sketch IR on it is because: +1. We want fast incremental change to the loop structures; +2. We want serializable transform history for replay, backtracking, and mutation; +3. We may create some macro schedule primitives that represent the combination of several +TVM schedule primitives. After the search is done, we will lower this IR to TVM IR with TVM's schedule primitives. -Because we share a lot common objects during search, the transformation is -implemented in copy on write style. All objects are immutable, which is -similar to TVM IR. +Because we share a lot common objects during search, the transformation is implemented in +copy on write style. All objects are immutable, which is similar to TVM IR. """ import tvm._ffi diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index ad98b38479f3..e07dfb8c433c 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -20,8 +20,8 @@ /*! * \file ansor/compute_dag.h * \brief Compute declaration graph and its related analysis tools. - * ComputeDAG is responsible for the interaction with the original TVM schedule system, to apply - * state to a runable TVM schedule or provide the schedule Python code. + * ComputeDAG is also responsible for the interaction with the original TVM schedule system, to + * apply state to a runable TVM schedule or provide the schedule Python code. */ #ifndef TVM_ANSOR_COMPUTE_DAG_H_ diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 000b76a69922..1c0228be4755 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -25,6 +25,7 @@ #include "loop_state.h" +#include #include #include diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index dab32d46156d..721478de0b7a 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -21,20 +21,19 @@ * \file ansor/loop_state.h * \brief The definition of the "state" in search. A state consists the current loop structure * and the transform history to reach its current loop structure. 
- * To enable flexible manipulation of the loop structure, we implemented a lightweight - * loop structure IR (Intermediate Representation) specifically for search. This can be seen as - * a preview of how this schedule looks like after tvm.lower or tvm.build. + * To enable flexible manipulation of the loop structures, we implemented a lightweight loop + * structure IR (Intermediate Representation) based on the original TVM IR but specifically + * for schedule search. * - * Basically this is a simplified TVM IR with schedule primitives. - * We don't use the existing TVM IR because - * 1. We want fast incremental change to the loop structures - * 2. We want serializable transformation history for replay, backtracking, and mutation. - * 3. We may create some macro schedule primitives + * We don't use the existing TVM IR but to extend a new Sketch IR on it is because: + * 1. We want fast incremental change to the loop structures; + * 2. We want serializable transform history for replay, backtracking, and mutation; + * 3. We may create some macro schedule primitives that represent the combination of several + * TVM schedule primitives. * - * After the search is done, we will lower this IR to TVM IR with TVM schedule primitives. - * Because we share a lot common objects during search, the transformation is - * implemented in copy on write style. - * All objects are immutable, which is similar to TVM IR. + * After the search is done, we will lower this IR to TVM IR with TVM's schedule primitives. + * Because we share a lot common objects during search, the transformation is implemented in + * copy on write style. All objects are immutable, which is similar to TVM IR. */ #ifndef TVM_ANSOR_LOOP_STATE_H_ diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index 1fc00d3f6ff0..3536024f46eb 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -61,7 +61,10 @@ typedef std::unordered_map, ObjectHash class Step; -/*! \brief The base class for a transformation step */ +/*! + * \brief The base class for a transformation step. Each step has its corresponding tvm.te + * schedule primitives. + */ class StepNode : public Object { public: /*! \brief The index of the target stage. 
*/ From 526cf42c6034415b601e8b09dcc19d21f277e7a5 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 30 Jun 2020 14:32:31 +0800 Subject: [PATCH 51/78] Bug fix after code merge to the new master --- src/ansor/loop_state.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 1c0228be4755..7e7e4c22cfa4 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -199,8 +199,7 @@ std::vector State::DoSplitStepCommon( } Iterator res; if (l.defined() && tosplit_min.defined() && tosplit_extent.defined()) { - res = Iterator(name, Range::make_by_min_extent(tosplit_min, l), - it->iter_type, kNone); + res = Iterator(name, Range::FromMinExtent(tosplit_min, l), it->iter_type, kNone); tosplit_min = 0; tosplit_extent = indexdiv(tosplit_extent + l - 1, l); } else { @@ -212,7 +211,7 @@ std::vector State::DoSplitStepCommon( Range range; if (tosplit_min.defined() && tosplit_extent.defined()) { - range = Range::make_by_min_extent(tosplit_min, tosplit_extent); + range = Range::FromMinExtent(tosplit_min, tosplit_extent); } if (inner_to_outer) { outs.push_back( @@ -278,7 +277,7 @@ Iterator State::DoFuseStep(const FuseStep& step) { Range range; if (new_extent.defined()) { - range = Range::make_by_min_extent(0, new_extent); + range = Range::FromMinExtent(0, new_extent); } Iterator new_it = Iterator(new_name, range, new_iter_type, kNone, &ori_iters); std::vector new_iters; From 426ec82e12eefaba6271e39eb35bbd082c16c2ff Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 30 Jun 2020 15:46:16 +0800 Subject: [PATCH 52/78] clang-format fix --- src/ansor/auto_schedule.cc | 26 ++- src/ansor/auto_schedule.h | 14 +- src/ansor/compute_dag.cc | 169 +++++++++--------- src/ansor/compute_dag.h | 10 +- src/ansor/loop_state.cc | 125 ++++++------- src/ansor/loop_state.h | 29 ++- src/ansor/measure.cc | 172 +++++++++--------- src/ansor/measure.h | 36 ++-- src/ansor/search_policy/search_policy.h | 4 +- src/ansor/search_task.cc | 39 ++-- src/ansor/search_task.h | 10 +- src/ansor/serialization.cc | 225 ++++++++++++------------ src/ansor/serialization.h | 13 +- src/ansor/transform_step.cc | 77 ++++---- src/ansor/utils.h | 52 +++--- 15 files changed, 468 insertions(+), 533 deletions(-) diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index d409988a8007..c06080b5cb32 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -50,12 +50,11 @@ TuneOption::TuneOption(int n_trials, int early_stopping, int num_measure_per_rou } std::pair > AutoSchedule(SearchTask task, - SearchPolicy search_policy, TuneOption tune_option) { + SearchPolicy search_policy, + TuneOption tune_option) { // Create a ProgramMeasurer to handle the schedule build and performance measure - ProgramMeasurer measurer = - ProgramMeasurer(tune_option->builder, tune_option->runner, - tune_option->measure_callbacks, - tune_option->verbose); + ProgramMeasurer measurer = ProgramMeasurer(tune_option->builder, tune_option->runner, + tune_option->measure_callbacks, tune_option->verbose); // Search for the best schedule State state = search_policy->Search(task, tune_option->n_trials, tune_option->early_stopping, tune_option->num_measure_per_round, tune_option->verbose, @@ -63,18 +62,17 @@ std::pair > AutoSchedule(SearchTask task, return task->compute_dag.ApplySteps(state->transform_steps); } -std::pair > AutoSchedule( - std::string workload_key, Target target, Target target_host, - SearchPolicy search_policy, HardwareParams hardware_params, - TuneOption 
tune_option) { +std::pair > AutoSchedule(std::string workload_key, Target target, + Target target_host, + SearchPolicy search_policy, + HardwareParams hardware_params, + TuneOption tune_option) { // Create SearchTask from the given workload key ComputeDAG dag = ComputeDAG(workload_key); - SearchTask task = SearchTask( - std::move(dag), std::move(workload_key), std::move(target), - std::move(target_host), std::move(hardware_params)); + SearchTask task = SearchTask(std::move(dag), std::move(workload_key), std::move(target), + std::move(target_host), std::move(hardware_params)); // Search for the best schedule - return AutoSchedule(std::move(task), std::move(search_policy), - std::move(tune_option)); + return AutoSchedule(std::move(task), std::move(search_policy), std::move(tune_option)); } TVM_REGISTER_GLOBAL("ansor.TuneOption") diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index a7bfb8449537..4ec0b99887c3 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -102,8 +102,9 @@ class TuneOption : public ObjectRef { * \param tune_option Tuning and measurement options. * \return A `te::Schedule` and the target `te::Tensor` to be used in `tvm.lower` or `tvm.build`. */ -std::pair > AutoSchedule( - SearchTask task, SearchPolicy search_policy, TuneOption tune_option); +std::pair > AutoSchedule(SearchTask task, + SearchPolicy search_policy, + TuneOption tune_option); /*! * \brief Auto schedule search for a given compute declaration, by workload key. @@ -115,10 +116,11 @@ std::pair > AutoSchedule( * \param tune_option Tuning and measurement options. * \return A `te::Schedule` and the target `te::Tensor` to be used in `tvm.lower` or `tvm.build`. */ -std::pair > AutoSchedule( - std::string workload_key, Target target, Target target_host, - SearchPolicy search_policy, HardwareParams hardware_params, - TuneOption tune_option); +std::pair > AutoSchedule(std::string workload_key, Target target, + Target target_host, + SearchPolicy search_policy, + HardwareParams hardware_params, + TuneOption tune_option); } // namespace ansor } // namespace tvm diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 644696fc2a82..ddcefbd81641 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -69,8 +69,7 @@ void UpdateStageAxis(const te::Stage& stage, StageToAxesMap* stage_to_axes) { // Topo-sort ops from tensors according to their read-write relations. 
// Results are stored in ops -void TopoSortOps(const Array& tensors, - std::vector* ops) { +void TopoSortOps(const Array& tensors, std::vector* ops) { std::unordered_map degree; std::unordered_map > edge_set; std::unordered_map priority; @@ -113,9 +112,7 @@ void TopoSortOps(const Array& tensors, ops->clear(); using Item = std::pair; - auto cmp = [](const Item& left, const Item& right) { - return left.second < right.second; - }; + auto cmp = [](const Item& left, const Item& right) { return left.second < right.second; }; std::priority_queue, decltype(cmp)> queue(cmp); for (const auto& iter : degree) { if (iter.second == 0) { @@ -138,7 +135,7 @@ void TopoSortOps(const Array& tensors, } // Estimate number of float operations in an expression -class FlopEstimator: public ExprFunctor { +class FlopEstimator : public ExprFunctor { public: double EstimateFlop(const Array& ops) { double ret = 0; @@ -190,29 +187,37 @@ class FlopEstimator: public ExprFunctor { double VisitExpr_(const VarNode* op) final { return 0.0; } double VisitExpr_(const SelectNode* op) final { - return VisitExpr(op->condition) + std::max(VisitExpr(op->true_value), - VisitExpr(op->false_value)); - } - -#define VisitBinary(Node) \ - double VisitExpr_(const Node* op) final { \ - return 1.0 + VisitExpr(op->a) + VisitExpr(op->b); \ - } -#define VisitUnary(Node) \ - double VisitExpr_(const Node* op) final { \ - return 1.0 + VisitExpr(op->a); \ + return VisitExpr(op->condition) + + std::max(VisitExpr(op->true_value), VisitExpr(op->false_value)); } - VisitBinary(AddNode); VisitBinary(SubNode); VisitBinary(MulNode) - VisitBinary(DivNode); VisitBinary(ModNode); VisitBinary(FloorDivNode) - VisitBinary(FloorModNode); VisitBinary(MaxNode); VisitBinary(MinNode); - VisitBinary(EQNode); VisitBinary(NENode); VisitBinary(LTNode); - VisitBinary(LENode); VisitBinary(GTNode); VisitBinary(GENode); - VisitBinary(AndNode); VisitBinary(OrNode); VisitUnary(NotNode); +#define VisitBinary(Node) \ + double VisitExpr_(const Node* op) final { return 1.0 + VisitExpr(op->a) + VisitExpr(op->b); } +#define VisitUnary(Node) \ + double VisitExpr_(const Node* op) final { return 1.0 + VisitExpr(op->a); } + + VisitBinary(AddNode); + VisitBinary(SubNode); + VisitBinary(MulNode); + VisitBinary(DivNode); + VisitBinary(ModNode); + VisitBinary(FloorDivNode); + VisitBinary(FloorModNode); + VisitBinary(MaxNode); + VisitBinary(MinNode); + VisitBinary(EQNode); + VisitBinary(NENode); + VisitBinary(LTNode); + VisitBinary(LENode); + VisitBinary(GTNode); + VisitBinary(GENode); + VisitBinary(AndNode); + VisitBinary(OrNode); + VisitUnary(NotNode); double VisitExpr_(const CallNode* op) final { double ret = 0.0; - for (const auto&x : op->args) { + for (const auto& x : op->args) { ret += VisitExpr(x); } return ret; @@ -294,14 +299,15 @@ std::string ComputeDAG::PrintStepsAsPython(const std::vector& transform_st ss << ", "; } } - ss << " = " << "tuple(" << stage->op->name << ".op.axis)" - << " + " << "tuple(" << stage->op->name << ".op.reduce_axis)\n"; + ss << " = " + << "tuple(" << stage->op->name << ".op.axis)" + << " + " + << "tuple(" << stage->op->name << ".op.reduce_axis)\n"; } } // Call each step's PrintAsPythonAPI method for (const auto& step : transform_steps) { - ss << step->PrintAsPythonAPI(&stages, &stage_to_axes, &schedule, - transform_steps); + ss << step->PrintAsPythonAPI(&stages, &stage_to_axes, &schedule, transform_steps); } return ss.str(); @@ -333,9 +339,10 @@ void ComputeDAG::InferBound(std::vector* states) const { auto worker_func = [&states, &out_states, this](int 
idx) { try { out_states[idx] = this->InferBound((*states)[idx]); - } catch (dmlc::Error &e) { - LOG(WARNING) << "InferBound fails on the state:\n" << (*states)[idx] - << "\n" << e.what() << std::endl; + } catch (dmlc::Error& e) { + LOG(WARNING) << "InferBound fails on the state:\n" + << (*states)[idx] << "\n" + << e.what() << std::endl; } }; @@ -358,8 +365,7 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { Map bounds; // Replay steps to tvm::Schedule - std::tie(sch, tensors) = ReplaySteps(pstate->transform_steps, &stages, - &stage_to_axes); + std::tie(sch, tensors) = ReplaySteps(pstate->transform_steps, &stages, &stage_to_axes); sch = sch.normalize(); // Get bound information from TVM schedule bounds = te::InferBound(sch); @@ -380,9 +386,8 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { auto find_res = bounds.find(axis); if (find_res != bounds.end()) { - new_iters.push_back(Iterator(iter->name, (*find_res).second, - iter->iter_type, iter->annotation, - &iter->ori_iters, iter->attr)); + new_iters.push_back(Iterator(iter->name, (*find_res).second, iter->iter_type, + iter->annotation, &iter->ori_iters, iter->attr)); } else { LOG(FATAL) << "Infer bound fails"; } @@ -445,62 +450,60 @@ std::pair > ComputeDAG::ReplaySteps( } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { - auto* node = static_cast(ref.get()); - std::stringstream ss; - - for (const auto& op : node->ops) { - if (op->IsInstance()) { - ss << op->name << " = PLACEHOLDER " << op.output(0)->shape << "\n"; - } else if (auto pop = op.as()) { - for (size_t k = 0; k < pop->body.size(); ++k) { - ss << op->name << "("; - for (size_t i = 0; i < pop->axis.size(); i++) { - ss << pop->axis[i]->var->name_hint; - if (i != pop->axis.size() - 1) { - ss << ", "; - } - } - ss << ")"; - if (pop->body.size() > 1) { - ss << ".v" << k; - } - if (auto preduce = pop->body[k].as()) { - CHECK_LT(k, preduce->combiner->result.size()); - PrimExpr combiner = preduce->combiner->result[k]; - if (combiner->IsInstance()) { - ss << " += " << preduce->source[0] << "\n"; - } else if (combiner->IsInstance()) { - ss << " max= " << preduce->source[0] << "\n"; - } else if (combiner->IsInstance()) { - ss << " min= " << preduce->source[0] << "\n"; - } else if (combiner->IsInstance()) { - const auto& select = combiner.as(); - ss << " select(" << select->condition << ", " << select->true_value - << ", " << select->false_value << ")= " << '(' - << preduce->source[0] << ',' << preduce->source[1] << ")\n"; - } else { - LOG(FATAL) << "Unsupported reduction operator" << combiner; + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + std::stringstream ss; + + for (const auto& op : node->ops) { + if (op->IsInstance()) { + ss << op->name << " = PLACEHOLDER " << op.output(0)->shape << "\n"; + } else if (auto pop = op.as()) { + for (size_t k = 0; k < pop->body.size(); ++k) { + ss << op->name << "("; + for (size_t i = 0; i < pop->axis.size(); i++) { + ss << pop->axis[i]->var->name_hint; + if (i != pop->axis.size() - 1) { + ss << ", "; + } + } + ss << ")"; + if (pop->body.size() > 1) { + ss << ".v" << k; + } + if (auto preduce = pop->body[k].as()) { + CHECK_LT(k, preduce->combiner->result.size()); + PrimExpr combiner = preduce->combiner->result[k]; + if (combiner->IsInstance()) { + ss << " += " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance()) { + ss << " max= " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance()) { + ss << " 
min= " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance()) { + const auto& select = combiner.as(); + ss << " select(" << select->condition << ", " << select->true_value << ", " + << select->false_value << ")= " << '(' << preduce->source[0] << ',' + << preduce->source[1] << ")\n"; + } else { + LOG(FATAL) << "Unsupported reduction operator" << combiner; + } + } else { + ss << " = " << pop->body[k] << "\n"; + } } } else { - ss << " = " << pop->body[k] << "\n"; + LOG(FATAL) << "Invalid op"; } } - } else { - LOG(FATAL) << "Invalid op"; - } - } - p->stream << ss.str(); -}); + p->stream << ss.str(); + }); -TVM_REGISTER_GLOBAL("ansor.ComputeDAG") -.set_body_typed([](Array tensors) { +TVM_REGISTER_GLOBAL("ansor.ComputeDAG").set_body_typed([](Array tensors) { return ComputeDAG(tensors); }); -TVM_REGISTER_GLOBAL("ansor.ComputeDAGGetInitState") -.set_body_method(&ComputeDAG::GetInitState); +TVM_REGISTER_GLOBAL("ansor.ComputeDAGGetInitState").set_body_method(&ComputeDAG::GetInitState); TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") .set_body_typed([](const ComputeDAG& dag, const State& state) { diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index e07dfb8c433c..3a5089aafb1a 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -50,7 +50,7 @@ typedef std::unordered_map, ObjectHash * \param stage A `te::Stage`. * \param stage_to_axes A pointer to StageToAxesMap. */ -void UpdateStageAxis(const tvm::te::Stage& stage, StageToAxesMap *stage_to_axes); +void UpdateStageAxis(const tvm::te::Stage& stage, StageToAxesMap* stage_to_axes); /*! \brief Computation declaration graph. */ class ComputeDAGNode : public Object { @@ -78,7 +78,7 @@ class ComputeDAGNode : public Object { * \brief Managed reference to ComputeDAGNode. * \sa ComputeDAGNode */ -class ComputeDAG: public ObjectRef { +class ComputeDAG : public ObjectRef { public: /*! \brief The constructor. * \param tensors `te::Tensor`s for a compute declaration. @@ -148,9 +148,9 @@ class ComputeDAG: public ObjectRef { * \param stage_to_axes A pointer to StageToAxesMap. * \return The return values can be used as arguments to `tvm.build` or `tvm.lower`. */ - std::pair > ReplaySteps( - const std::vector& transform_steps, std::vector* stages, - StageToAxesMap* stage_to_axes) const; + std::pair > ReplaySteps(const std::vector& transform_steps, + std::vector* stages, + StageToAxesMap* stage_to_axes) const; /*! * \brief Internal common parts for inferring bound. 
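Note: as the `ReplaySteps` comment above says, the replayed pair is exactly what `tvm.build`/`tvm.lower` expect. A minimal sketch of the round trip from a searched state back to a schedule, assuming `dag` and `state` came from an earlier search (the variable names are illustrative, not part of this patch):

    // Replay the recorded transform steps of a State on its ComputeDAG.
    tvm::te::Schedule sch;
    tvm::Array<tvm::te::Tensor> tensors;
    std::tie(sch, tensors) = dag.ApplySteps(state->transform_steps);
    // `sch` and `tensors` can now be handed to tvm.lower / tvm.build.
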
diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 7e7e4c22cfa4..46daf85c6b08 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -42,8 +42,7 @@ TVM_REGISTER_NODE_TYPE(IteratorNode); /********** Iterator **********/ Iterator::Iterator(std::string name, Range range, IteratorType iter_type, - IteratorAnnotation annotation, - const std::vector* ori_iters, + IteratorAnnotation annotation, const std::vector* ori_iters, std::string attr) { auto node = make_object(); node->name = std::move(name); @@ -64,12 +63,10 @@ Stage::Stage(te::Operation op) { node->op_type = kCompute; auto* pop = op.as(); for (const auto& axis : pop->axis) { - node->iters.push_back(Iterator(CleanName(axis->var->name_hint), - axis->dom, kSpace, kNone)); + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, kSpace, kNone)); } for (const auto& axis : pop->reduce_axis) { - node->iters.push_back(Iterator(CleanName(axis->var->name_hint), - axis->dom, kReduce, kNone)); + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, kReduce, kNone)); } } else if (op->IsInstance()) { node->op_type = kPlaceholder; @@ -84,9 +81,8 @@ Stage::Stage(te::Operation op) { data_ = std::move(node); } -Stage::Stage(te::Operation op, StageType op_type, - const std::vector& iters, ComputeAtType compute_at, - StageAttributes attrs) { +Stage::Stage(te::Operation op, StageType op_type, const std::vector& iters, + ComputeAtType compute_at, StageAttributes attrs) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; @@ -139,13 +135,11 @@ void State::reorder(int stage_id, const std::vector& order) { } std::vector State::split(int stage_id, const Iterator& it, - const std::vector& lengths, - bool inner_to_outer) { + const std::vector& lengths, bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; SplitStep step = SplitStep(stage_id, GetIndex(stage->iters, it), - it->range.defined() ? it->range->extent : PrimExpr(), - lengths, inner_to_outer); + it->range.defined() ? it->range->extent : PrimExpr(), lengths, inner_to_outer); CopyOnWrite()->transform_steps.push_back(step); return DoSplitStep(step); } @@ -172,9 +166,9 @@ void State::DoReorderStep(const ReorderStep& step) { } // common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep -std::vector State::DoSplitStepCommon( - int stage_id, int iter_id, const std::vector& lengths, - bool inner_to_outer) { +std::vector State::DoSplitStepCommon(int stage_id, int iter_id, + const std::vector& lengths, + bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; const Iterator& it = stage->iters[iter_id]; @@ -214,21 +208,17 @@ std::vector State::DoSplitStepCommon( range = Range::FromMinExtent(tosplit_min, tosplit_extent); } if (inner_to_outer) { - outs.push_back( - Iterator(it->name + ".0", range, it->iter_type, kNone)); + outs.push_back(Iterator(it->name + ".0", range, it->iter_type, kNone)); std::reverse(outs.begin(), outs.end()); } else { outs.push_back( - Iterator(it->name + "." + std::to_string(lengths.size()), - range, it->iter_type, kNone)); + Iterator(it->name + "." 
+ std::to_string(lengths.size()), range, it->iter_type, kNone)); } std::vector new_iters; - new_iters.insert(new_iters.end(), stage->iters.begin(), - stage->iters.begin() + iter_id); + new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id); new_iters.insert(new_iters.end(), outs.begin(), outs.end()); - new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, - stage->iters.end()); + new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end()); StateNode* pstate = CopyOnWrite(); pstate->stages[stage_id] = @@ -238,8 +228,7 @@ std::vector State::DoSplitStepCommon( } std::vector State::DoSplitStep(const SplitStep& step) { - return DoSplitStepCommon(step->stage_id, step->iter_id, step->lengths, - step->inner_to_outer); + return DoSplitStepCommon(step->stage_id, step->iter_id, step->lengths, step->inner_to_outer); } Iterator State::DoFuseStep(const FuseStep& step) { @@ -320,31 +309,28 @@ void State::DoSteps(const std::vector& steps, const ComputeDAG& dag) { } // Print stage to ostream -void PrintStage(std::ostream* os, int stage_id, const StateNode* state, - size_t base_indent, bool delete_trivial_loop) { +void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t base_indent, + bool delete_trivial_loop) { const Stage& stage = state->stages[stage_id]; if (stage->attrs.auto_unroll_max_step != 0) { for (size_t j = 0; j < base_indent; ++j) { *os << " "; } - *os << stage->op->name - << " auto_unroll: " << stage->attrs.auto_unroll_max_step << "\n"; + *os << stage->op->name << " auto_unroll: " << stage->attrs.auto_unroll_max_step << "\n"; } if (stage->attrs.storage_offset != 0) { for (size_t j = 0; j < base_indent; ++j) { *os << " "; } - *os << stage->op->name - << " storage_offset: " << stage->attrs.storage_offset << "\n"; + *os << stage->op->name << " storage_offset: " << stage->attrs.storage_offset << "\n"; } size_t indent = 0; for (size_t i = 0; i < stage->iters.size(); ++i) { const Iterator& iter = stage->iters[i]; - if (!(delete_trivial_loop && iter->range.defined() && - is_one(iter->range->extent))) { + if (!(delete_trivial_loop && iter->range.defined() && is_one(iter->range->extent))) { for (size_t j = 0; j < base_indent + indent; ++j) { *os << " "; } @@ -380,11 +366,11 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, *os << "tensorize "; break; default: - LOG(FATAL) << "Invalid Annotation " << iter->annotation; break; + LOG(FATAL) << "Invalid Annotation " << iter->annotation; + break; } if (iter->range.defined()) { - *os << iter->name << " (" << iter->range->min << "," - << iter->range->extent << ")"; + *os << iter->name << " (" << iter->range->min << "," << iter->range->extent << ")"; } else { *os << iter->name << " (None)"; } @@ -404,8 +390,7 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, } // Print state to ostream -void PrintState(std::ostream* os, const StateNode* node, - bool delete_trivial_loop) { +void PrintState(std::ostream* os, const StateNode* node, bool delete_trivial_loop) { // Gather placeholders std::vector placeholders; for (const auto& stage : node->stages) { @@ -445,10 +430,10 @@ std::string State::ToStr(bool delete_trivial_loop) const { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - PrintState(&p->stream, node, true); -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + 
PrintState(&p->stream, node, true); + }); /********** State interface API for ffi **********/ TVM_REGISTER_GLOBAL("ansor.StageGetIterators").set_body_typed([](const Stage& stage) { @@ -464,39 +449,37 @@ TVM_REGISTER_GLOBAL("ansor.StateGetTransformStepsSize").set_body_typed([](const }); TVM_REGISTER_GLOBAL("ansor.StateReorder") -.set_body_typed([](State state, int stage_id, const Array& order) { - std::vector ord; - for (const auto& i : order) { - ord.push_back(i); - } - state.reorder(stage_id, ord); - return state; -}); + .set_body_typed([](State state, int stage_id, const Array& order) { + std::vector ord; + for (const auto& i : order) { + ord.push_back(i); + } + state.reorder(stage_id, ord); + return state; + }); TVM_REGISTER_GLOBAL("ansor.StateSplit") -.set_body_typed([](State state, int stage_id, const Iterator& it, - const Array& lengths, bool inner_to_outer) { - std::vector len; - for (const auto& i : lengths) { - len.push_back(i); - } - const auto& res = state.split(stage_id, it, len, inner_to_outer); - return Array{state, Array(res)}; -}); + .set_body_typed([](State state, int stage_id, const Iterator& it, + const Array& lengths, bool inner_to_outer) { + std::vector len; + for (const auto& i : lengths) { + len.push_back(i); + } + const auto& res = state.split(stage_id, it, len, inner_to_outer); + return Array{state, Array(res)}; + }); TVM_REGISTER_GLOBAL("ansor.StateFuse") -.set_body_typed([](State state, int stage_id, - const Array& iters) { - std::vector its; - for (const auto& i : iters) { - its.push_back(i); - } - const auto& res = state.fuse(stage_id, its); - return Array{state, res}; -}); + .set_body_typed([](State state, int stage_id, const Array& iters) { + std::vector its; + for (const auto& i : iters) { + its.push_back(i); + } + const auto& res = state.fuse(stage_id, its); + return Array{state, res}; + }); -TVM_REGISTER_GLOBAL("ansor.StateEqual") -.set_body_typed([](State state1, State state2) { +TVM_REGISTER_GLOBAL("ansor.StateEqual").set_body_typed([](State state1, State state2) { return std::equal_to()(state1, state2); }); diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 721478de0b7a..9154f3b32c3d 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -135,7 +135,7 @@ class IteratorNode : public Object { v->Visit("attr", &attr); } - static constexpr const char *_type_key = "ansor.Iterator"; + static constexpr const char* _type_key = "ansor.Iterator"; TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); }; @@ -154,10 +154,8 @@ class Iterator : public ObjectRef { * \param ori_iters The original iterators before fusion. * \param attr The extra attribute of this iterator. */ - Iterator(std::string name, Range range, IteratorType iter_type, - IteratorAnnotation annotation, - const std::vector* ori_iters = nullptr, - std::string attr = ""); + Iterator(std::string name, Range range, IteratorType iter_type, IteratorAnnotation annotation, + const std::vector* ori_iters = nullptr, std::string attr = ""); TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode); }; @@ -187,11 +185,9 @@ class StageNode : public Object { /*! \brief Other stage-level attributes. 
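   * These include auto_unroll_max_step and storage_offset.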
*/ StageAttributes attrs; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("op", &op); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("op", &op); } - static constexpr const char *_type_key = "ansor.Stage"; + static constexpr const char* _type_key = "ansor.Stage"; TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object); }; @@ -214,8 +210,7 @@ class Stage : public ObjectRef { * \param compute_at The compute at type of this op. * \param attrs Other stage-level attributes. */ - Stage(te::Operation op, StageType op_type, - const std::vector& iters, + Stage(te::Operation op, StageType op_type, const std::vector& iters, ComputeAtType compute_at, StageAttributes attrs); /*! * \brief The constructor. @@ -225,8 +220,7 @@ class Stage : public ObjectRef { * \param compute_at The compute at type of this op. * \param attrs Other stage-level attributes. */ - Stage(te::Operation op, StageType op_type, - std::vector&& iters, + Stage(te::Operation op, StageType op_type, std::vector&& iters, ComputeAtType compute_at, StageAttributes attrs); TVM_DEFINE_OBJECT_REF_METHODS(Stage, ObjectRef, StageNode); @@ -238,7 +232,7 @@ class Stage : public ObjectRef { * It consists of the current loop structure and the history steps to reach this State. * Each State corresponds to a specific schedule for the target ComputeDAG. */ -class StateNode: public Object { +class StateNode : public Object { public: /*! \brief Current stages and loop structures. */ std::vector stages; @@ -296,8 +290,7 @@ class State : public ObjectRef { * \return The iterator results after split. */ std::vector split(int stage_id, const Iterator& it, - const std::vector& lengths, - bool inner_to_outer = true); + const std::vector& lengths, bool inner_to_outer = true); /*! * \brief Schedule primitive corresponds to te.fuse. * \param stage_id The index of the target stage. @@ -363,7 +356,6 @@ class State : public ObjectRef { } // namespace ansor } // namespace tvm - // Hash and equal function for State namespace std { @@ -378,8 +370,7 @@ struct hash<::tvm::ansor::State> { /*! \brief The equal_to function for ansor::State. 
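 * Two states compare equal exactly when their ToStr() dumps match.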
*/ template <> struct equal_to<::tvm::ansor::State> { - bool operator() (const ::tvm::ansor::State& lhs, - const ::tvm::ansor::State& rhs) const { + bool operator()(const ::tvm::ansor::State& lhs, const ::tvm::ansor::State& rhs) const { return lhs.ToStr() == rhs.ToStr(); } }; diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index b31b993618f4..08c66cf72d36 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -73,9 +73,8 @@ MeasureInput MeasureInputNode::copy() const { return MeasureInput(node); } -BuildResult::BuildResult(std::string filename, Array args, - int error_no, std::string error_msg, - double time_cost) { +BuildResult::BuildResult(std::string filename, Array args, int error_no, + std::string error_msg, double time_cost) { auto node = make_object(); node->filename = std::move(filename); node->args = std::move(args); @@ -85,9 +84,8 @@ BuildResult::BuildResult(std::string filename, Array args, data_ = std::move(node); } -MeasureResult::MeasureResult(Array costs, int error_no, - std::string error_msg, double all_cost, - double timestamp) { +MeasureResult::MeasureResult(Array costs, int error_no, std::string error_msg, + double all_cost, double timestamp) { auto node = make_object(); node->costs = std::move(costs); node->error_no = error_no; @@ -108,8 +106,7 @@ MeasureResult MeasureResultNode::copy() const { } /********** LocalBuilder **********/ -LocalBuilder::LocalBuilder(int timeout, int n_parallel, - const std::string& build_func) { +LocalBuilder::LocalBuilder(int timeout, int n_parallel, const std::string& build_func) { auto node = make_object(); node->timeout = timeout; node->n_parallel = n_parallel; @@ -117,11 +114,9 @@ LocalBuilder::LocalBuilder(int timeout, int n_parallel, data_ = std::move(node); } -Array LocalBuilderNode::Build(const Array& inputs, - int verbose) { +Array LocalBuilderNode::Build(const Array& inputs, int verbose) { if (const auto* f = runtime::Registry::Get("ansor.local_builder.build")) { - Array results = - (*f)(inputs, timeout, n_parallel, build_func, verbose); + Array results = (*f)(inputs, timeout, n_parallel, build_func, verbose); return results; } else { LOG(FATAL) << "ansor.local_builder.build is not registered"; @@ -130,8 +125,8 @@ Array LocalBuilderNode::Build(const Array& inputs, } /********** LocalRunner **********/ -LocalRunner::LocalRunner(int timeout, int number, int repeat, - int min_repeat_ms, double cooldown_interval) { +LocalRunner::LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, + double cooldown_interval) { ObjectPtr node = make_object(); node->timeout = timeout; node->number = number; @@ -141,13 +136,11 @@ LocalRunner::LocalRunner(int timeout, int number, int repeat, data_ = std::move(node); } -Array LocalRunnerNode::Run( - const Array& inputs, const Array& build_results, - int verbose) { +Array LocalRunnerNode::Run(const Array& inputs, + const Array& build_results, int verbose) { if (const auto* f = runtime::Registry::Get("ansor.local_runner.run")) { - Array results = - (*f)(inputs, build_results, timeout, number, repeat, min_repeat_ms, - cooldown_interval, verbose); + Array results = (*f)(inputs, build_results, timeout, number, repeat, + min_repeat_ms, cooldown_interval, verbose); return results; } else { LOG(FATAL) << "ansor.local_runner.run is not registered"; @@ -156,16 +149,16 @@ Array LocalRunnerNode::Run( } /********** ProgramMeasurer **********/ -ProgramMeasurer::ProgramMeasurer(Builder builder, Runner runner, - Array callbacks, int verbose, - int max_continous_error) { 
+ProgramMeasurer::ProgramMeasurer(Builder builder, Runner runner, Array callbacks, + int verbose, int max_continous_error) { auto node = make_object(); node->builder = std::move(builder); node->runner = std::move(runner); node->callbacks = std::move(callbacks); node->verbose = verbose; - node->max_continous_error = max_continous_error < 0 ? - ProgramMeasurerNode::DEFAULT_MAX_CONTINOUS_ERROR : max_continous_error; + node->max_continous_error = max_continous_error < 0 + ? ProgramMeasurerNode::DEFAULT_MAX_CONTINOUS_ERROR + : max_continous_error; data_ = std::move(node); } @@ -176,11 +169,9 @@ void ProgramMeasurerNode::Reset() { best_state.clear(); } -void ProgramMeasurerNode::Measure(const SearchTask& task, - const SearchPolicy& policy, +void ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& policy, const std::vector& inputs, - std::vector* results, - int batch_size) { + std::vector* results, int batch_size) { results->clear(); results->reserve(inputs.size()); @@ -189,14 +180,12 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, batch_size = builder->n_parallel * 2; } - StdCout(verbose) << "Get " << inputs.size() - << " programs for measure. (This may take a while)" + StdCout(verbose) << "Get " << inputs.size() << " programs for measure. (This may take a while)" << std::endl; for (size_t i = 0; i < inputs.size(); i += batch_size) { - std::vector input_batch( - inputs.begin() + i, - inputs.begin() + std::min(i + batch_size, inputs.size())); + std::vector input_batch(inputs.begin() + i, + inputs.begin() + std::min(i + batch_size, inputs.size())); std::vector result_batch; // build and run @@ -206,8 +195,7 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, for (size_t j = 0; j < input_batch.size(); ++j) { double flops; if (result_batch[j]->error_no == 0) { - flops = - task->compute_dag->flop_ct / FloatArrayMean(result_batch[j]->costs); + flops = task->compute_dag->flop_ct / FloatArrayMean(result_batch[j]->costs); error_ct = 0; } else { flops = 0.0; @@ -258,8 +246,7 @@ void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, // Call builder and runner Array build_res_batch = builder->Build(input_batch, verbose); - Array result_batch = - runner->Run(input_batch, build_res_batch, verbose); + Array result_batch = runner->Run(input_batch, build_res_batch, verbose); // Store result batch for (auto& res : result_batch) { @@ -269,44 +256,44 @@ void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, /********** Printing functions **********/ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - p->stream << "MeasureInput()"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "MeasureInput()"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - if (node->error_no == kNoError) { - p->stream << "MeasureResult(cost:["; - auto old_config = p->stream.precision(4); - for (size_t i = 0; i < node->costs.size(); ++i) { - auto pf = node->costs[i].as(); - CHECK(pf != nullptr); - p->stream << pf->value; - if (i != node->costs.size() - 1) { - p->stream << ","; + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + if (node->error_no == kNoError) { + p->stream << "MeasureResult(cost:["; + auto old_config = p->stream.precision(4); + for (size_t i = 0; i < node->costs.size(); ++i) { + auto pf = node->costs[i].as(); + CHECK(pf != nullptr); + p->stream 
<< pf->value; + if (i != node->costs.size() - 1) { + p->stream << ","; + } + } + p->stream.precision(old_config); + p->stream << "], "; + p->stream << "error_no:" << 0 << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; + } else { + p->stream << "MeasureResult(" + << "error_type:" << ErrorNoToStr[node->error_no] << ", " + << "error_msg:" << node->error_msg << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; } - } - p->stream.precision(old_config); - p->stream << "], "; - p->stream << "error_no:" << 0 << ", " - << "all_cost:" << node->all_cost << ", " - << "Tstamp:" << node->timestamp << ")"; - } else { - p->stream << "MeasureResult(" - << "error_type:" << ErrorNoToStr[node->error_no] << ", " - << "error_msg:" << node->error_msg << ", " - << "all_cost:" << node->all_cost << ", " - << "Tstamp:" << node->timestamp << ")"; - } -}); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "BuildResult(" << node->filename << ", " << node->error_no - << ", " << node->time_cost << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "BuildResult(" << node->filename << ", " << node->error_no << ", " + << node->time_cost << ")"; + }); /********** Measure interface API for ffi **********/ TVM_REGISTER_GLOBAL("ansor.MeasureInput").set_body_typed([](SearchTask task, State state) { @@ -314,38 +301,37 @@ TVM_REGISTER_GLOBAL("ansor.MeasureInput").set_body_typed([](SearchTask task, Sta }); TVM_REGISTER_GLOBAL("ansor.BuildResult") -.set_body_typed([](std::string filename, Array args, - int error_no, std::string error_msg, double time_cost) { - return BuildResult(filename, args, error_no, error_msg, time_cost); -}); + .set_body_typed([](std::string filename, Array args, int error_no, + std::string error_msg, double time_cost) { + return BuildResult(filename, args, error_no, error_msg, time_cost); + }); TVM_REGISTER_GLOBAL("ansor.MeasureResult") -.set_body_typed([](Array costs, int error_no, std::string error_msg, - double all_cost, double timestamp) { - return MeasureResult(costs, error_no, error_msg, all_cost, timestamp); -}); + .set_body_typed([](Array costs, int error_no, std::string error_msg, double all_cost, + double timestamp) { + return MeasureResult(costs, error_no, error_msg, all_cost, timestamp); + }); TVM_REGISTER_GLOBAL("ansor.BuilderBuild") -.set_body_typed([](const Builder& builder, const Array& inputs, int verbose) { - return builder->Build(inputs, verbose); -}); + .set_body_typed([](const Builder& builder, const Array& inputs, int verbose) { + return builder->Build(inputs, verbose); + }); TVM_REGISTER_GLOBAL("ansor.RunnerRun") -.set_body_typed([](const Runner& runner, const Array& inputs, - const Array& build_results, int verbose) { - return runner->Run(inputs, build_results, verbose); -}); + .set_body_typed([](const Runner& runner, const Array& inputs, + const Array& build_results, + int verbose) { return runner->Run(inputs, build_results, verbose); }); TVM_REGISTER_GLOBAL("ansor.LocalBuilder") -.set_body_typed([](int timeout, int n_parallel, const std::string& build_func) { - return LocalBuilder(timeout, n_parallel, build_func); -}); + .set_body_typed([](int timeout, int n_parallel, const std::string& build_func) { + return LocalBuilder(timeout, n_parallel, build_func); + }); TVM_REGISTER_GLOBAL("ansor.LocalRunner") 
-.set_body_typed([](int timeout, int number, int repeat, - int min_repeat_ms, double cooldown_interval) { - return LocalRunner(timeout, number, repeat, min_repeat_ms, cooldown_interval); -}); + .set_body_typed([](int timeout, int number, int repeat, int min_repeat_ms, + double cooldown_interval) { + return LocalRunner(timeout, number, repeat, min_repeat_ms, cooldown_interval); + }); } // namespace ansor } // namespace tvm diff --git a/src/ansor/measure.h b/src/ansor/measure.h index ee71ad558680..4955176aef2c 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -141,8 +141,8 @@ class BuildResult : public ObjectRef { * \param error_msg The error message if there is any error. * \param time_cost The time cost of build. */ - BuildResult(std::string filename, Array args, - int error_no, std::string error_msg, double time_cost); + BuildResult(std::string filename, Array args, int error_no, std::string error_msg, + double time_cost); TVM_DEFINE_OBJECT_REF_METHODS(BuildResult, ObjectRef, BuildResultNode); }; @@ -189,8 +189,8 @@ class MeasureResult : public ObjectRef { * \param all_cost The time cost of build and run. * \param timestamp The time stamps of this measurement. */ - MeasureResult(Array costs, int error_no, std::string error_msg, - double all_cost, double timestamp); + MeasureResult(Array costs, int error_no, std::string error_msg, double all_cost, + double timestamp); TVM_DEFINE_OBJECT_REF_METHODS(MeasureResult, ObjectRef, MeasureResultNode); }; @@ -207,7 +207,7 @@ class MeasureCallbackNode : public Object { */ virtual void Callback(const SearchPolicy& policy, const Array& inputs, const Array& results) = 0; - static constexpr const char *_type_key = "ansor.MeasureCallback"; + static constexpr const char* _type_key = "ansor.MeasureCallback"; TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object); }; @@ -265,8 +265,7 @@ class RunnerNode : public Object { * \return An Array of MeasureResult. */ virtual Array Run(const Array& inputs, - const Array& build_results, - int verbose) = 0; + const Array& build_results, int verbose) = 0; static constexpr const char* _type_key = "ansor.Runner"; TVM_DECLARE_BASE_OBJECT_INFO(RunnerNode, Object); @@ -325,8 +324,7 @@ class LocalRunnerNode : public RunnerNode { double cooldown_interval; Array Run(const Array& inputs, - const Array& build_results, - int verbose) final; + const Array& build_results, int verbose) final; static constexpr const char* _type_key = "ansor.LocalRunner"; TVM_DECLARE_FINAL_OBJECT_INFO(LocalRunnerNode, RunnerNode); @@ -346,11 +344,9 @@ class LocalRunner : public Runner { * \param min_repeat_ms The minimum duration of one repeat in milliseconds. * \param cooldown_interval The cool down interval between two measurements. */ - LocalRunner(int timeout, int number, int repeat, - int min_repeat_ms, double cooldown_interval); + LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LocalRunner, Runner, - LocalRunnerNode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LocalRunner, Runner, LocalRunnerNode); }; /*! @@ -390,10 +386,8 @@ class ProgramMeasurerNode : public Object { * \param results A pointer to MeasureResult vector, this is used as output. * \param batch_size Number of programs to be measured in one batch. 
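   * Passing the default of -1 makes the measurer fall back to builder->n_parallel * 2.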
*/ - void Measure(const SearchTask& task, - const SearchPolicy& policy, - const std::vector& inputs, - std::vector* results, + void Measure(const SearchTask& task, const SearchPolicy& policy, + const std::vector& inputs, std::vector* results, int batch_size = -1); /*! * \brief Do measurement silently. @@ -402,8 +396,7 @@ class ProgramMeasurerNode : public Object { * \param inputs The target MeasureInputs. * \param results A pointer to MeasureResult vector, this is used as output. */ - void SilentMeasure(const SearchTask& task, - const std::vector& inputs, + void SilentMeasure(const SearchTask& task, const std::vector& inputs, std::vector* results); /*! \brief The default max continuous error setting. */ @@ -427,9 +420,8 @@ class ProgramMeasurer : public ObjectRef { * \param verbose Verbose level. * \param max_continous_error The number of max continuous error. */ - ProgramMeasurer(Builder builder, Runner runner, - Array callbacks, - int verbose, int max_continous_error = -1); + ProgramMeasurer(Builder builder, Runner runner, Array callbacks, int verbose, + int max_continous_error = -1); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramMeasurer, ObjectRef, ProgramMeasurerNode); }; diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index 5f43f5352695..f507aa98e22c 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -70,7 +70,7 @@ class SearchCallbackNode : public Object { */ virtual void Callback(SearchPolicyNode* policy) = 0; - static constexpr const char *_type_key = "ansor.SearchCallback"; + static constexpr const char* _type_key = "ansor.SearchCallback"; TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object); }; @@ -123,7 +123,7 @@ class SearchPolicyNode : public Object { */ void RunCallbacks(const Array& callbacks); - static constexpr const char *_type_key = "ansor.SearchPolicy"; + static constexpr const char* _type_key = "ansor.SearchPolicy"; TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); protected: diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index e7ea6eb05e90..c64b919e008f 100644 --- a/src/ansor/search_task.cc +++ b/src/ansor/search_task.cc @@ -37,9 +37,8 @@ namespace ansor { TVM_REGISTER_NODE_TYPE(HardwareParamsNode); TVM_REGISTER_NODE_TYPE(SearchTaskNode); -HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, - int cache_line_bytes, int max_unroll_vec, - int max_innermost_split_factor) { +HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes, + int max_unroll_vec, int max_innermost_split_factor) { auto node = make_object(); node->num_cores = num_cores; node->vector_unit_bytes = vector_unit_bytes; @@ -49,8 +48,8 @@ HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, data_ = std::move(node); } -HardwareParams HardwareParamsNode::GetDefaultHardwareParams( - const Target& target, const Target& target_host) { +HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target, + const Target& target_host) { if (target->target_name == "llvm") { return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64, 64, 64); } else { @@ -59,9 +58,8 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams( return HardwareParams(); } -SearchTask::SearchTask(ComputeDAG compute_dag, std::string workload_key, - Target target, Target target_host, - HardwareParams hardware_params) { +SearchTask::SearchTask(ComputeDAG compute_dag, std::string workload_key, Target target, + Target 
target_host, HardwareParams hardware_params) { auto node = make_object(); node->compute_dag = std::move(compute_dag); node->workload_key = std::move(workload_key); @@ -70,27 +68,24 @@ SearchTask::SearchTask(ComputeDAG compute_dag, std::string workload_key, if (hardware_params.defined()) { node->hardware_params = std::move(hardware_params); } else { - node->hardware_params = HardwareParamsNode::GetDefaultHardwareParams( - node->target, node->target_host); + node->hardware_params = + HardwareParamsNode::GetDefaultHardwareParams(node->target, node->target_host); } data_ = std::move(node); } TVM_REGISTER_GLOBAL("ansor.HardwareParams") -.set_body_typed([](int num_cores, int vector_unit_bytes, - int cache_line_bytes, int max_unroll_vec, - int max_innermost_split_factor) { - return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes, - max_unroll_vec, max_innermost_split_factor); -}); + .set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes, + int max_unroll_vec, int max_innermost_split_factor) { + return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes, max_unroll_vec, + max_innermost_split_factor); + }); TVM_REGISTER_GLOBAL("ansor.SearchTask") -.set_body_typed([](ComputeDAG compute_dag, std::string workload_key, - Target target, Target target_host, - HardwareParams hardware_params) { - return SearchTask(compute_dag, workload_key, target, target_host, - hardware_params); -}); + .set_body_typed([](ComputeDAG compute_dag, std::string workload_key, Target target, + Target target_host, HardwareParams hardware_params) { + return SearchTask(compute_dag, workload_key, target, target_host, hardware_params); + }); } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h index 16601bc09516..351a51124e7e 100644 --- a/src/ansor/search_task.h +++ b/src/ansor/search_task.h @@ -82,8 +82,7 @@ class HardwareParamsNode : public Object { * \param target_host A `tvm.target` for host device. * \return A HardwareParams object. */ - static HardwareParams GetDefaultHardwareParams(const Target& target, - const Target& target_host); + static HardwareParams GetDefaultHardwareParams(const Target& target, const Target& target_host); static constexpr const char* _type_key = "ansor.HardwareParams"; TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object); @@ -103,8 +102,8 @@ class HardwareParams : public ObjectRef { * \param max_unroll_vec The max length of an axis to be unrolled or vectorized. * \param max_innermost_split_factor The max split factor for the innermost tile. */ - HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes, - int max_unroll_vec, int max_innermost_split_factor); + HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes, int max_unroll_vec, + int max_innermost_split_factor); TVM_DEFINE_OBJECT_REF_METHODS(HardwareParams, ObjectRef, HardwareParamsNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(HardwareParamsNode); @@ -150,8 +149,7 @@ class SearchTask : public ObjectRef { * \param target_host The target host device of this search task. * \param hardware_params Hardware parameters used in this search task. 
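   * If hardware_params is left undefined, GetDefaultHardwareParams(target, target_host) is used.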
*/ - SearchTask(ComputeDAG compute_dag, std::string workload_key, - Target target, Target target_host, + SearchTask(ComputeDAG compute_dag, std::string workload_key, Target target, Target target_host, HardwareParams hardware_params); TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index b6dafdc80625..62c5cd2c4033 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -46,7 +46,7 @@ namespace json { inline std::vector& IntArrayToVector(std::vector* out, const ::tvm::Array<::tvm::PrimExpr>& data) { out->clear(); - for (const auto&x : data) { + for (const auto& x : data) { auto pi = x.as<::tvm::tir::IntImmNode>(); CHECK(pi != nullptr) << "Can only contain int values"; out->push_back(pi->value); @@ -63,12 +63,13 @@ struct Handler> { inline static void Read(dmlc::JSONReader* reader, std::vector<::tvm::ansor::Stage>* data) { bool s; reader->BeginArray(); - s = reader->NextArrayItem(); CHECK(!s); + s = reader->NextArrayItem(); + CHECK(!s); } }; template <> -struct Handler > { +struct Handler> { inline static void Write(dmlc::JSONWriter* writer, const std::vector<::tvm::ansor::Step>& data) { std::vector tmp; writer->BeginArray(false); @@ -124,110 +125,123 @@ struct Handler > { data->clear(); while (reader->NextArrayItem()) { reader->BeginArray(); - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&name); if (name == "RE") { - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&int_list); data->push_back(::tvm::ansor::ReorderStep(stage_id, int_list)); } else if (name == "SP") { - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&iter_id); - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&extent); - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&int_list); - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&inner_to_outer); data->push_back(::tvm::ansor::SplitStep( stage_id, iter_id, extent, - std::vector<::tvm::PrimExpr>(int_list.begin(), int_list.end()), - inner_to_outer)); + std::vector<::tvm::PrimExpr>(int_list.begin(), int_list.end()), inner_to_outer)); } else if (name == "FU") { - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&int_list); data->push_back(::tvm::ansor::FuseStep(stage_id, int_list)); } else { LOG(FATAL) << "Invalid step format"; } - s = reader->NextArrayItem(); CHECK(!s); + s = reader->NextArrayItem(); + CHECK(!s); } } }; template <> struct Handler<::tvm::ansor::StateNode> { - inline static void Write(dmlc::JSONWriter* writer, - const ::tvm::ansor::StateNode& data) { + inline static void Write(dmlc::JSONWriter* writer, const ::tvm::ansor::StateNode& data) { writer->BeginArray(false); writer->WriteArrayItem(data.stages); writer->WriteArrayItem(data.transform_steps); writer->EndArray(); } - inline static void Read(dmlc::JSONReader* reader, - ::tvm::ansor::StateNode* data) { + inline static void Read(dmlc::JSONReader* reader, 
::tvm::ansor::StateNode* data) { reader->BeginArray(); bool s; - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&data->stages); - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&data->transform_steps); - s = reader->NextArrayItem(); CHECK(!s); + s = reader->NextArrayItem(); + CHECK(!s); } }; template <> struct Handler<::tvm::ansor::SearchTaskNode> { - inline static void Write(dmlc::JSONWriter* writer, - const ::tvm::ansor::SearchTaskNode& data) { + inline static void Write(dmlc::JSONWriter* writer, const ::tvm::ansor::SearchTaskNode& data) { writer->BeginArray(false); writer->WriteArrayItem(data.workload_key); writer->WriteArrayItem(data.target->str()); writer->EndArray(); } - inline static void Read(dmlc::JSONReader* reader, - ::tvm::ansor::SearchTaskNode* data) { + inline static void Read(dmlc::JSONReader* reader, ::tvm::ansor::SearchTaskNode* data) { std::string target_str; bool s; reader->BeginArray(); - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&data->workload_key); - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&target_str); data->target = ::tvm::Target::Create(target_str); - s = reader->NextArrayItem(); CHECK(!s); + s = reader->NextArrayItem(); + CHECK(!s); } }; template <> struct Handler<::tvm::ansor::MeasureInputNode> { - inline static void Write(dmlc::JSONWriter* writer, - const ::tvm::ansor::MeasureInputNode& data) { + inline static void Write(dmlc::JSONWriter* writer, const ::tvm::ansor::MeasureInputNode& data) { writer->BeginArray(false); writer->WriteArrayItem(*data.task.operator->()); writer->WriteArrayItem(*data.state.operator->()); writer->EndArray(); } - inline static void Read(dmlc::JSONReader* reader, - ::tvm::ansor::MeasureInputNode* data) { + inline static void Read(dmlc::JSONReader* reader, ::tvm::ansor::MeasureInputNode* data) { bool s; auto task_node = ::tvm::make_object<::tvm::ansor::SearchTaskNode>(); auto state_node = ::tvm::make_object<::tvm::ansor::StateNode>(); state_node->complete = true; reader->BeginArray(); - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(task_node.get()); - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(state_node.get()); - s = reader->NextArrayItem(); CHECK(!s); + s = reader->NextArrayItem(); + CHECK(!s); data->task = ::tvm::ansor::SearchTask(task_node); data->state = ::tvm::ansor::State(state_node); @@ -236,12 +250,11 @@ struct Handler<::tvm::ansor::MeasureInputNode> { template <> struct Handler<::tvm::ansor::MeasureResultNode> { - inline static void Write(dmlc::JSONWriter* writer, - const ::tvm::ansor::MeasureResultNode& data) { + inline static void Write(dmlc::JSONWriter* writer, const ::tvm::ansor::MeasureResultNode& data) { writer->BeginArray(false); writer->WriteArraySeperator(); writer->BeginArray(false); - for (const auto&x : data.costs) { + for (const auto& x : data.costs) { auto pf = x.as<::tvm::tir::FloatImmNode>(); CHECK(pf != nullptr) << "Cost can only contain float values"; writer->WriteArrayItem(pf->value); @@ -252,25 +265,29 @@ struct Handler<::tvm::ansor::MeasureResultNode> { writer->WriteArrayItem(static_cast((data.timestamp))); writer->EndArray(); } - inline static void Read(dmlc::JSONReader* reader, - ::tvm::ansor::MeasureResultNode* data) { + inline static void Read(dmlc::JSONReader* reader, ::tvm::ansor::MeasureResultNode* 
data) { bool s; std::vector tmp; reader->BeginArray(); - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&tmp); data->costs.clear(); for (const auto& i : tmp) { data->costs.push_back(::tvm::FloatImm(::tvm::DataType::Float(64), i)); } - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&data->error_no); - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&data->all_cost); - s = reader->NextArrayItem(); CHECK(s); + s = reader->NextArrayItem(); + CHECK(s); reader->Read(&data->timestamp); - s = reader->NextArrayItem(); CHECK(!s); + s = reader->NextArrayItem(); + CHECK(!s); } }; @@ -283,7 +300,7 @@ namespace ansor { TVM_REGISTER_OBJECT_TYPE(LogToFileNode); TVM_REGISTER_OBJECT_TYPE(LogReaderNode); -const std::string ANSOR_LOG_VERSION = "v0.2"; // NOLINT(*) +const std::string ANSOR_LOG_VERSION = "v0.2"; // NOLINT(*) LogToFile::LogToFile(std::string filename) { auto node = make_object(); @@ -291,8 +308,7 @@ LogToFile::LogToFile(std::string filename) { data_ = std::move(node); } -void WriteMeasureRecords(std::ostream* os, - const Array& inputs, +void WriteMeasureRecords(std::ostream* os, const Array& inputs, const Array& results) { dmlc::JSONWriter writer(os); for (size_t i = 0; i < inputs.size(); ++i) { @@ -305,9 +321,7 @@ void WriteMeasureRecords(std::ostream* os, } } -void ReadMeasureRecord(const std::string& str, - MeasureInputNode* inp, - MeasureResultNode* res, +void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res, std::string* log_version) { std::istringstream ss(str); dmlc::JSONReader reader(&ss); @@ -340,9 +354,7 @@ LogReader::LogReader(std::string filename) { data_ = std::move(node); } -LogReaderNode::~LogReaderNode() { - infile.close(); -} +LogReaderNode::~LogReaderNode() { infile.close(); } bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { std::string log_version; @@ -359,8 +371,8 @@ bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { return false; } -std::pair, Array > LogReaderNode::ReadLines( - int max_size, int skip_size) { +std::pair, Array> LogReaderNode::ReadLines(int max_size, + int skip_size) { auto inp = make_object(); auto res = make_object(); Array inputs; @@ -392,13 +404,12 @@ TVM_REGISTER_GLOBAL("ansor.LogReader").set_body_typed([](const std::string& file }); TVM_REGISTER_GLOBAL("ansor.LogReaderReadLines") -.set_body_typed([](LogReader reader, int size, int skip_size) { - const auto& res = reader->ReadLines(size, skip_size); - return Array{res.first, res.second}; -}); + .set_body_typed([](LogReader reader, int size, int skip_size) { + const auto& res = reader->ReadLines(size, skip_size); + return Array{res.first, res.second}; + }); -TVM_REGISTER_GLOBAL("ansor.LogReaderReadNext") -.set_body_typed([](LogReader reader) { +TVM_REGISTER_GLOBAL("ansor.LogReaderReadNext").set_body_typed([](LogReader reader) { auto inp = make_object(); auto res = make_object(); if (reader->ReadNext(inp.get(), res.get())) { @@ -408,8 +419,7 @@ TVM_REGISTER_GLOBAL("ansor.LogReaderReadNext") } }); -TVM_REGISTER_GLOBAL("ansor.WriteMeasureRecordsToFile") -.set_body([](TVMArgs args, TVMRetValue *ret) { +TVM_REGISTER_GLOBAL("ansor.WriteMeasureRecordsToFile").set_body([](TVMArgs args, TVMRetValue* ret) { std::string filename = args[0]; Array in = args[1]; Array res = args[2]; @@ -418,57 +428,54 @@ TVM_REGISTER_GLOBAL("ansor.WriteMeasureRecordsToFile") }); 
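// Note on the registration below: GetStatesFromMeasureInputs replays each record's
// transform_steps on a SearchTask to reconstruct the State. When a measure input is
// incomplete (its task has no ComputeDAG), the task is rebuilt from the workload key
// and cached in task_cache, keyed by the (workload_key, target string) pair.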
TVM_REGISTER_GLOBAL("ansor.GetStatesFromMeasureInputs") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Array inputs = args[0]; - SearchTask external_task; + .set_body([](TVMArgs args, TVMRetValue* ret) { + Array inputs = args[0]; + SearchTask external_task; - if (args.size() > 1) { - external_task = args[1]; - } + if (args.size() > 1) { + external_task = args[1]; + } - Array states; - states.reserve(inputs.size()); + Array states; + states.reserve(inputs.size()); - // (workload_key, target) -> (search_task) - std::unordered_map, SearchTask> task_cache; + // (workload_key, target) -> (search_task) + std::unordered_map, SearchTask> task_cache; - for (const auto& inp : inputs) { - const std::string& workload_key = inp->task->workload_key; - std::pair key(workload_key, inp->task->target->str()); + for (const auto& inp : inputs) { + const std::string& workload_key = inp->task->workload_key; + std::pair key(workload_key, inp->task->target->str()); - const SearchTaskNode* ptask; - if (external_task.defined()) { - ptask = external_task.operator->(); - } else { - auto find_res = task_cache.find(key); - if (find_res == task_cache.end()) { - if (inp->task->compute_dag.defined()) { // the measure input is complete - ptask = inp->task.operator->(); - } else { // the measure input is incomplete - // rebuild task for incomplete measure pairs read from file - SearchTask new_task = SearchTask( - ComputeDAG(workload_key), - workload_key, - inp->task->target, - inp->task->target_host, - inp->task->hardware_params); - task_cache.insert(std::make_pair(key, new_task)); - ptask = new_task.operator->(); + const SearchTaskNode* ptask; + if (external_task.defined()) { + ptask = external_task.operator->(); + } else { + auto find_res = task_cache.find(key); + if (find_res == task_cache.end()) { + if (inp->task->compute_dag.defined()) { // the measure input is complete + ptask = inp->task.operator->(); + } else { // the measure input is incomplete + // rebuild task for incomplete measure pairs read from file + SearchTask new_task = + SearchTask(ComputeDAG(workload_key), workload_key, inp->task->target, + inp->task->target_host, inp->task->hardware_params); + task_cache.insert(std::make_pair(key, new_task)); + ptask = new_task.operator->(); + } + } else { + ptask = find_res->second.operator->(); + } } - } else { - ptask = find_res->second.operator->(); - } - } - State tmp_s = ptask->compute_dag.GetInitState(); - StateNode *ps = tmp_s.CopyOnWrite(); - ps->transform_steps = inp->state->transform_steps; - tmp_s.DoSteps(ps->transform_steps, ptask->compute_dag); - states.push_back(std::move(tmp_s)); - } + State tmp_s = ptask->compute_dag.GetInitState(); + StateNode* ps = tmp_s.CopyOnWrite(); + ps->transform_steps = inp->state->transform_steps; + tmp_s.DoSteps(ps->transform_steps, ptask->compute_dag); + states.push_back(std::move(tmp_s)); + } - *ret = states; -}); + *ret = states; + }); } // namespace ansor } // namespace tvm diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h index f8ab6b42dda2..53ea88323cf9 100644 --- a/src/ansor/serialization.h +++ b/src/ansor/serialization.h @@ -43,7 +43,7 @@ class LogToFileNode : public MeasureCallbackNode { void Callback(const SearchPolicy& policy, const Array& inputs, const Array& results) final; - static constexpr const char *_type_key = "ansor.LogToFile"; + static constexpr const char* _type_key = "ansor.LogToFile"; TVM_DECLARE_FINAL_OBJECT_INFO(LogToFileNode, MeasureCallbackNode); }; @@ -85,8 +85,8 @@ class LogReaderNode : public Object { * \param skip_size Skip 
the first n lines. * \return The MeasureInputs and MeasureResults loaded from the log file. */ - std::pair, Array > ReadLines( - int max_size = -1, int skip_size = 0); + std::pair, Array > ReadLines(int max_size = -1, + int skip_size = 0); static constexpr const char* _type_key = "ansor.LogReader"; TVM_DECLARE_FINAL_OBJECT_INFO(LogReaderNode, Object); @@ -117,8 +117,7 @@ class LogReader : public ObjectRef { * \param inputs The target MeasureInputs to be written. * \param results The target MeasureResults to be written. */ -void WriteMeasureRecords(std::ostream* os, - const Array& inputs, +void WriteMeasureRecords(std::ostream* os, const Array& inputs, const Array& results); /*! @@ -128,9 +127,7 @@ void WriteMeasureRecords(std::ostream* os, * \param res A pointer to MeasureResultNode, this is used as output. * \param log_version A pointer to log version string. */ -void ReadMeasureRecord(const std::string& str, - MeasureInputNode* inp, - MeasureResultNode* res, +void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res, std::string* log_version); } // namespace ansor diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index 193a11ffd191..10ef9e3d6ab0 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -19,7 +19,8 @@ /*! * \file ansor/transform_step.cc - * \brief Transformation steps. For each schedule primitive, there is a corresponding transform step. + * \brief Transformation steps. For each schedule primitive, there is a corresponding transform + * step. */ #include "transform_step.h" @@ -43,8 +44,8 @@ ReorderStep::ReorderStep(int stage_id, const std::vector& after_ids) { data_ = std::move(node); } -void ReorderStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { +void ReorderStepNode::ApplyToSchedule(std::vector* stages, + StageToAxesMap* stage_to_axes) const { te::Stage& stage = (*stages)[stage_id]; const std::vector& axes = (*stage_to_axes)[stage]; CHECK_EQ(after_ids.size(), axes.size()); @@ -58,9 +59,8 @@ void ReorderStepNode::ApplyToSchedule(std::vector *stages, (*stage_to_axes)[stage] = std::move(new_axes); } -std::string ReorderStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, +std::string ReorderStepNode::PrintAsPythonAPI(std::vector* stages, + StageToAxesMap* stage_to_axes, te::Schedule* schedule, const std::vector& transform_steps) const { const te::Stage& stage = (*stages)[stage_id]; std::stringstream ss; @@ -79,10 +79,8 @@ std::string ReorderStepNode::PrintAsPythonAPI(std::vector *stages, } /********** Split **********/ -std::vector ApplySplitToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - int stage_id, - int iter_id, +std::vector ApplySplitToSchedule(std::vector* stages, + StageToAxesMap* stage_to_axes, int stage_id, int iter_id, const std::vector& lengths, bool inner_to_outer) { te::Stage& stage = (*stages)[stage_id]; @@ -120,36 +118,29 @@ std::vector ApplySplitToSchedule(std::vector *stages, return outs; } -std::string PrintSplitAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - int stage_id, - int iter_id, - const std::vector& lengths, +std::string PrintSplitAsPythonAPI(std::vector* stages, StageToAxesMap* stage_to_axes, + int stage_id, int iter_id, const std::vector& lengths, bool inner_to_outer) { te::Stage& stage = (*stages)[stage_id]; auto to_split = (*stage_to_axes)[stage][iter_id]; const auto& func_name = CleanName(stage->op->name); - const auto& outs = 
ApplySplitToSchedule(stages, stage_to_axes, stage_id, - iter_id, lengths, inner_to_outer); + const auto& outs = + ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); std::stringstream ss; int size = static_cast(lengths.size()); if (inner_to_outer) { for (int i = size - 1; i >= 0; i--) { ss << CleanName(outs[size - i]->var->name_hint) << ", " - << CleanName(outs[size - i - 1]->var->name_hint) - << " = s[" << func_name << "].split(" - << CleanName(to_split->var->name_hint) - << ", factor=" << lengths[i] << ")\n"; + << CleanName(outs[size - i - 1]->var->name_hint) << " = s[" << func_name << "].split(" + << CleanName(to_split->var->name_hint) << ", factor=" << lengths[i] << ")\n"; to_split = outs[size - i]; } } else { for (int i = 0; i < size; i++) { - ss << CleanName(outs[i]->var->name_hint) << ", " - << CleanName(outs[i + 1]->var->name_hint) - << " = s[" << func_name << "].split(" - << CleanName(to_split->var->name_hint) - << ", nparts=" << lengths[i] << ")\n"; + ss << CleanName(outs[i]->var->name_hint) << ", " << CleanName(outs[i + 1]->var->name_hint) + << " = s[" << func_name << "].split(" << CleanName(to_split->var->name_hint) + << ", nparts=" << lengths[i] << ")\n"; to_split = outs[i + 1]; } } @@ -158,8 +149,7 @@ std::string PrintSplitAsPythonAPI(std::vector *stages, } SplitStep::SplitStep(int stage_id, int iter_id, PrimExpr extent, - const std::vector& lengths, - bool inner_to_outer) { + const std::vector& lengths, bool inner_to_outer) { auto node = make_object(); node->stage_id = stage_id; // Extent can be a unreducible expression in some special cases @@ -172,17 +162,15 @@ SplitStep::SplitStep(int stage_id, int iter_id, PrimExpr extent, data_ = std::move(node); } -std::vector SplitStepNode::ApplyToSchedule( - std::vector *stages, StageToAxesMap *stage_to_axes) const { - return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, - lengths, inner_to_outer); +std::vector SplitStepNode::ApplyToSchedule(std::vector* stages, + StageToAxesMap* stage_to_axes) const { + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } -std::string SplitStepNode::PrintAsPythonAPI( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule, const std::vector& transform_steps) const { - return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, - lengths, inner_to_outer); +std::string SplitStepNode::PrintAsPythonAPI(std::vector* stages, + StageToAxesMap* stage_to_axes, te::Schedule* schedule, + const std::vector& transform_steps) const { + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } /********** Fuse **********/ @@ -193,8 +181,8 @@ FuseStep::FuseStep(int stage_id, const std::vector& fused_ids) { data_ = std::move(node); } -IterVar FuseStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { +IterVar FuseStepNode::ApplyToSchedule(std::vector* stages, + StageToAxesMap* stage_to_axes) const { te::Stage& stage = (*stages)[stage_id]; const std::vector& axes = (*stage_to_axes)[stage]; @@ -207,16 +195,14 @@ IterVar FuseStepNode::ApplyToSchedule(std::vector *stages, std::vector new_axes; new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids[0]); new_axes.push_back(fused_axis); - new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, - axes.end()); + new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, axes.end()); (*stage_to_axes)[stage] = std::move(new_axes); 
return fused_axis; } -std::string FuseStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, +std::string FuseStepNode::PrintAsPythonAPI(std::vector* stages, + StageToAxesMap* stage_to_axes, te::Schedule* schedule, const std::vector& transform_steps) const { const auto& stage = (*stages)[stage_id]; std::stringstream to_fuse; @@ -231,8 +217,7 @@ std::string FuseStepNode::PrintAsPythonAPI(std::vector *stages, std::stringstream ss; const auto& fused = ApplyToSchedule(stages, stage_to_axes); - ss << CleanName(fused->var->name_hint) << " = s[" - << CleanName(stage->op->name) << "].fuse(" + ss << CleanName(fused->var->name_hint) << " = s[" << CleanName(stage->op->name) << "].fuse(" << to_fuse.str() << ")\n"; return ss.str(); diff --git a/src/ansor/utils.h b/src/ansor/utils.h index cad27d51ba7e..b43edfc10527 100644 --- a/src/ansor/utils.h +++ b/src/ansor/utils.h @@ -27,25 +27,26 @@ #include #include -#include -#include + +#include #include -#include #include -#include +#include +#include #include -#include -#include +#include #include -#include +#include #include -#include +#include +#include +#include namespace std { /*! \brief Hash function for std::pair */ template -struct hash > { +struct hash> { std::size_t operator()(const std::pair& k) const { return ::dmlc::HashCombine(std::hash()(k.first), std::hash()(k.second)); } @@ -53,7 +54,7 @@ struct hash > { /*! \brief Hash function for std::tuple */ template -struct hash > { +struct hash> { std::size_t operator()(const std::tuple& k) const { return ::dmlc::HashCombine( ::dmlc::HashCombine(std::hash()(std::get<0>(k)), std::hash()(std::get<1>(k))), @@ -63,7 +64,7 @@ struct hash > { /*! \brief Hash function for std::vector */ template -struct hash > { +struct hash> { std::size_t operator()(const std::vector& vec) const { if (vec.empty()) { return 0; @@ -84,8 +85,7 @@ namespace ansor { /********** Utilities for std::vector, std::set, std::string **********/ /*! \brief Get the first appearance index of elements in a vector */ template -inline void GetIndices(const std::vector& array, - const std::vector& to_locate, +inline void GetIndices(const std::vector& array, const std::vector& to_locate, std::vector* indices) { for (const auto& v : to_locate) { auto it = std::find(array.begin(), array.end(), v); @@ -132,10 +132,10 @@ inline void StrReplace(std::string* base, const std::string& from, const std::st inline double FloatArrayMean(const Array& float_array) { double sum = 0; if (float_array.empty()) { - return 0.0; + return 0.0; } - for (const auto&x : float_array) { + for (const auto& x : float_array) { auto floatimm = x.as(); CHECK(floatimm != nullptr); sum += floatimm->value; @@ -182,7 +182,7 @@ inline std::string CleanName(const std::string& str) { class NullStream : public std::ostream { public: NullStream() : std::ostream(nullptr) {} - NullStream(const NullStream &) : std::ostream(nullptr) {} + NullStream(const NullStream&) : std::ostream(nullptr) {} static NullStream& Global(); }; @@ -203,7 +203,7 @@ inline std::ostream& StdCout(int verbose) { /*! 
\brief Print a title */ inline void PrintTitle(const std::string& title, int verbose) { if (verbose >= 1) { - std::cout << "------------------------------------------------------------" << "\n"; + std::cout << "------------------------------------------------------------\n"; std::cout << "----------------------- [ " << title << " ]\n"; std::cout << "------------------------------------------------------------" << std::endl; } @@ -214,7 +214,7 @@ class ThreadPool { public: void Launch(size_t n = 1) { for (std::size_t i = 0; i < n; ++i) { - threads_.emplace_back([this] {WorkerFunc();}); + threads_.emplace_back([this] { WorkerFunc(); }); } } @@ -223,7 +223,7 @@ class ThreadPool { is_finished_ = n <= 0; } - template::type> + template ::type> std::future Enqueue(F&& f, Args&&... args) { std::packaged_task p(std::bind(f, args...)); @@ -267,16 +267,12 @@ class ThreadPool { threads_.clear(); } - size_t NumWorkers() { - return threads_.size(); - } + size_t NumWorkers() { return threads_.size(); } static const int REFRESH_EVERY = 128; static ThreadPool& Global(); - ~ThreadPool() { - Join(); - } + ~ThreadPool() { Join(); } private: void WorkerFunc() { @@ -285,12 +281,14 @@ class ThreadPool { { std::unique_lock l(m_); if (work_.empty()) { - work_signal_.wait(l, [&]{ return !work_.empty(); }); + work_signal_.wait(l, [&] { return !work_.empty(); }); } f = std::move(work_.front()); work_.pop_front(); } - if (!f.valid()) { return; } + if (!f.valid()) { + return; + } f(); finish_ct_--; From 907c17c8a03b0d7436ac2d15da082ea9f68361fa Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 1 Jul 2020 15:07:52 +0800 Subject: [PATCH 53/78] Update --- .gitignore | 3 - python/tvm/ansor/__init__.py | 6 +- python/tvm/ansor/auto_schedule.py | 91 ++++++------ python/tvm/ansor/compute_dag.py | 35 ++++- python/tvm/ansor/loop_state.py | 33 +++-- python/tvm/ansor/measure.py | 67 +++++---- python/tvm/ansor/serialization.py | 42 +++--- python/tvm/ansor/utils.py | 40 +++++- python/tvm/ansor/workload_registry.py | 129 ++---------------- src/ansor/serialization.cc | 15 +- tests/python/unittest/test_ansor_common.py | 2 +- tests/python/unittest/test_ansor_measure.py | 2 +- .../unittest/test_ansor_search_policy.py | 4 +- 13 files changed, 220 insertions(+), 249 deletions(-) diff --git a/.gitignore b/.gitignore index 3c03e8ecda7a..506e54d93067 100644 --- a/.gitignore +++ b/.gitignore @@ -234,6 +234,3 @@ conda/pkg # antlr files *.tokens *.interp - -# ansor tuning logs -scripts/*.json diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 93a82f073ac3..2368cfd8489a 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -30,6 +30,6 @@ auto_schedule, EmptyPolicy from .measure import MeasureInput, LocalBuilder, LocalRunner from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \ - load_from_file, write_measure_records_to_file -from .workload_registry import register_workload_func, \ - workload_key_to_dag, make_workload_key_func + load_from_file, append_measure_records_to_file +from .workload_registry import register_workload, \ + workload_key_to_dag, make_workload_key_by_func diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 7a9d7c322c9e..8b1a2c14a5c3 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -60,9 +60,9 @@ class SearchTask(Object): The workload key for target compute declaration. target : tvm.target.Target The target device of this search task. 
-    target_host : tvm.target.Target
+    target_host : Optional[tvm.target.Target]
         The target host device of this search task.
-    hardware_params : HardwareParams
+    hardware_params : Optional[HardwareParams]
         Hardware parameters used in this search task.
     """
     def __init__(self, dag, workload_key, target, target_host=None,
@@ -88,33 +88,41 @@ def __init__(self):
 
 @tvm._ffi.register_object("ansor.TuneOption")
 class TuneOption(Object):
-    """ The options for tuning
+    """ The options for tuning.
 
     Parameters
     ----------
-    n_trials: int
-        Number of total measurement trials
-    early_stopping: int
-        Stops early the tuning if no improvement after n measurements
-    num_measure_per_round: int
-        The number of programs to be measured at each iteration
-    verbose: int
+    n_trials: int = 1
+        The number of total schedule measure trials.
+        Ansor takes `n_trials` states for measuring in total, and finally gets the best schedule
+        among them.
+        With `n_trials` == 1, Ansor will do the schedule search but will not involve measurement;
+        this can be used if we want to quickly get a runnable schedule without performance tuning.
+    early_stopping: int = -1
+        Stop the tuning early if no improvement is found after n measurements.
+    num_measure_per_round: int = 64
+        The number of programs to be measured at each search round.
+        The whole schedule search process is designed to have several rounds to try a total
+        `n_trials` schedules.
+        We have: `num_search_rounds` = `n_trials` // `num_measure_per_round`
+    verbose: int = 1
         Verbosity level. 0 means silent.
-    builder: Builder
-        Builder which builds the program
-    runner: Runner
-        Runner which runs the program and measure time costs
-    measure_callbacks: List[MeasureCallback]
-        Callback functions called after each measure
+    builder: Union[Builder, str] = 'local'
+        Builder which builds the program.
+    runner: Union[Runner, str] = 'local'
+        Runner which runs the program and measures time costs.
+    measure_callbacks: Optional[List[MeasureCallback]]
+        Callback functions called after each measure.
         Candidates:
           - ansor.LogToFile
-    pre_search_callbacks: List[SearchCallback]
-        Callback functions called before the search process
+    pre_search_callbacks: Optional[List[SearchCallback]]
+        Callback functions called before the search process.
         Candidates:
-          - ansor.PreloadMeasuredStates(will be added later)
-          - ansor.PreloadCustomSketchRule(will be added later)
+          - ansor.PreloadMeasuredStates
+          - ansor.PreloadCustomSketchRule
+          TODO(jcf94): Add these implementations in later PRs.
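As a concrete reference for the options documented above, a typical setup mirrors the unit tests later in this series; this is a sketch, with `matmul_ansor_test` being the workload those tests register:

    import tvm
    from tvm import ansor

    workload_key = ansor.make_workload_key_by_func(matmul_ansor_test, (128, 128, 128))
    dag = ansor.workload_key_to_dag(workload_key)
    task = ansor.SearchTask(dag, workload_key, tvm.target.create("llvm"))
    tune_option = ansor.TuneOption(n_trials=2, runner='local',
                                   measure_callbacks=[ansor.LogToFile("matmul.json")])
    sch, args = ansor.auto_schedule(task, tvm.target.create("llvm"),
                                    tune_option=tune_option)
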
""" - def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_round=64, + def __init__(self, n_trials=1, early_stopping=-1, num_measure_per_round=64, verbose=1, builder='local', runner='local', measure_callbacks=None, pre_search_callbacks=None): if isinstance(builder, str): @@ -127,40 +135,36 @@ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_round=64, if runner == 'local': runner = LocalRunner() else: - raise ValueError("Invalid builder: " + runner) + raise ValueError("Invalid runner: " + runner) - if measure_callbacks is None: - measure_callbacks = [] - - if pre_search_callbacks is None: - pre_search_callbacks = [] + measure_callbacks = [] if measure_callbacks is None else measure_callbacks + pre_search_callbacks = [] if pre_search_callbacks is None else pre_search_callbacks self.__init_handle_by_constructor__( _ffi_api.TuneOption, n_trials, early_stopping, num_measure_per_round, verbose, builder, runner, measure_callbacks, pre_search_callbacks) -def auto_schedule(workload, target=None, - target_host=None, search_policy='default', +def auto_schedule(task, target, target_host=None, search_policy='default', hardware_params=None, tune_option=None): """ Do auto scheduling for a computation declaration. - The workload parameter can be a `string` as workload_key, or directly + The task parameter can be a `string` as workload_key, or directly passing a `SearchTask` as input. Parameters ---------- - workload : Union[SearchTask, str] + task : Union[SearchTask, str] The target search task or workload key. - target : Target + target : tvm.target.Target The target device of this schedule search. - target_host : Target = None + target_host : Optional[tvm.target.Target] The target host device of this schedule search. - search_policy : Union[SearchPolicy, str] + search_policy : Union[SearchPolicy, str] = 'default' The search policy to be used for schedule search. - hardware_params : HardwareParams + hardware_params : Optional[HardwareParams] The hardware parameters of this schedule search. - tune_option : TuneOption + tune_option : Optional[TuneOption] Tuning and measurement options. Returns @@ -169,18 +173,19 @@ def auto_schedule(workload, target=None, """ if isinstance(search_policy, str): if search_policy == 'default': + # TODO(jcf94): This is an example policy for minimum system, will be upgrated to + # formal search policy later. search_policy = EmptyPolicy() else: raise ValueError("Invalid search policy: " + search_policy) - if tune_option is None: - tune_option = TuneOption(n_trials=0) + tune_option = TuneOption() if tune_option is None else tune_option - if isinstance(workload, str): + if isinstance(task, str): sch, tensors = _ffi_api.AutoScheduleByWorkloadKey( - workload, target, target_host, search_policy, hardware_params, tune_option) + task, target, target_host, search_policy, hardware_params, tune_option) return sch, tensors - if isinstance(workload, SearchTask): - sch, tensors = _ffi_api.AutoScheduleBySearchTask(workload, search_policy, tune_option) + if isinstance(task, SearchTask): + sch, tensors = _ffi_api.AutoScheduleBySearchTask(task, search_policy, tune_option) return sch, tensors - raise ValueError("Invalid workload: " + workload + ". Expect a string or SearchTask") + raise ValueError("Invalid task: " + task + ". 
Expect a string or SearchTask")
diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py
index 1e289aaafe0c..dbac298a3f92 100644
--- a/python/tvm/ansor/compute_dag.py
+++ b/python/tvm/ansor/compute_dag.py
@@ -17,9 +17,15 @@
 
 """ Computational graph and its analysis tools """
+import hashlib
+
 import tvm._ffi
 from tvm.runtime import Object
+from tvm.te import PlaceholderOp, ComputeOp
+
 from .loop_state import State, StateObject
+from .utils import get_const_tuple
+
 from . import _ffi_api
 
 
@@ -42,6 +48,7 @@ def get_init_state(self):
         Returns
         -------
         state : State
+            The initial State without any transform steps.
         """
         return State(_ffi_api.ComputeDAGGetInitState(self), self)
 
@@ -51,7 +58,7 @@ def apply_steps_from_state(self, state):
 
         Parameters
         ----------
-        state : StateObject or State
+        state : Union[State, StateObject]
             The target state to be applied to TVM schedule.
 
         Returns
@@ -67,12 +74,13 @@ def print_python_code_from_state(self, state):
 
         Parameters
         ----------
-        state : StateObject or State
+        state : Union[State, StateObject]
             The target state to be applied to TVM schedule.
 
         Returns
         -------
         str : Str
+            The Python schedule code.
         """
         state_obj = state if isinstance(state, StateObject) else state.state_object
         return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state_obj)
@@ -92,12 +100,33 @@ def infer_bound_from_state(self, state):
 
         Parameters
         ----------
-        state : StateObject
+        state : Union[State, StateObject]
             The target state to be applied to TVM schedule.
 
         Returns
         -------
         state : State
+            The State with complete bound information.
         """
         state_obj = state if isinstance(state, StateObject) else state.state_object
         return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self)
+
+    def __hash__(self):
+        # TODO(...): Implement this more carefully and move this to C++ as a member function
+        # of ComputeDAG
+        str_key = ''
+        for op in self.ops:
+            t = op.output(0)
+            if isinstance(op, PlaceholderOp):
+                str_key += 'placeholder,'
+                str_key += str(get_const_tuple(t.shape)) + ','
+                str_key += t.dtype + ';'
+            elif isinstance(op, ComputeOp):
+                str_key += str(t.op.body) + ','
+                str_key += str(get_const_tuple(t.shape)) + ','
+                str_key += t.dtype + ';'
+            else:
+                raise ValueError("Invalid op: " + str(op))
+
+        str_key = str_key.encode(encoding='utf-8')
+        # __hash__ must return an int, so fold the md5 hex digest into one
+        return int(hashlib.md5(str_key).hexdigest(), 16)
diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py
index a796373b9393..7ed32e477523 100644
--- a/python/tvm/ansor/loop_state.py
+++ b/python/tvm/ansor/loop_state.py
@@ -121,33 +121,35 @@ def transform_steps_size(self):
         """
         return _ffi_api.StateGetTransformStepsSize(self.state_object)
 
-    def reorder(self, stage_id, order):
+    def reorder(self, stage, order):
         """ Schedule primitive corresponds to te.reorder.
 
         Parameters
         ----------
-        stage_id : Union[int, Operation, Tensor]
-            The index of the stage to reorder
+        stage : Union[int, Operation, Tensor]
+            The target Stage to be reordered, can be a Stage order index, Stage operation or stage
+            output tensor.
         order : List[Iterator]
             Iterators in the expected order
         """
-        stage_id = self._resolve_stage_id(stage_id)
+        stage_id = self._resolve_stage_id(stage)
 
         self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order)
         self._clear_cache()
 
-    def split(self, stage_id, iterator, lengths, inner_to_outer=True):
+    def split(self, stage, iterator, lengths, inner_to_outer=True):
         """ Schedule primitive corresponds to te.split. 
 Parameters
         ----------
-        stage_id : Union[int, Operation, Tensor]
-            The index of the stage to split
+        stage : Union[int, Operation, Tensor]
+            The target Stage to be split, can be a Stage order index, Stage operation or stage
+            output tensor.
         iterator : Iterator
             The iterator to split
         lengths: List[int]
             The split factors
-        inner_to_outer: bool
+        inner_to_outer: bool = True
             True to use `factor` to split from inner to outer,
             False to use `nparts` to split from outer to inner
 
@@ -156,20 +158,21 @@ def split(self, stage_id, iterator, lengths, inner_to_outer=True):
         res_its : List[Iterator]
             The new Iterators produced by the split
         """
-        stage_id = self._resolve_stage_id(stage_id)
+        stage_id = self._resolve_stage_id(stage)
 
         self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, iterator,
                                                      lengths, inner_to_outer)
         self._clear_cache()
         return res
 
-    def fuse(self, stage_id, iters):
+    def fuse(self, stage, iters):
         """ Schedule primitive corresponds to te.fuse.
 
         Parameters
         ----------
-        stage_id : Union[int, Operation, Tensor]
-            The index of the stage to fuse
+        stage : Union[int, Operation, Tensor]
+            The target Stage to be reordered, can be a Stage order index, Stage operation or stage
+            output tensor.
         iters : List[Iterator]
             The iterators to be fused
 
@@ -178,7 +181,7 @@ def fuse(self, stage_id, iters):
         res_it : Iterator
             The fused Iterator
         """
-        stage_id = self._resolve_stage_id(stage_id)
+        stage_id = self._resolve_stage_id(stage)
 
         self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters)
         self._clear_cache()
@@ -193,7 +196,7 @@ def copy(self):
     def _resolve_stage_id(self, stage_id):
         if isinstance(stage_id, Operation):
             return self.stage_id_map[stage_id]
-        if isinstance(stage_id, tvm.te.Tensor):
+        if isinstance(stage_id, Tensor):
             return self.stage_id_map[stage_id.op]
         if isinstance(stage_id, int):
             return stage_id
@@ -215,7 +218,7 @@ def __getitem__(self, key):
             key = key.op
         if isinstance(key, Operation):
             return self.stages_cache[self.stage_id_map[key]]
-        raise ValueError("Item must be Tensor")
+        raise ValueError("Item must be Tensor or Operation")
 
     def __str__(self):
         return str(self.state_object)
diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py
index 06e7c6fa1a4e..5bb0a58f37aa 100644
--- a/python/tvm/ansor/measure.py
+++ b/python/tvm/ansor/measure.py
@@ -42,6 +42,9 @@
 # The maximum length of error message
 MAX_ERROR_MSG_LEN = 512
 
+# Global variables used in build function
+GLOBAL_BUILD_ARGUMENTS = None
+
 @tvm._ffi.register_object("ansor.MeasureCallback")
 class MeasureCallback(Object):
     """ Base class for measurement callback function. """
@@ -68,21 +71,23 @@ class BuildResult(Object):
 
     Parameters
     ----------
-    filename : Str
+    filename : Optional[str]
         The filename of built binary file.
     args : List[Tensor]
         The arguments.
-    error_no : Int
+    error_no : int
         The error code.
-    error_msg : Str
+    error_msg : Optional[str]
         The error message if there is any error.
-    time_cost : Float
+    time_cost : float
         The time cost of build. 
""" def __init__(self, filename, args, error_no, error_msg, time_cost): + filename = filename if filename else "" + error_msg = error_msg if error_msg else "" + self.__init_handle_by_constructor__( - _ffi_api.BuildResult, filename if filename else "", args, error_no, - error_msg if error_msg else "", time_cost) + _ffi_api.BuildResult, filename, args, error_no, error_msg, time_cost) @tvm._ffi.register_object("ansor.MeasureResult") @@ -91,21 +96,23 @@ class MeasureResult(Object): Parameters ---------- - costs : List[Float] + costs : List[float] The time costs of execution. - error_no : Int + error_no : int The error code. - error_msg : Str + error_msg : Optional[str] The error message if there is any error. - all_cost : Float + all_cost : float The time cost of build and run. - timestamp : Float + timestamp : float The time stamps of this measurement. """ def __init__(self, costs, error_no, error_msg, all_cost, timestamp): + error_msg = error_msg if error_msg else "" + self.__init_handle_by_constructor__( _ffi_api.MeasureResult, costs, error_no, - error_msg if error_msg else "", all_cost, timestamp) + error_msg, all_cost, timestamp) @tvm._ffi.register_object("ansor.Builder") @@ -119,7 +126,7 @@ def build(self, measure_inputs, verbose=1): ---------- measure_inputs : List[MeasureInput] A List of MeasureInput. - verbost : Int + verbost : int = 1 Verbosity level. (0 means silent) Returns @@ -142,6 +149,8 @@ def run(self, measure_inputs, build_results, verbose=1): A List of MeasureInput. build_results : List[BuildResult] A List of BuildResult to be ran. + verbost : int = 1 + Verbosity level. (0 means silent) Returns ------- @@ -156,11 +165,11 @@ class LocalBuilder(Builder): Parameters ---------- - timeout : Int + timeout : int = 15 The timeout limit for each build. - n_parallel : Int + n_parallel : int = multiprocessing.cpu_count() Number of threads used to build in parallel. - build_func : Str + build_func : str = 'default' The name of registered build function. """ @@ -178,15 +187,15 @@ class LocalRunner(Runner): Parameters ---------- - timeout : Int + timeout : int = 10 The timeout limit for each run. - number : Int + number : int = 3 Number of measure times. - repeat : Int + repeat : int = 1 Number of repeat times in each measure. - min_repeat_ms : Int + min_repeat_ms : int = 0 The minimum duration of one repeat in milliseconds. - cooldown_interval : Float + cooldown_interval : float = 0.0 The cool down interval between two measurements. """ @@ -224,16 +233,15 @@ def make_error_msg(): return error_msg -GLOBAL_BUILD_ARGUMENTS = None -GLOBAL_RUN_ARGUMENTS = None - - def local_build_worker(index): """ Local builder function. """ # We use fork to copy arguments from a global variable. 
# This can avoid expensive serialization of TVM IR when using multiprocessing.Pool + if not GLOBAL_BUILD_ARGUMENTS: + raise ValueError("GLOBAL_BUILD_ARGUMENTS not found") measure_inputs, build_func, timeout, verbose = GLOBAL_BUILD_ARGUMENTS assert isinstance(build_func, str) + if build_func == 'default': build_func = tar.tar elif build_func == 'ndk': @@ -253,7 +261,7 @@ def timed_func(): try: sch, args = task.compute_dag.apply_steps_from_state( inp.state) - # pylint: disable=W0703 + # pylint: disable=broad-except except Exception: error_no = MeasureErrorNo.INSTANTIATION_ERROR error_msg = make_error_msg() @@ -268,7 +276,7 @@ def timed_func(): func = build_module.build( sch, args, target=task.target, target_host=task.target_host) func.export_library(filename, build_func) - # pylint: disable=W0703 + # pylint: disable=broad-except except Exception: error_no = MeasureErrorNo.COMPILE_HOST error_msg = make_error_msg() @@ -326,7 +334,7 @@ def timed_func(inp, build_res): ctx = ndarray.context(str(inp.task.target), 0) time_f = func.time_evaluator( func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms) - # pylint: disable=W0703 + # pylint: disable=broad-except except Exception: costs = (max_float,) error_no = MeasureErrorNo.COMPILE_DEVICE @@ -337,9 +345,8 @@ def timed_func(inp, build_res): args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] ctx.sync() - costs = time_f(*args).results - # pylint: disable=W0703 + # pylint: disable=broad-except except Exception: costs = (max_float,) error_no = MeasureErrorNo.RUNTIME_DEVICE diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py index 9db85dc98ef9..8c8723ffbbaa 100644 --- a/python/tvm/ansor/serialization.py +++ b/python/tvm/ansor/serialization.py @@ -32,7 +32,7 @@ class LogToFile(MeasureCallback): Parameters ---------- - filename : Str + filename : str File name for this callback to write log to. """ def __init__(self, filename="ansor_tuning.json"): @@ -46,20 +46,20 @@ class LogReader(Object): Parameters ---------- - filename : Str + filename : str = "ansor_tuning.json" File name for this reader to load log from. """ def __init__(self, filename="ansor_tuning.json"): self.__init_handle_by_constructor__(_ffi_api.LogReader, filename) - def read_lines(self, max_size=-1, skip_size=0): + def read_lines(self, max_lines=-1, skip_lines=0): """ Read multiple lines from the log file. Parameters ---------- - max_size : Int - The maximum number of lines. -1 means read all lines. - skip_size : Int + max_lines : int = -1 + The maximum number of lines. -1 means to read all lines. + skip_lines : int = 0 Skip the first n lines. Returns @@ -69,8 +69,7 @@ def read_lines(self, max_size=-1, skip_size=0): results : List[MeasureResult] The MeasureResults loaded from the log file. """ - inputs, results = _ffi_api.LogReaderReadLines( - self, max_size, skip_size) + inputs, results = _ffi_api.LogReaderReadLines(self, max_lines, skip_lines) return inputs, results def __iter__(self): @@ -81,13 +80,13 @@ def __iter__(self): yield ret[0], ret[1] # (input, result) -def load_from_file(filename: str): +def load_from_file(filename): """ Load measurement records from a file. Parameters ---------- - filename : Str + filename : str File name to load log from. 
 Returns
@@ -97,32 +96,35 @@
     return zip(*LogReader(filename).read_lines())
 
 
-def write_measure_records_to_file(filename, inputs, results):
+def append_measure_records_to_file(filename, inputs, results):
     """
-    Write(append) measure records to file.
+    Append measure records to file.
 
     Parameters
     ----------
-    filename : Str
+    filename : str
         File name to write log to.
     inputs: List[MeasureInputs]
         The target MeasureInputs to be written.
     results: List[MeasureResults]
         The target MeasureResults to be written.
     """
-    _ffi_api.WriteMeasureRecordsToFile(filename, inputs, results)
+    _ffi_api.AppendMeasureRecordsToFile(filename, inputs, results)
 
 
 def best_measure_pair_in_file(filename, workload_key=None, target=None):
-    """ Return the best measurement pair form a log file
+    """ Return the best measurement pair from a log file. This may return no result if
+    no valid measure pair with the specified workload_key/target is found in the log file.
 
     Parameters
     ----------
-    filename : Str
+    filename : str
         File name to load log from.
-    workload_key : Str
+    workload_key : Optional[str] = None
         The workload key of the target compute declaration.
-    target : Str
+        With `None`, this returns the best measure pair of all workloads.
+    target : Optional[tvm.target.Target] = None
         The target device.
+        With `None`, this returns the best measure pair of all target devices.
 
     Returns
     -------
@@ -144,9 +146,7 @@ def best_measure_pair_in_file(filename, workload_key=None, target=None):
         if target and inp.task.target.target_name != target.target_name:
             continue
 
-        costs = []
-        for value in res.costs:
-            costs.append(value.value)
+        costs = [v.value for v in res.costs]
         cost = np.mean(costs)
         if cost < best_cost:
             best_cost = cost
diff --git a/python/tvm/ansor/utils.py b/python/tvm/ansor/utils.py
index 041327d147d5..309f63cec93c 100644
--- a/python/tvm/ansor/utils.py
+++ b/python/tvm/ansor/utils.py
@@ -17,6 +17,7 @@
 
 """Common utilities for ansor"""
 
+from typing import Hashable
 import multiprocessing
 import multiprocessing.pool
 import queue
@@ -25,11 +26,12 @@
 try:
     import psutil
 except ImportError:
-    psutil = None
+    raise ImportError("psutil not found, try `pip install psutil` to fix this")
 
 from tvm.tir import expr
 from tvm.tir.transform import Simplify
 from tvm.ir.transform import Sequential
+from ..te import Tensor, placeholder
 
 
 def get_func_name(func):
@@ -87,6 +89,42 @@ def get_const_tuple(in_tuple):
     return tuple(get_const_int(x) for x in in_tuple)
 
 
+
+def list_to_tuple(x):
+    """ Convert a list to a tuple recursively. """
+    assert isinstance(x, list)
+    return tuple(list_to_tuple(y) if isinstance(y, list) else y for y in x)
+
+
+def serialize_args(args):
+    """
+    Serialize arguments of a function to a hashable and jsonable tuple. 
+ Currently this is mainly used for tvm.tensor.Tensor + """ + ret = [] + for t in args: + if isinstance(t, Tensor): + t = ('TENSOR', get_const_tuple(t.shape), t.dtype) + elif isinstance(t, list): + t = list_to_tuple(t) + + assert isinstance(t, Hashable), str(t) + " is not hashable" + ret.append(t) + + return tuple(ret) + + +def deserialize_args(args): + """The inverse function of :code:`serialize_args`""" + ret = [] + for t in args: + if isinstance(t, (tuple, list)) and t[0] == 'TENSOR': + ret.append(placeholder(shape=t[1], dtype=t[2])) + else: + ret.append(t) + return ret + + class NoDaemonProcess(multiprocessing.Process): @property def daemon(self): diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py index d423c689bf99..405d2afeff3c 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/ansor/workload_registry.py @@ -29,27 +29,25 @@ When we need the dag, we decode the string and call the function, which will return the dag. """ -from typing import Hashable import pickle import json -import hashlib import tvm._ffi -from ..te import Tensor, PlaceholderOp, ComputeOp, placeholder -from .utils import get_const_tuple +from .utils import serialize_args, deserialize_args from .compute_dag import ComputeDAG WORKLOAD_FUNC_REGISTRY = {} -def register_workload_func(func): - """Register a workload generation function +def register_workload(func): + """ Register a workload by generation function. + The input function should take hashable and jsonable arguments (int, float, tuple of int, tvm.tensor.Tensor, ...) and return a list of tvm.tensor.Tensor. Examples -------- - @register_workload_func + @register_workload def matmul(N, M, K): A = te.placeholder((N, K), name='A') B = te.placeholder((K, M), name='B') @@ -64,94 +62,6 @@ def matmul(N, M, K): return func -def compute_dag_hash(dag): - """ Get hash value for a ComputeDAG. - - Parameters - ---------- - dag : ComputeDAG - The target ComputeDAG. - - Returns - ------- - hash_value : Str - The hash value of this ComputeDAG in hex digest. - """ - # todo: implement this more carefully and move this to c++ as a member function of ComputeDAG - str_key = '' - for op in dag.ops: - t = op.output(0) - if isinstance(op, PlaceholderOp): - str_key += 'placeholder,' - str_key += str(get_const_tuple(t.shape)) + ',' - str_key += t.dtype + ';' - elif isinstance(op, ComputeOp): - str_key += str(t.op.body) + ',' - str_key += str(get_const_tuple(t.shape)) + ',' - str_key += t.dtype + ';' - else: - raise ValueError("Invalid op: " + op) - - str_key = str_key.encode(encoding='utf-8') - return hashlib.md5(str_key).hexdigest() - - -def register_workload_bufs(bufs): - """ Directly register buffers of a workload and return the workload_key. - - The buffers can be looked up with workload_key_to_tensors by the workload_key. - - Parameters - ---------- - bufs : List[Tensor] - A list of Tensors for the target compute declaration. - - Returns - ------- - workload_key : Str - A workload key mapping to the registered compute declaration. - """ - dag = ComputeDAG(bufs) - key = compute_dag_hash(dag) - WORKLOAD_FUNC_REGISTRY[key] = bufs - return json.dumps((key,)) - - -def list_to_tuple(x): - """Convert a list to a tuple recursively""" - assert isinstance(x, list) - return tuple(list_to_tuple(y) if isinstance(y, list) else y for y in x) - - -def serialize_args(args): - """ - Serialize arguments of a function to a hashable and jsonable tuple. 
- Currently this is mainly used for tvm.tensor.Tensor - """ - ret = [] - for t in args: - if isinstance(t, Tensor): - t = ('TENSOR', get_const_tuple(t.shape), t.dtype) - elif isinstance(t, list): - t = list_to_tuple(t) - - assert isinstance(t, Hashable), str(t) + " is not hashable" - ret.append(t) - - return tuple(ret) - - -def deserialize_args(args): - """The inverse function of :code:`serialize_args`""" - ret = [] - for t in args: - if isinstance(t, (tuple, list)) and t[0] == 'TENSOR': - ret.append(placeholder(shape=t[1], dtype=t[2])) - else: - ret.append(t) - return ret - - @tvm._ffi.register_func("ansor.workload_key_to_tensors") def workload_key_to_tensors(workload_key): """ Decode a workload key to the input/output tensors. @@ -170,10 +80,9 @@ def workload_key_to_tensors(workload_key): name = workload[0] lookup = WORKLOAD_FUNC_REGISTRY[name] - if callable(lookup): - args = deserialize_args(workload[1:]) - return lookup(*args) - return lookup + assert callable(lookup) + args = deserialize_args(workload[1:]) + return lookup(*args) @ tvm._ffi.register_func("ansor.workload_key_to_dag") @@ -194,7 +103,7 @@ def workload_key_to_dag(workload_key): return ComputeDAG(tensors) -def make_workload_key_func(func, args): +def make_workload_key_by_func(func, args): """ make a workload key from function and arguments. Parameters @@ -219,29 +128,11 @@ def make_workload_key_func(func, args): raise ValueError("Invalid function: " + str(func)) assert func_name in WORKLOAD_FUNC_REGISTRY, \ - "%s is not registered. Please register it with register_auto_scheduler_workload_func" % func + "%s is not registered. Please register it with @ansor.register_workload" % func return json.dumps((func_name,) + args) -def make_workload_key_bufs(bufs): - """ make a workload key from bufs. - - Parameters - ---------- - bufs : List[Tensor] - A list of Tensors for the target compute declaration. - - Returns - ------- - workload_key : Str - A workload key mapping to the registered compute declaration. - """ - dag = ComputeDAG(bufs) - key = compute_dag_hash(dag) - return json.dumps((key,)) - - def dump_workload_func_registry(filename): """ Dump workload function registry to a pickle binary file. 
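The registry above keys everything on the function name: `make_workload_key_by_func` JSON-encodes the name plus the serialized arguments, and lookup simply reverses that. A small sketch of the intended flow, with the matmul body filled in along the lines of the truncated docstring example above:

    import json
    from tvm import te, ansor

    @ansor.register_workload
    def matmul(N, M, K):
        A = te.placeholder((N, K), name='A')
        B = te.placeholder((K, M), name='B')
        k = te.reduce_axis((0, K), name='k')
        C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
        return [A, B, C]

    key = ansor.make_workload_key_by_func(matmul, (128, 128, 128))
    assert json.loads(key) == ['matmul', 128, 128, 128]
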
diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index 62c5cd2c4033..ff778c94249c 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -419,13 +419,14 @@ TVM_REGISTER_GLOBAL("ansor.LogReaderReadNext").set_body_typed([](LogReader reade } }); -TVM_REGISTER_GLOBAL("ansor.WriteMeasureRecordsToFile").set_body([](TVMArgs args, TVMRetValue* ret) { - std::string filename = args[0]; - Array in = args[1]; - Array res = args[2]; - std::ofstream ofs(filename, std::ofstream::app); - WriteMeasureRecords(&ofs, in, res); -}); +TVM_REGISTER_GLOBAL("ansor.AppendMeasureRecordsToFile") + .set_body([](TVMArgs args, TVMRetValue* ret) { + std::string filename = args[0]; + Array in = args[1]; + Array res = args[2]; + std::ofstream ofs(filename, std::ofstream::app); + WriteMeasureRecords(&ofs, in, res); + }); TVM_REGISTER_GLOBAL("ansor.GetStatesFromMeasureInputs") .set_body([](TVMArgs args, TVMRetValue* ret) { diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index 62ebeb99a6c8..9288d88b6270 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -21,7 +21,7 @@ import topi -@ansor.register_workload_func +@ansor.register_workload def matmul_ansor_test(N, M, K): A = te.placeholder((N, K), name='A') B = te.placeholder((K, M), name='B') diff --git a/tests/python/unittest/test_ansor_measure.py b/tests/python/unittest/test_ansor_measure.py index f8d41edd27dd..df8314686d7a 100644 --- a/tests/python/unittest/test_ansor_measure.py +++ b/tests/python/unittest/test_ansor_measure.py @@ -33,7 +33,7 @@ def test_serialization(): res = ansor.measure.MeasureResult([0.1], 0, "", 0.2, 1) with tempfile.NamedTemporaryFile() as fp: - ansor.serialization.write_measure_records_to_file(fp.name, [inp], [res]) + ansor.serialization.append_measure_records_to_file(fp.name, [inp], [res]) log_reader = ansor.serialization.LogReader(fp.name) inputs, results = log_reader.read_lines() diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index 20d93b8681e7..6245878696b2 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -34,7 +34,7 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' random.seed(seed) N = 128 - workload_key = ansor.make_workload_key_func(matmul_ansor_test, (N, N, N)) + workload_key = ansor.make_workload_key_by_func(matmul_ansor_test, (N, N, N)) dag = ansor.workload_key_to_dag(workload_key) target = tvm.target.create(target) task = ansor.SearchTask(dag, workload_key, target) @@ -47,7 +47,7 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner, verbose=0, measure_callbacks=[ansor.LogToFile(log_file)], pre_search_callbacks=pre_search_callbacks) - sch, args = ansor.auto_schedule(task, search_policy=search_policy, + sch, args = ansor.auto_schedule(task, target, search_policy=search_policy, tune_option=tune_option) inp, res = ansor.best_measure_pair_in_file(log_file, workload_key, target) From 64f8f8d73d186229514f38b07c7b76771925397a Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 1 Jul 2020 16:22:59 +0800 Subject: [PATCH 54/78] Update --- python/tvm/ansor/__init__.py | 3 +- python/tvm/ansor/auto_schedule.py | 17 ++--- python/tvm/ansor/compute_dag.py | 17 +++-- python/tvm/ansor/workload_registry.py | 71 +++++++------------ 
 src/ansor/auto_schedule.cc                    |  27 +------
 src/ansor/auto_schedule.h                     |  17 -----
 src/ansor/compute_dag.cc                      |  19 -----
 src/ansor/compute_dag.h                       |   4 --
 src/ansor/serialization.cc                    |  18 +++--
 tests/python/unittest/test_ansor_common.py    |   2 +-
 tests/python/unittest/test_ansor_measure.py   |   5 ++
 .../unittest/test_ansor_search_policy.py      |   4 +-
 12 files changed, 71 insertions(+), 133 deletions(-)

diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py
index 2368cfd8489a..480fb3422624 100644
--- a/python/tvm/ansor/__init__.py
+++ b/python/tvm/ansor/__init__.py
@@ -31,5 +31,4 @@
 from .measure import MeasureInput, LocalBuilder, LocalRunner
 from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \
     load_from_file, append_measure_records_to_file
-from .workload_registry import register_workload, \
-    workload_key_to_dag, make_workload_key_by_func
+from .workload_registry import register_workload_by_func, make_workload_key_by_func
diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py
index 8b1a2c14a5c3..12a74cc6cd5a 100644
--- a/python/tvm/ansor/auto_schedule.py
+++ b/python/tvm/ansor/auto_schedule.py
@@ -19,6 +19,7 @@
 import tvm._ffi
 from tvm.runtime import Object
 
+from .compute_dag import ComputeDAG
 from .measure import LocalBuilder, LocalRunner
 from . import _ffi_api
 
@@ -179,13 +180,13 @@ def auto_schedule(task, target, target_host=None, search_policy='default',
         else:
             raise ValueError("Invalid search policy: " + search_policy)
 
-    tune_option = TuneOption() if tune_option is None else tune_option
+    tune_option = tune_option if tune_option else TuneOption()
 
     if isinstance(task, str):
-        sch, tensors = _ffi_api.AutoScheduleByWorkloadKey(
-            task, target, target_host, search_policy, hardware_params, tune_option)
-        return sch, tensors
-    if isinstance(task, SearchTask):
-        sch, tensors = _ffi_api.AutoScheduleBySearchTask(task, search_policy, tune_option)
-        return sch, tensors
-    raise ValueError("Invalid task: " + str(task) + ". Expect a string or SearchTask")
+        dag = ComputeDAG(task)
+        task = SearchTask(dag, task, target, target_host, hardware_params)
+    elif not isinstance(task, SearchTask):
+        raise ValueError("Invalid task: " + str(task) + ". Expect a string or SearchTask")
+
+    sch, tensors = _ffi_api.AutoSchedule(task, search_policy, tune_option)
+    return sch, tensors
diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py
index dbac298a3f92..aa4626ed2153 100644
--- a/python/tvm/ansor/compute_dag.py
+++ b/python/tvm/ansor/compute_dag.py
@@ -25,6 +25,7 @@
 
 from .loop_state import State, StateObject
 from .utils import get_const_tuple
+from .workload_registry import workload_key_to_tensors
 
 from . import _ffi_api
 
@@ -36,11 +37,19 @@ class ComputeDAG(Object):
 
     Parameters
     ----------
-    tensors : List[Tensor]
-        `Tensor`s for a compute declaration.
+    compute : Union[List[Tensor], str]
+        `Tensor`s or workload key for a compute declaration.
     """
-    def __init__(self, tensors):
-        self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, tensors)
+    def __init__(self, compute):
+        if isinstance(compute, str):
+            compute = workload_key_to_tensors(compute)
+        elif isinstance(compute, list):
+            for item in compute:
+                if not isinstance(item, tvm.te.Tensor):
+                    raise ValueError("The input of ComputeDAG should be a list of Tensor")
+        else:
+            raise ValueError("Invalid compute: " + str(compute) +
+                             ". Expect a string or list of Tensor")
+        self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute)
 
     def get_init_state(self):
         """ Get init state of this ComputeDAG. 
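With this patch the string form and the SearchTask form of `auto_schedule` converge on a single `_ffi_api.AutoSchedule` entry point, so a bare workload key is now enough. Roughly, as a sketch reusing the `key` from the registry example earlier:

    import tvm
    from tvm import ansor

    # Either form works after this change; the string form builds the
    # ComputeDAG and SearchTask internally.
    sch, args = ansor.auto_schedule(key, tvm.target.create("llvm"))

    task = ansor.SearchTask(ansor.ComputeDAG(key), key, tvm.target.create("llvm"))
    sch, args = ansor.auto_schedule(task, tvm.target.create("llvm"))
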
diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py index 405d2afeff3c..01adb4075e1b 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/ansor/workload_registry.py @@ -34,12 +34,11 @@ import tvm._ffi from .utils import serialize_args, deserialize_args -from .compute_dag import ComputeDAG WORKLOAD_FUNC_REGISTRY = {} -def register_workload(func): +def register_workload_by_func(func): """ Register a workload by generation function. The input function should take hashable and jsonable arguments @@ -47,7 +46,7 @@ def register_workload(func): Examples -------- - @register_workload + @register_workload_by_func def matmul(N, M, K): A = te.placeholder((N, K), name='A') B = te.placeholder((K, M), name='B') @@ -62,47 +61,6 @@ def matmul(N, M, K): return func -@tvm._ffi.register_func("ansor.workload_key_to_tensors") -def workload_key_to_tensors(workload_key): - """ Decode a workload key to the input/output tensors. - - Parameters - ---------- - workload_key : Str - The target workload key. - - Returns - ------- - tensors : List[Tensor] - The registered compute declaration Tensors. - """ - workload = json.loads(workload_key) - name = workload[0] - lookup = WORKLOAD_FUNC_REGISTRY[name] - - assert callable(lookup) - args = deserialize_args(workload[1:]) - return lookup(*args) - - -@ tvm._ffi.register_func("ansor.workload_key_to_dag") -def workload_key_to_dag(workload_key): - """ Decode a workload key to a compute dag. - - Parameters - ---------- - workload_key : Str - The target workload key. - - Returns - ------- - dag : ComputeDAG - ComputeDAG to the registered compute declaration. - """ - tensors = workload_key_to_tensors(workload_key) - return ComputeDAG(tensors) - - def make_workload_key_by_func(func, args): """ make a workload key from function and arguments. @@ -128,11 +86,34 @@ def make_workload_key_by_func(func, args): raise ValueError("Invalid function: " + str(func)) assert func_name in WORKLOAD_FUNC_REGISTRY, \ - "%s is not registered. Please register it with @ansor.register_workload" % func + "%s is not registered. Please register it with @ansor.register_workload_by_func" % func return json.dumps((func_name,) + args) +@tvm._ffi.register_func("ansor.workload_key_to_tensors") +def workload_key_to_tensors(workload_key): + """ Decode a workload key to the input/output tensors. + + Parameters + ---------- + workload_key : Str + The target workload key. + + Returns + ------- + tensors : List[Tensor] + The registered compute declaration Tensors. + """ + workload = json.loads(workload_key) + name = workload[0] + lookup = WORKLOAD_FUNC_REGISTRY[name] + + assert callable(lookup) + args = deserialize_args(workload[1:]) + return lookup(*args) + + def dump_workload_func_registry(filename): """ Dump workload function registry to a pickle binary file. 
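Since `workload_key_to_tensors` is registered as a packed function, the C++ side (see the serialization change below) can rebuild tensors for incomplete measure records. Decoding a key by hand would look roughly like this sketch, under the same registry assumptions as above:

    import json
    from tvm.ansor import workload_registry
    from tvm.ansor.utils import deserialize_args

    workload = json.loads(key)                      # e.g. ['matmul', 128, 128, 128]
    func = workload_registry.WORKLOAD_FUNC_REGISTRY[workload[0]]
    tensors = func(*deserialize_args(workload[1:]))
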
diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index c06080b5cb32..91b96ff1cbaf 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -62,19 +62,6 @@ std::pair > AutoSchedule(SearchTask task, return task->compute_dag.ApplySteps(state->transform_steps); } -std::pair > AutoSchedule(std::string workload_key, Target target, - Target target_host, - SearchPolicy search_policy, - HardwareParams hardware_params, - TuneOption tune_option) { - // Create SearchTask from the given workload key - ComputeDAG dag = ComputeDAG(workload_key); - SearchTask task = SearchTask(std::move(dag), std::move(workload_key), std::move(target), - std::move(target_host), std::move(hardware_params)); - // Search for the best schedule - return AutoSchedule(std::move(task), std::move(search_policy), std::move(tune_option)); -} - TVM_REGISTER_GLOBAL("ansor.TuneOption") .set_body_typed([](int n_trials, int early_stopping, int num_measure_per_round, int verbose, Builder builder, Runner runner, Array measure_callbacks, @@ -83,24 +70,12 @@ TVM_REGISTER_GLOBAL("ansor.TuneOption") measure_callbacks, pre_search_callbacks); }); -TVM_REGISTER_GLOBAL("ansor.AutoScheduleBySearchTask") +TVM_REGISTER_GLOBAL("ansor.AutoSchedule") .set_body_typed([](SearchTask task, SearchPolicy search_policy, TuneOption tune_option) { te::Schedule sch; Array return_tensors; std::tie(sch, return_tensors) = AutoSchedule(task, search_policy, tune_option); return Array{sch, return_tensors}; }); - -TVM_REGISTER_GLOBAL("ansor.AutoScheduleByWorkloadKey") - .set_body_typed([](std::string workload_key, Target target, Target target_host, - SearchPolicy search_policy, HardwareParams hardware_params, - TuneOption tune_option) { - te::Schedule sch; - Array return_tensors; - std::tie(sch, return_tensors) = AutoSchedule(workload_key, target, target_host, search_policy, - hardware_params, tune_option); - return Array{sch, return_tensors}; - }); - } // namespace ansor } // namespace tvm diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index 4ec0b99887c3..1e6974bc1373 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -105,23 +105,6 @@ class TuneOption : public ObjectRef { std::pair > AutoSchedule(SearchTask task, SearchPolicy search_policy, TuneOption tune_option); - -/*! - * \brief Auto schedule search for a given compute declaration, by workload key. - * \param workload_key The target workload key. - * \param target The target device of this schedule search. - * \param target_host The target host device of this schedule search. - * \param search_policy The search policy to be used for schedule search. - * \param hardware_params The hardware parameters of this schedule search. - * \param tune_option Tuning and measurement options. - * \return A `te::Schedule` and the target `te::Tensor` to be used in `tvm.lower` or `tvm.build`. 
- */ -std::pair > AutoSchedule(std::string workload_key, Target target, - Target target_host, - SearchPolicy search_policy, - HardwareParams hardware_params, - TuneOption tune_option); - } // namespace ansor } // namespace tvm diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index ddcefbd81641..4a43b7727bec 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -243,25 +243,6 @@ ComputeDAG::ComputeDAG(Array tensors) { data_ = std::move(node); } -ComputeDAG::ComputeDAG(const std::string& workload_key) { - Array tens; - // Call python function to decode the workload_key and get the I/O tensors - if (const auto* f = runtime::Registry::Get("ansor.workload_key_to_tensors")) { - tens = (*f)(workload_key); - } else { - LOG(FATAL) << "ansor.workload_key_to_tensors is not registered"; - } - auto node = make_object(); - FlopEstimator estimator; - node->tensors = std::move(tens); - std::vector ops; - TopoSortOps(node->tensors, &ops); - node->ops = Array(ops); - node->flop_ct = estimator.EstimateFlop(node->ops); - node->init_state = State(node->ops); - data_ = std::move(node); -} - State ComputeDAG::GetInitState() const { return Downcast(operator->()->init_state); } std::pair > ComputeDAG::ApplySteps( diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 3a5089aafb1a..92e300a0b71d 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -84,10 +84,6 @@ class ComputeDAG : public ObjectRef { * \param tensors `te::Tensor`s for a compute declaration. */ explicit ComputeDAG(Array tensors); - /*! \brief The constructor. - * \param workload_key Workload key for a compute declaration. - */ - explicit ComputeDAG(const std::string& workload_key); /*! * \brief Apply transform steps to the init state of this DAG, and get the diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index ff778c94249c..4d25f944fb72 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -453,13 +453,19 @@ TVM_REGISTER_GLOBAL("ansor.GetStatesFromMeasureInputs") } else { auto find_res = task_cache.find(key); if (find_res == task_cache.end()) { - if (inp->task->compute_dag.defined()) { // the measure input is complete + if (inp->task->compute_dag.defined()) { ptask = inp->task.operator->(); - } else { // the measure input is incomplete - // rebuild task for incomplete measure pairs read from file - SearchTask new_task = - SearchTask(ComputeDAG(workload_key), workload_key, inp->task->target, - inp->task->target_host, inp->task->hardware_params); + } else { + // If the measure input is incomplete, rebuild task for it + Array tens; + // Call python function to decode the workload_key and get the I/O tensors + if (const auto* f = runtime::Registry::Get("ansor.workload_key_to_tensors")) { + tens = (*f)(workload_key); + } else { + LOG(FATAL) << "ansor.workload_key_to_tensors is not registered"; + } + SearchTask new_task = SearchTask(ComputeDAG(tens), workload_key, inp->task->target, + inp->task->target_host, inp->task->hardware_params); task_cache.insert(std::make_pair(key, new_task)); ptask = new_task.operator->(); } diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index 9288d88b6270..773ca8e4f13e 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -21,7 +21,7 @@ import topi -@ansor.register_workload +@ansor.register_workload_by_func def matmul_ansor_test(N, M, K): A = te.placeholder((N, K), name='A') B = te.placeholder((K, M), name='B') diff 
--git a/tests/python/unittest/test_ansor_measure.py b/tests/python/unittest/test_ansor_measure.py
index df8314686d7a..a21f70f0d956 100644
--- a/tests/python/unittest/test_ansor_measure.py
+++ b/tests/python/unittest/test_ansor_measure.py
@@ -26,6 +26,9 @@
 
 def test_serialization():
     dag, s = get_tiled_matmul()
+
+    if not tvm.runtime.enabled("llvm"):
+        return
     target = tvm.target.create("llvm")
     task = ansor.SearchTask(dag, "test", target)
 
@@ -49,6 +52,8 @@ def test_serialization():
 
 def test_measure_local_builder_runner():
     dag, s0 = get_tiled_matmul()
+    if not tvm.runtime.enabled("llvm"):
+        return
     tgt = tvm.target.create("llvm")
     task = ansor.SearchTask(dag, "test", tgt)
 
diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py
index 6245878696b2..ed900aa211ba 100644
--- a/tests/python/unittest/test_ansor_search_policy.py
+++ b/tests/python/unittest/test_ansor_search_policy.py
@@ -35,7 +35,7 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local'
     random.seed(seed)
     N = 128
     workload_key = ansor.make_workload_key_by_func(matmul_ansor_test, (N, N, N))
-    dag = ansor.workload_key_to_dag(workload_key)
+    dag = ansor.ComputeDAG(workload_key)
     target = tvm.target.create(target)
     task = ansor.SearchTask(dag, workload_key, target)
 
@@ -74,6 +74,8 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local'
 
 
 def test_search_basic():
+    if not tvm.runtime.enabled("llvm"):
+        return
     # wrap the search in a new thread to avoid the conflict
     # between python's multiprocessing and tvm's thread pool
     t = threading.Thread(target=search_common, kwargs={'seed': 944563397})

From 1b16dd480a9689586f1afcfa74e7e284bbdfb8c8 Mon Sep 17 00:00:00 2001
From: "chengfan.jcf" 
Date: Wed, 1 Jul 2020 17:58:01 +0800
Subject: [PATCH 55/78] Update std::vector to Array; Update verbosity setting;
 Some comments addressed

---
 python/tvm/ansor/auto_schedule.py             |  40 ++++---
 python/tvm/ansor/loop_state.py                |  30 ++----
 python/tvm/ansor/measure.py                   |  14 +--
 python/tvm/ansor/workload_registry.py         |  44 ++++++--
 src/ansor/auto_schedule.cc                    |  25 +++--
 src/ansor/auto_schedule.h                     |  19 ++--
 src/ansor/compute_dag.cc                      |  23 ++--
 src/ansor/compute_dag.h                       |  15 ++-
 src/ansor/loop_state.cc                       | 102 ++++++------------
 src/ansor/loop_state.h                        |  45 ++++----
 src/ansor/measure.cc                          |   5 +-
 src/ansor/measure.h                           |   8 +-
 src/ansor/search_policy/empty_policy.cc       |  18 +++-
 src/ansor/search_policy/empty_policy.h        |   4 +-
 src/ansor/search_policy/search_policy.h       |  20 ++--
 src/ansor/serialization.cc                    |  66 ++++++------
 src/ansor/transform_step.cc                   |  43 ++++----
 src/ansor/transform_step.h                    |  32 +++---
 src/ansor/utils.h                             |  53 ++++++---
 .../unittest/test_ansor_search_policy.py      |   4 +-
 20 files changed, 328 insertions(+), 282 deletions(-)

diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py
index 12a74cc6cd5a..6379e99bc41a 100644
--- a/python/tvm/ansor/auto_schedule.py
+++ b/python/tvm/ansor/auto_schedule.py
@@ -15,7 +15,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""User interface for auto-scheduler"""
+"""
+User interface for Ansor auto-scheduler.
+
+The basic schedule search process for Ansor is designed to be:
+`Program sampling` -> `Performance Tuning`.
+
+In `Program sampling`, we use some predefined or heuristic rules to generate several initial
+schedules. Based on these initial starting points, `Performance Tuning` applies a cost model
+and evolutionary search to seek schedules with the best performance. 
Candidate schedules will
+be measured on the target hardware.
+"""

 import tvm._ffi
 from tvm.runtime import Object
@@ -29,6 +39,9 @@ class HardwareParams(Object):
     """ The parameters of target hardware, which are used to guide the search process of
     SearchPolicy.

+    TODO(...): This is planned to be merged with the new Target specification:
+    https://discuss.tvm.ai/t/rfc-tvm-target-specification/6844
+
     Parameters
     ----------
     num_cores : int
@@ -89,25 +102,26 @@ def __init__(self):
 @tvm._ffi.register_object("ansor.TuneOption")
 class TuneOption(Object):
-    """ The options for tuning.
+    """ This controls the options of performance tuning.

     Parameters
     ----------
-    n_trials: int = 1
+    num_measure_trials: int = 0
       The number of total schedule measure trials.
-      Ansor takes `n_trials` state for measuring in total, and finally gets the best schedule
-      among them.
-      With `n_trials` == 1, Ansor will do the schedule search but don't involve measurement,
-      this can be used if we want to quickly get a runnable schedule without performance tuning.
+      Ansor takes `num_measure_trials` states for measurement in total, and finally gets the best
+      schedule among them.
+      With `num_measure_trials` == 0, Ansor will do the schedule search but will not involve
+      measurement; this can be used if we want to quickly get a runnable schedule without
+      performance tuning.
     early_stopping: int = -1
       Stop the tuning early if no improvement is found after n measurements.
-    num_measure_per_round: int = 64
+    num_measures_per_round: int = 64
       The number of programs to be measured at each search round.
       The whole schedule search process is designed to have several rounds to try a total
-      `n_trials` schedules.
-      We have: `num_search_rounds` = `n_trials` // `num_measure_per_round`
+      `num_measure_trials` schedules.
+      We have: `num_search_rounds` = `num_measure_trials` // `num_measures_per_round`
     verbose: int = 1
-      Verbosity level. 0 means silent.
+      Verbosity level. 0 for silent, 1 to output information during schedule search.
     builder: Union[Builder, str] = 'local'
       Builder which builds the program.
     runner: Union[Runner, str] = 'local'
@@ -123,7 +137,7 @@ class TuneOption(Object):
       - ansor.PreloadCustomSketchRule
       TODO(jcf94): Add these implementations in later PRs.
     """
-    def __init__(self, n_trials=1, early_stopping=-1, num_measure_per_round=64,
+    def __init__(self, num_measure_trials=0, early_stopping=-1, num_measures_per_round=64,
                  verbose=1, builder='local', runner='local', measure_callbacks=None,
                  pre_search_callbacks=None):
         if isinstance(builder, str):
@@ -142,7 +156,7 @@ def __init__(self, n_trials=1, early_stopping=-1, num_measure_per_round=64,
         pre_search_callbacks = [] if pre_search_callbacks is None else pre_search_callbacks

         self.__init_handle_by_constructor__(
-            _ffi_api.TuneOption, n_trials, early_stopping, num_measure_per_round,
+            _ffi_api.TuneOption, num_measure_trials, early_stopping, num_measures_per_round,
             verbose, builder, runner, measure_callbacks, pre_search_callbacks)

diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py
index 7ed32e477523..a1420bf9b30e 100644
--- a/python/tvm/ansor/loop_state.py
+++ b/python/tvm/ansor/loop_state.py
@@ -49,17 +49,6 @@ class Iterator(Object):

 class Stage(Object):
     """A stage in the compute declaration. 
Similar to tvm.te.schedule.Stage"""

-    @property
-    def iters(self):
-        """
-        Returns
-        -------
-        iters : List[Iterator]
-        """
-        if not hasattr(self, "iterators_cache"):
-            setattr(self, "iterators_cache", _ffi_api.StageGetIterators(self))
-        return getattr(self, "iterators_cache")

 @tvm._ffi.register_object("ansor.State")
 class StateObject(Object):
@@ -102,7 +91,7 @@ def stages(self):
         stages : List[Stage]
         """
         if not self.stages_cache:
-            self.stages_cache = _ffi_api.StateGetStages(self.state_object)
+            self.stages_cache = self.state_object.stages
         return self.stages_cache

     @property
@@ -113,14 +102,9 @@ def stage_ops(self):
         ops: List[Operation]
         """
         if not self.stages_cache:
-            self.stages_cache = _ffi_api.StateGetStages(self.state_object)
+            self.stages_cache = self.state_object.stages
         return [stage.op for stage in self.stages_cache]

-    def transform_steps_size(self):
-        """ Return the size of current transform_steps
-        """
-        return _ffi_api.StateGetTransformStepsSize(self.state_object)
-
     def reorder(self, stage, order):
         """ Schedule primitive corresponds to te.reorder.

@@ -171,7 +155,7 @@ def fuse(self, stage, iters):
         Parameters
         ----------
         stage : Union[int, Operation, Tensor]
-            The target Stage to be reordered, can be a Stage order index, Stage operation or stage
+            The target Stage to be fused, which can be a Stage order index, Stage operation or stage
             output tensor.
         iters : List[Iterator]
             The iterators to be fused

@@ -200,11 +184,11 @@ def _resolve_stage_id(self, stage_id):
             return self.stage_id_map[stage_id.op]
         if isinstance(stage_id, int):
             return stage_id
-        raise ValueError("Invalid stage_id")
+        raise ValueError("Invalid stage_id: " + str(stage_id) + ". Expect an int, Operation or Tensor")

     def _update_stage_id_map(self):
         if not self.stages_cache:
-            self.stages_cache = _ffi_api.StateGetStages(self.state_object)
+            self.stages_cache = self.state_object.stages
         for index, stage in enumerate(self.stages_cache):
             self.stage_id_map[stage.op] = index

@@ -213,12 +197,12 @@ def _clear_cache(self):

     def __getitem__(self, key):
         if not self.stages_cache:
-            self.stages_cache = _ffi_api.StateGetStages(self.state_object)
+            self.stages_cache = self.state_object.stages
         if isinstance(key, Tensor):
             key = key.op
         if isinstance(key, Operation):
             return self.stages_cache[self.stage_id_map[key]]
-        raise ValueError("Item must be Tensor or Operation")
+        raise ValueError("Invalid item: " + str(key) + ". Expect an Operation or Tensor")

     def __str__(self):
         return str(self.state_object)
diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py
index 5bb0a58f37aa..691dcca6f85c 100644
--- a/python/tvm/ansor/measure.py
+++ b/python/tvm/ansor/measure.py
@@ -127,7 +127,7 @@ def build(self, measure_inputs, verbose=1):
         measure_inputs : List[MeasureInput]
             A List of MeasureInput.
         verbose : int = 1
-            Verbosity level. (0 means silent)
+            Verbosity level. 0 for silent, 1 to output information during program building.

         Returns
         -------
@@ -150,7 +150,7 @@ def run(self, measure_inputs, build_results, verbose=1):
         build_results : List[BuildResult]
             A List of BuildResult to be run.
         verbose : int = 1
-            Verbosity level. (0 means silent)
+            Verbosity level. 0 for silent, 1 to output information during program running.
         Returns
         -------
@@ -283,7 +283,7 @@ def timed_func():
     else:
         filename = ""

-    if verbose >= 1:
+    if verbose == 1:
         if error_no == MeasureErrorNo.NO_ERROR:
             print(".", end="")
         else:
@@ -292,7 +292,7 @@

     res = call_func_with_timeout(timeout, timed_func)
     if isinstance(res, TimeoutError):
-        if verbose >= 1:
+        if verbose == 1:
            print(".T", end="")  # Build timeout
        res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout

@@ -356,7 +356,7 @@ def timed_func(inp, build_res):
         toc = time.time()
         time.sleep(cooldown_interval)

-        if verbose >= 1:
+        if verbose == 1:
             if error_no == MeasureErrorNo.NO_ERROR:
                 print("*", end="")
             else:
@@ -374,13 +374,13 @@
         res = call_func_with_timeout(
             timeout, timed_func, args=(inp, build_res))
         if isinstance(res, TimeoutError):
-            if verbose >= 1:
+            if verbose == 1:
                 print("*T", end="")  # Run timeout
             res = (max_float,), MeasureErrorNo.RUN_TIMEOUT, None, \
                 build_res.time_cost + timeout, time.time()
         measure_results.append(MeasureResult(*res))

-    if verbose >= 1:
+    if verbose == 1:
         print("")

     return measure_results
diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py
index 01adb4075e1b..3dae9d15a9d3 100644
--- a/python/tvm/ansor/workload_registry.py
+++ b/python/tvm/ansor/workload_registry.py
@@ -46,7 +46,7 @@ def register_workload_by_func(func):

     Examples
     --------
-    @register_workload_by_func
+    @ansor.register_workload_by_func
     def matmul(N, M, K):
         A = te.placeholder((N, K), name='A')
         B = te.placeholder((K, M), name='B')

@@ -85,19 +85,44 @@ def make_workload_key_by_func(func, args):
     else:
         raise ValueError("Invalid function: " + str(func))

-    assert func_name in WORKLOAD_FUNC_REGISTRY, \
-        "%s is not registered. Please register it with @ansor.register_workload_by_func" % func
+    if func_name not in WORKLOAD_FUNC_REGISTRY:
+        raise ValueError("%s is not registered. " % func +
+                         "Please register it with @ansor.register_workload_by_func")

     return json.dumps((func_name,) + args)

+def decode_workload_key_to_func_args(workload_key):
+    """ Decode a workload key to the registered function name and its corresponding args.
+
+    Parameters
+    ----------
+    workload_key : str
+        The target workload key.
+
+    Returns
+    -------
+    name : str
+        The function name of this workload key.
+    args : List
+        The args of the workload generation function.
+    """
+    workload = json.loads(workload_key)
+    if workload[0] not in WORKLOAD_FUNC_REGISTRY:
+        raise ValueError("%s is not registered. " % workload[0] +
+                         "Please register it with @ansor.register_workload_by_func")
+    return workload[0], deserialize_args(workload[1:])
+
+
 @tvm._ffi.register_func("ansor.workload_key_to_tensors")
 def workload_key_to_tensors(workload_key):
-    """ Decode a workload key to the input/output tensors.
+    """ Get the input/output tensors from the workload key.
+
+    This method is usually used to create a ComputeDAG from a workload key.

     Parameters
     ----------
-    workload_key : Str
+    workload_key : str
         The target workload key.

     Returns
@@ -105,12 +130,9 @@
     tensors : List[Tensor]
         The registered compute declaration Tensors.
     """
-    workload = json.loads(workload_key)
-    name = workload[0]
+    name, args = decode_workload_key_to_func_args(workload_key)
     lookup = WORKLOAD_FUNC_REGISTRY[name]
-    assert callable(lookup)

-    args = deserialize_args(workload[1:])
     return lookup(*args)

@@ -119,7 +141,7 @@ def dump_workload_func_registry(filename):

     Parameters
     ----------
-    filename : Str
+    filename : str
         The filename to dump workload function registry to.
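
    Examples
    --------
    A rough round-trip sketch; the on-disk format is an implementation detail, and
    `matmul` is assumed to be registered as in the `register_workload_by_func`
    example above:

    key = make_workload_key_by_func(matmul, (128, 128, 128))
    dump_workload_func_registry('workload_registry.bin')
    # e.g. in a fresh process, restore the registry before decoding keys
    load_workload_func_registry('workload_registry.bin')
    tensors = workload_key_to_tensors(key)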
""" global WORKLOAD_FUNC_REGISTRY @@ -132,7 +154,7 @@ def load_workload_func_registry(filename): Parameters ---------- - filename : Str + filename : str The filename to load workload function registry from. """ global WORKLOAD_FUNC_REGISTRY diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index 91b96ff1cbaf..4dff43fd0b2f 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -34,13 +34,14 @@ namespace ansor { TVM_REGISTER_NODE_TYPE(TuneOptionNode); -TuneOption::TuneOption(int n_trials, int early_stopping, int num_measure_per_round, int verbose, - Builder builder, Runner runner, Array measure_callbacks, +TuneOption::TuneOption(int num_measure_trials, int early_stopping, int num_measures_per_round, + int verbose, Builder builder, Runner runner, + Array measure_callbacks, Array pre_search_callbacks) { auto node = make_object(); - node->n_trials = n_trials; + node->num_measure_trials = num_measure_trials; node->early_stopping = early_stopping; - node->num_measure_per_round = num_measure_per_round; + node->num_measures_per_round = num_measures_per_round; node->verbose = verbose; node->builder = std::move(builder); node->runner = std::move(runner); @@ -56,18 +57,20 @@ std::pair > AutoSchedule(SearchTask task, ProgramMeasurer measurer = ProgramMeasurer(tune_option->builder, tune_option->runner, tune_option->measure_callbacks, tune_option->verbose); // Search for the best schedule - State state = search_policy->Search(task, tune_option->n_trials, tune_option->early_stopping, - tune_option->num_measure_per_round, tune_option->verbose, - measurer, tune_option->pre_search_callbacks); + State state = + search_policy->Search(task, tune_option->num_measure_trials, tune_option->early_stopping, + tune_option->num_measures_per_round, tune_option->verbose, measurer, + tune_option->pre_search_callbacks); return task->compute_dag.ApplySteps(state->transform_steps); } TVM_REGISTER_GLOBAL("ansor.TuneOption") - .set_body_typed([](int n_trials, int early_stopping, int num_measure_per_round, int verbose, - Builder builder, Runner runner, Array measure_callbacks, + .set_body_typed([](int num_measure_trials, int early_stopping, int num_measures_per_round, + int verbose, Builder builder, Runner runner, + Array measure_callbacks, Array pre_search_callbacks) { - return TuneOption(n_trials, early_stopping, num_measure_per_round, verbose, builder, runner, - measure_callbacks, pre_search_callbacks); + return TuneOption(num_measure_trials, early_stopping, num_measures_per_round, verbose, + builder, runner, measure_callbacks, pre_search_callbacks); }); TVM_REGISTER_GLOBAL("ansor.AutoSchedule") diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index 1e6974bc1373..fd65efd4c4af 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -40,12 +40,12 @@ namespace ansor { class TuneOptionNode : public Object { public: /*! \brief Number of total measurement trials. */ - int n_trials; + int num_measure_trials; /*! \brief Stops early the tuning if no improvement after n measurements. */ int early_stopping; /*! \brief The number of programs to be measured at each search round. */ - int num_measure_per_round; - /*! \brief Verbosity level. (0 means silent) */ + int num_measures_per_round; + /*! \brief Verbosity level. 0 for silent, 1 to output information during schedule searching. */ int verbose; /*! 
\brief Builder which builds the program */ Builder builder; @@ -57,9 +57,9 @@ class TuneOptionNode : public Object { Array pre_search_callbacks; void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("n_trials", &n_trials); + v->Visit("num_measure_trials", &num_measure_trials); v->Visit("early_stopping", &early_stopping); - v->Visit("num_measure_per_round", &num_measure_per_round); + v->Visit("num_measures_per_round", &num_measures_per_round); v->Visit("verbose", &verbose); v->Visit("builder", &builder); v->Visit("runner", &runner); @@ -79,16 +79,17 @@ class TuneOption : public ObjectRef { public: /*! * \brief The constructor - * \param n_trials Number of total measurement trials. + * \param num_measure_trials Number of total measurement trials. * \param early_stopping Stops early the tuning if no improvement after n measurements. - * \param num_measure_per_round The number of programs to be measured at each search round. - * \param verbose Verbosity level. (0 means silent) + * \param num_measures_per_round The number of programs to be measured at each search round. + * \param verbose Verbosity level. 0 for silent, 1 to output information during schedule + * search. * \param builder Builder which builds the program. * \param runner Runner which runs the program and measure time costs. * \param measure_callbacks MeasureCallback functions to be called after each measure batch. * \param pre_search_callbacks SearchCallback functions to be called before schedule search. */ - TuneOption(int n_trials, int early_stopping, int num_measure_per_round, int verbose, + TuneOption(int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, Builder builder, Runner runner, Array measure_callbacks, Array pre_search_callbacks); diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 4a43b7727bec..1edd600cdd02 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -69,7 +69,7 @@ void UpdateStageAxis(const te::Stage& stage, StageToAxesMap* stage_to_axes) { // Topo-sort ops from tensors according to their read-write relations. 
// Results are stored in ops -void TopoSortOps(const Array& tensors, std::vector* ops) { +void TopoSortOps(const Array& tensors, Array* ops) { std::unordered_map degree; std::unordered_map > edge_set; std::unordered_map priority; @@ -234,10 +234,10 @@ class FlopEstimator : public ExprFunctor { ComputeDAG::ComputeDAG(Array tensors) { auto node = make_object(); FlopEstimator estimator; + Array ops; node->tensors = std::move(tensors); - std::vector ops; TopoSortOps(node->tensors, &ops); - node->ops = Array(ops); + node->ops = std::move(ops); node->flop_ct = estimator.EstimateFlop(node->ops); node->init_state = State(node->ops); data_ = std::move(node); @@ -246,13 +246,13 @@ ComputeDAG::ComputeDAG(Array tensors) { State ComputeDAG::GetInitState() const { return Downcast(operator->()->init_state); } std::pair > ComputeDAG::ApplySteps( - const std::vector& transform_steps) const { + const Array& transform_steps) const { std::vector stages; StageToAxesMap stage_to_axes; return ReplaySteps(transform_steps, &stages, &stage_to_axes); } -std::string ComputeDAG::PrintStepsAsPython(const std::vector& transform_steps) const { +std::string ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const { std::vector stages; StageToAxesMap stage_to_axes; Array ops; @@ -286,15 +286,16 @@ std::string ComputeDAG::PrintStepsAsPython(const std::vector& transform_st << "tuple(" << stage->op->name << ".op.reduce_axis)\n"; } } + std::vector step_vector(transform_steps.begin(), transform_steps.end()); // Call each step's PrintAsPythonAPI method for (const auto& step : transform_steps) { - ss << step->PrintAsPythonAPI(&stages, &stage_to_axes, &schedule, transform_steps); + ss << step->PrintAsPythonAPI(&stages, &stage_to_axes, &schedule, step_vector); } return ss.str(); } -State ComputeDAG::ReplayAndInferBound(const std::vector& transform_steps) const { +State ComputeDAG::ReplayAndInferBound(const Array& transform_steps) const { State ret_state = GetInitState(); StateNode* pstate = ret_state.CopyOnWrite(); pstate->transform_steps = transform_steps; @@ -359,7 +360,7 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { continue; } - std::vector new_iters; + Array new_iters; new_iters.reserve(stage->iters.size()); for (size_t j = 0; j < stage->iters.size(); ++j) { const Iterator& iter = stage->iters[j]; @@ -374,13 +375,13 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { } } - pstate->stages[i] = - Stage(stage->op, stage->op_type, std::move(new_iters), stage->compute_at, stage->attrs); + pstate->stages.Set( + i, Stage(stage->op, stage->op_type, std::move(new_iters), stage->compute_at, stage->attrs)); } } std::pair > ComputeDAG::ReplaySteps( - const std::vector& transform_steps, std::vector* stages, + const Array& transform_steps, std::vector* stages, StageToAxesMap* stage_to_axes) const { std::vector ops; for (const auto& op : operator->()->ops) { diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 92e300a0b71d..a8bace4d3f23 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -35,15 +35,13 @@ #include #include +#include "transform_step.h" + namespace tvm { namespace ansor { class StateNode; class State; -class Step; - -typedef std::unordered_map, ObjectHash, ObjectEqual> - StageToAxesMap; /*! * \brief Update stage and axes mapping during replay. @@ -91,14 +89,13 @@ class ComputeDAG : public ObjectRef { * \param transform_steps Transform steps of the target state. * \return The return values can be used as arguments to `tvm.build` or `tvm.lower`. 
*/ - std::pair > ApplySteps( - const std::vector& transform_steps) const; + std::pair > ApplySteps(const Array& transform_steps) const; /*! * \brief Print transform steps as equivalent python schedule API. * \param transform_steps Transform steps of the target state. * \return Python schedule code. */ - std::string PrintStepsAsPython(const std::vector& transform_steps) const; + std::string PrintStepsAsPython(const Array& transform_steps) const; /*! * \brief Replay the transform steps and call ir_pass::InferBound to fill correct bound @@ -112,7 +109,7 @@ class ComputeDAG : public ObjectRef { * \param transform_steps Transform steps of the target state. * \return The State after inferbound. */ - State ReplayAndInferBound(const std::vector& transform_steps) const; + State ReplayAndInferBound(const Array& transform_steps) const; /*! * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound. * \param state The target state. @@ -144,7 +141,7 @@ class ComputeDAG : public ObjectRef { * \param stage_to_axes A pointer to StageToAxesMap. * \return The return values can be used as arguments to `tvm.build` or `tvm.lower`. */ - std::pair > ReplaySteps(const std::vector& transform_steps, + std::pair > ReplaySteps(const Array& transform_steps, std::vector* stages, StageToAxesMap* stage_to_axes) const; diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 46daf85c6b08..e69339b3ecd3 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -81,7 +81,7 @@ Stage::Stage(te::Operation op) { data_ = std::move(node); } -Stage::Stage(te::Operation op, StageType op_type, const std::vector& iters, +Stage::Stage(te::Operation op, StageType op_type, const Array& iters, ComputeAtType compute_at, StageAttributes attrs) { auto node = make_object(); node->op = std::move(op); @@ -92,8 +92,8 @@ Stage::Stage(te::Operation op, StageType op_type, const std::vector& i data_ = std::move(node); } -Stage::Stage(te::Operation op, StageType op_type, std::vector&& iters, - ComputeAtType compute_at, StageAttributes attrs) { +Stage::Stage(te::Operation op, StageType op_type, Array&& iters, ComputeAtType compute_at, + StageAttributes attrs) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; @@ -113,29 +113,20 @@ State::State(const Array& ops) { data_ = std::move(node); } -State::State(const std::vector& stages, const std::vector& transform_steps, - bool complete) { - auto node = make_object(); - node->stages = stages; - node->transform_steps = transform_steps; - node->complete = complete; - data_ = std::move(node); -} - /********** Schedule primitives apis for state **********/ -void State::reorder(int stage_id, const std::vector& order) { +void State::reorder(int stage_id, const Array& order) { const Stage& stage = operator->()->stages[stage_id]; CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " << "should be specified"; - std::vector after_ids; + Array after_ids; GetIndices(stage->iters, order, &after_ids); ReorderStep step = ReorderStep(stage_id, after_ids); CopyOnWrite()->transform_steps.push_back(step); DoReorderStep(step); } -std::vector State::split(int stage_id, const Iterator& it, - const std::vector& lengths, bool inner_to_outer) { +Array State::split(int stage_id, const Iterator& it, const Array& lengths, + bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; SplitStep step = SplitStep(stage_id, GetIndex(stage->iters, it), @@ -144,9 +135,9 @@ std::vector State::split(int stage_id, const 
Iterator& it, return DoSplitStep(step); } -Iterator State::fuse(int stage_id, const std::vector& iters) { +Iterator State::fuse(int stage_id, const Array& iters) { const Stage& stage = operator->()->stages[stage_id]; - std::vector indices; + Array indices; GetIndices(stage->iters, iters, &indices); FuseStep step = FuseStep(stage_id, indices); CopyOnWrite()->transform_steps.push_back(step); @@ -156,19 +147,18 @@ Iterator State::fuse(int stage_id, const std::vector& iters) { /********** Step implementations for state **********/ void State::DoReorderStep(const ReorderStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; - std::vector iters; + Array iters; for (auto x : step->after_ids) { - iters.push_back(stage->iters[x]); + iters.push_back(stage->iters[x->value]); } StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = - Stage(stage->op, stage->op_type, std::move(iters), stage->compute_at, stage->attrs); + pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(iters), + stage->compute_at, stage->attrs)); } // common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep -std::vector State::DoSplitStepCommon(int stage_id, int iter_id, - const std::vector& lengths, - bool inner_to_outer) { +Array State::DoSplitStepCommon(int stage_id, int iter_id, const Array& lengths, + bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; const Iterator& it = stage->iters[iter_id]; @@ -180,7 +170,7 @@ std::vector State::DoSplitStepCommon(int stage_id, int iter_id, tosplit_min = tosplit_extent = PrimExpr(); } - std::vector outs; + Array outs; for (size_t i = 0; i < lengths.size(); ++i) { PrimExpr l; std::string name; @@ -209,25 +199,27 @@ std::vector State::DoSplitStepCommon(int stage_id, int iter_id, } if (inner_to_outer) { outs.push_back(Iterator(it->name + ".0", range, it->iter_type, kNone)); - std::reverse(outs.begin(), outs.end()); + // Reverse the Iterator array + Array temp(std::move(outs.rbegin()), std::move(outs.rend())); + outs = std::move(temp); } else { outs.push_back( Iterator(it->name + "." 
+ std::to_string(lengths.size()), range, it->iter_type, kNone)); } - std::vector new_iters; + Array new_iters; new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id); new_iters.insert(new_iters.end(), outs.begin(), outs.end()); new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages[stage_id] = - Stage(stage->op, stage->op_type, std::move(new_iters), stage->compute_at, stage->attrs); + pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), + stage->compute_at, stage->attrs)); return outs; } -std::vector State::DoSplitStep(const SplitStep& step) { +Array State::DoSplitStep(const SplitStep& step) { return DoSplitStepCommon(step->stage_id, step->iter_id, step->lengths, step->inner_to_outer); } @@ -242,10 +234,10 @@ Iterator State::DoFuseStep(const FuseStep& step) { std::vector ori_iters; for (size_t i = 0; i < step->fused_ids.size(); ++i) { if (i > 0) { - CHECK_EQ(step->fused_ids[i], step->fused_ids[i - 1] + 1); + CHECK_EQ(step->fused_ids[i]->value, step->fused_ids[i - 1]->value + 1); } - const Iterator& it = stage->iters[step->fused_ids[i]]; + const Iterator& it = stage->iters[step->fused_ids[i]->value]; ori_iters.push_back(it); new_name += it->name + "@"; @@ -269,21 +261,21 @@ Iterator State::DoFuseStep(const FuseStep& step) { range = Range::FromMinExtent(0, new_extent); } Iterator new_it = Iterator(new_name, range, new_iter_type, kNone, &ori_iters); - std::vector new_iters; + Array new_iters; new_iters.insert(new_iters.end(), stage->iters.begin(), - stage->iters.begin() + step->fused_ids.front()); + stage->iters.begin() + step->fused_ids.front()->value); new_iters.push_back(new_it); - new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back() + 1, + new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back()->value + 1, stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages[stage_id] = - Stage(stage->op, stage->op_type, std::move(new_iters), stage->compute_at, stage->attrs); + pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), + stage->compute_at, stage->attrs)); return new_it; } -void State::DoSteps(const std::vector& steps, const ComputeDAG& dag) { +void State::DoSteps(const Array& steps, const ComputeDAG& dag) { // Use complete rate for the study in the paper const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); double complete_rate = -1.0; @@ -436,46 +428,22 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); /********** State interface API for ffi **********/ -TVM_REGISTER_GLOBAL("ansor.StageGetIterators").set_body_typed([](const Stage& stage) { - return Array(stage->iters); -}); - -TVM_REGISTER_GLOBAL("ansor.StateGetStages").set_body_typed([](const State& state) { - return Array(state->stages); -}); - -TVM_REGISTER_GLOBAL("ansor.StateGetTransformStepsSize").set_body_typed([](const State& state) { - return static_cast(state->transform_steps.size()); -}); - TVM_REGISTER_GLOBAL("ansor.StateReorder") .set_body_typed([](State state, int stage_id, const Array& order) { - std::vector ord; - for (const auto& i : order) { - ord.push_back(i); - } - state.reorder(stage_id, ord); + state.reorder(stage_id, order); return state; }); TVM_REGISTER_GLOBAL("ansor.StateSplit") .set_body_typed([](State state, int stage_id, const Iterator& it, const Array& lengths, bool inner_to_outer) { - std::vector len; - for (const auto& i : lengths) { - 
len.push_back(i); - } - const auto& res = state.split(stage_id, it, len, inner_to_outer); - return Array{state, Array(res)}; + const auto& res = state.split(stage_id, it, lengths, inner_to_outer); + return Array{state, res}; }); TVM_REGISTER_GLOBAL("ansor.StateFuse") .set_body_typed([](State state, int stage_id, const Array& iters) { - std::vector its; - for (const auto& i : iters) { - its.push_back(i); - } - const auto& res = state.fuse(stage_id, its); + const auto& res = state.fuse(stage_id, iters); return Array{state, res}; }); diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 9154f3b32c3d..c53a52380b69 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -179,13 +179,16 @@ class StageNode : public Object { /*! \brief The type of this stage. */ StageType op_type; /*! \brief The iterators in this stage. */ - std::vector iters; + Array iters; /*! \brief The compute location of this stage. */ ComputeAtType compute_at; /*! \brief Other stage-level attributes. */ StageAttributes attrs; - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("op", &op); } + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("op", &op); + v->Visit("iters", &iters); + } static constexpr const char* _type_key = "ansor.Stage"; TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object); @@ -210,8 +213,8 @@ class Stage : public ObjectRef { * \param compute_at The compute at type of this op. * \param attrs Other stage-level attributes. */ - Stage(te::Operation op, StageType op_type, const std::vector& iters, - ComputeAtType compute_at, StageAttributes attrs); + Stage(te::Operation op, StageType op_type, const Array& iters, ComputeAtType compute_at, + StageAttributes attrs); /*! * \brief The constructor. * \param op A `te::Operation`. @@ -220,8 +223,8 @@ class Stage : public ObjectRef { * \param compute_at The compute at type of this op. * \param attrs Other stage-level attributes. */ - Stage(te::Operation op, StageType op_type, std::vector&& iters, - ComputeAtType compute_at, StageAttributes attrs); + Stage(te::Operation op, StageType op_type, Array&& iters, ComputeAtType compute_at, + StageAttributes attrs); TVM_DEFINE_OBJECT_REF_METHODS(Stage, ObjectRef, StageNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode); @@ -235,9 +238,9 @@ class Stage : public ObjectRef { class StateNode : public Object { public: /*! \brief Current stages and loop structures. */ - std::vector stages; + Array stages; /*! \brief History transformation steps. */ - std::vector transform_steps; + Array transform_steps; /*! \brief Indicate whether this state has unfilled tile sizes. */ bool complete; /*! @@ -248,6 +251,8 @@ class StateNode : public Object { ComputeDAG task_dag; void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("stages", &stages); + v->Visit("transform_steps", &transform_steps); v->Visit("complete", &complete); v->Visit("task_dag", &task_dag); } @@ -267,20 +272,13 @@ class State : public ObjectRef { * \param ops `te::Operation`s for a compute declaration. */ explicit State(const Array& ops); - /*! - * \brief The constructor. - * \param stages Stages of the target state. - * \param transform_steps Transform steps of the target state. - * \param complete Indicate whether this state has unfilled tile sizes. - */ - State(const std::vector& stages, const std::vector& transform_steps, bool complete); /*! * \brief Schedule primitive corresponds to te.reorder. * \param stage_id The index of the target stage. * \param order The target iterator order. 
*/ - void reorder(int stage_id, const std::vector& order); + void reorder(int stage_id, const Array& order); /*! * \brief Schedule primitive corresponds to te.split. * \param stage_id The index of the target stage. @@ -289,22 +287,22 @@ class State : public ObjectRef { * \param inner_to_outer True for split from inner to outer & False for outer to inner. * \return The iterator results after split. */ - std::vector split(int stage_id, const Iterator& it, - const std::vector& lengths, bool inner_to_outer = true); + Array split(int stage_id, const Iterator& it, const Array& lengths, + bool inner_to_outer = true); /*! * \brief Schedule primitive corresponds to te.fuse. * \param stage_id The index of the target stage. * \param iters The target iterators to be fused. * \return The iterator result after fuse. */ - Iterator fuse(int stage_id, const std::vector& iters); + Iterator fuse(int stage_id, const Array& iters); /*! * \brief General do step functions with a runtime dynamic dispatcher. * \param steps The target transform steps. * \param dag The target ComputeDAG. */ - void DoSteps(const std::vector& steps, const ComputeDAG& dag); + void DoSteps(const Array& steps, const ComputeDAG& dag); /*! * \brief Print the state to a string. @@ -332,7 +330,7 @@ class State : public ObjectRef { * \param step A SplitStep. * \return The iterator results after split. */ - std::vector DoSplitStep(const SplitStep& step); + Array DoSplitStep(const SplitStep& step); /*! * \brief Apply fuse step to current state. * \param step A FuseStep. @@ -348,9 +346,8 @@ class State : public ObjectRef { * \param inner_to_outer The split direction. * \return The iterator results after split. */ - std::vector DoSplitStepCommon(int stage_id, int iter_id, - const std::vector& lengths, - bool inner_to_outer); + Array DoSplitStepCommon(int stage_id, int iter_id, const Array& lengths, + bool inner_to_outer); }; } // namespace ansor diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 08c66cf72d36..ad33302b0fd8 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -210,11 +210,10 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& po } ct++; - StdCout(verbose) << std::fixed << std::setprecision(2) - << "===============================================\n" + StdCout(verbose) << std::fixed << std::setprecision(2) << Chars('=', 50) << "\n" << "No: " << ct << "\tGFLOPS: " << flops / 1e9 << " / " << best_flops[workload_key] / 1e9 << "\tresults: " << result_batch[j] << "\n" - << "===============================================\n" + << Chars('=', 50) << "\n" << input_batch[j]->state << "\n"; } diff --git a/src/ansor/measure.h b/src/ansor/measure.h index 4955176aef2c..34663a72d09e 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -233,7 +233,7 @@ class BuilderNode : public Object { /*! * \brief Build programs and return results. * \param inputs An Array of MeasureInput. - * \param verbose Verbosity level. (0 means silent) + * \param verbose Verbosity level. 0 for silent, 1 to output information during program building. * \return An Array of MeasureResult. */ virtual Array Build(const Array& inputs, int verbose) = 0; @@ -261,7 +261,7 @@ class RunnerNode : public Object { * \brief Run measurement and return results. * \param inputs An Array of MeasureInput. * \param build_results An Array of BuildResult. - * \param verbose Verbosity level. (0 means silent) + * \param verbose Verbosity level. 0 for silent, 1 to output information during program running. 
   * \return An Array of MeasureResult.
   */
  virtual Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
@@ -370,7 +370,7 @@ class ProgramMeasurerNode : public Object {
  Runner runner;
  /*! \brief MeasureCallback to be called after each measure batch. */
  Array<MeasureCallback> callbacks;
-  /*! \brief Verbose level. */
+  /*! \brief Verbosity level. 0 for silent, 1 to output information during program measuring. */
  int verbose;
  /*! \brief The maximum number of allowed continuous errors. */
  int max_continous_error;

@@ -417,7 +417,7 @@ class ProgramMeasurer : public ObjectRef {
   * \param builder The Builder to build each program.
   * \param runner The Runner to measure each program.
   * \param callbacks MeasureCallback to be called after each measure batch.
-   * \param verbose Verbose level.
+   * \param verbose Verbosity level. 0 for silent, 1 to output information during program measuring.
   * \param max_continous_error The maximum number of allowed continuous errors.
   */
  ProgramMeasurer(Builder builder, Runner runner, Array<MeasureCallback> callbacks, int verbose,
diff --git a/src/ansor/search_policy/empty_policy.cc b/src/ansor/search_policy/empty_policy.cc
index d4ebc829f7a8..51a506e39eac 100644
--- a/src/ansor/search_policy/empty_policy.cc
+++ b/src/ansor/search_policy/empty_policy.cc
@@ -33,8 +33,8 @@ namespace ansor {

 TVM_REGISTER_NODE_TYPE(EmptyPolicyNode);

-State EmptyPolicyNode::Search(SearchTask task, int n_trials, int early_stopping,
-                              int num_measure_per_round, int verbose, ProgramMeasurer measurer,
+State EmptyPolicyNode::Search(SearchTask task, int num_measure_trials, int early_stopping,
+                              int num_measures_per_round, int verbose, ProgramMeasurer measurer,
                               Array<SearchCallback> pre_search_callbacks) {
   cur_task = task;

@@ -44,8 +44,8 @@ State EmptyPolicyNode::Search(SearchTask task, int n_trials, int early_stopping,

   // Basic design principle: call `SearchOneRound()` several times to get candidate states,
   // measure them and return the best one
-  // Measure is disabled if n_trials <= 1
-  if (n_trials <= 1) {
+  // Measure is disabled if num_measure_trials <= 1
+  if (num_measure_trials <= 1) {
     const auto& res = SearchOneRound();
     CHECK_GT(res.size(), 0);

@@ -58,7 +58,7 @@ State EmptyPolicyNode::Search(SearchTask task, int n_trials, int early_stopping,
     int ct = 0;
     // In each round, we call SearchOneRound to get several candidate states,
     // then use ProgramMeasurer to test their performance
-    while (ct < n_trials) {
+    while (ct < num_measure_trials) {
       const auto& res = SearchOneRound();
       ct += res.size();
       // Build MeasureInputs for measuring
@@ -80,7 +80,15 @@ State EmptyPolicyNode::Search(SearchTask task, int n_trials, int early_stopping,
 // As an example policy, EmptyPolicy always returns an init state
 std::vector<State> EmptyPolicyNode::SearchOneRound() {
   std::vector<State> res;
+
+  // 1. We will process `Program sampling` first to generate several initial schedules
   res.push_back(cur_task->compute_dag.GetInitState());
+
+  // 2. Then `Performance Tuning`: use the cost model and evolutionary search to find the
+  // schedule with the best performance
+  // Note: This example policy does not include this part
+
+  // 3. The returned candidate schedules will be measured on the target hardware
   return res;
 }
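To make the flow above concrete: each candidate State is paired with a MeasureInput and handed to the ProgramMeasurer. A rough Python-side sketch of one such measure batch, using the Builder/Runner interfaces documented in measure.py; the LocalBuilder/LocalRunner constructors and the MeasureInput signature are assumptions inferred from their usage in the tests, not a fixed API:

# Hedged sketch: build and run one batch of candidate states by hand.
builder = ansor.LocalBuilder()   # assumed default-constructible local builder
runner = ansor.LocalRunner()     # assumed default-constructible local runner
inputs = [ansor.MeasureInput(task, state) for state in candidate_states]
build_results = builder.build(inputs, verbose=1)                # prints '.' per built program
measure_results = runner.run(inputs, build_results, verbose=1)  # prints '*' per measured program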
diff --git a/src/ansor/search_policy/empty_policy.h b/src/ansor/search_policy/empty_policy.h
index 01a47a9d4120..c55dd1b2e272 100644
--- a/src/ansor/search_policy/empty_policy.h
+++ b/src/ansor/search_policy/empty_policy.h
@@ -42,8 +42,8 @@ namespace ansor {
 */
class EmptyPolicyNode : public SearchPolicyNode {
 public:
-  State Search(SearchTask task, int n_trials, int early_stopping, int num_measure_per_round,
-               int verbose, ProgramMeasurer measurer,
+  State Search(SearchTask task, int num_measure_trials, int early_stopping,
+               int num_measures_per_round, int verbose, ProgramMeasurer measurer,
               Array<SearchCallback> pre_search_callbacks) final;

  static constexpr const char* _type_key = "ansor.EmptyPolicy";
diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h
index f507aa98e22c..d9b1df96709e 100644
--- a/src/ansor/search_policy/search_policy.h
+++ b/src/ansor/search_policy/search_policy.h
@@ -22,6 +22,14 @@
 * \brief The base class for search policy, including the abstract definition of search policy and
 * some other supporting structures.
 *
+ * The basic schedule search process for Ansor is designed to be:
+ * `Program sampling` -> `Performance Tuning`.
+ *
+ * In `Program sampling`, we use some predefined or heuristic rules to generate several initial
+ * schedules. Based on these initial starting points, `Performance Tuning` then applies the cost
+ * model and evolutionary search to find schedules with the best performance. Candidate schedules
+ * will be measured on the target hardware.
+ *
 * \note Adding a new search policy.
 * In design, there's no need for users to implement their own search policy; our formal search
 * policy (to be added later) should be enough to cover auto schedule generation for different
 *
@@ -92,7 +100,7 @@ class SearchPolicyNode : public Object {
  SearchTask cur_task;
  /*!
   * \brief Verbosity level to control the screen output during schedule search.
-   * (0 means silent)
+   * 0 for silent, 1 to output information.
   */
  int verbose;

@@ -105,16 +113,16 @@ class SearchPolicyNode : public Object {
   * \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state
   * found during the search process.
   * \param task The target search task.
-   * \param n_trials Total schedules to be tried during this search.
+   * \param num_measure_trials Total schedules to be tried during this search.
   * \param early_stopping Early stop if no better schedule is found.
-   * \param num_measure_per_round Max measure batch in one search round.
-   * \param verbose Verbose level. (0 means silent)
+   * \param num_measures_per_round Max measure batch in one search round.
+   * \param verbose Verbosity level. 0 for silent, 1 to output information during schedule search.
   * \param measurer A ProgramMeasurer which packs Builder & Runner inside.
   * \param pre_search_callbacks SearchCallback to be called before schedule search.
   * \return The best state found.
   */
  virtual State Search(SearchTask task, int num_measure_trials, int early_stopping,
                       int num_measures_per_round, int verbose, ProgramMeasurer measurer,
                       Array<SearchCallback> pre_search_callbacks) = 0;

  /*!
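
Taken together with the TuneOption and workload-registry changes above, this Search interface is what `ansor.auto_schedule` drives from Python. A minimal end-to-end sketch, mirroring test_ansor_search_policy.py; the exact `auto_schedule` keyword names are assumptions based on that test, and `matmul_ansor_test` is the workload registered in test_ansor_common.py:

# Hedged sketch: search-only run with the example EmptyPolicy (no measurement).
import tvm
from tvm import ansor
from test_ansor_common import matmul_ansor_test

workload_key = ansor.make_workload_key_by_func(matmul_ansor_test, (128, 128, 128))
dag = ansor.ComputeDAG(workload_key)
target = tvm.target.create("llvm")
task = ansor.SearchTask(dag, workload_key, target)
tune_option = ansor.TuneOption(num_measure_trials=0, verbose=1)
sch, args = ansor.auto_schedule(task, target, search_policy=ansor.EmptyPolicy(),
                                tune_option=tune_option)
func = tvm.build(sch, args, target)  # the returned pair feeds tvm.build / tvm.lower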
diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index 4d25f944fb72..358d170e00df 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -54,13 +54,24 @@ inline std::vector& IntArrayToVector(std::vector* out, return *out; } +inline std::vector& IntArrayToVector(std::vector* out, + const ::tvm::Array<::tvm::IntImm>& data) { + out->clear(); + for (const auto& x : data) { + CHECK(x.defined()); + out->push_back(x->value); + } + return *out; +} + template <> -struct Handler> { - inline static void Write(dmlc::JSONWriter* writer, const std::vector<::tvm::ansor::Stage>& data) { +struct Handler<::tvm::Array<::tvm::ansor::Stage>> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::Array<::tvm::ansor::Stage>& data) { writer->BeginArray(false); writer->EndArray(); } - inline static void Read(dmlc::JSONReader* reader, std::vector<::tvm::ansor::Stage>* data) { + inline static void Read(dmlc::JSONReader* reader, ::tvm::Array<::tvm::ansor::Stage>* data) { bool s; reader->BeginArray(); s = reader->NextArrayItem(); @@ -69,8 +80,8 @@ struct Handler> { }; template <> -struct Handler> { - inline static void Write(dmlc::JSONWriter* writer, const std::vector<::tvm::ansor::Step>& data) { +struct Handler<::tvm::Array<::tvm::ansor::Step>> { + inline static void Write(dmlc::JSONWriter* writer, const ::tvm::Array<::tvm::ansor::Step>& data) { std::vector tmp; writer->BeginArray(false); for (size_t i = 0; i < data.size(); ++i) { @@ -79,34 +90,18 @@ struct Handler> { if (auto ps = data[i].as<::tvm::ansor::ReorderStepNode>()) { writer->WriteArrayItem(std::string("RE")); writer->WriteArrayItem(ps->stage_id); - - writer->WriteArraySeperator(); - writer->BeginArray(false); - for (int x : ps->after_ids) { - writer->WriteArrayItem(x); - } - writer->EndArray(); + writer->WriteArrayItem(IntArrayToVector(&tmp, ps->after_ids)); } else if (auto ps = data[i].as<::tvm::ansor::SplitStepNode>()) { writer->WriteArrayItem(std::string("SP")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); - if (ps->extent.defined()) { - writer->WriteArrayItem(::tvm::ansor::GetIntImm(ps->extent)); - } else { - writer->WriteArrayItem(0); - } + writer->WriteArrayItem(ps->extent.defined() ? 
::tvm::ansor::GetIntImm(ps->extent) : 0); writer->WriteArrayItem(IntArrayToVector(&tmp, ps->lengths)); writer->WriteArrayItem(static_cast(ps->inner_to_outer)); } else if (auto ps = data[i].as<::tvm::ansor::FuseStepNode>()) { writer->WriteArrayItem(std::string("FU")); writer->WriteArrayItem(ps->stage_id); - - writer->WriteArraySeperator(); - writer->BeginArray(false); - for (int x : ps->fused_ids) { - writer->WriteArrayItem(x); - } - writer->EndArray(); + writer->WriteArrayItem(IntArrayToVector(&tmp, ps->fused_ids)); } else { LOG(FATAL) << "Invalid step: " << data[i]; } @@ -115,7 +110,7 @@ struct Handler> { writer->EndArray(); } - inline static void Read(dmlc::JSONReader* reader, std::vector<::tvm::ansor::Step>* data) { + inline static void Read(dmlc::JSONReader* reader, ::tvm::Array<::tvm::ansor::Step>* data) { std::vector int_list; bool s, inner_to_outer; std::string name, scope_name, pragma_type, ti_func_name; @@ -135,7 +130,11 @@ struct Handler> { s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); - data->push_back(::tvm::ansor::ReorderStep(stage_id, int_list)); + ::tvm::Array<::tvm::IntImm> after_ids; + for (const auto& i : int_list) { + after_ids.push_back(::tvm::IntImm(::tvm::DataType::Int(32), i)); + } + data->push_back(::tvm::ansor::ReorderStep(stage_id, after_ids)); } else if (name == "SP") { s = reader->NextArrayItem(); CHECK(s); @@ -152,9 +151,12 @@ struct Handler> { s = reader->NextArrayItem(); CHECK(s); reader->Read(&inner_to_outer); - data->push_back(::tvm::ansor::SplitStep( - stage_id, iter_id, extent, - std::vector<::tvm::PrimExpr>(int_list.begin(), int_list.end()), inner_to_outer)); + ::tvm::Array<::tvm::PrimExpr> lengths; + for (const auto& i : int_list) { + lengths.push_back(::tvm::PrimExpr(i)); + } + data->push_back( + ::tvm::ansor::SplitStep(stage_id, iter_id, extent, lengths, inner_to_outer)); } else if (name == "FU") { s = reader->NextArrayItem(); CHECK(s); @@ -162,7 +164,11 @@ struct Handler> { s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); - data->push_back(::tvm::ansor::FuseStep(stage_id, int_list)); + ::tvm::Array<::tvm::IntImm> fused_ids; + for (const auto& i : int_list) { + fused_ids.push_back(::tvm::IntImm(::tvm::DataType::Int(32), i)); + } + data->push_back(::tvm::ansor::FuseStep(stage_id, fused_ids)); } else { LOG(FATAL) << "Invalid step format"; } diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index 10ef9e3d6ab0..a6de46b5f631 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -37,7 +37,7 @@ namespace tvm { namespace ansor { /********** Reorder **********/ -ReorderStep::ReorderStep(int stage_id, const std::vector& after_ids) { +ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { auto node = make_object(); node->stage_id = stage_id; node->after_ids = after_ids; @@ -53,7 +53,7 @@ void ReorderStepNode::ApplyToSchedule(std::vector* stages, std::vector new_axes; new_axes.reserve(axes.size()); for (auto i : after_ids) { - new_axes.push_back(axes[i]); + new_axes.push_back(axes[i->value]); } stage.reorder(new_axes); (*stage_to_axes)[stage] = std::move(new_axes); @@ -67,7 +67,7 @@ std::string ReorderStepNode::PrintAsPythonAPI(std::vector* stages, ss << "s[" << CleanName(stage->op->name) << "].reorder("; for (size_t i = 0; i < after_ids.size(); ++i) { - ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint); + ss << CleanName((*stage_to_axes)[stage][after_ids[i]->value]->var->name_hint); if (i != after_ids.size() - 1) { ss << ", "; } @@ -79,14 +79,13 @@ 
std::string ReorderStepNode::PrintAsPythonAPI(std::vector* stages, } /********** Split **********/ -std::vector ApplySplitToSchedule(std::vector* stages, - StageToAxesMap* stage_to_axes, int stage_id, int iter_id, - const std::vector& lengths, - bool inner_to_outer) { +Array ApplySplitToSchedule(std::vector* stages, StageToAxesMap* stage_to_axes, + int stage_id, int iter_id, const Array& lengths, + bool inner_to_outer) { te::Stage& stage = (*stages)[stage_id]; const std::vector& axes = (*stage_to_axes)[stage]; - std::vector outs; + Array outs; if (inner_to_outer) { IterVar outer = axes[iter_id], inner; for (int i = static_cast(lengths.size()) - 1; i >= 0; i--) { @@ -108,9 +107,13 @@ std::vector ApplySplitToSchedule(std::vector* stages, std::vector new_axes; new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + iter_id); if (inner_to_outer) { - new_axes.insert(new_axes.end(), outs.rbegin(), outs.rend()); + for (auto x = outs.rbegin(); x != outs.rend(); ++x) { + new_axes.push_back((*x)); + } } else { - new_axes.insert(new_axes.end(), outs.begin(), outs.end()); + for (const auto& x : outs) { + new_axes.push_back(x); + } } new_axes.insert(new_axes.end(), axes.begin() + iter_id + 1, axes.end()); (*stage_to_axes)[stage] = std::move(new_axes); @@ -119,7 +122,7 @@ std::vector ApplySplitToSchedule(std::vector* stages, } std::string PrintSplitAsPythonAPI(std::vector* stages, StageToAxesMap* stage_to_axes, - int stage_id, int iter_id, const std::vector& lengths, + int stage_id, int iter_id, const Array& lengths, bool inner_to_outer) { te::Stage& stage = (*stages)[stage_id]; auto to_split = (*stage_to_axes)[stage][iter_id]; @@ -148,8 +151,8 @@ std::string PrintSplitAsPythonAPI(std::vector* stages, StageToAxesMap return ss.str(); } -SplitStep::SplitStep(int stage_id, int iter_id, PrimExpr extent, - const std::vector& lengths, bool inner_to_outer) { +SplitStep::SplitStep(int stage_id, int iter_id, PrimExpr extent, const Array& lengths, + bool inner_to_outer) { auto node = make_object(); node->stage_id = stage_id; // Extent can be a unreducible expression in some special cases @@ -162,8 +165,8 @@ SplitStep::SplitStep(int stage_id, int iter_id, PrimExpr extent, data_ = std::move(node); } -std::vector SplitStepNode::ApplyToSchedule(std::vector* stages, - StageToAxesMap* stage_to_axes) const { +Array SplitStepNode::ApplyToSchedule(std::vector* stages, + StageToAxesMap* stage_to_axes) const { return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } @@ -174,7 +177,7 @@ std::string SplitStepNode::PrintAsPythonAPI(std::vector* stages, } /********** Fuse **********/ -FuseStep::FuseStep(int stage_id, const std::vector& fused_ids) { +FuseStep::FuseStep(int stage_id, const Array& fused_ids) { auto node = make_object(); node->stage_id = stage_id; node->fused_ids = fused_ids; @@ -188,14 +191,14 @@ IterVar FuseStepNode::ApplyToSchedule(std::vector* stages, Array to_fuse; for (auto i : fused_ids) { - to_fuse.push_back(axes[i]); + to_fuse.push_back(axes[i->value]); } IterVar fused_axis; stage.fuse(to_fuse, &fused_axis); std::vector new_axes; - new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids[0]); + new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front()->value); new_axes.push_back(fused_axis); - new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, axes.end()); + new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back()->value + 1, axes.end()); (*stage_to_axes)[stage] = std::move(new_axes); return 
fused_axis; @@ -208,7 +211,7 @@ std::string FuseStepNode::PrintAsPythonAPI(std::vector* stages, std::stringstream to_fuse; for (size_t i = 0; i < fused_ids.size(); ++i) { - to_fuse << CleanName((*stage_to_axes)[stage][fused_ids[i]]->var->name_hint); + to_fuse << CleanName((*stage_to_axes)[stage][fused_ids[i]->value]->var->name_hint); if (i != fused_ids.size() - 1) { to_fuse << ", "; } diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index 3536024f46eb..b2c1a5896aa7 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -98,7 +98,7 @@ class ReorderStepNode : public StepNode { * \brief The iterator ids after reorder. * This array should specify the order of all iterators. */ - std::vector after_ids; + Array after_ids; /*! * \brief Apply the current state to tvm.schedule @@ -126,7 +126,7 @@ class ReorderStep : public Step { * \param stage_id The index of the target stage. * \param after_ids The index of the iterators after reorder. */ - ReorderStep(int stage_id, const std::vector& after_ids); + ReorderStep(int stage_id, const Array& after_ids); TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode); }; @@ -142,7 +142,7 @@ class SplitStepNode : public StepNode { /*! \brief The extent length of the axis to split. */ PrimExpr extent; /*! \brief The split factors. */ - std::vector lengths; + Array lengths; /*! * \brief If true, the `lengths` denote the lengths of iterators * from inner level to outer level @@ -155,8 +155,8 @@ class SplitStepNode : public StepNode { * \param stage_to_axes A pointer to StageToAxesMap. * \return The iterator results after split. */ - std::vector ApplyToSchedule(std::vector* stages, - StageToAxesMap* stage_to_axes) const; + Array ApplyToSchedule(std::vector* stages, + StageToAxesMap* stage_to_axes) const; std::string PrintAsPythonAPI(std::vector* stages, StageToAxesMap* stage_to_axes, te::Schedule* schedule, @@ -179,7 +179,7 @@ class SplitStep : public Step { * \param lengths The extent length of the axis to split. * \param inner_to_outer The split direction. */ - SplitStep(int stage_id, int iter_id, PrimExpr extent, const std::vector& lengths, + SplitStep(int stage_id, int iter_id, PrimExpr extent, const Array& lengths, bool inner_to_outer); TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); @@ -189,7 +189,7 @@ class SplitStep : public Step { class FuseStepNode : public StepNode { public: /*! \brief The ids of iterators to fuse. */ - std::vector fused_ids; + Array fused_ids; /*! * \brief Apply the current state to tvm.schedule @@ -218,7 +218,7 @@ class FuseStep : public Step { * \param stage_id The index of the target stage. * \param fused_ids The index of the target iterators to be fused. 
*/ - FuseStep(int stage_id, const std::vector& fused_ids); + FuseStep(int stage_id, const Array& fused_ids); TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); }; @@ -235,8 +235,12 @@ struct hash<::tvm::ansor::Step> { std::size_t operator()(const ::tvm::ansor::Step& step) const { // clang-format off if (auto ps = step.as<::tvm::ansor::ReorderStepNode>()) { - return ::dmlc::HashCombine(1, - ::dmlc::HashCombine(std::hash()(ps->stage_id), ps->after_ids)); + size_t ret = ::dmlc::HashCombine(1, std::hash()(ps->stage_id)); + for (const auto& x : ps->after_ids) { + CHECK(x.defined()); + ret = ::dmlc::HashCombine(ret, x->value); + } + return ret; } else if (auto ps = step.as<::tvm::ansor::SplitStepNode>()) { size_t ret = ::dmlc::HashCombine(2, ::dmlc::HashCombine(std::hash()(ps->stage_id), @@ -252,8 +256,12 @@ struct hash<::tvm::ansor::Step> { } return ret; } else if (auto ps = step.as<::tvm::ansor::FuseStepNode>()) { - return ::dmlc::HashCombine(3, - ::dmlc::HashCombine(std::hash()(ps->stage_id), ps->fused_ids)); + size_t ret = ::dmlc::HashCombine(3, std::hash()(ps->stage_id)); + for (const auto& x : ps->fused_ids) { + CHECK(x.defined()); + ret = ::dmlc::HashCombine(ret, x->value); + } + return ret; } else { LOG(FATAL) << "Invalid step"; } diff --git a/src/ansor/utils.h b/src/ansor/utils.h index b43edfc10527..7d76828a24cb 100644 --- a/src/ansor/utils.h +++ b/src/ansor/utils.h @@ -83,10 +83,9 @@ namespace tvm { namespace ansor { /********** Utilities for std::vector, std::set, std::string **********/ -/*! \brief Get the first appearance index of elements in a vector */ -template -inline void GetIndices(const std::vector& array, const std::vector& to_locate, - std::vector* indices) { +/*! \brief Get the first appearance index of elements in a array type object */ +template +inline void GetIndices(const ArrayT0& array, const ArrayT1& to_locate, std::vector* indices) { for (const auto& v : to_locate) { auto it = std::find(array.begin(), array.end(), v); if (it != array.end()) { @@ -97,6 +96,19 @@ inline void GetIndices(const std::vector& array, const std::vector& to_loc } } +/*! \brief Get the first appearance index of elements in a array type object */ +template +inline void GetIndices(const ArrayT0& array, const ArrayT1& to_locate, Array* indices) { + for (const auto& v : to_locate) { + auto it = std::find(array.begin(), array.end(), v); + if (it != array.end()) { + indices->push_back(IntImm(tvm::DataType::Int(32), it - array.begin())); + } else { + LOG(FATAL) << "Cannot find the item"; + } + } +} + /*! \brief Get the first appearance index of an element in a vector */ template inline int GetIndex(const std::vector& array, const T& to_locate) { @@ -109,6 +121,18 @@ inline int GetIndex(const std::vector& array, const T& to_locate) { return -1; } +/*! \brief Get the first appearance index of an element in a vector */ +template +inline int GetIndex(const ArrayT& array, const T& to_locate) { + for (size_t i = 0; i < array.size(); ++i) { + if (array[i] == to_locate) { + return i; + } + } + LOG(FATAL) << "Cannot find the item"; + return -1; +} + /*! \brief Delete an element in a vector */ template inline void DeleteItem(std::vector* array, const T& to_delete) { @@ -193,20 +217,23 @@ NullStream& operator<<(NullStream& os, const T& value) { /*! \brief Get std cout with verbose control */ inline std::ostream& StdCout(int verbose) { - if (verbose >= 1) { - return std::cout; - } else { - return NullStream::Global(); + return verbose == 1 ? std::cout : NullStream::Global(); +} + +/*! 
\brief Print multiple chars */ +inline std::string Chars(const char& str, int times) { + std::stringstream ret; + for (int i = 0; i < times; ++i) { + ret << str; } + return ret.str(); } /*! \brief Print a title */ inline void PrintTitle(const std::string& title, int verbose) { - if (verbose >= 1) { - std::cout << "------------------------------------------------------------\n"; - std::cout << "----------------------- [ " << title << " ]\n"; - std::cout << "------------------------------------------------------------" << std::endl; - } + StdCout(verbose) << Chars('-', 60) << "\n" + << Chars('-', 25) << " [ " << title << " ]\n" + << Chars('-', 60) << std::endl; } /*! \brief A simple thread pool */ diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index ed900aa211ba..0e4a70d840d0 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -28,7 +28,7 @@ from test_ansor_common import matmul_ansor_test def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local', - cost_model=None, n_trials=2, params=None, + cost_model=None, num_measure_trials=2, params=None, pre_search_callbacks=None): print("Test %s schedule search with the default search policy" % (target)) @@ -44,7 +44,7 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' search_policy = ansor.EmptyPolicy() # search_policy = ansor.SketchSearchPolicy(cost_model, params=params, seed=seed) - tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner, verbose=0, + tune_option = ansor.TuneOption(num_measure_trials=num_measure_trials, runner=runner, verbose=0, measure_callbacks=[ansor.LogToFile(log_file)], pre_search_callbacks=pre_search_callbacks) sch, args = ansor.auto_schedule(task, target, search_policy=search_policy, From 9fa897bc156a18f805ed3a315b26a2fab5db1ac8 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 2 Jul 2020 14:28:13 +0800 Subject: [PATCH 56/78] std::vector->Array & std::string->String --- src/ansor/auto_schedule.cc | 3 - src/ansor/auto_schedule.h | 1 - src/ansor/compute_dag.cc | 15 ++-- src/ansor/compute_dag.h | 7 +- src/ansor/loop_state.cc | 36 +++++----- src/ansor/loop_state.h | 17 +++-- src/ansor/measure.cc | 39 +++++------ src/ansor/measure.h | 28 ++++---- src/ansor/search_policy/empty_policy.cc | 10 +-- src/ansor/search_policy/empty_policy.h | 6 +- src/ansor/search_policy/search_policy.h | 4 +- src/ansor/search_task.cc | 6 +- src/ansor/search_task.h | 6 +- src/ansor/serialization.cc | 93 ++++--------------------- src/ansor/serialization.h | 8 +-- src/ansor/transform_step.cc | 87 +++++++++++++---------- src/ansor/transform_step.h | 54 +++++++------- src/ansor/utils.h | 48 ++----------- 18 files changed, 181 insertions(+), 287 deletions(-) diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index 4dff43fd0b2f..2b0860a07303 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -26,9 +26,6 @@ #include -#include -#include - namespace tvm { namespace ansor { diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index fd65efd4c4af..130b56f8a54e 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -27,7 +27,6 @@ #ifndef TVM_ANSOR_AUTO_SCHEDULE_H_ #define TVM_ANSOR_AUTO_SCHEDULE_H_ -#include #include #include "measure.h" diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 1edd600cdd02..48e8a8149905 100644 --- a/src/ansor/compute_dag.cc +++ 
b/src/ansor/compute_dag.cc @@ -33,7 +33,6 @@ #include #include #include -#include #include #include #include @@ -52,7 +51,7 @@ TVM_REGISTER_NODE_TYPE(ComputeDAGNode); // Update stage to axis mapping void UpdateStageAxis(const te::Stage& stage, StageToAxesMap* stage_to_axes) { if (auto pop = stage->op.as()) { - std::vector& axes = (*stage_to_axes)[stage]; + Array& axes = (*stage_to_axes)[stage]; axes.clear(); for (const auto& axis : pop->axis) { axes.push_back(axis); @@ -247,13 +246,13 @@ State ComputeDAG::GetInitState() const { return Downcast(operator->()->in std::pair > ComputeDAG::ApplySteps( const Array& transform_steps) const { - std::vector stages; + Array stages; StageToAxesMap stage_to_axes; return ReplaySteps(transform_steps, &stages, &stage_to_axes); } -std::string ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const { - std::vector stages; +String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const { + Array stages; StageToAxesMap stage_to_axes; Array ops; for (const auto& op : operator->()->ops) { @@ -340,7 +339,7 @@ void ComputeDAG::InferBound(std::vector* states) const { } void ComputeDAG::InferBoundCommon(StateNode* pstate) const { - std::vector stages; + Array stages; StageToAxesMap stage_to_axes; te::Schedule sch; Array tensors; @@ -381,9 +380,9 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { } std::pair > ComputeDAG::ReplaySteps( - const Array& transform_steps, std::vector* stages, + const Array& transform_steps, Array* stages, StageToAxesMap* stage_to_axes) const { - std::vector ops; + Array ops; for (const auto& op : operator->()->ops) { if (!op->IsInstance()) { ops.push_back(op); diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index a8bace4d3f23..5c8cb649cd1e 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -27,11 +27,8 @@ #ifndef TVM_ANSOR_COMPUTE_DAG_H_ #define TVM_ANSOR_COMPUTE_DAG_H_ -#include #include -#include -#include #include #include @@ -95,7 +92,7 @@ class ComputeDAG : public ObjectRef { * \param transform_steps Transform steps of the target state. * \return Python schedule code. */ - std::string PrintStepsAsPython(const Array& transform_steps) const; + String PrintStepsAsPython(const Array& transform_steps) const; /*! * \brief Replay the transform steps and call ir_pass::InferBound to fill correct bound @@ -142,7 +139,7 @@ class ComputeDAG : public ObjectRef { * \return The return values can be used as arguments to `tvm.build` or `tvm.lower`. */ std::pair > ReplaySteps(const Array& transform_steps, - std::vector* stages, + Array* stages, StageToAxesMap* stage_to_axes) const; /*! 
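
The `std::vector<te::Stage>` to `Array<te::Stage>` switch in these signatures has one practical consequence worth noting: `Array::operator[]` returns the element by value (a cheap reference-counted handle) rather than a mutable lvalue, so element updates must be written back with `Array::Set`, which copies the backing store only when it is shared. A standalone sketch of the resulting access pattern (illustrative; `UpdateStage` is hypothetical):

    #include <tvm/runtime/container.h>
    #include <tvm/te/schedule.h>

    using tvm::runtime::Array;

    void UpdateStage(Array<tvm::te::Stage>* stages, int stage_id) {
      // `te::Stage& stage = (*stages)[stage_id];` no longer compiles:
      // operator[] hands out a copy of the handle, not a reference.
      tvm::te::Stage stage = (*stages)[stage_id];
      // ... apply schedule primitives to `stage` here ...
      // Write the handle back; Set copy-on-writes only if the array is shared.
      stages->Set(stage_id, std::move(stage));
    }
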
diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index e69339b3ecd3..f084cd62a749 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -25,10 +25,11 @@ #include "loop_state.h" -#include #include #include +#include + #include "transform_step.h" #include "utils.h" @@ -41,9 +42,8 @@ TVM_REGISTER_NODE_TYPE(StateNode); TVM_REGISTER_NODE_TYPE(IteratorNode); /********** Iterator **********/ -Iterator::Iterator(std::string name, Range range, IteratorType iter_type, - IteratorAnnotation annotation, const std::vector* ori_iters, - std::string attr) { +Iterator::Iterator(String name, Range range, IteratorType iter_type, IteratorAnnotation annotation, + const std::vector* ori_iters, String attr) { auto node = make_object(); node->name = std::move(name); node->range = std::move(range); @@ -118,7 +118,7 @@ void State::reorder(int stage_id, const Array& order) { const Stage& stage = operator->()->stages[stage_id]; CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " << "should be specified"; - Array after_ids; + Array after_ids; GetIndices(stage->iters, order, &after_ids); ReorderStep step = ReorderStep(stage_id, after_ids); CopyOnWrite()->transform_steps.push_back(step); @@ -137,7 +137,7 @@ Array State::split(int stage_id, const Iterator& it, const Array& iters) { const Stage& stage = operator->()->stages[stage_id]; - Array indices; + Array indices; GetIndices(stage->iters, iters, &indices); FuseStep step = FuseStep(stage_id, indices); CopyOnWrite()->transform_steps.push_back(step); @@ -149,7 +149,7 @@ void State::DoReorderStep(const ReorderStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; Array iters; for (auto x : step->after_ids) { - iters.push_back(stage->iters[x->value]); + iters.push_back(stage->iters[x.as()->value]); } StateNode* pstate = CopyOnWrite(); pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(iters), @@ -173,7 +173,7 @@ Array State::DoSplitStepCommon(int stage_id, int iter_id, const Array< Array outs; for (size_t i = 0; i < lengths.size(); ++i) { PrimExpr l; - std::string name; + String name; if (inner_to_outer) { l = lengths[lengths.size() - i - 1]; name = it->name + "." 
+ std::to_string(lengths.size() - i); @@ -200,7 +200,7 @@ Array State::DoSplitStepCommon(int stage_id, int iter_id, const Array< if (inner_to_outer) { outs.push_back(Iterator(it->name + ".0", range, it->iter_type, kNone)); // Reverse the Iterator array - Array temp(std::move(outs.rbegin()), std::move(outs.rend())); + Array temp(outs.rbegin(), outs.rend()); outs = std::move(temp); } else { outs.push_back( @@ -227,19 +227,20 @@ Iterator State::DoFuseStep(const FuseStep& step) { int stage_id = step->stage_id; const Stage& stage = operator->()->stages[stage_id]; - std::string new_name; + String new_name; PrimExpr new_extent = 1; IteratorType new_iter_type = kSpecial; std::vector ori_iters; for (size_t i = 0; i < step->fused_ids.size(); ++i) { if (i > 0) { - CHECK_EQ(step->fused_ids[i]->value, step->fused_ids[i - 1]->value + 1); + CHECK_EQ(step->fused_ids[i].as()->value, + step->fused_ids[i - 1].as()->value + 1); } - const Iterator& it = stage->iters[step->fused_ids[i]->value]; + const Iterator& it = stage->iters[step->fused_ids[i].as()->value]; ori_iters.push_back(it); - new_name += it->name + "@"; + new_name = new_name + it->name + "@"; if (it->range.defined() && new_extent.defined()) { new_extent = new_extent * it->range->extent; @@ -263,9 +264,10 @@ Iterator State::DoFuseStep(const FuseStep& step) { Iterator new_it = Iterator(new_name, range, new_iter_type, kNone, &ori_iters); Array new_iters; new_iters.insert(new_iters.end(), stage->iters.begin(), - stage->iters.begin() + step->fused_ids.front()->value); + stage->iters.begin() + step->fused_ids.front().as()->value); new_iters.push_back(new_it); - new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back()->value + 1, + new_iters.insert(new_iters.end(), + stage->iters.begin() + step->fused_ids.back().as()->value + 1, stage->iters.end()); StateNode* pstate = CopyOnWrite(); @@ -384,7 +386,7 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t b // Print state to ostream void PrintState(std::ostream* os, const StateNode* node, bool delete_trivial_loop) { // Gather placeholders - std::vector placeholders; + std::vector placeholders; for (const auto& stage : node->stages) { if (stage->op_type == kPlaceholder) { placeholders.push_back(stage->op->name); @@ -415,7 +417,7 @@ void PrintState(std::ostream* os, const StateNode* node, bool delete_trivial_loo } } -std::string State::ToStr(bool delete_trivial_loop) const { +String State::ToStr(bool delete_trivial_loop) const { std::ostringstream os; PrintState(&os, operator->(), delete_trivial_loop); return os.str(); diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index c53a52380b69..819a2c37fade 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -39,10 +39,9 @@ #ifndef TVM_ANSOR_LOOP_STATE_H_ #define TVM_ANSOR_LOOP_STATE_H_ +#include + #include -#include -#include -#include #include #include "compute_dag.h" @@ -117,7 +116,7 @@ class Iterator; class IteratorNode : public Object { public: /*! \brief The name of this iterator. */ - std::string name; + String name; /*! \brief The target range of this iterator. */ Range range; /*! \brief The iterator type of this iterator. */ @@ -127,7 +126,7 @@ class IteratorNode : public Object { /*! \brief The original iterators before fusion. */ std::vector ori_iters; /*! \brief The extra attributes of this iterator. 
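
Two idioms above follow directly from `String` and `Array` being immutable, reference-counted containers: `String` has no `operator+=`, so the fused name is rebuilt as `new_name = new_name + it->name + "@"`, and the split results are reversed by copying handles into a fresh `Array` through the iterator-pair constructor. A sketch of the concatenation pattern, assuming the `operator+` overloads this patch relies on (`JoinWithAt` is illustrative):

    #include <tvm/runtime/container.h>

    using tvm::runtime::Array;
    using tvm::runtime::String;

    String JoinWithAt(const Array<String>& parts) {
      String out = "";
      for (const String& p : parts) {
        // Each step allocates a fresh String; acceptable for short iterator
        // names, but prefer std::ostringstream on a hot path.
        out = out + p + "@";
      }
      return out;
    }
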
*/ - std::string attr; + String attr; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); @@ -154,8 +153,8 @@ class Iterator : public ObjectRef { * \param ori_iters The original iterators before fusion. * \param attr The extra attribute of this iterator. */ - Iterator(std::string name, Range range, IteratorType iter_type, IteratorAnnotation annotation, - const std::vector* ori_iters = nullptr, std::string attr = ""); + Iterator(String name, Range range, IteratorType iter_type, IteratorAnnotation annotation, + const std::vector* ori_iters = nullptr, String attr = ""); TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode); }; @@ -310,7 +309,7 @@ class State : public ObjectRef { * (undefined or extent == 1, default set to True) * \return The human readable state structure. */ - std::string ToStr(bool delete_trivial_loop = true) const; + String ToStr(bool delete_trivial_loop = true) const; TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); @@ -360,7 +359,7 @@ namespace std { template <> struct hash<::tvm::ansor::State> { std::size_t operator()(const ::tvm::ansor::State& state) const { - return std::hash()(state.ToStr()); + return tvm::runtime::ObjectHash()(state.ToStr()); } }; diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index ad33302b0fd8..f03d1a1f957b 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -24,13 +24,9 @@ #include "measure.h" -#include #include #include -#include -#include -#include #include "utils.h" @@ -73,8 +69,8 @@ MeasureInput MeasureInputNode::copy() const { return MeasureInput(node); } -BuildResult::BuildResult(std::string filename, Array args, int error_no, - std::string error_msg, double time_cost) { +BuildResult::BuildResult(String filename, Array args, int error_no, String error_msg, + double time_cost) { auto node = make_object(); node->filename = std::move(filename); node->args = std::move(args); @@ -84,8 +80,8 @@ BuildResult::BuildResult(std::string filename, Array args, int error data_ = std::move(node); } -MeasureResult::MeasureResult(Array costs, int error_no, std::string error_msg, - double all_cost, double timestamp) { +MeasureResult::MeasureResult(Array costs, int error_no, String error_msg, double all_cost, + double timestamp) { auto node = make_object(); node->costs = std::move(costs); node->error_no = error_no; @@ -106,7 +102,7 @@ MeasureResult MeasureResultNode::copy() const { } /********** LocalBuilder **********/ -LocalBuilder::LocalBuilder(int timeout, int n_parallel, const std::string& build_func) { +LocalBuilder::LocalBuilder(int timeout, int n_parallel, const String& build_func) { auto node = make_object(); node->timeout = timeout; node->n_parallel = n_parallel; @@ -170,8 +166,8 @@ void ProgramMeasurerNode::Reset() { } void ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& policy, - const std::vector& inputs, - std::vector* results, int batch_size) { + const Array& inputs, Array* results, + int batch_size) { results->clear(); results->reserve(inputs.size()); @@ -184,9 +180,9 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& po << std::endl; for (size_t i = 0; i < inputs.size(); i += batch_size) { - std::vector input_batch(inputs.begin() + i, - inputs.begin() + std::min(i + batch_size, inputs.size())); - std::vector result_batch; + Array input_batch(inputs.begin() + i, + inputs.begin() + std::min(i + batch_size, inputs.size())); + Array result_batch; // build and run SilentMeasure(task, 
input_batch, &result_batch); @@ -202,7 +198,7 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& po error_ct++; } - const std::string& workload_key = input_batch[j]->task->workload_key; + const String& workload_key = input_batch[j]->task->workload_key; if (flops > best_flops[workload_key]) { best_flops[workload_key] = flops; best_state[workload_key] = input_batch[j]->state; @@ -233,9 +229,8 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& po } } -void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, - const std::vector& inputs, - std::vector* results) { +void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, const Array& inputs, + Array* results) { // Close the thread pool to avoid the conflits with python environment ThreadPool::Global().Abort(); @@ -300,13 +295,13 @@ TVM_REGISTER_GLOBAL("ansor.MeasureInput").set_body_typed([](SearchTask task, Sta }); TVM_REGISTER_GLOBAL("ansor.BuildResult") - .set_body_typed([](std::string filename, Array args, int error_no, - std::string error_msg, double time_cost) { + .set_body_typed([](String filename, Array args, int error_no, String error_msg, + double time_cost) { return BuildResult(filename, args, error_no, error_msg, time_cost); }); TVM_REGISTER_GLOBAL("ansor.MeasureResult") - .set_body_typed([](Array costs, int error_no, std::string error_msg, double all_cost, + .set_body_typed([](Array costs, int error_no, String error_msg, double all_cost, double timestamp) { return MeasureResult(costs, error_no, error_msg, all_cost, timestamp); }); @@ -322,7 +317,7 @@ TVM_REGISTER_GLOBAL("ansor.RunnerRun") int verbose) { return runner->Run(inputs, build_results, verbose); }); TVM_REGISTER_GLOBAL("ansor.LocalBuilder") - .set_body_typed([](int timeout, int n_parallel, const std::string& build_func) { + .set_body_typed([](int timeout, int n_parallel, const String& build_func) { return LocalBuilder(timeout, n_parallel, build_func); }); diff --git a/src/ansor/measure.h b/src/ansor/measure.h index 34663a72d09e..d552d688e12c 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -26,10 +26,8 @@ #ifndef TVM_ANSOR_MEASURE_H_ #define TVM_ANSOR_MEASURE_H_ -#include #include #include -#include #include "loop_state.h" #include "search_task.h" @@ -105,13 +103,13 @@ class MeasureInput : public ObjectRef { class BuildResultNode : public Object { public: /*! \brief The filename of built binary file. */ - std::string filename; + String filename; /*! \brief The arguments. */ Array args; /*! \brief The error code. (0 means no error, see MeasureErrorNO) */ int error_no; /*! \brief The error message if there is any error. */ - std::string error_msg; + String error_msg; /*! \brief The time cost of build. */ double time_cost; @@ -141,7 +139,7 @@ class BuildResult : public ObjectRef { * \param error_msg The error message if there is any error. * \param time_cost The time cost of build. */ - BuildResult(std::string filename, Array args, int error_no, std::string error_msg, + BuildResult(String filename, Array args, int error_no, String error_msg, double time_cost); TVM_DEFINE_OBJECT_REF_METHODS(BuildResult, ObjectRef, BuildResultNode); }; @@ -154,7 +152,7 @@ class MeasureResultNode : public Object { /*! \brief The error code. (0 means no error, see MeasureErrorNO) */ int error_no; /*! \brief The error message if there is any error. */ - std::string error_msg; + String error_msg; /*! \brief The time cost of build and run. */ double all_cost; /*! \brief The time stamps of this measurement. 
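
The batching code above slices the pending inputs with `Array`'s iterator-pair constructor; because elements are reference-counted handles, the slice copies pointers rather than `MeasureInput` payloads. A generic sketch of that slicing helper (`Slice` is a hypothetical name):

    #include <algorithm>
    #include <tvm/runtime/container.h>

    using tvm::runtime::Array;

    template <typename T>
    Array<T> Slice(const Array<T>& arr, size_t begin, size_t count) {
      size_t end = std::min(begin + count, arr.size());
      // The (first, last) constructor copies handles only, not payloads.
      return Array<T>(arr.begin() + begin, arr.begin() + end);
    }
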
*/ @@ -189,7 +187,7 @@ class MeasureResult : public ObjectRef { * \param all_cost The time cost of build and run. * \param timestamp The time stamps of this measurement. */ - MeasureResult(Array costs, int error_no, std::string error_msg, double all_cost, + MeasureResult(Array costs, int error_no, String error_msg, double all_cost, double timestamp); TVM_DEFINE_OBJECT_REF_METHODS(MeasureResult, ObjectRef, MeasureResultNode); @@ -286,7 +284,7 @@ class Runner : public ObjectRef { class LocalBuilderNode : public BuilderNode { public: /*! \brief Build function. */ - std::string build_func; + String build_func; Array Build(const Array& inputs, int verbose) final; @@ -306,7 +304,7 @@ class LocalBuilder : public Builder { * \param n_parallel Number of threads used to build in parallel. * \param build_func The name of registered build function. */ - LocalBuilder(int timeout, int n_parallel, const std::string& build_func); + LocalBuilder(int timeout, int n_parallel, const String& build_func); TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, Builder, LocalBuilderNode); }; @@ -359,11 +357,11 @@ class ProgramMeasurerNode : public Object { /*! \brief Continuous error counter. */ int error_ct; /*! \brief Workload key to best flops map. */ - std::unordered_map best_flops; + std::unordered_map best_flops; /*! \brief Workload key to best state map. */ - std::unordered_map best_state; + std::unordered_map best_state; /*! \brief Workload key to best state's count index map. */ - std::unordered_map best_ct; + std::unordered_map best_ct; /*! \brief The Builder to build each program. */ Builder builder; /*! \brief The Runner to measure each program. */ @@ -387,7 +385,7 @@ class ProgramMeasurerNode : public Object { * \param batch_size Number of programs to be measured in one batch. */ void Measure(const SearchTask& task, const SearchPolicy& policy, - const std::vector& inputs, std::vector* results, + const Array& inputs, Array* results, int batch_size = -1); /*! * \brief Do measurement silently. @@ -396,8 +394,8 @@ class ProgramMeasurerNode : public Object { * \param inputs The target MeasureInputs. * \param results A pointer to MeasureResult vector, this is used as output. */ - void SilentMeasure(const SearchTask& task, const std::vector& inputs, - std::vector* results); + void SilentMeasure(const SearchTask& task, const Array& inputs, + Array* results); /*! \brief The default max continuous error setting. 
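
`best_flops` and the neighboring maps can key `std::unordered_map` on `String` because TVM ships an `std::hash<tvm::runtime::String>` specialization that hashes the character content, and `operator==` compares bytes as well, so lookups succeed across distinct String objects. A small sketch (`LookupDemo` is illustrative):

    #include <tvm/runtime/container.h>
    #include <unordered_map>

    using tvm::runtime::String;

    double LookupDemo() {
      std::unordered_map<String, double> best_flops;
      best_flops[String("matmul-key")] = 1e9;
      // A freshly constructed String with the same bytes lands in the same
      // slot: hashing and equality are content-based, not pointer-based.
      return best_flops[String("matmul-key")];
    }
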
*/ static const int DEFAULT_MAX_CONTINOUS_ERROR = 150; diff --git a/src/ansor/search_policy/empty_policy.cc b/src/ansor/search_policy/empty_policy.cc index 51a506e39eac..53cafd6524f3 100644 --- a/src/ansor/search_policy/empty_policy.cc +++ b/src/ansor/search_policy/empty_policy.cc @@ -51,8 +51,8 @@ State EmptyPolicyNode::Search(SearchTask task, int num_measure_trials, int early return res[0]; } else { - std::vector inputs; - std::vector results; + Array inputs; + Array results; measurer->Reset(); int ct = 0; @@ -66,7 +66,7 @@ State EmptyPolicyNode::Search(SearchTask task, int num_measure_trials, int early for (const auto& state : res) { // The class members measured_states_set_ provided by SearchPolicy can be used to filter // out the already measured states - inputs.emplace_back(cur_task, state); + inputs.push_back(MeasureInput(cur_task, state)); } // ProgramMeasurer will record the state with best performance during measure process measurer->Measure(cur_task, GetRef(this), inputs, &results); @@ -78,8 +78,8 @@ State EmptyPolicyNode::Search(SearchTask task, int num_measure_trials, int early } // As an example policy, EmptyPolicy always returns a init state -std::vector EmptyPolicyNode::SearchOneRound() { - std::vector res; +Array EmptyPolicyNode::SearchOneRound() { + Array res; // 1. We will process `Program sampling` first to generate several initial schedules res.push_back(cur_task->compute_dag.GetInitState()); diff --git a/src/ansor/search_policy/empty_policy.h b/src/ansor/search_policy/empty_policy.h index c55dd1b2e272..a8fd4fb424e9 100644 --- a/src/ansor/search_policy/empty_policy.h +++ b/src/ansor/search_policy/empty_policy.h @@ -25,9 +25,7 @@ #ifndef TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_ #define TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_ -#include -#include - +#include "../loop_state.h" #include "search_policy.h" namespace tvm { @@ -54,7 +52,7 @@ class EmptyPolicyNode : public SearchPolicyNode { * \brief Use a sub function to generate several candidate states in each search round. * \returns Several generated states */ - std::vector SearchOneRound(); + Array SearchOneRound(); }; /*! diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index d9b1df96709e..aee93283aae6 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -52,9 +52,7 @@ #include -#include #include -#include #include #include "../search_task.h" @@ -139,7 +137,7 @@ class SearchPolicyNode : public Object { * \brief The set of already measured states. * We store the string format for redundancy check. */ - std::unordered_set measured_states_set_; + std::unordered_set measured_states_set_; /*! \brief The array of already measured states. */ std::vector measured_states_vector_; /*! 
\brief The throughputs of already measured states */ diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index c64b919e008f..7e6cb9d903d2 100644 --- a/src/ansor/search_task.cc +++ b/src/ansor/search_task.cc @@ -24,11 +24,9 @@ #include "search_task.h" -#include #include #include -#include #include namespace tvm { @@ -58,7 +56,7 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target return HardwareParams(); } -SearchTask::SearchTask(ComputeDAG compute_dag, std::string workload_key, Target target, +SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, HardwareParams hardware_params) { auto node = make_object(); node->compute_dag = std::move(compute_dag); @@ -82,7 +80,7 @@ TVM_REGISTER_GLOBAL("ansor.HardwareParams") }); TVM_REGISTER_GLOBAL("ansor.SearchTask") - .set_body_typed([](ComputeDAG compute_dag, std::string workload_key, Target target, + .set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target, Target target_host, HardwareParams hardware_params) { return SearchTask(compute_dag, workload_key, target, target_host, hardware_params); }); diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h index 351a51124e7e..9c92da5e387b 100644 --- a/src/ansor/search_task.h +++ b/src/ansor/search_task.h @@ -27,8 +27,6 @@ #include -#include - #include "compute_dag.h" namespace tvm { @@ -115,7 +113,7 @@ class SearchTaskNode : public Object { /*! \brief The ComputeDAG for target compute declaration. */ ComputeDAG compute_dag; /*! \brief The workload key for target compute declaration. */ - std::string workload_key; + String workload_key; /*! \brief The target device of this search task. */ Target target; /*! \brief The target host device of this search task. */ @@ -149,7 +147,7 @@ class SearchTask : public ObjectRef { * \param target_host The target host device of this search task. * \param hardware_params Hardware parameters used in this search task. 
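
Note that `std::move(workload_key)` stays meaningful after the switch to `String`: moving transfers the reference-counted handle with a pointer swap, while copying would bump the atomic refcount; neither touches the character data. A sketch of the constructor pattern used above (`TaskLike` is a stand-in class):

    #include <tvm/runtime/container.h>
    #include <utility>

    using tvm::runtime::String;

    class TaskLike {
     public:
      explicit TaskLike(String workload_key)
          // Move steals the handle; a copy would only bump the refcount.
          : workload_key_(std::move(workload_key)) {}

     private:
      String workload_key_;
    };
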
*/ - SearchTask(ComputeDAG compute_dag, std::string workload_key, Target target, Target target_host, + SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, HardwareParams hardware_params); TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index 358d170e00df..bf6dc2b2875d 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -54,16 +54,6 @@ inline std::vector& IntArrayToVector(std::vector* out, return *out; } -inline std::vector& IntArrayToVector(std::vector* out, - const ::tvm::Array<::tvm::IntImm>& data) { - out->clear(); - for (const auto& x : data) { - CHECK(x.defined()); - out->push_back(x->value); - } - return *out; -} - template <> struct Handler<::tvm::Array<::tvm::ansor::Stage>> { inline static void Write(dmlc::JSONWriter* writer, @@ -130,9 +120,9 @@ struct Handler<::tvm::Array<::tvm::ansor::Step>> { s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); - ::tvm::Array<::tvm::IntImm> after_ids; + ::tvm::Array<::tvm::PrimExpr> after_ids; for (const auto& i : int_list) { - after_ids.push_back(::tvm::IntImm(::tvm::DataType::Int(32), i)); + after_ids.push_back(i); } data->push_back(::tvm::ansor::ReorderStep(stage_id, after_ids)); } else if (name == "SP") { @@ -164,9 +154,9 @@ struct Handler<::tvm::Array<::tvm::ansor::Step>> { s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); - ::tvm::Array<::tvm::IntImm> fused_ids; + ::tvm::Array<::tvm::PrimExpr> fused_ids; for (const auto& i : int_list) { - fused_ids.push_back(::tvm::IntImm(::tvm::DataType::Int(32), i)); + fused_ids.push_back(i); } data->push_back(::tvm::ansor::FuseStep(stage_id, fused_ids)); } else { @@ -204,7 +194,7 @@ template <> struct Handler<::tvm::ansor::SearchTaskNode> { inline static void Write(dmlc::JSONWriter* writer, const ::tvm::ansor::SearchTaskNode& data) { writer->BeginArray(false); - writer->WriteArrayItem(data.workload_key); + writer->WriteArrayItem(std::string(data.workload_key)); writer->WriteArrayItem(data.target->str()); writer->EndArray(); } @@ -215,7 +205,8 @@ struct Handler<::tvm::ansor::SearchTaskNode> { reader->BeginArray(); s = reader->NextArrayItem(); CHECK(s); - reader->Read(&data->workload_key); + reader->Read(&target_str); + data->workload_key = std::move(target_str); s = reader->NextArrayItem(); CHECK(s); reader->Read(&target_str); @@ -308,7 +299,7 @@ TVM_REGISTER_OBJECT_TYPE(LogReaderNode); const std::string ANSOR_LOG_VERSION = "v0.2"; // NOLINT(*) -LogToFile::LogToFile(std::string filename) { +LogToFile::LogToFile(String filename) { auto node = make_object(); node->filename = std::move(filename); data_ = std::move(node); @@ -353,7 +344,7 @@ void LogToFileNode::Callback(const SearchPolicy& policy, const Array(); node->filename = filename; node->infile.open(filename, std::ifstream::in); @@ -401,11 +392,11 @@ std::pair, Array> LogReaderNode::ReadLines(in return std::make_pair(inputs, results); } -TVM_REGISTER_GLOBAL("ansor.LogToFile").set_body_typed([](const std::string& filename) { +TVM_REGISTER_GLOBAL("ansor.LogToFile").set_body_typed([](const String& filename) { return LogToFile(filename); }); -TVM_REGISTER_GLOBAL("ansor.LogReader").set_body_typed([](const std::string& filename) { +TVM_REGISTER_GLOBAL("ansor.LogReader").set_body_typed([](const String& filename) { return LogReader(filename); }); @@ -426,69 +417,9 @@ TVM_REGISTER_GLOBAL("ansor.LogReaderReadNext").set_body_typed([](LogReader reade }); 
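
In the deserializer above, `after_ids.push_back(i)` type-checks against an `Array<PrimExpr>` because `PrimExpr` has an implicit constructor from `int` that boxes the value as a 32-bit `IntImm`; consumers later unbox with `as<IntImmNode>()`, exactly as the updated hash function does. A standalone round-trip sketch (`FromInts`/`ToInts` are illustrative):

    #include <tvm/ir/expr.h>
    #include <tvm/runtime/container.h>
    #include <tvm/tir/expr.h>
    #include <vector>

    using tvm::PrimExpr;
    using tvm::runtime::Array;

    Array<PrimExpr> FromInts(const std::vector<int>& raw) {
      Array<PrimExpr> out;
      for (int i : raw) {
        out.push_back(i);  // implicit int -> IntImm -> PrimExpr
      }
      return out;
    }

    std::vector<int> ToInts(const Array<PrimExpr>& exprs) {
      std::vector<int> out;
      for (const PrimExpr& e : exprs) {
        const auto* imm = e.as<tvm::tir::IntImmNode>();
        // Only constant indices are expected in a deserialized step.
        out.push_back(imm != nullptr ? static_cast<int>(imm->value) : -1);
      }
      return out;
    }
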
TVM_REGISTER_GLOBAL("ansor.AppendMeasureRecordsToFile") - .set_body([](TVMArgs args, TVMRetValue* ret) { - std::string filename = args[0]; - Array in = args[1]; - Array res = args[2]; + .set_body_typed([](String filename, Array in, Array res) { std::ofstream ofs(filename, std::ofstream::app); WriteMeasureRecords(&ofs, in, res); }); - -TVM_REGISTER_GLOBAL("ansor.GetStatesFromMeasureInputs") - .set_body([](TVMArgs args, TVMRetValue* ret) { - Array inputs = args[0]; - SearchTask external_task; - - if (args.size() > 1) { - external_task = args[1]; - } - - Array states; - states.reserve(inputs.size()); - - // (workload_key, target) -> (search_task) - std::unordered_map, SearchTask> task_cache; - - for (const auto& inp : inputs) { - const std::string& workload_key = inp->task->workload_key; - std::pair key(workload_key, inp->task->target->str()); - - const SearchTaskNode* ptask; - if (external_task.defined()) { - ptask = external_task.operator->(); - } else { - auto find_res = task_cache.find(key); - if (find_res == task_cache.end()) { - if (inp->task->compute_dag.defined()) { - ptask = inp->task.operator->(); - } else { - // If the measure input is incomplete, rebuild task for it - Array tens; - // Call python function to decode the workload_key and get the I/O tensors - if (const auto* f = runtime::Registry::Get("ansor.workload_key_to_tensors")) { - tens = (*f)(workload_key); - } else { - LOG(FATAL) << "ansor.workload_key_to_tensors is not registered"; - } - SearchTask new_task = SearchTask(ComputeDAG(tens), workload_key, inp->task->target, - inp->task->target_host, inp->task->hardware_params); - task_cache.insert(std::make_pair(key, new_task)); - ptask = new_task.operator->(); - } - } else { - ptask = find_res->second.operator->(); - } - } - - State tmp_s = ptask->compute_dag.GetInitState(); - StateNode* ps = tmp_s.CopyOnWrite(); - ps->transform_steps = inp->state->transform_steps; - tmp_s.DoSteps(ps->transform_steps, ptask->compute_dag); - states.push_back(std::move(tmp_s)); - } - - *ret = states; - }); - } // namespace ansor } // namespace tvm diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h index 53ea88323cf9..600f4c2ce18e 100644 --- a/src/ansor/serialization.h +++ b/src/ansor/serialization.h @@ -38,7 +38,7 @@ namespace ansor { class LogToFileNode : public MeasureCallbackNode { public: /*! \brief File name for this callback to write log to. */ - std::string filename; + String filename; void Callback(const SearchPolicy& policy, const Array& inputs, const Array& results) final; @@ -57,7 +57,7 @@ class LogToFile : public MeasureCallback { * \brief The constructor. * \param filename File name for this callback to write log. */ - explicit LogToFile(std::string filename); + explicit LogToFile(String filename); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LogToFile, MeasureCallback, LogToFileNode); }; @@ -66,7 +66,7 @@ class LogToFile : public MeasureCallback { class LogReaderNode : public Object { public: /*! \brief File name for this reader to load log from. */ - std::string filename; + String filename; /*! \brief The reading file stream. */ std::ifstream infile; @@ -106,7 +106,7 @@ class LogReader : public ObjectRef { * \brief The constructor. * \param filename File name for this callback to write log. 
*/ - explicit LogReader(std::string filename); + explicit LogReader(String filename); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LogReader, ObjectRef, LogReaderNode); }; diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index a6de46b5f631..4cc94236848f 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -37,37 +37,42 @@ namespace tvm { namespace ansor { /********** Reorder **********/ -ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { +ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { auto node = make_object(); node->stage_id = stage_id; + for (const auto& x : after_ids) { + CHECK(x.defined() && x->IsInstance()); + } node->after_ids = after_ids; data_ = std::move(node); } -void ReorderStepNode::ApplyToSchedule(std::vector* stages, +void ReorderStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; + auto stage = (*stages)[stage_id]; + const Array& axes = (*stage_to_axes)[stage]; CHECK_EQ(after_ids.size(), axes.size()); - std::vector new_axes; + Array new_axes; new_axes.reserve(axes.size()); for (auto i : after_ids) { - new_axes.push_back(axes[i->value]); + new_axes.push_back(axes[i.as()->value]); } stage.reorder(new_axes); + (*stage_to_axes)[stage] = std::move(new_axes); + stages->Set(stage_id, std::move(stage)); } -std::string ReorderStepNode::PrintAsPythonAPI(std::vector* stages, - StageToAxesMap* stage_to_axes, te::Schedule* schedule, - const std::vector& transform_steps) const { - const te::Stage& stage = (*stages)[stage_id]; +String ReorderStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule, + const std::vector& transform_steps) const { + const auto& stage = (*stages)[stage_id]; std::stringstream ss; ss << "s[" << CleanName(stage->op->name) << "].reorder("; for (size_t i = 0; i < after_ids.size(); ++i) { - ss << CleanName((*stage_to_axes)[stage][after_ids[i]->value]->var->name_hint); + ss << CleanName((*stage_to_axes)[stage][after_ids[i].as()->value]->var->name_hint); if (i != after_ids.size() - 1) { ss << ", "; } @@ -79,11 +84,11 @@ std::string ReorderStepNode::PrintAsPythonAPI(std::vector* stages, } /********** Split **********/ -Array ApplySplitToSchedule(std::vector* stages, StageToAxesMap* stage_to_axes, +Array ApplySplitToSchedule(Array* stages, StageToAxesMap* stage_to_axes, int stage_id, int iter_id, const Array& lengths, bool inner_to_outer) { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; + auto stage = (*stages)[stage_id]; + const Array& axes = (*stage_to_axes)[stage]; Array outs; if (inner_to_outer) { @@ -104,7 +109,7 @@ Array ApplySplitToSchedule(std::vector* stages, StageToAxesM outs.push_back(inner); } - std::vector new_axes; + Array new_axes; new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + iter_id); if (inner_to_outer) { for (auto x = outs.rbegin(); x != outs.rend(); ++x) { @@ -116,19 +121,21 @@ Array ApplySplitToSchedule(std::vector* stages, StageToAxesM } } new_axes.insert(new_axes.end(), axes.begin() + iter_id + 1, axes.end()); - (*stage_to_axes)[stage] = std::move(new_axes); + (*stage_to_axes)[stage] = std::move(new_axes); + stages->Set(stage_id, std::move(stage)); return outs; } -std::string PrintSplitAsPythonAPI(std::vector* stages, StageToAxesMap* stage_to_axes, +std::string PrintSplitAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, int stage_id, int iter_id, 
const Array& lengths, bool inner_to_outer) { - te::Stage& stage = (*stages)[stage_id]; + const auto& stage = (*stages)[stage_id]; auto to_split = (*stage_to_axes)[stage][iter_id]; const auto& func_name = CleanName(stage->op->name); const auto& outs = ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); + CHECK_EQ(outs.size(), lengths.size() + 1); std::stringstream ss; int size = static_cast(lengths.size()); @@ -165,53 +172,61 @@ SplitStep::SplitStep(int stage_id, int iter_id, PrimExpr extent, const Array SplitStepNode::ApplyToSchedule(std::vector* stages, +Array SplitStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const { return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } -std::string SplitStepNode::PrintAsPythonAPI(std::vector* stages, - StageToAxesMap* stage_to_axes, te::Schedule* schedule, - const std::vector& transform_steps) const { +String SplitStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule, + const std::vector& transform_steps) const { return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } /********** Fuse **********/ -FuseStep::FuseStep(int stage_id, const Array& fused_ids) { +FuseStep::FuseStep(int stage_id, const Array& fused_ids) { auto node = make_object(); node->stage_id = stage_id; + for (const auto& x : fused_ids) { + CHECK(x.defined() && x->IsInstance()); + } node->fused_ids = fused_ids; data_ = std::move(node); } -IterVar FuseStepNode::ApplyToSchedule(std::vector* stages, +IterVar FuseStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; + auto stage = (*stages)[stage_id]; + const Array& axes = (*stage_to_axes)[stage]; Array to_fuse; - for (auto i : fused_ids) { - to_fuse.push_back(axes[i->value]); + for (const auto& i : fused_ids) { + to_fuse.push_back(axes[i.as()->value]); } IterVar fused_axis; stage.fuse(to_fuse, &fused_axis); - std::vector new_axes; - new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front()->value); + + Array new_axes; + new_axes.insert(new_axes.end(), axes.begin(), + axes.begin() + fused_ids.front().as()->value); new_axes.push_back(fused_axis); - new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back()->value + 1, axes.end()); - (*stage_to_axes)[stage] = std::move(new_axes); + new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back().as()->value + 1, + axes.end()); + (*stage_to_axes)[stage] = std::move(new_axes); + stages->Set(stage_id, std::move(stage)); return fused_axis; } -std::string FuseStepNode::PrintAsPythonAPI(std::vector* stages, - StageToAxesMap* stage_to_axes, te::Schedule* schedule, - const std::vector& transform_steps) const { +String FuseStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule, + const std::vector& transform_steps) const { const auto& stage = (*stages)[stage_id]; std::stringstream to_fuse; for (size_t i = 0; i < fused_ids.size(); ++i) { - to_fuse << CleanName((*stage_to_axes)[stage][fused_ids[i]->value]->var->name_hint); + to_fuse << CleanName( + (*stage_to_axes)[stage][fused_ids[i].as()->value]->var->name_hint); if (i != fused_ids.size() - 1) { to_fuse << ", "; } diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index b2c1a5896aa7..af0cf4038399 100644 --- a/src/ansor/transform_step.h +++ 
b/src/ansor/transform_step.h @@ -56,7 +56,7 @@ namespace tvm { namespace ansor { -typedef std::unordered_map, ObjectHash, ObjectEqual> +typedef std::unordered_map, ObjectHash, ObjectEqual> StageToAxesMap; class Step; @@ -78,9 +78,9 @@ class StepNode : public Object { * \param transform_steps Transform steps of the target state. * \return Python schedule code. */ - virtual std::string PrintAsPythonAPI(std::vector* stages, - StageToAxesMap* stage_to_axes, te::Schedule* schedule, - const std::vector& transform_steps) const = 0; + virtual String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule, + const std::vector& transform_steps) const = 0; static constexpr const char* _type_key = "ansor.Step"; TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object); @@ -98,18 +98,18 @@ class ReorderStepNode : public StepNode { * \brief The iterator ids after reorder. * This array should specify the order of all iterators. */ - Array after_ids; + Array after_ids; /*! * \brief Apply the current state to tvm.schedule * \param stages A pointer to `te::Stage` vector. * \param stage_to_axes A pointer to StageToAxesMap. */ - void ApplyToSchedule(std::vector* stages, StageToAxesMap* stage_to_axes) const; + void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; - std::string PrintAsPythonAPI(std::vector* stages, StageToAxesMap* stage_to_axes, - te::Schedule* schedule, - const std::vector& transform_steps) const final; + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule, + const std::vector& transform_steps) const final; static constexpr const char* _type_key = "ansor.ReorderStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object); @@ -126,7 +126,7 @@ class ReorderStep : public Step { * \param stage_id The index of the target stage. * \param after_ids The index of the iterators after reorder. */ - ReorderStep(int stage_id, const Array& after_ids); + ReorderStep(int stage_id, const Array& after_ids); TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode); }; @@ -155,12 +155,12 @@ class SplitStepNode : public StepNode { * \param stage_to_axes A pointer to StageToAxesMap. * \return The iterator results after split. */ - Array ApplyToSchedule(std::vector* stages, + Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; - std::string PrintAsPythonAPI(std::vector* stages, StageToAxesMap* stage_to_axes, - te::Schedule* schedule, - const std::vector& transform_steps) const final; + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule, + const std::vector& transform_steps) const final; static constexpr const char* _type_key = "ansor.SplitStep"; TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object); @@ -189,7 +189,7 @@ class SplitStep : public Step { class FuseStepNode : public StepNode { public: /*! \brief The ids of iterators to fuse. */ - Array fused_ids; + Array fused_ids; /*! * \brief Apply the current state to tvm.schedule @@ -197,11 +197,11 @@ class FuseStepNode : public StepNode { * \param stage_to_axes A pointer to StageToAxesMap. * \return The iterator result after fuse. 
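
The `StageToAxesMap` typedef above keys its map on `te::Stage` by object identity: `ObjectHash` and `ObjectEqual` hash and compare the underlying node pointer, so every handle to one stage hits a single entry while structurally equal but distinct stages do not collide. A minimal sketch of the idiom, assuming `ObjectHash`/`ObjectEqual` live in TVM's runtime object header at this revision:

    #include <tvm/runtime/object.h>
    #include <unordered_map>

    // Identity-keyed map: two handles to one node share an entry; nodes that
    // are equal by value but distinct in memory do not. That is exactly the
    // behavior stage replay wants here.
    template <typename K, typename V>
    using IdentityMap =
        std::unordered_map<K, V, tvm::runtime::ObjectHash, tvm::runtime::ObjectEqual>;
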
*/ - tir::IterVar ApplyToSchedule(std::vector* stages, StageToAxesMap* stage_to_axes) const; + tir::IterVar ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; - std::string PrintAsPythonAPI(std::vector* stages, StageToAxesMap* stage_to_axes, - te::Schedule* schedule, - const std::vector& transform_steps) const final; + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule, + const std::vector& transform_steps) const final; static constexpr const char* _type_key = "ansor.FuseStep"; TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); @@ -218,7 +218,7 @@ class FuseStep : public Step { * \param stage_id The index of the target stage. * \param fused_ids The index of the target iterators to be fused. */ - FuseStep(int stage_id, const Array& fused_ids); + FuseStep(int stage_id, const Array& fused_ids); TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); }; @@ -238,16 +238,18 @@ struct hash<::tvm::ansor::Step> { size_t ret = ::dmlc::HashCombine(1, std::hash()(ps->stage_id)); for (const auto& x : ps->after_ids) { CHECK(x.defined()); - ret = ::dmlc::HashCombine(ret, x->value); + const auto& pint = x.as<::tvm::tir::IntImmNode>(); + CHECK(pint != nullptr); + ret = ::dmlc::HashCombine(ret, pint->value); } return ret; } else if (auto ps = step.as<::tvm::ansor::SplitStepNode>()) { size_t ret = ::dmlc::HashCombine(2, ::dmlc::HashCombine(std::hash()(ps->stage_id), ::dmlc::HashCombine(std::hash()(ps->iter_id), ps->inner_to_outer))); - for (const auto& len : ps->lengths) { - if (len.defined()) { - auto pint = len.as<::tvm::tir::IntImmNode>(); + for (const auto& x : ps->lengths) { + if (x.defined()) { + const auto& pint = x.as<::tvm::tir::IntImmNode>(); CHECK(pint != nullptr); ret = ::dmlc::HashCombine(ret, pint->value); } else { @@ -259,7 +261,9 @@ struct hash<::tvm::ansor::Step> { size_t ret = ::dmlc::HashCombine(3, std::hash()(ps->stage_id)); for (const auto& x : ps->fused_ids) { CHECK(x.defined()); - ret = ::dmlc::HashCombine(ret, x->value); + const auto& pint = x.as<::tvm::tir::IntImmNode>(); + CHECK(pint != nullptr); + ret = ::dmlc::HashCombine(ret, pint->value); } return ret; } else { diff --git a/src/ansor/utils.h b/src/ansor/utils.h index 7d76828a24cb..c8d18d3109b2 100644 --- a/src/ansor/utils.h +++ b/src/ansor/utils.h @@ -82,36 +82,23 @@ struct hash> { namespace tvm { namespace ansor { -/********** Utilities for std::vector, std::set, std::string **********/ -/*! \brief Get the first appearance index of elements in a array type object */ -template -inline void GetIndices(const ArrayT0& array, const ArrayT1& to_locate, std::vector* indices) { - for (const auto& v : to_locate) { - auto it = std::find(array.begin(), array.end(), v); - if (it != array.end()) { - indices->push_back(it - array.begin()); - } else { - LOG(FATAL) << "Cannot find the item"; - } - } -} - -/*! \brief Get the first appearance index of elements in a array type object */ -template -inline void GetIndices(const ArrayT0& array, const ArrayT1& to_locate, Array* indices) { +/********** Utilities for Array, std::string **********/ +/*! 
\brief Get the first appearance index of elements in an Array */ +template +inline void GetIndices(const Array& array, const Array& to_locate, Array* indices) { for (const auto& v : to_locate) { auto it = std::find(array.begin(), array.end(), v); if (it != array.end()) { - indices->push_back(IntImm(tvm::DataType::Int(32), it - array.begin())); + indices->push_back(static_cast(it - array.begin())); } else { LOG(FATAL) << "Cannot find the item"; } } } -/*! \brief Get the first appearance index of an element in a vector */ +/*! \brief Get the first appearance index of an element in an Array */ template -inline int GetIndex(const std::vector& array, const T& to_locate) { +inline int GetIndex(const Array& array, const T& to_locate) { for (size_t i = 0; i < array.size(); ++i) { if (array[i] == to_locate) { return i; @@ -121,27 +108,6 @@ inline int GetIndex(const std::vector& array, const T& to_locate) { return -1; } -/*! \brief Get the first appearance index of an element in a vector */ -template -inline int GetIndex(const ArrayT& array, const T& to_locate) { - for (size_t i = 0; i < array.size(); ++i) { - if (array[i] == to_locate) { - return i; - } - } - LOG(FATAL) << "Cannot find the item"; - return -1; -} - -/*! \brief Delete an element in a vector */ -template -inline void DeleteItem(std::vector* array, const T& to_delete) { - auto iter = std::find(array->begin(), array->end(), to_delete); - if (iter != array->end()) { - array->erase(iter); - } -} - /*! \brief Replace a sub-string to another sub-string in a string */ inline void StrReplace(std::string* base, const std::string& from, const std::string& to) { auto pos = base->find(from); From f40c7af7d98f1541918f7c7b537b8d4c5bc0f50b Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 2 Jul 2020 17:47:00 +0800 Subject: [PATCH 57/78] Add init_state to ComputeDAG --- python/tvm/ansor/compute_dag.py | 2 +- src/ansor/compute_dag.cc | 12 ++++-------- src/ansor/compute_dag.h | 17 ++++------------- src/ansor/loop_state.h | 18 ++++++++++-------- src/ansor/search_policy/empty_policy.cc | 2 +- 5 files changed, 20 insertions(+), 31 deletions(-) diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index aa4626ed2153..ddc91125530e 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -59,7 +59,7 @@ def get_init_state(self): state : State The initial State without any transform steps. 
""" - return State(_ffi_api.ComputeDAGGetInitState(self), self) + return State(self.init_state, self) def apply_steps_from_state(self, state): """ diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 48e8a8149905..f4fac24b0aa3 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -242,8 +242,6 @@ ComputeDAG::ComputeDAG(Array tensors) { data_ = std::move(node); } -State ComputeDAG::GetInitState() const { return Downcast(operator->()->init_state); } - std::pair > ComputeDAG::ApplySteps( const Array& transform_steps) const { Array stages; @@ -295,7 +293,7 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const } State ComputeDAG::ReplayAndInferBound(const Array& transform_steps) const { - State ret_state = GetInitState(); + State ret_state = operator->()->init_state; StateNode* pstate = ret_state.CopyOnWrite(); pstate->transform_steps = transform_steps; ret_state.DoSteps(transform_steps, *this); @@ -314,12 +312,12 @@ State ComputeDAG::InferBound(const State& state) const { return ret_state; } -void ComputeDAG::InferBound(std::vector* states) const { - std::vector out_states(states->size(), State()); +void ComputeDAG::InferBound(Array* states) const { + Array out_states(states->size(), State()); auto worker_func = [&states, &out_states, this](int idx) { try { - out_states[idx] = this->InferBound((*states)[idx]); + out_states.Set(idx, this->InferBound((*states)[idx])); } catch (dmlc::Error& e) { LOG(WARNING) << "InferBound fails on the state:\n" << (*states)[idx] << "\n" @@ -484,8 +482,6 @@ TVM_REGISTER_GLOBAL("ansor.ComputeDAG").set_body_typed([](Array tens return ComputeDAG(tensors); }); -TVM_REGISTER_GLOBAL("ansor.ComputeDAGGetInitState").set_body_method(&ComputeDAG::GetInitState); - TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") .set_body_typed([](const ComputeDAG& dag, const State& state) { te::Schedule sch; diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 5c8cb649cd1e..9260707f9012 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -30,16 +30,12 @@ #include #include -#include -#include "transform_step.h" +#include "loop_state.h" namespace tvm { namespace ansor { -class StateNode; -class State; - /*! * \brief Update stage and axes mapping during replay. * \param stage A `te::Stage`. @@ -57,12 +53,13 @@ class ComputeDAGNode : public Object { /*! \brief Number of total float operations for this ComputeDAG. */ double flop_ct; /*! \brief The initial state without any transform steps. */ - ObjectRef init_state; + State init_state; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("tensors", &tensors); v->Visit("ops", &ops); v->Visit("flop_ct", &flop_ct); + v->Visit("init_state", &init_state); } static constexpr const char* _type_key = "ansor.ComputeDAG"; @@ -118,13 +115,7 @@ class ComputeDAG : public ObjectRef { * Return the new states inplace. * \param states A pointer to a State vector, States are updated inplace. */ - void InferBound(std::vector* states) const; - - /*! - * \brief Get the init state. - * \return The init state. 
- */ - State GetInitState() const; + void InferBound(Array* states) const; TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode); diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 819a2c37fade..726e2688f9ce 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -44,7 +44,6 @@ #include #include -#include "compute_dag.h" #include "transform_step.h" namespace tvm { @@ -52,6 +51,8 @@ namespace ansor { using namespace tvm::tir; +class ComputeDAG; + /*! \brief The type of a stage. */ enum StageType { /*! \brief A placeholder stage. */ @@ -242,22 +243,23 @@ class StateNode : public Object { Array transform_steps; /*! \brief Indicate whether this state has unfilled tile sizes. */ bool complete; - /*! - * \brief The up-to-date ComputeDAG of this state, used for some steps that may change the - * stage structure of the ComputeDAG, for exp. CacheReadStep/CacheWriteStep(Will be added later). - * The default value is an empty NodeRef. (means no modification to the original DAG) - */ - ComputeDAG task_dag; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("stages", &stages); v->Visit("transform_steps", &transform_steps); v->Visit("complete", &complete); - v->Visit("task_dag", &task_dag); } static constexpr const char* _type_key = "ansor.State"; TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object); + + private: + /*! + * \brief The up-to-date ComputeDAG of this state, used for some steps that may change the + * stage structure of the ComputeDAG, for exp. CacheReadStep/CacheWriteStep(Will be added later). + * The default value is an empty ObjectRef. (means no modification to the original DAG) + */ + ObjectRef current_compute_dag; }; /*! diff --git a/src/ansor/search_policy/empty_policy.cc b/src/ansor/search_policy/empty_policy.cc index 53cafd6524f3..659e0441d940 100644 --- a/src/ansor/search_policy/empty_policy.cc +++ b/src/ansor/search_policy/empty_policy.cc @@ -82,7 +82,7 @@ Array EmptyPolicyNode::SearchOneRound() { Array res; // 1. We will process `Program sampling` first to generate several initial schedules - res.push_back(cur_task->compute_dag.GetInitState()); + res.push_back(cur_task->compute_dag->init_state); // 2. Then `Performance Tuning`: use cost model and evolutionary search to seek for the schedule // with best performance From 0a24daf4c5b07f27df32fca989d1a83fcaaef1b6 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 2 Jul 2020 18:45:13 +0800 Subject: [PATCH 58/78] Update --- python/tvm/ansor/auto_schedule.py | 2 +- python/tvm/ansor/workload_registry.py | 16 +++++-- src/ansor/compute_dag.cc | 19 +++++--- src/ansor/loop_state.cc | 17 ++------ src/ansor/loop_state.h | 14 +----- src/ansor/transform_step.cc | 20 ++++----- src/ansor/transform_step.h | 62 +++++++++++++++------------ src/ansor/utils.h | 15 ------- 8 files changed, 73 insertions(+), 92 deletions(-) diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 6379e99bc41a..bb28f2c78727 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -18,7 +18,7 @@ """ User interface for Ansor auto-scheduler. -The basic schedule search process for Ansor is design to be: +The basic schedule search process for Ansor is designed to be: `Program sampling` -> `Performance Tuning`. 
In `Program sampling`, we use some predefined or heuristic rules to generate several initial diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py index 3dae9d15a9d3..0084102a8f75 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/ansor/workload_registry.py @@ -44,6 +44,11 @@ def register_workload_by_func(func): The input function should take hashable and jsonable arguments (int, float, tuple of int, tvm.tensor.Tensor, ...) and return a list of tvm.tensor.Tensor. + Parameters + ---------- + func : Function + The target function that returns the compute declaration Tensors. + Examples -------- @ansor.register_workload_by_func @@ -54,9 +59,11 @@ def matmul(N, M, K): C = te.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C') return [A, B, C] """ + assert callable(func) func_name = func.__name__ if func_name in WORKLOAD_FUNC_REGISTRY: raise RuntimeError('%s has been registered already' % func_name) + WORKLOAD_FUNC_REGISTRY[func_name] = func return func @@ -66,8 +73,9 @@ def make_workload_key_by_func(func, args): Parameters ---------- - func : Function + func : Union[Function, str] The target function that returns the compute declaration Tensors. + Can be the a function or the function name. args : Args The args of the target function. @@ -76,8 +84,6 @@ def make_workload_key_by_func(func, args): workload_key : Str The workload key of the target function. """ - args = serialize_args(args) - if callable(func): func_name = func.__name__ elif isinstance(func, str): @@ -89,6 +95,8 @@ def make_workload_key_by_func(func, args): raise ValueError("%s is not registered. " % func, "Please register it with @ansor.register_workload_by_func") + args = serialize_args(args) + return json.dumps((func_name,) + args) @@ -136,7 +144,7 @@ def workload_key_to_tensors(workload_key): return lookup(*args) -def dump_workload_func_registry(filename): +def save_workload_func_registry(filename): """ Dump workload function registry to a pickle binary file. 
Parameters diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index f4fac24b0aa3..b1bd97ac0f1a 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -268,7 +268,6 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const } std::stringstream ss; - for (const auto& stage : stages) { if (stage->op->IsInstance()) { for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { @@ -283,10 +282,17 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const << "tuple(" << stage->op->name << ".op.reduce_axis)\n"; } } - std::vector step_vector(transform_steps.begin(), transform_steps.end()); // Call each step's PrintAsPythonAPI method for (const auto& step : transform_steps) { - ss << step->PrintAsPythonAPI(&stages, &stage_to_axes, &schedule, step_vector); + if (auto ps = step.as()) { + ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); + } else if (auto ps = step.as()) { + ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); + } else if (auto ps = step.as()) { + ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); + } else { + LOG(FATAL) << "Invalid Step"; + } } return ss.str(); @@ -365,8 +371,8 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { auto find_res = bounds.find(axis); if (find_res != bounds.end()) { - new_iters.push_back(Iterator(iter->name, (*find_res).second, iter->iter_type, - iter->annotation, &iter->ori_iters, iter->attr)); + new_iters.push_back( + Iterator(iter->name, (*find_res).second, iter->iter_type, iter->annotation)); } else { LOG(FATAL) << "Infer bound fails"; } @@ -412,8 +418,7 @@ std::pair > ComputeDAG::ReplaySteps( } // Call each step's ApplyToSchedule method // Note: some steps have extra parameters that must be passed and they may need different - // return value, so the ApplyToSchedule is not able to be merged to single interface like - // PrintAsPythonAPI does + // return value, so the ApplyToSchedule is not able to be merged to single interface if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index f084cd62a749..d02aee82f9ba 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -42,17 +42,13 @@ TVM_REGISTER_NODE_TYPE(StateNode); TVM_REGISTER_NODE_TYPE(IteratorNode); /********** Iterator **********/ -Iterator::Iterator(String name, Range range, IteratorType iter_type, IteratorAnnotation annotation, - const std::vector* ori_iters, String attr) { +Iterator::Iterator(String name, Range range, IteratorType iter_type, + IteratorAnnotation annotation) { auto node = make_object(); node->name = std::move(name); node->range = std::move(range); node->iter_type = iter_type; node->annotation = annotation; - if (ori_iters != nullptr) { - node->ori_iters = *ori_iters; - } - node->attr = std::move(attr); data_ = std::move(node); } @@ -231,7 +227,6 @@ Iterator State::DoFuseStep(const FuseStep& step) { PrimExpr new_extent = 1; IteratorType new_iter_type = kSpecial; - std::vector ori_iters; for (size_t i = 0; i < step->fused_ids.size(); ++i) { if (i > 0) { CHECK_EQ(step->fused_ids[i].as()->value, @@ -239,7 +234,6 @@ Iterator State::DoFuseStep(const FuseStep& step) { } const Iterator& it = stage->iters[step->fused_ids[i].as()->value]; - ori_iters.push_back(it); new_name = new_name + it->name + "@"; if (it->range.defined() && new_extent.defined()) { @@ -261,7 +255,7 @@ Iterator State::DoFuseStep(const FuseStep& step) { if (new_extent.defined()) { range = 
Range::FromMinExtent(0, new_extent); } - Iterator new_it = Iterator(new_name, range, new_iter_type, kNone, &ori_iters); + Iterator new_it = Iterator(new_name, range, new_iter_type, kNone); Array new_iters; new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + step->fused_ids.front().as()->value); @@ -368,9 +362,6 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t b } else { *os << iter->name << " (None)"; } - if (!iter->attr.empty()) { - *os << " " << iter->attr; - } *os << "\n"; indent += 2; @@ -386,7 +377,7 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t b // Print state to ostream void PrintState(std::ostream* os, const StateNode* node, bool delete_trivial_loop) { // Gather placeholders - std::vector placeholders; + Array placeholders; for (const auto& stage : node->stages) { if (stage->op_type == kPlaceholder) { placeholders.push_back(stage->op->name); diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 726e2688f9ce..c0288fbeccb6 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -42,7 +42,6 @@ #include #include -#include #include "transform_step.h" @@ -107,9 +106,6 @@ enum IteratorAnnotation { kTensorized = 9 }; -// forward declaration -class Iterator; - /*! * \brief A for loop iterator * Similar to tvm::IterVar in `include/tvm/tir/expr.h` @@ -124,15 +120,10 @@ class IteratorNode : public Object { IteratorType iter_type; /*! \brief The annotation type of this iterator. */ IteratorAnnotation annotation; - /*! \brief The original iterators before fusion. */ - std::vector ori_iters; - /*! \brief The extra attributes of this iterator. */ - String attr; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); v->Visit("range", &range); - v->Visit("attr", &attr); } static constexpr const char* _type_key = "ansor.Iterator"; @@ -151,11 +142,8 @@ class Iterator : public ObjectRef { * \param range The target range of this iterator. * \param iter_type The iterator type of this iterator. * \param annotation The annotation type of this iterator. - * \param ori_iters The original iterators before fusion. - * \param attr The extra attribute of this iterator. 
*/ - Iterator(String name, Range range, IteratorType iter_type, IteratorAnnotation annotation, - const std::vector* ori_iters = nullptr, String attr = ""); + Iterator(String name, Range range, IteratorType iter_type, IteratorAnnotation annotation); TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode); }; diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index 4cc94236848f..c8fb5bc185ce 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -64,9 +64,8 @@ void ReorderStepNode::ApplyToSchedule(Array* stages, stages->Set(stage_id, std::move(stage)); } -String ReorderStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, - te::Schedule* schedule, - const std::vector& transform_steps) const { +String ReorderStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { const auto& stage = (*stages)[stage_id]; std::stringstream ss; @@ -127,9 +126,8 @@ Array ApplySplitToSchedule(Array* stages, StageToAxesMap* st return outs; } -std::string PrintSplitAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, - int stage_id, int iter_id, const Array& lengths, - bool inner_to_outer) { +String PrintSplitAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, int stage_id, + int iter_id, const Array& lengths, bool inner_to_outer) { const auto& stage = (*stages)[stage_id]; auto to_split = (*stage_to_axes)[stage][iter_id]; const auto& func_name = CleanName(stage->op->name); @@ -177,9 +175,8 @@ Array SplitStepNode::ApplyToSchedule(Array* stages, return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } -String SplitStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, - te::Schedule* schedule, - const std::vector& transform_steps) const { +String SplitStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } @@ -218,9 +215,8 @@ IterVar FuseStepNode::ApplyToSchedule(Array* stages, return fused_axis; } -String FuseStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, - te::Schedule* schedule, - const std::vector& transform_steps) const { +String FuseStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { const auto& stage = (*stages)[stage_id]; std::stringstream to_fuse; diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index af0cf4038399..7e0b23e0567a 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -34,7 +34,7 @@ * - In these two functions you need to incrementally update all data structures in State with * CopyOnWrite style * 4. Add you step to `ComputeDAG::ReplaySteps` and make sure it works. - * 5. Add serialization support in `struct Handler >` + * 5. Add serialization support in `struct Handler >` * in `serialization.cc`. * 6. Add hash support in `struct hash<::tvm::ansor::Step>`. (search for this function in this file) * 7. Add its corresponding Python API to `loop_state.py` and necessary unit test. @@ -47,9 +47,7 @@ #include #include -#include #include -#include #include "utils.h" @@ -59,8 +57,6 @@ namespace ansor { typedef std::unordered_map, ObjectHash, ObjectEqual> StageToAxesMap; -class Step; - /*! * \brief The base class for a transformation step. Each step has its corresponding tvm.te * schedule primitives. @@ -70,22 +66,14 @@ class StepNode : public Object { /*! \brief The index of the target stage. */ int stage_id; - /*! 
- * \brief Print step as equivalent python schedule API. - * \param stages A pointer to `te::Stage` vector. - * \param stage_to_axes A pointer to StageToAxesMap. - * \param schedule A pointer to `te::Schedule`. - * \param transform_steps Transform steps of the target state. - * \return Python schedule code. - */ - virtual String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, - te::Schedule* schedule, - const std::vector& transform_steps) const = 0; - static constexpr const char* _type_key = "ansor.Step"; TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object); }; +/*! + * \brief Managed reference to StepNode. + * \sa StepNode + */ class Step : public ObjectRef { public: TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode); @@ -107,9 +95,13 @@ class ReorderStepNode : public StepNode { */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; - String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, - te::Schedule* schedule, - const std::vector& transform_steps) const final; + /*! + * \brief Print step as equivalent python schedule API. + * \param stages A pointer to `te::Stage` vector. + * \param stage_to_axes A pointer to StageToAxesMap. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* _type_key = "ansor.ReorderStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object); @@ -158,9 +150,13 @@ class SplitStepNode : public StepNode { Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; - String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, - te::Schedule* schedule, - const std::vector& transform_steps) const final; + /*! + * \brief Print step as equivalent python schedule API. + * \param stages A pointer to `te::Stage` vector. + * \param stage_to_axes A pointer to StageToAxesMap. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* _type_key = "ansor.SplitStep"; TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object); @@ -199,9 +195,13 @@ class FuseStepNode : public StepNode { */ tir::IterVar ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; - String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, - te::Schedule* schedule, - const std::vector& transform_steps) const final; + /*! + * \brief Print step as equivalent python schedule API. + * \param stages A pointer to `te::Stage` vector. + * \param stage_to_axes A pointer to StageToAxesMap. + * \return Python schedule code. 
+ */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; static constexpr const char* _type_key = "ansor.FuseStep"; TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); @@ -246,7 +246,15 @@ struct hash<::tvm::ansor::Step> { } else if (auto ps = step.as<::tvm::ansor::SplitStepNode>()) { size_t ret = ::dmlc::HashCombine(2, ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), ps->inner_to_outer))); + ::dmlc::HashCombine(std::hash()(ps->iter_id), + std::hash()(ps->inner_to_outer)))); + if (ps->extent.defined()) { + const auto& pint = ps->extent.as<::tvm::tir::IntImmNode>(); + CHECK(pint != nullptr); + ret = ::dmlc::HashCombine(ret, pint->value); + } else { + ret = ::dmlc::HashCombine(ret, 0x5D); // a magic number + } for (const auto& x : ps->lengths) { if (x.defined()) { const auto& pint = x.as<::tvm::tir::IntImmNode>(); diff --git a/src/ansor/utils.h b/src/ansor/utils.h index c8d18d3109b2..5a9c0a26d1cd 100644 --- a/src/ansor/utils.h +++ b/src/ansor/utils.h @@ -62,21 +62,6 @@ struct hash> { } }; -/*! \brief Hash function for std::vector */ -template -struct hash> { - std::size_t operator()(const std::vector& vec) const { - if (vec.empty()) { - return 0; - } - std::size_t ret = std::hash()(vec[0]); - for (size_t i = 1; i < vec.size(); ++i) { - ret = ::dmlc::HashCombine(ret, std::hash()(vec[i])); - } - return ret; - } -}; - } // namespace std namespace tvm { From a45fd898b9d709e412f6d9ec4b61e94c9429345f Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 2 Jul 2020 20:39:34 +0800 Subject: [PATCH 59/78] Update some unordered_map to Map --- src/ansor/compute_dag.cc | 6 ++---- src/ansor/compute_dag.h | 7 ------- src/ansor/serialization.cc | 4 ++-- src/ansor/transform_step.cc | 16 ++++++++-------- src/ansor/transform_step.h | 10 ++++------ src/ansor/utils.h | 4 ---- 6 files changed, 16 insertions(+), 31 deletions(-) diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index b1bd97ac0f1a..3d5776cbded9 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -32,10 +32,8 @@ #include #include -#include #include #include -#include #include #include "loop_state.h" @@ -51,14 +49,14 @@ TVM_REGISTER_NODE_TYPE(ComputeDAGNode); // Update stage to axis mapping void UpdateStageAxis(const te::Stage& stage, StageToAxesMap* stage_to_axes) { if (auto pop = stage->op.as()) { - Array& axes = (*stage_to_axes)[stage]; - axes.clear(); + Array axes; for (const auto& axis : pop->axis) { axes.push_back(axis); } for (const auto& axis : pop->reduce_axis) { axes.push_back(axis); } + stage_to_axes->Set(stage, std::move(axes)); } else if (stage->op->IsInstance()) { {} // do nothing } else { diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 9260707f9012..a2b50b902733 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -36,13 +36,6 @@ namespace tvm { namespace ansor { -/*! - * \brief Update stage and axes mapping during replay. - * \param stage A `te::Stage`. - * \param stage_to_axes A pointer to StageToAxesMap. - */ -void UpdateStageAxis(const tvm::te::Stage& stage, StageToAxesMap* stage_to_axes); - /*! \brief Computation declaration graph. 
*/ class ComputeDAGNode : public Object { public: diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index bf6dc2b2875d..a29df425de0b 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -30,7 +30,6 @@ #include #include #include -#include #include #include @@ -146,7 +145,8 @@ struct Handler<::tvm::Array<::tvm::ansor::Step>> { lengths.push_back(::tvm::PrimExpr(i)); } data->push_back( - ::tvm::ansor::SplitStep(stage_id, iter_id, extent, lengths, inner_to_outer)); + ::tvm::ansor::SplitStep(stage_id, iter_id, extent == 0 ? ::tvm::PrimExpr() : extent, + lengths, inner_to_outer)); } else if (name == "FU") { s = reader->NextArrayItem(); CHECK(s); diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index c8fb5bc185ce..2e1fbfb9cdbc 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -50,7 +50,7 @@ ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { void ReorderStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const { auto stage = (*stages)[stage_id]; - const Array& axes = (*stage_to_axes)[stage]; + const Array& axes = stage_to_axes->at(stage); CHECK_EQ(after_ids.size(), axes.size()); Array new_axes; @@ -60,7 +60,7 @@ void ReorderStepNode::ApplyToSchedule(Array* stages, } stage.reorder(new_axes); - (*stage_to_axes)[stage] = std::move(new_axes); + stage_to_axes->Set(stage, std::move(new_axes)); stages->Set(stage_id, std::move(stage)); } @@ -87,7 +87,7 @@ Array ApplySplitToSchedule(Array* stages, StageToAxesMap* st int stage_id, int iter_id, const Array& lengths, bool inner_to_outer) { auto stage = (*stages)[stage_id]; - const Array& axes = (*stage_to_axes)[stage]; + const Array& axes = stage_to_axes->at(stage); Array outs; if (inner_to_outer) { @@ -121,7 +121,7 @@ Array ApplySplitToSchedule(Array* stages, StageToAxesMap* st } new_axes.insert(new_axes.end(), axes.begin() + iter_id + 1, axes.end()); - (*stage_to_axes)[stage] = std::move(new_axes); + stage_to_axes->Set(stage, std::move(new_axes)); stages->Set(stage_id, std::move(stage)); return outs; } @@ -129,7 +129,7 @@ Array ApplySplitToSchedule(Array* stages, StageToAxesMap* st String PrintSplitAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, int stage_id, int iter_id, const Array& lengths, bool inner_to_outer) { const auto& stage = (*stages)[stage_id]; - auto to_split = (*stage_to_axes)[stage][iter_id]; + auto to_split = stage_to_axes->at(stage)[iter_id]; const auto& func_name = CleanName(stage->op->name); const auto& outs = ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); @@ -194,7 +194,7 @@ FuseStep::FuseStep(int stage_id, const Array& fused_ids) { IterVar FuseStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const { auto stage = (*stages)[stage_id]; - const Array& axes = (*stage_to_axes)[stage]; + const Array& axes = stage_to_axes->at(stage); Array to_fuse; for (const auto& i : fused_ids) { @@ -210,7 +210,7 @@ IterVar FuseStepNode::ApplyToSchedule(Array* stages, new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back().as()->value + 1, axes.end()); - (*stage_to_axes)[stage] = std::move(new_axes); + stage_to_axes->Set(stage, std::move(new_axes)); stages->Set(stage_id, std::move(stage)); return fused_axis; } @@ -222,7 +222,7 @@ String FuseStepNode::PrintAsPythonAPI(Array* stages, for (size_t i = 0; i < fused_ids.size(); ++i) { to_fuse << CleanName( - (*stage_to_axes)[stage][fused_ids[i].as()->value]->var->name_hint); + 
stage_to_axes->at(stage)[fused_ids[i].as()->value]->var->name_hint); if (i != fused_ids.size() - 1) { to_fuse << ", "; } diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index 7e0b23e0567a..c3144f2c10d7 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -47,14 +47,12 @@ #include #include -#include - #include "utils.h" namespace tvm { namespace ansor { -typedef std::unordered_map, ObjectHash, ObjectEqual> +typedef Map, ObjectHash, ObjectEqual> StageToAxesMap; /*! @@ -249,9 +247,9 @@ struct hash<::tvm::ansor::Step> { ::dmlc::HashCombine(std::hash()(ps->iter_id), std::hash()(ps->inner_to_outer)))); if (ps->extent.defined()) { - const auto& pint = ps->extent.as<::tvm::tir::IntImmNode>(); - CHECK(pint != nullptr); - ret = ::dmlc::HashCombine(ret, pint->value); + const auto& pint = ps->extent.as<::tvm::tir::IntImmNode>(); + CHECK(pint != nullptr); + ret = ::dmlc::HashCombine(ret, pint->value); } else { ret = ::dmlc::HashCombine(ret, 0x5D); // a magic number } diff --git a/src/ansor/utils.h b/src/ansor/utils.h index 5a9c0a26d1cd..c7fb7204ac69 100644 --- a/src/ansor/utils.h +++ b/src/ansor/utils.h @@ -32,13 +32,9 @@ #include #include #include -#include -#include -#include #include #include #include -#include #include #include From bfc66639917bbf699617ca0e93f0c909f6519286 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 2 Jul 2020 20:43:09 +0800 Subject: [PATCH 60/78] clang-format fix --- src/ansor/serialization.cc | 5 ++--- src/ansor/transform_step.h | 3 +-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index a29df425de0b..0df937065e74 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -144,9 +144,8 @@ struct Handler<::tvm::Array<::tvm::ansor::Step>> { for (const auto& i : int_list) { lengths.push_back(::tvm::PrimExpr(i)); } - data->push_back( - ::tvm::ansor::SplitStep(stage_id, iter_id, extent == 0 ? ::tvm::PrimExpr() : extent, - lengths, inner_to_outer)); + data->push_back(::tvm::ansor::SplitStep( + stage_id, iter_id, extent == 0 ? ::tvm::PrimExpr() : extent, lengths, inner_to_outer)); } else if (name == "FU") { s = reader->NextArrayItem(); CHECK(s); diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index c3144f2c10d7..f9271e73c3c8 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -52,8 +52,7 @@ namespace tvm { namespace ansor { -typedef Map, ObjectHash, ObjectEqual> - StageToAxesMap; +typedef Map, ObjectHash, ObjectEqual> StageToAxesMap; /*! * \brief The base class for a transformation step. 
Each step has its corresponding tvm.te From eb02e77cf73772607eb057dccd13c50effe97ef0 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 3 Jul 2020 15:22:31 +0800 Subject: [PATCH 61/78] Comments addressed Delete ReplayAndInferBound Delete ReplaySteps & InferBoundCommon --- python/tvm/ansor/__init__.py | 2 +- python/tvm/ansor/auto_schedule.py | 47 ++--- python/tvm/ansor/compute_dag.py | 23 ++- python/tvm/ansor/loop_state.py | 32 ++-- python/tvm/ansor/measure.py | 20 +- python/tvm/ansor/utils.py | 2 +- src/ansor/auto_schedule.cc | 39 ++-- src/ansor/auto_schedule.h | 37 ++-- src/ansor/compute_dag.cc | 222 +++++++++++------------ src/ansor/compute_dag.h | 59 +++--- src/ansor/loop_state.cc | 9 +- src/ansor/loop_state.h | 69 ++++--- src/ansor/measure.cc | 20 +- src/ansor/measure.h | 70 +++---- src/ansor/search_policy/search_policy.cc | 2 +- src/ansor/search_policy/search_policy.h | 19 +- src/ansor/search_task.h | 6 +- src/ansor/serialization.h | 12 +- src/ansor/transform_step.h | 28 +-- 19 files changed, 367 insertions(+), 351 deletions(-) diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 480fb3422624..4fcf1008a2ea 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -26,7 +26,7 @@ # Shortcut from .compute_dag import ComputeDAG -from .auto_schedule import SearchTask, TuneOption, HardwareParams, \ +from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \ auto_schedule, EmptyPolicy from .measure import MeasureInput, LocalBuilder, LocalRunner from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \ diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index bb28f2c78727..9243263c4a4b 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -21,10 +21,11 @@ The basic schedule search process for Ansor is designed to be: `Program sampling` -> `Performance Tuning`. -In `Program sampling`, we use some predefined or heuristic rules to generate several initial -schedules. Based on these initial start points, we have `Performance Tuning` to apply cost model -and evolutionary search to seek for schedules with the best performance. Candidate schedules will -be measured in the target hardware. +In `Program sampling`, we use some predefined precise or heuristic rules to generate several +initial schedules. Based on these initial starting points, we perform `Performance Tuning` which +uses cost model based evolutionary search to select schedules with the best performance. + +Candidate schedules are measured against the specific hardware target. """ import tvm._ffi @@ -36,10 +37,9 @@ @tvm._ffi.register_object("ansor.HardwareParams") class HardwareParams(Object): - """ The parameters of target hardware, this is used to guide the search process of - SearchPolicy. + """ The parameters of target hardware used to guide the search process of SearchPolicy. - TODO(...): This is considering to merge with the new Target: + TODO(jcf94): This is considering to merge with the new Target: https://discuss.tvm.ai/t/rfc-tvm-target-specification/6844 Parameters @@ -64,14 +64,14 @@ def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes, @tvm._ffi.register_object("ansor.SearchTask") class SearchTask(Object): - """ The meta-information of a search task. + """ The computation information and hardware parameters for a specific schedule search task. Parameters ---------- dag : ComputeDAG - The ComputeDAG for target compute declaration. 
+ The ComputeDAG for the target compute declaration. workload_key : str - The workload key for target compute declaration. + The workload key for the target compute declaration. target : tvm.target.Target The target device of this search task. target_host : Optional[tvm.target.Target] @@ -88,7 +88,7 @@ def __init__(self, dag, workload_key, target, target_host=None, @tvm._ffi.register_object("ansor.SearchPolicy") class SearchPolicy(Object): - """ The base class for search policy """ + """ The base class of search policies. """ @tvm._ffi.register_object("ansor.EmptyPolicy") @@ -100,8 +100,8 @@ def __init__(self): self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy) -@tvm._ffi.register_object("ansor.TuneOption") -class TuneOption(Object): +@tvm._ffi.register_object("ansor.TuningOptions") +class TuningOptions(Object): """ This controls the options of performance tuning. Parameters @@ -122,10 +122,10 @@ class TuneOption(Object): We have: `num_search_rounds` = `num_measure_trials` // `num_measures_per_round` verbose: int = 1 Verbosity level. 0 for silent, 1 to output information during schedule search. - builder: Union[Builder, str] = 'local' - Builder which builds the program. - runner: Union[Runner, str] = 'local' - Runner which runs the program and measures time costs. + builder: Union[ProgramBuilder, str] = 'local' + ProgramBuilder which builds the program. + runner: Union[ProgramRunner, str] = 'local' + ProgramRunner which runs the program and measures time costs. measure_callbacks: Optional[List[MeasureCallback]] Callback functions called after each measure. Candidates: @@ -156,12 +156,12 @@ def __init__(self, num_measure_trials=0, early_stopping=-1, num_measures_per_rou pre_search_callbacks = [] if pre_search_callbacks is None else pre_search_callbacks self.__init_handle_by_constructor__( - _ffi_api.TuneOption, num_measure_trials, early_stopping, num_measures_per_round, + _ffi_api.TuningOptions, num_measure_trials, early_stopping, num_measures_per_round, verbose, builder, runner, measure_callbacks, pre_search_callbacks) def auto_schedule(task, target, target_host=None, search_policy='default', - hardware_params=None, tune_option=None): + hardware_params=None, tuning_options=None): """ Do auto scheduling for a computation declaration. The task parameter can be a `string` as workload_key, or directly @@ -179,7 +179,7 @@ def auto_schedule(task, target, target_host=None, search_policy='default', The search policy to be used for schedule search. hardware_params : Optional[HardwareParams] The hardware parameters of this schedule search. - tune_option : Optional[TuneOption] + tuning_options : Optional[TuningOptions] Tuning and measurement options. Returns @@ -194,13 +194,14 @@ def auto_schedule(task, target, target_host=None, search_policy='default', else: raise ValueError("Invalid search policy: " + search_policy) - tune_option = tune_option if tune_option else TuneOption() + tuning_options = tuning_options if tuning_options else TuningOptions() if isinstance(task, str): dag = ComputeDAG(task) task = SearchTask(dag, task, target, target_host, hardware_params) elif not isinstance(task, SearchTask): - raise ValueError("Invalid task: " + task + ". Expect a string or SearchTask") + raise ValueError("Invalid task: " + task + + " . 
`ansor.auto_schedule` expects a `str` or `SearchTask`.") - sch, tensors = _ffi_api.AutoSchedule(task, search_policy, tune_option) + sch, tensors = _ffi_api.AutoSchedule(task, search_policy, tuning_options) return sch, tensors diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index ddc91125530e..fa836562db60 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -33,7 +33,18 @@ @tvm._ffi.register_object("ansor.ComputeDAG") class ComputeDAG(Object): """ - Computation declaration graph. + The Ansor computational graph and related program analyses. + + We convert a compute declaration described by `tvm.compute` (could be a single operator or a + subgraph) to a ComputeDAG. It keeps the input/output tensors of the target compute declaration, + a list of all related operations in topo order as well as a set of analyses over each operation + stage (e.g. the total float operation count, consumer/producer relations of each operation + stage, whether an operation stage should be tiled/compute inlined ...). These analyses can + help the search policy make specific decisions during the schedule search process. + + ComputeDAG is also responsible for the interaction between Ansor LoopState and TVM schedule + (e.g. applying the LoopState transform steps to TVM schedule, providing LoopState with extra + information obtained from TVM schedule ...). Parameters ---------- @@ -52,7 +63,7 @@ def __init__(self, compute): self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute) def get_init_state(self): - """ Get init state of this ComputeDAG. + """ Get the init state of this ComputeDAG. Returns ------- @@ -63,7 +74,7 @@ def apply_steps_from_state(self, state): """ - Apply transform steps according to the history of a State. + Apply the history transform steps of a State to TVM schedule. Parameters ---------- @@ -96,14 +107,14 @@ def print_python_code_from_state(self, state): def infer_bound_from_state(self, state): """ - Infer bound for a state using TVM schedule. + Infer and fill the bound of all iterators of a state using TVM schedule. The State API allows defining a split step with its split factor as a blank placeholder, so sometimes we may get a State with incomplete iterator extent information. And another situation is after some steps (e.g. compute_at), it may be hard to track the extent change of all iterators. - We perform infer bound using TVM schedule and fill the State with those informations. After + We perform infer bound using TVM schedule and fill the State with that information. After applying this method, the State is guaranteed to have complete iterator extent information. @@ -121,7 +132,7 @@ return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self) def __hash__(self): - # TODO(...): Implement this more carefully and move this to c++ as a member function + # TODO(merrymercy): Implement this more carefully and move this to c++ as a member function # of ComputeDAG str_key = '' for op in self.ops:
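Taken together, the ComputeDAG methods above form the following round trip (a sketch only; the search policy that records steps into the state is elided, and `dag` is the hypothetical object from the earlier sketch):

    state = dag.get_init_state()
    # ... a search policy records reorder/split/fuse steps into `state` ...
    state = dag.infer_bound_from_state(state)      # fill in missing iterator extents
    sch, args = dag.apply_steps_from_state(state)  # arguments for tvm.lower / tvm.build
    print(dag.print_python_code_from_state(state))

The printed code is plain te scheduling; for a split followed by a reorder it is roughly of this shape (the names here are illustrative, the real output derives them from each stage's op):

    i, j = tuple(C.op.axis) + tuple(C.op.reduce_axis)
    i_o, i_i = s[C].split(i, factor=32)
    s[C].reorder(i_o, j, i_i)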
diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index a1420bf9b30e..dbf3d678263b 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -17,20 +17,28 @@ # pylint: disable=unused-import """ -The definition of the "state" in search. A state consists a current loop structure -and the transform history to reach its current loop structure. -To enable flexible manipulation of the loop structures, we implemented a lightweight loop -structure IR (Intermediate Representation) based on the original TVM IR but specifically -for schedule search. -We don't use the existing TVM IR but to extend a new Sketch IR on it is because: -1. We want fast incremental change to the loop structures; +The definition of the "state" in search. + +Each LoopState corresponds to a specific schedule for its target ComputeDAG. +A LoopState consists of: 1. a current loop structure; 2. a history of transformations used to +construct the loop structure. +The loop structure keeps a preview of how the schedule will finally look after lowering the +current state (e.g. number of iterators, the extent of each iterator, the compute_at locations ...). +During the schedule search process, the loop structure can provide the search policy with the +necessary information on how to perform further operations with the current state. +The transform history is a sequence of TransformStep which will finally be mapped to schedule +primitives. The steps can also be used for serialization of a state. + +The LoopState can be seen as a lightweight loop structure IR specifically for schedule search. +We don't use the existing TVM IR directly but extend a new structure on top of it because: +1. We want fast incremental changes to the loop structures; the search policy needs immediate +loop structure updates rather than waiting for TVM lowering; 2. We want serializable transform history for replay, backtracking, and mutation; 3. We may create some macro schedule primitives that represent the combination of several TVM schedule primitives. -After the search is done, we will lower this IR to TVM IR with TVM's schedule primitives. -Because we share a lot common objects during search, the transformation is implemented in +When the search is complete, we will lower the state to TVM IR with TVM's schedule primitives. +Since we share a lot of common objects during search, the transformation is implemented in copy on write style. All objects are immutable, which is similar to TVM IR. """ @@ -60,9 +68,9 @@ def __eq__(self, other): class State: """ A state in the search process. It consists of the current loop structure - and the history steps to reach this state. + and a history of transformations used to construct it. - Each State corresponds to a specific schedule for the target ComputeDAG. + Each State corresponds to a specific schedule for its target ComputeDAG. Parameters ---------- diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 691dcca6f85c..7a3a3d5ec64a 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -115,9 +115,9 @@ def __init__(self, costs, error_no, error_msg, all_cost, timestamp): error_msg, all_cost, timestamp) -@tvm._ffi.register_object("ansor.Builder") -class Builder(Object): - """ Base class of Builder. """ +@tvm._ffi.register_object("ansor.ProgramBuilder") +class ProgramBuilder(Object): + """ Base class of ProgramBuilder. """ def build(self, measure_inputs, verbose=1): """ Build programs and return results.
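With the renames in this patch (TuningOptions, ProgramBuilder/ProgramRunner), a complete tuning call then looks roughly as follows; a sketch only, reusing the hypothetical `key`, `dag`, and `state` from the earlier sketches, with placeholder trial counts and log file name:

    import tvm
    from tvm import ansor

    tune_opt = ansor.TuningOptions(num_measure_trials=64,
                                   builder=ansor.LocalBuilder(),
                                   runner=ansor.LocalRunner(),
                                   measure_callbacks=[ansor.LogToFile('matmul.json')])
    sch, args = ansor.auto_schedule(key, target=tvm.target.create('llvm'),
                                    tuning_options=tune_opt)
    func = tvm.build(sch, args, 'llvm')

The same builder/runner pair can also be driven by hand, which is what ProgramMeasurer does internally during the search:

    task = ansor.SearchTask(dag, key, tvm.target.create('llvm'))
    inputs = [ansor.MeasureInput(task, state)]
    build_results = ansor.LocalBuilder().build(inputs)        # List[BuildResult]
    results = ansor.LocalRunner().run(inputs, build_results)  # List[MeasureResult]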
@@ -133,12 +133,12 @@ def build(self, measure_inputs, verbose=1): ------- res : List[BuildResult] """ - return _ffi_api.BuilderBuild(self, measure_inputs, verbose) + return _ffi_api.ProgramBuilderBuild(self, measure_inputs, verbose) -@tvm._ffi.register_object("ansor.Runner") -class Runner(Object): - """ Base class of Runner """ +@tvm._ffi.register_object("ansor.ProgramRunner") +class ProgramRunner(Object): + """ Base class of ProgramRunner """ def run(self, measure_inputs, build_results, verbose=1): """ Run measurement and return results. @@ -156,11 +156,11 @@ def run(self, measure_inputs, build_results, verbose=1): ------- res : List[MeasureResult] """ - return _ffi_api.RunnerRun(self, measure_inputs, build_results, verbose) + return _ffi_api.ProgramRunnerRun(self, measure_inputs, build_results, verbose) @tvm._ffi.register_object("ansor.LocalBuilder") -class LocalBuilder(Builder): +class LocalBuilder(ProgramBuilder): """ LocalBuilder use local CPU cores to build programs in parallel. Parameters @@ -182,7 +182,7 @@ def __init__(self, @tvm._ffi.register_object("ansor.LocalRunner") -class LocalRunner(Runner): +class LocalRunner(ProgramRunner): """ LocalRunner that uses local CPU/GPU to measures the time cost of programs. Parameters diff --git a/python/tvm/ansor/utils.py b/python/tvm/ansor/utils.py index 309f63cec93c..9dbcd81f36e7 100644 --- a/python/tvm/ansor/utils.py +++ b/python/tvm/ansor/utils.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -"""Common utilities for ansor""" +""" Common utilities for ansor. """ from typing import Hashable import multiprocessing diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index 2b0860a07303..dfaff797a179 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -29,13 +29,13 @@ namespace tvm { namespace ansor { -TVM_REGISTER_NODE_TYPE(TuneOptionNode); +TVM_REGISTER_NODE_TYPE(TuningOptionsNode); -TuneOption::TuneOption(int num_measure_trials, int early_stopping, int num_measures_per_round, - int verbose, Builder builder, Runner runner, - Array measure_callbacks, - Array pre_search_callbacks) { - auto node = make_object(); +TuningOptions::TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, + int verbose, ProgramBuilder builder, ProgramRunner runner, + Array measure_callbacks, + Array pre_search_callbacks) { + auto node = make_object(); node->num_measure_trials = num_measure_trials; node->early_stopping = early_stopping; node->num_measures_per_round = num_measures_per_round; @@ -49,32 +49,33 @@ TuneOption::TuneOption(int num_measure_trials, int early_stopping, int num_measu std::pair > AutoSchedule(SearchTask task, SearchPolicy search_policy, - TuneOption tune_option) { + TuningOptions tuning_options) { // Create a ProgramMeasurer to handle the schedule build and performance measure - ProgramMeasurer measurer = ProgramMeasurer(tune_option->builder, tune_option->runner, - tune_option->measure_callbacks, tune_option->verbose); + ProgramMeasurer measurer = + ProgramMeasurer(tuning_options->builder, tuning_options->runner, + tuning_options->measure_callbacks, tuning_options->verbose); // Search for the best schedule - State state = - search_policy->Search(task, tune_option->num_measure_trials, tune_option->early_stopping, - tune_option->num_measures_per_round, tune_option->verbose, measurer, - tune_option->pre_search_callbacks); + State state = search_policy->Search( + task, tuning_options->num_measure_trials, 
tuning_options->early_stopping, + tuning_options->num_measures_per_round, tuning_options->verbose, measurer, + tuning_options->pre_search_callbacks); return task->compute_dag.ApplySteps(state->transform_steps); } -TVM_REGISTER_GLOBAL("ansor.TuneOption") +TVM_REGISTER_GLOBAL("ansor.TuningOptions") .set_body_typed([](int num_measure_trials, int early_stopping, int num_measures_per_round, - int verbose, Builder builder, Runner runner, + int verbose, ProgramBuilder builder, ProgramRunner runner, Array measure_callbacks, Array pre_search_callbacks) { - return TuneOption(num_measure_trials, early_stopping, num_measures_per_round, verbose, - builder, runner, measure_callbacks, pre_search_callbacks); + return TuningOptions(num_measure_trials, early_stopping, num_measures_per_round, verbose, + builder, runner, measure_callbacks, pre_search_callbacks); }); TVM_REGISTER_GLOBAL("ansor.AutoSchedule") - .set_body_typed([](SearchTask task, SearchPolicy search_policy, TuneOption tune_option) { + .set_body_typed([](SearchTask task, SearchPolicy search_policy, TuningOptions tuning_options) { te::Schedule sch; Array return_tensors; - std::tie(sch, return_tensors) = AutoSchedule(task, search_policy, tune_option); + std::tie(sch, return_tensors) = AutoSchedule(task, search_policy, tuning_options); return Array{sch, return_tensors}; }); } // namespace ansor diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index 130b56f8a54e..8127990ca2ec 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -36,7 +36,7 @@ namespace tvm { namespace ansor { /*! \brief Tuning and measurement options. */ -class TuneOptionNode : public Object { +class TuningOptionsNode : public Object { public: /*! \brief Number of total measurement trials. */ int num_measure_trials; @@ -46,10 +46,10 @@ class TuneOptionNode : public Object { int num_measures_per_round; /*! \brief Verbosity level. 0 for silent, 1 to output information during schedule searching. */ int verbose; - /*! \brief Builder which builds the program */ - Builder builder; - /*! \brief Runner which runs the program and measure time costs */ - Runner runner; + /*! \brief ProgramBuilder which builds the program */ + ProgramBuilder builder; + /*! \brief ProgramRunner which runs the program and measure time costs */ + ProgramRunner runner; /*! \brief MeasureCallback functions to be called after each measure batch */ Array measure_callbacks; /*! \brief SearchCallback functions to be called before schedule search */ @@ -66,15 +66,15 @@ class TuneOptionNode : public Object { v->Visit("pre_search_callbacks", &pre_search_callbacks); } - static constexpr const char* _type_key = "ansor.TuneOption"; - TVM_DECLARE_FINAL_OBJECT_INFO(TuneOptionNode, Object); + static constexpr const char* _type_key = "ansor.TuningOptions"; + TVM_DECLARE_FINAL_OBJECT_INFO(TuningOptionsNode, Object); }; /*! - * \brief Managed reference to TuneOptionNode. - * \sa TuneOptionNode + * \brief Managed reference to TuningOptionsNode. + * \sa TuningOptionsNode */ -class TuneOption : public ObjectRef { +class TuningOptions : public ObjectRef { public: /*! * \brief The constructor @@ -83,28 +83,29 @@ class TuneOption : public ObjectRef { * \param num_measures_per_round The number of programs to be measured at each search round. * \param verbose Verbosity level. 0 for silent, 1 to output information during schedule * search. - * \param builder Builder which builds the program. - * \param runner Runner which runs the program and measure time costs. 
+ * \param builder ProgramBuilder which builds the program. + * \param runner ProgramRunner which runs the program and measure time costs. * \param measure_callbacks MeasureCallback functions to be called after each measure batch. * \param pre_search_callbacks SearchCallback functions to be called before schedule search. */ - TuneOption(int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, - Builder builder, Runner runner, Array measure_callbacks, - Array pre_search_callbacks); + TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, + ProgramBuilder builder, ProgramRunner runner, + Array measure_callbacks, + Array pre_search_callbacks); - TVM_DEFINE_OBJECT_REF_METHODS(TuneOption, ObjectRef, TuneOptionNode); + TVM_DEFINE_OBJECT_REF_METHODS(TuningOptions, ObjectRef, TuningOptionsNode); }; /*! * \brief Auto schedule search for a given compute declaration, by SearchTask. * \param task The target search task. * \param search_policy The search policy to be used for schedule search. - * \param tune_option Tuning and measurement options. + * \param tuning_options Tuning and measurement options. * \return A `te::Schedule` and the target `te::Tensor` to be used in `tvm.lower` or `tvm.build`. */ std::pair > AutoSchedule(SearchTask task, SearchPolicy search_policy, - TuneOption tune_option); + TuningOptions tuning_options); } // namespace ansor } // namespace tvm diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 3d5776cbded9..80843c420044 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -46,24 +46,6 @@ using namespace tvm::tir; TVM_REGISTER_NODE_TYPE(ComputeDAGNode); -// Update stage to axis mapping -void UpdateStageAxis(const te::Stage& stage, StageToAxesMap* stage_to_axes) { - if (auto pop = stage->op.as()) { - Array axes; - for (const auto& axis : pop->axis) { - axes.push_back(axis); - } - for (const auto& axis : pop->reduce_axis) { - axes.push_back(axis); - } - stage_to_axes->Set(stage, std::move(axes)); - } else if (stage->op->IsInstance()) { - {} // do nothing - } else { - LOG(FATAL) << "Invalid op " << stage->op; - } -} - // Topo-sort ops from tensors according to their read-write relations. 
// Results are stored in ops void TopoSortOps(const Array& tensors, Array* ops) { @@ -240,11 +222,79 @@ ComputeDAG::ComputeDAG(Array tensors) { data_ = std::move(node); } +// Update the te::stage to tir::IterVar axis mapping +void UpdateStageAxis(const te::Stage& stage, StageToAxesMap* stage_to_axes) { + if (auto pop = stage->op.as()) { + Array axes; + for (const auto& axis : pop->axis) { + axes.push_back(axis); + } + for (const auto& axis : pop->reduce_axis) { + axes.push_back(axis); + } + stage_to_axes->Set(stage, std::move(axes)); + } else if (stage->op->IsInstance()) { + {} // do nothing on Placeholder + } else { + LOG(FATAL) << "Invalid op " << stage->op; + } +} + std::pair > ComputeDAG::ApplySteps( - const Array& transform_steps) const { - Array stages; - StageToAxesMap stage_to_axes; - return ReplaySteps(transform_steps, &stages, &stage_to_axes); + const Array& transform_steps, Array* stages, + StageToAxesMap* stage_to_axes) const { + // Temporal object to be used if the input pointer is nullptr + Array temp_stages; + StageToAxesMap temp_stage_to_axes; + if (stages == nullptr) { + stages = &temp_stages; + } + if (stage_to_axes == nullptr) { + stage_to_axes = &temp_stage_to_axes; + } + Array ops; + for (const auto& op : operator->()->ops) { + if (!op->IsInstance()) { + ops.push_back(op); + } + } + // Create the initial schedule + te::Schedule schedule = te::create_schedule({ops.back()}); + + // init axes + for (const auto& x : operator->()->ops) { + const te::Stage& stage = schedule.operator[](x); + stages->push_back(stage); + UpdateStageAxis(stage, stage_to_axes); + } + + // Use complete rate for the study in the paper + const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); + double complete_rate = -1.0; + if (complete_rate_str) { + complete_rate = std::stod(complete_rate_str); + } + size_t ct = 0; + // Apply the history steps to TVM schedule + for (const auto& step : transform_steps) { + if (complete_rate >= 0 && ct++ > transform_steps.size() * complete_rate) { + break; + } + // Call each step's ApplyToSchedule method + // Note: some steps have extra parameters that must be passed and they may need different + // return value, so the ApplyToSchedule is not able to be merged to single interface + if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else { + LOG(FATAL) << "Invalid Step"; + } + } + + return std::make_pair(schedule, operator->()->tensors); } String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const { @@ -256,6 +306,7 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const ops.push_back(op); } } + // Create the initial schedule te::Schedule schedule = te::create_schedule({ops.back()}); // init axes @@ -296,62 +347,31 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const return ss.str(); } -State ComputeDAG::ReplayAndInferBound(const Array& transform_steps) const { - State ret_state = operator->()->init_state; - StateNode* pstate = ret_state.CopyOnWrite(); - pstate->transform_steps = transform_steps; - ret_state.DoSteps(transform_steps, *this); - - InferBoundCommon(pstate); - - return ret_state; -} - State ComputeDAG::InferBound(const State& state) const { - State ret_state = state; - StateNode* pstate = ret_state.CopyOnWrite(); - - InferBoundCommon(pstate); + State ret_state; + StateNode* pstate; - return 
ret_state; -} - -void ComputeDAG::InferBound(Array* states) const { - Array out_states(states->size(), State()); - - auto worker_func = [&states, &out_states, this](int idx) { - try { - out_states.Set(idx, this->InferBound((*states)[idx])); - } catch (dmlc::Error& e) { - LOG(WARNING) << "InferBound fails on the state:\n" - << (*states)[idx] << "\n" - << e.what() << std::endl; - } - }; - - // Lower states in parallel - ThreadPool& pool = ThreadPool::Global(); - pool.BeginBatch(states->size()); - for (size_t i = 0; i < states->size(); ++i) { - pool.Enqueue(worker_func, i); + if (state->stages.size()) { + ret_state = state; + pstate = ret_state.CopyOnWrite(); + } else { + // If the input state is incomplete with empty operation stage + // create a new state from init_state and update it first + ret_state = operator->()->init_state; + pstate = ret_state.CopyOnWrite(); + pstate->transform_steps = state->transform_steps; + ret_state.DoSteps((*this)); } - pool.WaitBatch(); - *states = std::move(out_states); -} - -void ComputeDAG::InferBoundCommon(StateNode* pstate) const { Array stages; StageToAxesMap stage_to_axes; te::Schedule sch; Array tensors; - Map bounds; - // Replay steps to tvm::Schedule - std::tie(sch, tensors) = ReplaySteps(pstate->transform_steps, &stages, &stage_to_axes); + std::tie(sch, tensors) = ApplySteps(pstate->transform_steps, &stages, &stage_to_axes); sch = sch.normalize(); // Get bound information from TVM schedule - bounds = te::InferBound(sch); + Map bounds = te::InferBound(sch); // Update the state bound information for (size_t i = 0; i < pstate->stages.size(); ++i) { @@ -363,6 +383,8 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { Array new_iters; new_iters.reserve(stage->iters.size()); + // Get bound information from schedule + // the StageToAxesMap is used to find the corresponding IterVar in TVM schedule result for (size_t j = 0; j < stage->iters.size(); ++j) { const Iterator& iter = stage->iters[j]; const IterVar& axis = stage_to_axes.at(stages[i])[j]; @@ -379,56 +401,32 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { pstate->stages.Set( i, Stage(stage->op, stage->op_type, std::move(new_iters), stage->compute_at, stage->attrs)); } -} -std::pair > ComputeDAG::ReplaySteps( - const Array& transform_steps, Array* stages, - StageToAxesMap* stage_to_axes) const { - Array ops; - for (const auto& op : operator->()->ops) { - if (!op->IsInstance()) { - ops.push_back(op); - } - } - - te::Schedule schedule = te::create_schedule({ops.back()}); - - // init axes - stages->reserve(operator->()->ops.size()); - for (const auto& x : operator->()->ops) { - const te::Stage& stage = schedule.operator[](x); - stages->push_back(stage); - UpdateStageAxis(stage, stage_to_axes); - } + return ret_state; +} - // Use complete rate for the study in the paper - const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); - double complete_rate = -1.0; - if (complete_rate_str) { - complete_rate = std::stod(complete_rate_str); - } - size_t ct = 0; +void ComputeDAG::InferBound(Array* states) const { + Array out_states(states->size(), State()); - // replay history - for (const auto& step : transform_steps) { - if (complete_rate >= 0 && ct++ > transform_steps.size() * complete_rate) { - break; - } - // Call each step's ApplyToSchedule method - // Note: some steps have extra parameters that must be passed and they may need different - // return value, so the ApplyToSchedule is not able to be merged to single interface - if (auto ps = step.as()) { - 
ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else { LOG(FATAL) << "Invalid Step"; } } return std::make_pair(schedule, operator->()->tensors); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -500,7 +498,7 @@ TVM_REGISTER_GLOBAL("ansor.ComputeDAGPrintPythonCodeFromState") TVM_REGISTER_GLOBAL("ansor.ComputeDAGInferBoundFromState") .set_body_typed([](const ComputeDAG& dag, const State& state) { - return dag.ReplayAndInferBound(state->transform_steps); + return dag.InferBound(state); }); } // namespace ansor diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index a2b50b902733..7275ac645e20 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -19,9 +19,18 @@ /*! * \file ansor/compute_dag.h - * \brief Compute declaration graph and its related analysis tools. - * ComputeDAG is also responsible for the interaction with the original TVM schedule system, to - * apply state to a runable TVM schedule or provide the schedule Python code. + * \brief The Ansor computational graph and related program analyses. + * + * We convert a compute declaration described by `tvm.compute` (could be a single operator or a + * subgraph) to a ComputeDAG. It keeps the input/output tensors of the target compute declaration, + * a list of all related operations in topo order as well as a set of analyses over each operation + * stage (e.g. the total float operation count, consumer/producer relations of each operation + * stage, whether an operation stage should be tiled/compute inlined ...). These analyses can + * help the search policy make specific decisions during the schedule search process. + * + * ComputeDAG is also responsible for the interaction between Ansor LoopState and TVM schedule + * (e.g. applying the LoopState transform steps to TVM schedule, providing LoopState with extra + * information obtained from TVM schedule ...). */ #ifndef TVM_ANSOR_COMPUTE_DAG_H_ @@ -74,9 +83,16 @@ class ComputeDAG : public ObjectRef { * \brief Apply transform steps to the init state of this DAG, and get the * equivalent `tvm::schedule`. * \param transform_steps Transform steps of the target state. + * \param stages A pointer to a `te::Stage` Array, defaults to nullptr. + * Pass a valid pointer if this information needs to be used outside this function. + * \param stage_to_axes A pointer to a StageToAxesMap, defaults to nullptr. + * Pass a valid pointer if this information needs to be used outside this function. * \return The return values can be used as arguments to `tvm.build` or `tvm.lower`. */ - std::pair > ApplySteps(const Array& transform_steps) const; + std::pair > ApplySteps( + const Array& transform_steps, Array* stages = nullptr, + StageToAxesMap* stage_to_axes = nullptr) const; + /*! * \brief Print transform steps as equivalent python schedule API. * \param transform_steps Transform steps of the target state.
@@ -84,19 +100,6 @@ class ComputeDAG : public ObjectRef { */ String PrintStepsAsPython(const Array& transform_steps) const; - /*! - * \brief Replay the transform steps and call ir_pass::InferBound to fill correct bound - * information. - * State api supports to define a split step with its split factor to be a blank placeholder, - * so sometimes we may get a State will incomplete iterator extent information. - * And another situation is after some steps (for exp. compute_at), it may be hard to track the - * extent change of all iterators. - * We perform infer bound using TVM schedule and fill the State with those informations. After - * applying this methods, the State is guaranteed to have complete interator extent information. - * \param transform_steps Transform steps of the target state. - * \return The State after inferbound. - */ - State ReplayAndInferBound(const Array& transform_steps) const; /*! * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound. * \param state The target state. @@ -106,32 +109,12 @@ class ComputeDAG : public ObjectRef { /*! * \brief Fill the correct bound information for a list of given states. * Return the new states inplace. - * \param states A pointer to a State vector, States are updated inplace. + * \param states A pointer to a State Array, States are updated inplace. */ void InferBound(Array* states) const; TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode); - - private: - /*! - * \brief Internal common parts for replaying steps. This is the key method to apply steps to - * TVM schedule. - * \param transform_steps Transform steps of the target state. - * \param stages A pointer to `te::Stage` vector. - * \param stage_to_axes A pointer to StageToAxesMap. - * \return The return values can be used as arguments to `tvm.build` or `tvm.lower`. - */ - std::pair > ReplaySteps(const Array& transform_steps, - Array* stages, - StageToAxesMap* stage_to_axes) const; - - /*! - * \brief Internal common parts for inferring bound. - * \param pstate A pointer to StateNode, the target state will be updated with filled - * bound information. - */ - void InferBoundCommon(StateNode* pstate) const; }; } // namespace ansor diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index d02aee82f9ba..4f6d72d46694 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -271,7 +271,9 @@ Iterator State::DoFuseStep(const FuseStep& step) { return new_it; } -void State::DoSteps(const Array& steps, const ComputeDAG& dag) { +void State::DoSteps(const ComputeDAG& dag) { + CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages."; + // Use complete rate for the study in the paper const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); double complete_rate = -1.0; @@ -279,9 +281,8 @@ void State::DoSteps(const Array& steps, const ComputeDAG& dag) { complete_rate = std::stod(complete_rate_str); } size_t ct = 0; - - for (const auto& step : steps) { - if (complete_rate >= 0 && ct++ > steps.size() * complete_rate) { + for (const auto& step : operator->()->transform_steps) { + if (complete_rate >= 0 && ct++ > operator->()->transform_steps.size() * complete_rate) { break; } if (auto ps = step.as()) { diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index c0288fbeccb6..e8ddcc480dcd 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -19,21 +19,29 @@ /*! 
* \file ansor/loop_state.h - * \brief The definition of the "state" in search. A state consists the current loop structure - * and the transform history to reach its current loop structure. - * To enable flexible manipulation of the loop structures, we implemented a lightweight loop - * structure IR (Intermediate Representation) based on the original TVM IR but specifically - * for schedule search. + * \brief The definition of the "state" in search. * - * We don't use the existing TVM IR but to extend a new Sketch IR on it is because: - * 1. We want fast incremental change to the loop structures; + * Each LoopState corresponds to a specific schedule for its target ComputeDAG. + * A LoopState consists of: 1. a current loop structure; 2. a history of transformations used to + * construct it. + * The loop structure keeps a preview of how the schedule will finally look like after lowering the + * current state (e.g. number of iterators, the extent of each iterator, the compute_at locations + * ...). During the schedule search process, the loop structure can provide search policy with + * necessary information on how to perform further operations with the current state. + * The transform history is a sequence of TransformStep which will finally be mapped to schedule + * primitives. The steps can also be used for serialization of a state. + * + * The LoopState can be seen as a lightweight loop structure IR specifically for schedule search. + * We don't use the existing TVM IR but to extend a new structure on it is because: + * 1. We want fast incremental change to the loop structures, search policy needs to get the + * immediate loop structures update rather than after TVM lowering; * 2. We want serializable transform history for replay, backtracking, and mutation; - * 3. We may create some macro schedule primitives that represent the combination of several - * TVM schedule primitives. + * 3. We may create some macro schedule primitives that represent the combination of several TVM + * schedule primitives. * - * After the search is done, we will lower this IR to TVM IR with TVM's schedule primitives. - * Because we share a lot common objects during search, the transformation is implemented in - * copy on write style. All objects are immutable, which is similar to TVM IR. + * When the search is complete, we will lower the state to TVM IR with TVM's schedule primitives. + * Since we share a lot of common objects during search, the transformation is implemented in copy + * on write style. All objects are immutable, which is similar to TVM IR. */ #ifndef TVM_ANSOR_LOOP_STATE_H_ @@ -219,9 +227,9 @@ class Stage : public ObjectRef { }; /*! - * \brief A State in the search process. - * It consists of the current loop structure and the history steps to reach this State. - * Each State corresponds to a specific schedule for the target ComputeDAG. + * \brief A state in the search process. + * It consists of the current loop structure and a history of transformations used to construct it. + * Each State corresponds to a specific schedule for its target ComputeDAG. */ class StateNode : public Object { public: @@ -262,6 +270,22 @@ class State : public ObjectRef { */ explicit State(const Array& ops); + /*! + * \brief Print the state to a human readable string. + * \param delete_trivial_loop True for skipping the trivial loops. + * (undefined or extent == 1, default set to True) + * \return The human readable state structure. + */ + String ToStr(bool delete_trivial_loop = true) const; + + /*! 
+ * \brief General do step functions with a runtime dynamic dispatcher. + * \param dag The target ComputeDAG. + */ + void DoSteps(const ComputeDAG& dag); + + /* Step APIs for State. */ + /*! * \brief Schedule primitive corresponds to te.reorder. * \param stage_id The index of the target stage. @@ -286,21 +310,6 @@ class State : public ObjectRef { */ Iterator fuse(int stage_id, const Array& iters); - /*! - * \brief General do step functions with a runtime dynamic dispatcher. - * \param steps The target transform steps. - * \param dag The target ComputeDAG. - */ - void DoSteps(const Array& steps, const ComputeDAG& dag); - - /*! - * \brief Print the state to a string. - * \param delete_trivial_loop True for skipping the trivial loops. - * (undefined or extent == 1, default set to True) - * \return The human readable state structure. - */ - String ToStr(bool delete_trivial_loop = true) const; - TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index f03d1a1f957b..003e25d95aff 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -37,8 +37,8 @@ TVM_REGISTER_NODE_TYPE(MeasureInputNode); TVM_REGISTER_NODE_TYPE(BuildResultNode); TVM_REGISTER_NODE_TYPE(MeasureResultNode); TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); -TVM_REGISTER_OBJECT_TYPE(RunnerNode); -TVM_REGISTER_OBJECT_TYPE(BuilderNode); +TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode); +TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode); TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode); TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode); @@ -145,8 +145,9 @@ Array LocalRunnerNode::Run(const Array& inputs, } /********** ProgramMeasurer **********/ -ProgramMeasurer::ProgramMeasurer(Builder builder, Runner runner, Array callbacks, - int verbose, int max_continous_error) { +ProgramMeasurer::ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, + Array callbacks, int verbose, + int max_continous_error) { auto node = make_object(); node->builder = std::move(builder); node->runner = std::move(runner); @@ -306,13 +307,12 @@ TVM_REGISTER_GLOBAL("ansor.MeasureResult") return MeasureResult(costs, error_no, error_msg, all_cost, timestamp); }); -TVM_REGISTER_GLOBAL("ansor.BuilderBuild") - .set_body_typed([](const Builder& builder, const Array& inputs, int verbose) { - return builder->Build(inputs, verbose); - }); +TVM_REGISTER_GLOBAL("ansor.ProgramBuilderBuild") + .set_body_typed([](const ProgramBuilder& builder, const Array& inputs, + int verbose) { return builder->Build(inputs, verbose); }); -TVM_REGISTER_GLOBAL("ansor.RunnerRun") - .set_body_typed([](const Runner& runner, const Array& inputs, +TVM_REGISTER_GLOBAL("ansor.ProgramRunnerRun") + .set_body_typed([](const ProgramRunner& runner, const Array& inputs, const Array& build_results, int verbose) { return runner->Run(inputs, build_results, verbose); }); diff --git a/src/ansor/measure.h b/src/ansor/measure.h index d552d688e12c..a3854e3d1bdd 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -220,8 +220,8 @@ class MeasureCallback : public ObjectRef { // Base class for builder and runner -/*! \brief Builder that builds the programs */ -class BuilderNode : public Object { +/*! \brief ProgramBuilder that builds the programs */ +class ProgramBuilderNode : public Object { public: /*! 
\brief The number of tasks to run in parallel */ int n_parallel; @@ -236,21 +236,21 @@ class BuilderNode : public Object { */ virtual Array Build(const Array& inputs, int verbose) = 0; - static constexpr const char* _type_key = "ansor.Builder"; - TVM_DECLARE_BASE_OBJECT_INFO(BuilderNode, Object); + static constexpr const char* _type_key = "ansor.ProgramBuilder"; + TVM_DECLARE_BASE_OBJECT_INFO(ProgramBuilderNode, Object); }; /*! - * \brief Managed reference to BuilderNode. - * \sa BuilderNode + * \brief Managed reference to ProgramBuilderNode. + * \sa ProgramBuilderNode */ -class Builder : public ObjectRef { +class ProgramBuilder : public ObjectRef { public: - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Builder, ObjectRef, BuilderNode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramBuilder, ObjectRef, ProgramBuilderNode); }; -/*! \brief Runner that runs the built programs and measure the time cost. */ -class RunnerNode : public Object { +/*! \brief ProgramRunner that runs the built programs and measure the time cost. */ +class ProgramRunnerNode : public Object { public: /*! \brief Timeout of a run. */ int timeout; @@ -265,23 +265,23 @@ class RunnerNode : public Object { virtual Array Run(const Array& inputs, const Array& build_results, int verbose) = 0; - static constexpr const char* _type_key = "ansor.Runner"; - TVM_DECLARE_BASE_OBJECT_INFO(RunnerNode, Object); + static constexpr const char* _type_key = "ansor.ProgramRunner"; + TVM_DECLARE_BASE_OBJECT_INFO(ProgramRunnerNode, Object); }; /*! - * \brief Managed reference to RunnerNode. - * \sa RunnerNode + * \brief Managed reference to ProgramRunnerNode. + * \sa ProgramRunnerNode */ -class Runner : public ObjectRef { +class ProgramRunner : public ObjectRef { public: - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Runner, ObjectRef, RunnerNode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramRunner, ObjectRef, ProgramRunnerNode); }; // Implementation of various builders and runners /*! \brief LocalBuilder use local CPU cores to build programs in parallel */ -class LocalBuilderNode : public BuilderNode { +class LocalBuilderNode : public ProgramBuilderNode { public: /*! \brief Build function. */ String build_func; @@ -289,14 +289,14 @@ class LocalBuilderNode : public BuilderNode { Array Build(const Array& inputs, int verbose) final; static constexpr const char* _type_key = "ansor.LocalBuilder"; - TVM_DECLARE_FINAL_OBJECT_INFO(LocalBuilderNode, BuilderNode); + TVM_DECLARE_FINAL_OBJECT_INFO(LocalBuilderNode, ProgramBuilderNode); }; /*! * \brief Managed reference to LocalBuilderNode. * \sa LocalBuilderNode */ -class LocalBuilder : public Builder { +class LocalBuilder : public ProgramBuilder { public: /*! * \brief The constructor. @@ -306,11 +306,11 @@ class LocalBuilder : public Builder { */ LocalBuilder(int timeout, int n_parallel, const String& build_func); - TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, Builder, LocalBuilderNode); + TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, ProgramBuilder, LocalBuilderNode); }; /*! \brief LocalRunner that uses local CPU/GPU to measures the time cost of programs */ -class LocalRunnerNode : public RunnerNode { +class LocalRunnerNode : public ProgramRunnerNode { public: /*! \brief Number of measure times. 
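   * (i.e. the number of times the compiled program is run for one measurement; the measured
   * cost is averaged over these runs. This is assumed to match the `number` argument of
   * TVM's `time_evaluator`.)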
*/ int number; @@ -325,14 +325,14 @@ class LocalRunnerNode : public RunnerNode { const Array& build_results, int verbose) final; static constexpr const char* _type_key = "ansor.LocalRunner"; - TVM_DECLARE_FINAL_OBJECT_INFO(LocalRunnerNode, RunnerNode); + TVM_DECLARE_FINAL_OBJECT_INFO(LocalRunnerNode, ProgramRunnerNode); }; /*! * \brief Managed reference to LocalRunnerNode. * \sa LocalRunnerNode */ -class LocalRunner : public Runner { +class LocalRunner : public ProgramRunner { public: /*! * \brief The constructor. @@ -344,12 +344,12 @@ class LocalRunner : public Runner { */ LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LocalRunner, Runner, LocalRunnerNode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LocalRunner, ProgramRunner, LocalRunnerNode); }; /*! * \brief Measurer that measures the time costs of tvm programs - * This class combines Builder and Runner, and provides a simpler API */ + * This class combines ProgramBuilder and ProgramRunner, and provides a simpler API */ class ProgramMeasurerNode : public Object { public: /*! \brief Measured programs counter. */ @@ -362,10 +362,10 @@ class ProgramMeasurerNode : public Object { std::unordered_map best_state; /*! \brief Workload key to best state's count index map. */ std::unordered_map best_ct; - /*! \brief The Builder to build each program. */ - Builder builder; - /*! \brief The Runner to measure each program. */ - Runner runner; + /*! \brief The ProgramBuilder to build each program. */ + ProgramBuilder builder; + /*! \brief The ProgramRunner to measure each program. */ + ProgramRunner runner; /*! \brief MeasureCallback to be called after each measure batch. */ Array callbacks; /*! \brief Verbosity level. 0 for silent, 1 to output information during program measuring. */ @@ -381,7 +381,7 @@ class ProgramMeasurerNode : public Object { * \param task The current SearchTask. * \param policy The current SearchPolicy. * \param inputs The target MeasureInputs. - * \param results A pointer to MeasureResult vector, this is used as output. + * \param results A pointer to a MeasureResult Array, this is used as output. * \param batch_size Number of programs to be measured in one batch. */ void Measure(const SearchTask& task, const SearchPolicy& policy, @@ -392,7 +392,7 @@ class ProgramMeasurerNode : public Object { * This API will not print the measure results to screen. * \param task The current SearchTask. * \param inputs The target MeasureInputs. - * \param results A pointer to MeasureResult vector, this is used as output. + * \param results A pointer to a MeasureResult Array, this is used as output. */ void SilentMeasure(const SearchTask& task, const Array& inputs, Array* results); @@ -412,14 +412,14 @@ class ProgramMeasurer : public ObjectRef { public: /*! * \brief The constructor. - * \param builder The Builder to build each program. - * \param runner The Runner to measure each program. + * \param builder The ProgramBuilder to build each program. + * \param runner The ProgramRunner to measure each program. * \param callbacks MeasureCallback to be called after each measure batch. * \param verbose Verbosity level. 0 for silent, 1 to output information during program measuring. * \param max_continous_error The number of max continuous error. 
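   *
   * A minimal construction sketch (illustrative only; assumes `builder`, `runner` and
   * `callbacks` were created elsewhere):
   *   ProgramMeasurer measurer(builder, runner, callbacks, 1);  // verbose = 1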
 */
-  ProgramMeasurer(Builder builder, Runner runner, Array callbacks, int verbose,
-                  int max_continous_error = -1);
+  ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, Array callbacks,
+                  int verbose, int max_continous_error = -1);

   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramMeasurer, ObjectRef, ProgramMeasurerNode);
 };
diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc
index 7d782cb0eba2..2a1d9d3fcc9b 100644
--- a/src/ansor/search_policy/search_policy.cc
+++ b/src/ansor/search_policy/search_policy.cc
@@ -19,7 +19,7 @@

 /*!
  * \file ansor/search_policy/search_policy.cc
- * \brief The base class for search policy.
+ * \brief The base class for search policies.
  */

 #include "search_policy.h"
diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h
index aee93283aae6..96e27ccdb2eb 100644
--- a/src/ansor/search_policy/search_policy.h
+++ b/src/ansor/search_policy/search_policy.h
@@ -19,16 +19,17 @@

 /*!
  * \file ansor/search_policy/search_policy.h
- * \brief The base class for search policy, including the abstract defination of search policy and
- * some other supporting structures.
+ * \brief The base class of search policies, including the abstract definition of search policy and
+ * other supporting data structures.
 *
 * The basic schedule search process for Ansor is designed to be:
 * `Program sampling` -> `Performance Tuning`.
 *
- * In `Program sampling`, we use some predefined or heuristic rules to generate several initial
- * schedules. Based on these initial start points, we have `Performance Tuning` to apply cost model
- * and evolutionary search to seek for schedules with the best performance. Candidate schedules
- * will be measured in the target hardware.
+ * In `Program sampling`, we use predefined precise or heuristic rules to generate several
+ * initial schedules. Based on these initial starting points, we perform `Performance Tuning`,
+ * which uses cost-model-based evolutionary search to select the schedules with the best
+ * performance.
+ *
+ * Candidate schedules are measured against the specific hardware target.
 *
 * \note Adding a new search policy.
 * In design, there's no need for users to implement their own search policy, our formal search
@@ -72,7 +73,7 @@ class SearchCallbackNode : public Object {
  public:
   /*!
    * \brief Run the registered callback function.
-   * \param policy A pointer to SearchPolicyNode.
+   * \param policy A pointer to a SearchPolicyNode.
    */
   virtual void Callback(SearchPolicyNode* policy) = 0;
@@ -90,7 +91,7 @@ class SearchCallback : public ObjectRef {
 };

 /*!
- * \brief The base class for search policy.
+ * \brief The base class for search policies.
 */
 class SearchPolicyNode : public Object {
  public:
@@ -115,7 +116,7 @@ class SearchPolicyNode : public Object {
    * \param early_stopping Early stop if no better schedule is found.
    * \param num_measures_per_round Max measure batch in one search round.
    * \param verbose Verbose level. 0 for silent, 1 to output information during schedule search.
-   * \param measurer A ProgramMeasurer which packs Builder & Runner inside.
+   * \param measurer A ProgramMeasurer which packs ProgramBuilder & ProgramRunner inside.
    * \param pre_search_callbacks SearchCallback to be called before schedule search.
    * \return The best state found.
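   *
   * A hypothetical override sketch (the signature is reconstructed from the parameters
   * documented above, not copied from the codebase):
   *   State Search(SearchTask task, int num_measure_trials, int early_stopping,
   *                int num_measures_per_round, int verbose, ProgramMeasurer measurer,
   *                Array<SearchCallback> pre_search_callbacks) final;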
*/ diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h index 9c92da5e387b..db78991010c0 100644 --- a/src/ansor/search_task.h +++ b/src/ansor/search_task.h @@ -34,7 +34,7 @@ namespace ansor { class HardwareParams; -/*! \brief Hardware related parameters */ +/*! \brief The parameters of target hardware used to guide the search process of SearchPolicy. */ class HardwareParamsNode : public Object { public: /*! \brief The number of cores. */ @@ -107,7 +107,9 @@ class HardwareParams : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(HardwareParamsNode); }; -/*! \brief Meta-info for a search task */ +/*! + * \brief The computation information and hardware parameters for a specific schedule search task. + */ class SearchTaskNode : public Object { public: /*! \brief The ComputeDAG for target compute declaration. */ diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h index 600f4c2ce18e..3b8fe124e0bf 100644 --- a/src/ansor/serialization.h +++ b/src/ansor/serialization.h @@ -74,8 +74,8 @@ class LogReaderNode : public Object { /*! * \brief Read next line in the log file. - * \param inp A pointer to MeasureInputNode, this is used as output. - * \param res A pointer to MeasureResultNode, this is used as output. + * \param inp A pointer to a MeasureInputNode, this is used as output. + * \param res A pointer to a MeasureResultNode, this is used as output. * \return Whether the read is successful. */ bool ReadNext(MeasureInputNode* inp, MeasureResultNode* res); @@ -113,7 +113,7 @@ class LogReader : public ObjectRef { /*! * \brief Write measure records to an output stream. - * \param os A pointer to output stream. + * \param os A pointer to a output stream. * \param inputs The target MeasureInputs to be written. * \param results The target MeasureResults to be written. */ @@ -123,9 +123,9 @@ void WriteMeasureRecords(std::ostream* os, const Array& inputs, /*! * \brief Read one measure record from a string. * \param str The target record string to be extract. - * \param inp A pointer to MeasureInputNode, this is used as output. - * \param res A pointer to MeasureResultNode, this is used as output. - * \param log_version A pointer to log version string. + * \param inp A pointer to a MeasureInputNode, this is used as output. + * \param res A pointer to a MeasureResultNode, this is used as output. + * \param log_version A pointer to a log version string. */ void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res, std::string* log_version); diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index f9271e73c3c8..f52daf593340 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -33,7 +33,7 @@ * 3. Implement `State::fuse` and `State::DoFuseStep`. * - In these two functions you need to incrementally update all data structures in State with * CopyOnWrite style - * 4. Add you step to `ComputeDAG::ReplaySteps` and make sure it works. + * 4. Add you step to `ComputeDAG::ApplySteps` and make sure it works. * 5. Add serialization support in `struct Handler >` * in `serialization.cc`. * 6. Add hash support in `struct hash<::tvm::ansor::Step>`. (search for this function in this file) @@ -55,7 +55,7 @@ namespace ansor { typedef Map, ObjectHash, ObjectEqual> StageToAxesMap; /*! - * \brief The base class for a transformation step. Each step has its corresponding tvm.te + * \brief The base class for transformation steps. Each step has its corresponding tvm.te * schedule primitives. 
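 * For example, a ReorderStep maps to `te.reorder`, a SplitStep to `te.split`, and a FuseStep
 * to `te.fuse` (see the corresponding step classes below).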
*/ class StepNode : public Object { @@ -87,15 +87,15 @@ class ReorderStepNode : public StepNode { /*! * \brief Apply the current state to tvm.schedule - * \param stages A pointer to `te::Stage` vector. - * \param stage_to_axes A pointer to StageToAxesMap. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print step as equivalent python schedule API. - * \param stages A pointer to `te::Stage` vector. - * \param stage_to_axes A pointer to StageToAxesMap. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -140,8 +140,8 @@ class SplitStepNode : public StepNode { /*! * \brief Apply the current state to tvm.schedule - * \param stages A pointer to `te::Stage` vector. - * \param stage_to_axes A pointer to StageToAxesMap. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. * \return The iterator results after split. */ Array ApplyToSchedule(Array* stages, @@ -149,8 +149,8 @@ class SplitStepNode : public StepNode { /*! * \brief Print step as equivalent python schedule API. - * \param stages A pointer to `te::Stage` vector. - * \param stage_to_axes A pointer to StageToAxesMap. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -186,16 +186,16 @@ class FuseStepNode : public StepNode { /*! * \brief Apply the current state to tvm.schedule - * \param stages A pointer to `te::Stage` vector. - * \param stage_to_axes A pointer to StageToAxesMap. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. * \return The iterator result after fuse. */ tir::IterVar ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print step as equivalent python schedule API. - * \param stages A pointer to `te::Stage` vector. - * \param stage_to_axes A pointer to StageToAxesMap. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; From cb2442f3e3a9e2d979e7e0c68c945d6d6ec0c8e6 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 3 Jul 2020 17:22:34 +0800 Subject: [PATCH 62/78] Lint fix --- python/tvm/ansor/compute_dag.py | 3 ++- python/tvm/ansor/loop_state.py | 6 ++++-- python/tvm/ansor/measure.py | 6 +++--- src/ansor/measure.h | 2 +- src/ansor/search_policy/search_policy.cc | 2 +- src/ansor/search_policy/search_policy.h | 17 +++++++++-------- src/ansor/search_task.cc | 1 + src/ansor/transform_step.h | 2 +- 8 files changed, 22 insertions(+), 17 deletions(-) diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index fa836562db60..fd1a4eb9bdc0 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -59,7 +59,8 @@ def __init__(self, compute): if not isinstance(item, tvm.te.Tensor): raise ValueError("The input of ComputeDAG should be a list of Tensor") else: - raise ValueError("Invalid compute: " + compute + ". 
Expect a string or list of Tensor")
+            raise ValueError("Invalid compute: " + compute +
+                             ". `ComputeDAG` expects a string or list of Tensor")
         self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute)

     def get_init_state(self):
diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py
index dbf3d678263b..8f379c816d9c 100644
--- a/python/tvm/ansor/loop_state.py
+++ b/python/tvm/ansor/loop_state.py
@@ -192,7 +192,8 @@ def _resolve_stage_id(self, stage_id):
             return self.stage_id_map[stage_id.op]
         if isinstance(stage_id, int):
             return stage_id
-        raise ValueError("Invalid stage_id: " + stage_id + ". Expect a int, Operation or Tensor")
+        raise ValueError("Invalid stage: " + stage_id +
+                         ". Expect an int, Operation or Tensor")

     def _update_stage_id_map(self):
         if not self.stages_cache:
@@ -210,7 +211,8 @@ def __getitem__(self, key):
             key = key.op
         if isinstance(key, Operation):
             return self.stages_cache[self.stage_id_map[key]]
-        raise ValueError("Invalid item: " + key + ". Expect a Operation or Tensor")
+        raise ValueError("Invalid item: " + key +
+                         ". Expect an Operation or Tensor")

     def __str__(self):
         return str(self.state_object)
diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py
index 7a3a3d5ec64a..85b1ff6891d3 100644
--- a/python/tvm/ansor/measure.py
+++ b/python/tvm/ansor/measure.py
@@ -47,7 +47,7 @@

 @tvm._ffi.register_object("ansor.MeasureCallback")
 class MeasureCallback(Object):
-    """ Base class for measurement callback function. """
+    """ The base class of measurement callback functions. """


 @tvm._ffi.register_object("ansor.MeasureInput")
@@ -117,7 +117,7 @@ def __init__(self, costs, error_no, error_msg, all_cost, timestamp):

 @tvm._ffi.register_object("ansor.ProgramBuilder")
 class ProgramBuilder(Object):
-    """ Base class of ProgramBuilder. """
+    """ The base class of ProgramBuilders. """

     def build(self, measure_inputs, verbose=1):
         """ Build programs and return results.
@@ -138,7 +138,7 @@ def build(self, measure_inputs, verbose=1):

 @tvm._ffi.register_object("ansor.ProgramRunner")
 class ProgramRunner(Object):
-    """ Base class of ProgramRunner """
+    """ The base class of ProgramRunners. """

     def run(self, measure_inputs, build_results, verbose=1):
         """ Run measurement and return results.
diff --git a/src/ansor/measure.h b/src/ansor/measure.h
index a3854e3d1bdd..c9978523a1cd 100644
--- a/src/ansor/measure.h
+++ b/src/ansor/measure.h
@@ -218,7 +218,7 @@ class MeasureCallback : public ObjectRef {
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode);
 };

-// Base class for builder and runner
+// The base class of ProgramBuilders and ProgramRunners.

 /*! \brief ProgramBuilder that builds the programs */
 class ProgramBuilderNode : public Object {
diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc
index 2a1d9d3fcc9b..70bed321f802 100644
--- a/src/ansor/search_policy/search_policy.cc
+++ b/src/ansor/search_policy/search_policy.cc
@@ -19,7 +19,7 @@

 /*!
  * \file ansor/search_policy/search_policy.cc
- * \brief The base class for search policies.
+ * \brief The base class of search policies.
 */

 #include "search_policy.h"
diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h
index 96e27ccdb2eb..7d1f0a3349fa 100644
--- a/src/ansor/search_policy/search_policy.h
+++ b/src/ansor/search_policy/search_policy.h
@@ -33,16 +33,17 @@
 *
 * \note Adding a new search policy.
 * In design, there's no need for users to implement their own search policy, our formal search
-   * policy(will be brought later) should be enough to cover auto schedule generation for different
-   * ops/subgraphs, and in the meantime, a custom rule mechanism will be provided to enable
-   * user-defined template search. (which should play a same role as the current AutoTVM template)
+   * policy (to be added later) should be enough to cover most use cases. Meanwhile, a custom rule
+   * mechanism will be provided to enable user-defined template search, serving the same
+   * function as the current AutoTVM templates.
+   *
 * This guide is to help understand it better in case some advanced users have special
 * requirements.
 * 1. The only function that must be implemented is Search(); the design principle for it is to be
-   * the entry of starting a schedule search and returns the best schedule get.
-   * 2. Imformations about the target ops/subgraphs can be acquired from SearchTask, this structure
-   * also contains HardwareParams which can be used to limit the search space. (For exp. limit the
-   * max vectorize size depending on the vector unit weight of a specific device)
+   * the entry point for starting a schedule search process, returning the best schedule found.
+   * 2. Information about the compute declaration of ops/subgraphs can be acquired from SearchTask.
+   * This structure also contains some information about the target device. (e.g. knowing the weight
+   * of the device vector unit, we can limit the max vectorize size during schedule generation)
 * 3. SearchCallback provides more flexibility to do extra affairs during the search process.
 * 4. ProgramMeasurer provides a simple but useful API to help check the performance of states
 * obtained during the search process.
@@ -91,7 +92,7 @@ class SearchCallback : public ObjectRef {
 };

 /*!
- * \brief The base class for search policies.
+ * \brief The base class of search policies.
 */
 class SearchPolicyNode : public Object {
  public:
diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc
index 7e6cb9d903d2..59161e73a61e 100644
--- a/src/ansor/search_task.cc
+++ b/src/ansor/search_task.cc
@@ -24,6 +24,7 @@

 #include "search_task.h"

+#include 
 #include 
 #include 

diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h
index f52daf593340..48dc9d3430f8 100644
--- a/src/ansor/transform_step.h
+++ b/src/ansor/transform_step.h
@@ -55,7 +55,7 @@ namespace ansor {
 typedef Map, ObjectHash, ObjectEqual> StageToAxesMap;

 /*!
- * \brief The base class for transformation steps. Each step has its corresponding tvm.te
+ * \brief The base class of transformation steps. Each step has its corresponding tvm.te
  * schedule primitives.
*/ class StepNode : public Object { From b1ca20c7cf45bcb263b310639429a5e648345e8c Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 3 Jul 2020 17:29:19 +0800 Subject: [PATCH 63/78] Update --- python/tvm/ansor/__init__.py | 2 +- python/tvm/ansor/workload_registry.py | 10 +++++----- src/ansor/search_task.cc | 1 - tests/python/unittest/test_ansor_common.py | 2 +- tests/python/unittest/test_ansor_search_policy.py | 2 +- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 4fcf1008a2ea..04a10f2def5b 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -31,4 +31,4 @@ from .measure import MeasureInput, LocalBuilder, LocalRunner from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \ load_from_file, append_measure_records_to_file -from .workload_registry import register_workload_by_func, make_workload_key_by_func +from .workload_registry import register_workload, make_workload_key diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py index 0084102a8f75..03ca1d771682 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/ansor/workload_registry.py @@ -38,7 +38,7 @@ WORKLOAD_FUNC_REGISTRY = {} -def register_workload_by_func(func): +def register_workload(func): """ Register a workload by generation function. The input function should take hashable and jsonable arguments @@ -51,7 +51,7 @@ def register_workload_by_func(func): Examples -------- - @ansor.register_workload_by_func + @ansor.register_workload def matmul(N, M, K): A = te.placeholder((N, K), name='A') B = te.placeholder((K, M), name='B') @@ -68,7 +68,7 @@ def matmul(N, M, K): return func -def make_workload_key_by_func(func, args): +def make_workload_key(func, args): """ make a workload key from function and arguments. Parameters @@ -93,7 +93,7 @@ def make_workload_key_by_func(func, args): if not func_name in WORKLOAD_FUNC_REGISTRY: raise ValueError("%s is not registered. " % func, - "Please register it with @ansor.register_workload_by_func") + "Please register it with @ansor.register_workload") args = serialize_args(args) @@ -118,7 +118,7 @@ def decode_workload_key_to_func_args(workload_key): workload = json.loads(workload_key) if not workload[0] in WORKLOAD_FUNC_REGISTRY: raise ValueError("%s is not registered. 
" % workload[0] + - "Please register it with @ansor.register_workload_by_func") + "Please register it with @ansor.register_workload") return workload[0], deserialize_args(workload[1:]) diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index 59161e73a61e..7e6cb9d903d2 100644 --- a/src/ansor/search_task.cc +++ b/src/ansor/search_task.cc @@ -24,7 +24,6 @@ #include "search_task.h" -#include #include #include diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index 773ca8e4f13e..9288d88b6270 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -21,7 +21,7 @@ import topi -@ansor.register_workload_by_func +@ansor.register_workload def matmul_ansor_test(N, M, K): A = te.placeholder((N, K), name='A') B = te.placeholder((K, M), name='B') diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index 0e4a70d840d0..f7990e272eb0 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -34,7 +34,7 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' random.seed(seed) N = 128 - workload_key = ansor.make_workload_key_by_func(matmul_ansor_test, (N, N, N)) + workload_key = ansor.make_workload_key(matmul_ansor_test, (N, N, N)) dag = ansor.ComputeDAG(workload_key) target = tvm.target.create(target) task = ansor.SearchTask(dag, workload_key, target) From 8add768325052192015fe4679baa94594ee85b51 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 3 Jul 2020 17:41:40 +0800 Subject: [PATCH 64/78] Update --- src/ansor/search_task.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index 7e6cb9d903d2..a65a8aea8f27 100644 --- a/src/ansor/search_task.cc +++ b/src/ansor/search_task.cc @@ -48,7 +48,7 @@ HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_l HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target, const Target& target_host) { - if (target->target_name == "llvm") { + if (target->id->name == "llvm") { return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64, 64, 64); } else { LOG(FATAL) << "No default hardware parameters for target: " << target; From 78e531344cca654d2929a14e5afe0b98ebcb1e3f Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 4 Jul 2020 10:29:31 +0800 Subject: [PATCH 65/78] Update --- python/tvm/ansor/auto_schedule.py | 14 +--- python/tvm/ansor/serialization.py | 2 +- src/ansor/compute_dag.h | 3 +- src/ansor/loop_state.cc | 24 +++--- src/ansor/loop_state.h | 4 +- src/ansor/search_task.cc | 13 +--- src/ansor/search_task.h | 14 +--- src/ansor/serialization.cc | 15 ++-- src/ansor/transform_step.cc | 27 +++---- src/ansor/transform_step.h | 74 ++----------------- src/ansor/utils.h | 4 +- tests/python/unittest/test_ansor_common.py | 16 ++++ .../unittest/test_ansor_search_policy.py | 14 ++-- 13 files changed, 79 insertions(+), 145 deletions(-) diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 9243263c4a4b..9afc72c348c8 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -50,16 +50,10 @@ class HardwareParams(Object): The width of vector units in bytes. cache_line_bytes : int The size of cache line in bytes. - max_unroll_vec : int - The max length of an axis to be unrolled or vectorized. 
- max_innermost_split_factor : int - The max split factor for the innermost tile. """ - def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes, - max_unroll_vec, max_innermost_split_factor): + def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes): self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores, - vector_unit_bytes, cache_line_bytes, - max_unroll_vec, max_innermost_split_factor) + vector_unit_bytes, cache_line_bytes) @tvm._ffi.register_object("ansor.SearchTask") @@ -69,9 +63,9 @@ class SearchTask(Object): Parameters ---------- dag : ComputeDAG - The ComputeDAG for the target compute declaration. + The ComputeDAG for the corresponding compute declaration. workload_key : str - The workload key for the target compute declaration. + The workload key for the corresponding compute declaration. target : tvm.target.Target The target device of this search task. target_host : Optional[tvm.target.Target] diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py index 8c8723ffbbaa..65534bb117a0 100644 --- a/python/tvm/ansor/serialization.py +++ b/python/tvm/ansor/serialization.py @@ -143,7 +143,7 @@ def best_measure_pair_in_file(filename, workload_key=None, target=None): continue if workload_key and inp.task.workload_key != workload_key: continue - if target and inp.task.target.target_name != target.target_name: + if target and inp.task.target.id.name != target.id.name: continue costs = [v.value for v in res.costs] diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 7275ac645e20..b32f6ea7c373 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -45,7 +45,7 @@ namespace tvm { namespace ansor { -/*! \brief Computation declaration graph. */ +/*! \brief The Ansor computational graph and related program analyses. */ class ComputeDAGNode : public Object { public: /*! \brief Input and output tensors. */ @@ -56,6 +56,7 @@ class ComputeDAGNode : public Object { double flop_ct; /*! \brief The initial state without any transform steps. */ State init_state; + // TODO(merrymercy): Add more analyses later. 
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("tensors", &tensors); diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 4f6d72d46694..d792ffbd833f 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -114,14 +114,14 @@ void State::reorder(int stage_id, const Array& order) { const Stage& stage = operator->()->stages[stage_id]; CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " << "should be specified"; - Array after_ids; + Array after_ids; GetIndices(stage->iters, order, &after_ids); ReorderStep step = ReorderStep(stage_id, after_ids); CopyOnWrite()->transform_steps.push_back(step); DoReorderStep(step); } -Array State::split(int stage_id, const Iterator& it, const Array& lengths, +Array State::split(int stage_id, const Iterator& it, const Array& lengths, bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; SplitStep step = @@ -133,7 +133,7 @@ Array State::split(int stage_id, const Iterator& it, const Array& iters) { const Stage& stage = operator->()->stages[stage_id]; - Array indices; + Array indices; GetIndices(stage->iters, iters, &indices); FuseStep step = FuseStep(stage_id, indices); CopyOnWrite()->transform_steps.push_back(step); @@ -145,7 +145,7 @@ void State::DoReorderStep(const ReorderStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; Array iters; for (auto x : step->after_ids) { - iters.push_back(stage->iters[x.as()->value]); + iters.push_back(stage->iters[x]); } StateNode* pstate = CopyOnWrite(); pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(iters), @@ -153,7 +153,7 @@ void State::DoReorderStep(const ReorderStep& step) { } // common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep -Array State::DoSplitStepCommon(int stage_id, int iter_id, const Array& lengths, +Array State::DoSplitStepCommon(int stage_id, int iter_id, const Array& lengths, bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; const Iterator& it = stage->iters[iter_id]; @@ -229,11 +229,10 @@ Iterator State::DoFuseStep(const FuseStep& step) { for (size_t i = 0; i < step->fused_ids.size(); ++i) { if (i > 0) { - CHECK_EQ(step->fused_ids[i].as()->value, - step->fused_ids[i - 1].as()->value + 1); + CHECK_EQ(step->fused_ids[i]->value, step->fused_ids[i - 1]->value + 1); } - const Iterator& it = stage->iters[step->fused_ids[i].as()->value]; + const Iterator& it = stage->iters[step->fused_ids[i]]; new_name = new_name + it->name + "@"; if (it->range.defined() && new_extent.defined()) { @@ -258,10 +257,9 @@ Iterator State::DoFuseStep(const FuseStep& step) { Iterator new_it = Iterator(new_name, range, new_iter_type, kNone); Array new_iters; new_iters.insert(new_iters.end(), stage->iters.begin(), - stage->iters.begin() + step->fused_ids.front().as()->value); + stage->iters.begin() + step->fused_ids.front()); new_iters.push_back(new_it); - new_iters.insert(new_iters.end(), - stage->iters.begin() + step->fused_ids.back().as()->value + 1, + new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back() + 1, stage->iters.end()); StateNode* pstate = CopyOnWrite(); @@ -429,8 +427,8 @@ TVM_REGISTER_GLOBAL("ansor.StateReorder") }); TVM_REGISTER_GLOBAL("ansor.StateSplit") - .set_body_typed([](State state, int stage_id, const Iterator& it, - const Array& lengths, bool inner_to_outer) { + .set_body_typed([](State state, int stage_id, const Iterator& it, const Array& lengths, + bool inner_to_outer) { const auto& res = state.split(stage_id, it, 
lengths, inner_to_outer); return Array{state, res}; }); diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index e8ddcc480dcd..654896b880b4 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -300,7 +300,7 @@ class State : public ObjectRef { * \param inner_to_outer True for split from inner to outer & False for outer to inner. * \return The iterator results after split. */ - Array split(int stage_id, const Iterator& it, const Array& lengths, + Array split(int stage_id, const Iterator& it, const Array& lengths, bool inner_to_outer = true); /*! * \brief Schedule primitive corresponds to te.fuse. @@ -344,7 +344,7 @@ class State : public ObjectRef { * \param inner_to_outer The split direction. * \return The iterator results after split. */ - Array DoSplitStepCommon(int stage_id, int iter_id, const Array& lengths, + Array DoSplitStepCommon(int stage_id, int iter_id, const Array& lengths, bool inner_to_outer); }; diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index a65a8aea8f27..090f6c58f175 100644 --- a/src/ansor/search_task.cc +++ b/src/ansor/search_task.cc @@ -35,21 +35,18 @@ namespace ansor { TVM_REGISTER_NODE_TYPE(HardwareParamsNode); TVM_REGISTER_NODE_TYPE(SearchTaskNode); -HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes, - int max_unroll_vec, int max_innermost_split_factor) { +HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes) { auto node = make_object(); node->num_cores = num_cores; node->vector_unit_bytes = vector_unit_bytes; node->cache_line_bytes = cache_line_bytes; - node->max_unroll_vec = max_unroll_vec; - node->max_innermost_split_factor = max_innermost_split_factor; data_ = std::move(node); } HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target, const Target& target_host) { if (target->id->name == "llvm") { - return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64, 64, 64); + return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64); } else { LOG(FATAL) << "No default hardware parameters for target: " << target; } @@ -73,10 +70,8 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target targe } TVM_REGISTER_GLOBAL("ansor.HardwareParams") - .set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes, - int max_unroll_vec, int max_innermost_split_factor) { - return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes, max_unroll_vec, - max_innermost_split_factor); + .set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes) { + return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes); }); TVM_REGISTER_GLOBAL("ansor.SearchTask") diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h index db78991010c0..489dafa615b2 100644 --- a/src/ansor/search_task.h +++ b/src/ansor/search_task.h @@ -43,12 +43,9 @@ class HardwareParamsNode : public Object { int vector_unit_bytes; /*! \brief The size of cache line in bytes. */ int cache_line_bytes; - /*! \brief The max length of an axis to be unrolled or vectorized. */ - int max_unroll_vec; - /*! \brief The max split factor for the innermost tile. */ - int max_innermost_split_factor; - // Limitation params for GPU + // Some GPU related limitations + // Get from TVM device api /*! \brief The max shared memory per block. 
*/ int max_shared_memory_per_block{INT32_MAX}; @@ -65,8 +62,6 @@ class HardwareParamsNode : public Object { v->Visit("num_cores", &num_cores); v->Visit("vector_unit_bytes", &vector_unit_bytes); v->Visit("cache_line_bytes", &cache_line_bytes); - v->Visit("max_unroll_vec", &max_unroll_vec); - v->Visit("max_innermost_split_factor", &max_innermost_split_factor); v->Visit("max_shared_memory_per_block", &max_shared_memory_per_block); v->Visit("max_registers_per_block", &max_registers_per_block); v->Visit("max_threads_per_block", &max_threads_per_block); @@ -97,11 +92,8 @@ class HardwareParams : public ObjectRef { * \param num_cores The number of cores. * \param vector_unit_bytes The width of vector units in bytes. * \param cache_line_bytes The size of cache line in bytes. - * \param max_unroll_vec The max length of an axis to be unrolled or vectorized. - * \param max_innermost_split_factor The max split factor for the innermost tile. */ - HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes, int max_unroll_vec, - int max_innermost_split_factor); + HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes); TVM_DEFINE_OBJECT_REF_METHODS(HardwareParams, ObjectRef, HardwareParamsNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(HardwareParamsNode); diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index 0df937065e74..ec97ba74666d 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -43,12 +43,11 @@ namespace dmlc { namespace json { inline std::vector& IntArrayToVector(std::vector* out, - const ::tvm::Array<::tvm::PrimExpr>& data) { + const ::tvm::Array<::tvm::Integer>& data) { out->clear(); for (const auto& x : data) { - auto pi = x.as<::tvm::tir::IntImmNode>(); - CHECK(pi != nullptr) << "Can only contain int values"; - out->push_back(pi->value); + CHECK(x.defined()); + out->push_back(x); } return *out; } @@ -119,7 +118,7 @@ struct Handler<::tvm::Array<::tvm::ansor::Step>> { s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); - ::tvm::Array<::tvm::PrimExpr> after_ids; + ::tvm::Array<::tvm::Integer> after_ids; for (const auto& i : int_list) { after_ids.push_back(i); } @@ -140,9 +139,9 @@ struct Handler<::tvm::Array<::tvm::ansor::Step>> { s = reader->NextArrayItem(); CHECK(s); reader->Read(&inner_to_outer); - ::tvm::Array<::tvm::PrimExpr> lengths; + ::tvm::Array<::tvm::Integer> lengths; for (const auto& i : int_list) { - lengths.push_back(::tvm::PrimExpr(i)); + lengths.push_back(i); } data->push_back(::tvm::ansor::SplitStep( stage_id, iter_id, extent == 0 ? 
::tvm::PrimExpr() : extent, lengths, inner_to_outer)); @@ -153,7 +152,7 @@ struct Handler<::tvm::Array<::tvm::ansor::Step>> { s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); - ::tvm::Array<::tvm::PrimExpr> fused_ids; + ::tvm::Array<::tvm::Integer> fused_ids; for (const auto& i : int_list) { fused_ids.push_back(i); } diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index 2e1fbfb9cdbc..8769714d0513 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -37,7 +37,7 @@ namespace tvm { namespace ansor { /********** Reorder **********/ -ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { +ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { auto node = make_object(); node->stage_id = stage_id; for (const auto& x : after_ids) { @@ -56,7 +56,7 @@ void ReorderStepNode::ApplyToSchedule(Array* stages, Array new_axes; new_axes.reserve(axes.size()); for (auto i : after_ids) { - new_axes.push_back(axes[i.as()->value]); + new_axes.push_back(axes[i]); } stage.reorder(new_axes); @@ -71,7 +71,7 @@ String ReorderStepNode::PrintAsPythonAPI(Array* stages, ss << "s[" << CleanName(stage->op->name) << "].reorder("; for (size_t i = 0; i < after_ids.size(); ++i) { - ss << CleanName((*stage_to_axes)[stage][after_ids[i].as()->value]->var->name_hint); + ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint); if (i != after_ids.size() - 1) { ss << ", "; } @@ -84,7 +84,7 @@ String ReorderStepNode::PrintAsPythonAPI(Array* stages, /********** Split **********/ Array ApplySplitToSchedule(Array* stages, StageToAxesMap* stage_to_axes, - int stage_id, int iter_id, const Array& lengths, + int stage_id, int iter_id, const Array& lengths, bool inner_to_outer) { auto stage = (*stages)[stage_id]; const Array& axes = stage_to_axes->at(stage); @@ -127,7 +127,7 @@ Array ApplySplitToSchedule(Array* stages, StageToAxesMap* st } String PrintSplitAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, int stage_id, - int iter_id, const Array& lengths, bool inner_to_outer) { + int iter_id, const Array& lengths, bool inner_to_outer) { const auto& stage = (*stages)[stage_id]; auto to_split = stage_to_axes->at(stage)[iter_id]; const auto& func_name = CleanName(stage->op->name); @@ -156,13 +156,13 @@ String PrintSplitAsPythonAPI(Array* stages, StageToAxesMap* stage_to_ return ss.str(); } -SplitStep::SplitStep(int stage_id, int iter_id, PrimExpr extent, const Array& lengths, +SplitStep::SplitStep(int stage_id, int iter_id, PrimExpr extent, const Array& lengths, bool inner_to_outer) { auto node = make_object(); node->stage_id = stage_id; // Extent can be a unreducible expression in some special cases if (extent->IsInstance()) { - node->extent = std::move(extent); + node->extent = std::move(tvm::Downcast(extent)); } node->iter_id = iter_id; node->lengths = lengths; @@ -181,7 +181,7 @@ String SplitStepNode::PrintAsPythonAPI(Array* stages, } /********** Fuse **********/ -FuseStep::FuseStep(int stage_id, const Array& fused_ids) { +FuseStep::FuseStep(int stage_id, const Array& fused_ids) { auto node = make_object(); node->stage_id = stage_id; for (const auto& x : fused_ids) { @@ -198,17 +198,15 @@ IterVar FuseStepNode::ApplyToSchedule(Array* stages, Array to_fuse; for (const auto& i : fused_ids) { - to_fuse.push_back(axes[i.as()->value]); + to_fuse.push_back(axes[i]); } IterVar fused_axis; stage.fuse(to_fuse, &fused_axis); Array new_axes; - new_axes.insert(new_axes.end(), axes.begin(), - axes.begin() + fused_ids.front().as()->value); + 
new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front()); new_axes.push_back(fused_axis); - new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back().as()->value + 1, - axes.end()); + new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, axes.end()); stage_to_axes->Set(stage, std::move(new_axes)); stages->Set(stage_id, std::move(stage)); @@ -221,8 +219,7 @@ String FuseStepNode::PrintAsPythonAPI(Array* stages, std::stringstream to_fuse; for (size_t i = 0; i < fused_ids.size(); ++i) { - to_fuse << CleanName( - stage_to_axes->at(stage)[fused_ids[i].as()->value]->var->name_hint); + to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint); if (i != fused_ids.size() - 1) { to_fuse << ", "; } diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index 48dc9d3430f8..140b7b0539f1 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -36,8 +36,7 @@ * 4. Add you step to `ComputeDAG::ApplySteps` and make sure it works. * 5. Add serialization support in `struct Handler >` * in `serialization.cc`. - * 6. Add hash support in `struct hash<::tvm::ansor::Step>`. (search for this function in this file) - * 7. Add its corresponding Python API to `loop_state.py` and necessary unit test. + * 6. Add its corresponding Python API to `loop_state.py` and necessary unit test. */ #ifndef TVM_ANSOR_TRANSFORM_STEP_H_ @@ -83,7 +82,7 @@ class ReorderStepNode : public StepNode { * \brief The iterator ids after reorder. * This array should specify the order of all iterators. */ - Array after_ids; + Array after_ids; /*! * \brief Apply the current state to tvm.schedule @@ -115,7 +114,7 @@ class ReorderStep : public Step { * \param stage_id The index of the target stage. * \param after_ids The index of the iterators after reorder. */ - ReorderStep(int stage_id, const Array& after_ids); + ReorderStep(int stage_id, const Array& after_ids); TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode); }; @@ -129,9 +128,9 @@ class SplitStepNode : public StepNode { /*! \brief The id of the iter to split. */ int iter_id; /*! \brief The extent length of the axis to split. */ - PrimExpr extent; + Integer extent; /*! \brief The split factors. */ - Array lengths; + Array lengths; /*! * \brief If true, the `lengths` denote the lengths of iterators * from inner level to outer level @@ -172,7 +171,7 @@ class SplitStep : public Step { * \param lengths The extent length of the axis to split. * \param inner_to_outer The split direction. */ - SplitStep(int stage_id, int iter_id, PrimExpr extent, const Array& lengths, + SplitStep(int stage_id, int iter_id, PrimExpr extent, const Array& lengths, bool inner_to_outer); TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); @@ -182,7 +181,7 @@ class SplitStep : public Step { class FuseStepNode : public StepNode { public: /*! \brief The ids of iterators to fuse. */ - Array fused_ids; + Array fused_ids; /*! * \brief Apply the current state to tvm.schedule @@ -215,7 +214,7 @@ class FuseStep : public Step { * \param stage_id The index of the target stage. * \param fused_ids The index of the target iterators to be fused. */ - FuseStep(int stage_id, const Array& fused_ids); + FuseStep(int stage_id, const Array& fused_ids); TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); }; @@ -223,61 +222,4 @@ class FuseStep : public Step { } // namespace ansor } // namespace tvm -// Hash and equal function for Step -namespace std { - -/*! \brief The hash function of each transform step. 
*/ -template <> -struct hash<::tvm::ansor::Step> { - std::size_t operator()(const ::tvm::ansor::Step& step) const { - // clang-format off - if (auto ps = step.as<::tvm::ansor::ReorderStepNode>()) { - size_t ret = ::dmlc::HashCombine(1, std::hash()(ps->stage_id)); - for (const auto& x : ps->after_ids) { - CHECK(x.defined()); - const auto& pint = x.as<::tvm::tir::IntImmNode>(); - CHECK(pint != nullptr); - ret = ::dmlc::HashCombine(ret, pint->value); - } - return ret; - } else if (auto ps = step.as<::tvm::ansor::SplitStepNode>()) { - size_t ret = ::dmlc::HashCombine(2, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - std::hash()(ps->inner_to_outer)))); - if (ps->extent.defined()) { - const auto& pint = ps->extent.as<::tvm::tir::IntImmNode>(); - CHECK(pint != nullptr); - ret = ::dmlc::HashCombine(ret, pint->value); - } else { - ret = ::dmlc::HashCombine(ret, 0x5D); // a magic number - } - for (const auto& x : ps->lengths) { - if (x.defined()) { - const auto& pint = x.as<::tvm::tir::IntImmNode>(); - CHECK(pint != nullptr); - ret = ::dmlc::HashCombine(ret, pint->value); - } else { - ret = ::dmlc::HashCombine(ret, 0x5D); // a magic number - } - } - return ret; - } else if (auto ps = step.as<::tvm::ansor::FuseStepNode>()) { - size_t ret = ::dmlc::HashCombine(3, std::hash()(ps->stage_id)); - for (const auto& x : ps->fused_ids) { - CHECK(x.defined()); - const auto& pint = x.as<::tvm::tir::IntImmNode>(); - CHECK(pint != nullptr); - ret = ::dmlc::HashCombine(ret, pint->value); - } - return ret; - } else { - LOG(FATAL) << "Invalid step"; - } - return 0; - // clang-format on - } -}; -} // namespace std - #endif // TVM_ANSOR_TRANSFORM_STEP_H_ diff --git a/src/ansor/utils.h b/src/ansor/utils.h index c7fb7204ac69..54e434804549 100644 --- a/src/ansor/utils.h +++ b/src/ansor/utils.h @@ -66,11 +66,11 @@ namespace ansor { /********** Utilities for Array, std::string **********/ /*! 
\brief Get the first appearance index of elements in an Array */ template -inline void GetIndices(const Array& array, const Array& to_locate, Array* indices) { +inline void GetIndices(const Array& array, const Array& to_locate, Array* indices) { for (const auto& v : to_locate) { auto it = std::find(array.begin(), array.end(), v); if (it != array.end()) { - indices->push_back(static_cast(it - array.begin())); + indices->push_back(it - array.begin()); } else { LOG(FATAL) << "Cannot find the item"; } diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index 9288d88b6270..8c3895128849 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -17,6 +17,8 @@ """Common functions for ansor test cases""" +import threading + from tvm import te, ansor import topi @@ -67,3 +69,17 @@ def get_tiled_matmul(): return dag, s0 + +class PropagatingThread(threading.Thread): + def run(self): + self.exc = None + try: + self.ret = self._target(*self._args, **self._kwargs) + except BaseException as e: + self.exc = e + + def join(self): + super(PropagatingThread, self).join() + if self.exc: + raise self.exc + return self.ret diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index f7990e272eb0..8922fd722690 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -20,12 +20,11 @@ import random import numpy as np import tempfile -import threading import tvm from tvm import ansor -from test_ansor_common import matmul_ansor_test +from test_ansor_common import matmul_ansor_test, PropagatingThread def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local', cost_model=None, num_measure_trials=2, params=None, @@ -44,11 +43,12 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' search_policy = ansor.EmptyPolicy() # search_policy = ansor.SketchSearchPolicy(cost_model, params=params, seed=seed) - tune_option = ansor.TuneOption(num_measure_trials=num_measure_trials, runner=runner, verbose=0, - measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=pre_search_callbacks) + tuning_options = ansor.TuningOptions(num_measure_trials=num_measure_trials, runner=runner, + verbose=0, + measure_callbacks=[ansor.LogToFile(log_file)], + pre_search_callbacks=pre_search_callbacks) sch, args = ansor.auto_schedule(task, target, search_policy=search_policy, - tune_option=tune_option) + tuning_options=tuning_options) inp, res = ansor.best_measure_pair_in_file(log_file, workload_key, target) print("==== Python Code ====") @@ -78,7 +78,7 @@ def test_search_basic(): return # wrap the search in a new thread to avoid the conflict # between python's multiprocessing and tvm's thread pool - t = threading.Thread(target=search_common, kwargs={'seed': 944563397}) + t = PropagatingThread(target=search_common, kwargs={'seed': 944563397}) t.start() t.join() From 546abbe0c179c48850600e095bc136d24ddf0fe2 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 4 Jul 2020 22:25:12 +0800 Subject: [PATCH 66/78] Update --- python/tvm/ansor/__init__.py | 4 +- python/tvm/ansor/auto_schedule.py | 29 ++++----- python/tvm/ansor/compute_dag.py | 44 ++++++------- python/tvm/ansor/loop_state.py | 39 +++++------ python/tvm/ansor/measure.py | 27 +++++--- .../tvm/ansor/{serialization.py => record.py} | 14 ++-- python/tvm/ansor/utils.py | 12 ++-- python/tvm/ansor/workload_registry.py 
| 14 ++-- src/ansor/auto_schedule.cc | 5 +- src/ansor/auto_schedule.h | 12 ++-- src/ansor/compute_dag.cc | 4 +- src/ansor/compute_dag.h | 34 +++++----- src/ansor/loop_state.h | 64 ++++++++++--------- src/ansor/measure.h | 20 +++--- src/ansor/{serialization.cc => record.cc} | 4 +- src/ansor/{serialization.h => record.h} | 20 +++--- src/ansor/search_policy/empty_policy.h | 5 +- src/ansor/search_policy/search_policy.h | 17 +++-- src/ansor/search_task.h | 11 ++-- src/ansor/transform_step.cc | 2 +- src/ansor/transform_step.h | 23 +++---- src/ansor/utils.h | 6 +- tests/python/unittest/test_ansor_measure.py | 10 +-- 23 files changed, 219 insertions(+), 201 deletions(-) rename python/tvm/ansor/{serialization.py => record.py} (92%) rename src/ansor/{serialization.cc => record.cc} (99%) rename src/ansor/{serialization.h => record.h} (89%) diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 04a10f2def5b..4baba87e1231 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -19,7 +19,7 @@ from . import compute_dag from . import measure -from . import serialization +from . import record from . import loop_state from . import utils from . import workload_registry @@ -29,6 +29,6 @@ from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \ auto_schedule, EmptyPolicy from .measure import MeasureInput, LocalBuilder, LocalRunner -from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \ +from .record import LogToFile, LogReader, best_measure_pair_in_file, \ load_from_file, append_measure_records_to_file from .workload_registry import register_workload, make_workload_key diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 9afc72c348c8..86c4adeabdb2 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -88,7 +88,7 @@ class SearchPolicy(Object): @tvm._ffi.register_object("ansor.EmptyPolicy") class EmptyPolicy(SearchPolicy): """ This is an example empty search policy which will always generate - the init state of target ComputeDAG. + the init state of ComputeDAG. """ def __init__(self): self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy) @@ -101,19 +101,18 @@ class TuningOptions(Object): Parameters ---------- num_measure_trials: int = 0 - The number of total schedule measure trials. - Ansor takes `num_measure_trials` state for measuring in total, and finally gets the best - schedule among them. - With `num_measure_trials` == 0, Ansor will do the schedule search but don't involve - measurement, this can be used if we want to quickly get a runnable schedule without - performance tuning. + The number of measurement trials. + The search policy measures `num_measure_trials` schedules in total and returns the best one + among them. + With `num_measure_trials` == 0, the policy will do the schedule search but won't involve + measurement. + This can be used to get a runnable schedule quickly without auto-tuning. early_stopping: int = -1 - Stops early the tuning if no improvement get after n measurements. + Stop the tuning early if getting no improvement after n measurements. num_measures_per_round: int = 64 - The number of programs to be measured at each search round. - The whole schedule search process is designed to have several rounds to try a total - `num_measure_trials` schedules. - We have: `num_search_rounds` = `num_measure_trials` // `num_measures_per_round` + The number of schedules to be measured at each search round. 
+        The whole schedule search process will try a total of `num_measure_trials` schedules
+        in several rounds.
     verbose: int = 1
       Verbosity level. 0 for silent, 1 to output information during schedule search.
     builder: Union[ProgramBuilder, str] = 'local'
@@ -121,7 +120,7 @@ class TuningOptions(Object):
     runner: Union[ProgramRunner, str] = 'local'
       ProgramRunner which runs the program and measures time costs.
     measure_callbacks: Optional[List[MeasureCallback]]
-      Callback functions called after each measure.
+      Callback functions called after each measurement.
       Candidates:
         - ansor.LogToFile
     pre_search_callbacks: Optional[List[SearchCallback]]
@@ -164,7 +163,7 @@ def auto_schedule(task, target, target_host=None, search_policy='default',
     Parameters
     ----------
     task : Union[SearchTask, str]
-        The target search task or workload key.
+        The SearchTask or workload key for the computation declaration.
     target : tvm.target.Target
         The target device of this schedule search.
     target_host : Optional[tvm.target.Target]
@@ -178,7 +177,7 @@ def auto_schedule(task, target, target_host=None, search_policy='default',

     Returns
     -------
-    A `te.schedule` and the target `te.Tensor`s to be used in `tvm.lower` or `tvm.build`
+    A `te.schedule` and a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
     """
     if isinstance(search_policy, str):
         if search_policy == 'default':
diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py
index fd1a4eb9bdc0..b72c8649133e 100644
--- a/python/tvm/ansor/compute_dag.py
+++ b/python/tvm/ansor/compute_dag.py
@@ -36,15 +36,14 @@ class ComputeDAG(Object):
     The Ansor computational graph and related program analyses.

     We convert a compute declaration described by `tvm.compute` (could be a single operator or a
-    subgraph) to a ComputeDAG. It keeps the input/output tensors of the target compute declaration,
-    a list of all related operations in topo order as well as a set of analyses over each operation
-    stage (e.g. the total float operation count, consumer/producer relations of each operation
-    stage, whether a operation stage should be tiled/compute inlined ...). These analyses can
-    help the search policy to do some specific decisions during schedule search process.
-
-    ComputeDAG is also responsible for the interaction between Ansor LoopState and TVM schedule
-    (e.g. applying the LoopState transform steps to TVM schedule, providing LoopState with extra
-    information get from TVM schedule ...).
+    subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+    a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+    total float operation count, consumer/producer relations of each operation stage, whether an
+    operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+    to make decisions during the search process.
+    ComputeDAG is also responsible for the interaction between Ansor `LoopState` and TVM schedule
+    (e.g. applying the `LoopState` transform steps to TVM schedule, providing `LoopState` with extra
+    information got from TVM schedule ...).

     Parameters
     ----------
@@ -75,16 +74,16 @@ def get_init_state(self):

     def apply_steps_from_state(self, state):
         """
-        Apply the history transform steps of a State to TVM schedule.
+        Apply the history transform steps from a State to get a TVM schedule.

         Parameters
         ----------
         state : Union[State, StateObject]
-            The target state to be applied to TVM schedule.
+            The state from which we get transform steps.

         Returns
         -------
-        A `te.schedule` and the target `te.Tensor`s to be used in `tvm.lower` or `tvm.build`
+        A `te.schedule` and a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
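+
+        Examples
+        --------
+        A minimal sketch, assuming `dag` is a `ComputeDAG` and `state` is a transformed
+        `State` (e.g. the pair returned by `get_tiled_matmul` in the unit tests):
+
+        .. code-block:: python
+
+            # Lower the transform steps recorded in the state into a TVM schedule
+            sch, args = dag.apply_steps_from_state(state)
+            # The outputs can be passed to tvm.lower or tvm.build
+            print(tvm.lower(sch, args, simple_mode=True))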
        """
        state_obj = state if isinstance(state, StateObject) else state.state_object
        return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj)
@@ -93,10 +92,13 @@ def print_python_code_from_state(self, state):
         """
         Print transform steps in the history of a State as TVM's python schedule primitive.

+        This can be used for debugging or to apply the schedule on a former TVM version without
+        Ansor support.
+
         Parameters
         ----------
         state : Union[State, StateObject]
-            The target state to be applied to TVM schedule.
+            The state from which we get transform steps.

         Returns
         -------
@@ -108,21 +110,19 @@ def infer_bound_from_state(self, state):
         """
-        Infer and fill the bound of all iterators of a state using TVM schedule.
-
-        State api supports to define a split step with its split factor to be a blank placeholder,
-        so sometimes we may get a State will incomplete iterator extent information.
-        And another situation is after some steps (for exp. compute_at), it may be hard to track
-        the extent change of all iterators.
+        Infer and fill the bound of all iterators of a state.

-        We perform infer bound using TVM schedule and fill the State with those information. After
-        applying this methods, the State is guaranteed to have complete interator extent
+        The states can lose complete bound information after some transform steps
+        (e.g., compute_at).
+        We can call this function to infer and fill all the bound information.
+        This function calls TVM InferBound pass internally to get the bound.
+        The returned state of this function is guaranteed to have complete iterator extent
         information.

         Parameters
         ----------
         state : Union[State, StateObject]
-            The target state to be applied to TVM schedule.
+            The state whose iterator bounds are to be inferred and filled.

         Returns
         -------
diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py
index 8f379c816d9c..61ac371762a4 100644
--- a/python/tvm/ansor/loop_state.py
+++ b/python/tvm/ansor/loop_state.py
@@ -19,20 +19,20 @@
 """
 The definition of the "state" in search.

-Each LoopState corresponds to a specific schedule for its target ComputeDAG.
-A LoopState consists of: 1. a current loop structure; 2. a history of transformations used to
+Each LoopState corresponds to a schedule for its ComputeDAG.
+A LoopState consists of: 1. a current loop structure; 2. a list of transformation steps used to
 construct the loop structure.
 The loop structure keeps a preview of how the schedule will finally look like after lowering the
 current state (e.g. number of iterators, the extent of each iterator, the compute_at locations
 ...).
 During the schedule search process, the loop structure can provide search policy with necessary
-information on how to perform further operations with the current state.
-The transform history is a sequence of TransformStep which will finally be mapped to schedule
-primitives. The steps can also be used for serialization of a state.
+information on how to manipulate the current state.
+The transform history is a sequence of `TransformStep` which will finally be mapped to TVM schedule
+primitives. The steps can also be used for the serialization of a state.

 The LoopState can be seen as a lightweight loop structure IR specifically for schedule search.
We don't use the existing TVM IR but to extend a new structure on it is because: -1. We want fast incremental change to the loop structures, search policy needs to get the immediate -loop structures update rather than after TVM lowering; +1. We want fast incremental change to the loop structures. The search policy needs to get the +immediate loop structures update rather than after TVM lowering; 2. We want serializable transform history for replay, backtracking, and mutation; 3. We may create some macro schedule primitives that represent the combination of several TVM schedule primitives. @@ -55,7 +55,7 @@ class Iterator(Object): @tvm._ffi.register_object("ansor.Stage") class Stage(Object): - """A stage in the compute declaration. Similar to tvm.te.schedule.Stage""" + """ A stage in the compute declaration. Similar to tvm.te.schedule.Stage. """ @tvm._ffi.register_object("ansor.State") @@ -68,16 +68,16 @@ def __eq__(self, other): class State: """ A state in the search process. It consists of the current loop structure - and a history of transformations used to construct it. + and a list of transformation steps used to construct it. - Each State corresponds to a specific schedule for its target ComputeDAG. + Each State corresponds to a specific schedule for its ComputeDAG. Parameters ---------- state_object : StateObject - The target StateObject, corresponding to C++ internal State object. + The StateObject corresponding to C++ internal State object. dag : ComputeDAG - The original target ComputeDAG of this State. + The original ComputeDAG of this State. Notes ----- @@ -119,10 +119,10 @@ def reorder(self, stage, order): Parameters ---------- stage : Union[int, Operation, Tensor] - The target Stage to be reordered, can be a Stage order index, Stage operation or stage + The Stage to be reordered, can be a Stage order index, Stage operation or stage output tensor. order : List[Iterator] - Iterators in the expected order + Iterators in the expected order. """ stage_id = self._resolve_stage_id(stage) @@ -132,15 +132,18 @@ def reorder(self, stage, order): def split(self, stage, iterator, lengths, inner_to_outer=True): """ Schedule primitive corresponds to te.split. + This API supports multiple split factors. (e.g. with 2 split factors, the original iterator + will be split to 3 parts, use `inner_to_outer` to control the split order) + Parameters ---------- stage : Union[int, Operation, Tensor] - The target Stage to be split, can be a Stage order index, Stage operation or stage + The Stage to be split, can be a Stage order index, Stage operation or stage output tensor. iterator : Iterator - The iterator to split + The iterator to be split. lengths: List[int] - The split factors + The multiple split factors. Can be None to be filled by search policy. inner_to_outer: bool = True True to use `factor` to split from inner to outer, False to use `nparts` to split from outer to inner @@ -163,7 +166,7 @@ def fuse(self, stage, iters): Parameters ---------- stage : Union[int, Operation, Tensor] - The target Stage to be fused, can be a Stage order index, Stage operation or stage + The Stage to be fused, can be a Stage order index, Stage operation or stage output tensor. 
        iters : List[Iterator]
            The iterators to be fused

diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py
index 85b1ff6891d3..6c3410c87076 100644
--- a/python/tvm/ansor/measure.py
+++ b/python/tvm/ansor/measure.py
@@ -42,7 +42,8 @@
 # The maximum length of error message
 MAX_ERROR_MSG_LEN = 512

-# Global variables used in build function
+# We use fork and a global variable to copy arguments between processes.
+# This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
 GLOBAL_BUILD_ARGUMENTS = None

 @tvm._ffi.register_object("ansor.MeasureCallback")
@@ -57,9 +58,9 @@ class MeasureInput(Object):
     Parameters
     ----------
     task : SearchTask
-        The target SearchTask.
+        The SearchTask of this measure.
     state : State
-        The current State to be measured.
+        The State to be measured.
     """
     def __init__(self, task, state):
         self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object)
@@ -190,11 +191,21 @@ class LocalRunner(ProgramRunner):
     timeout : int = 10
         The timeout limit for each run.
     number : int = 3
-        Number of measure times.
+        The number of times to run the generated code for taking average.
+        We count these runs as one `repeat` of measurement.
     repeat : int = 1
-        Number of repeat times in each measure.
+        The number of times to repeat the measurement.
+        In total, the generated code will be run (1 + number x repeat) times,
+        where the first "1" is warm up and will be discarded.
+        The returned result contains `repeat` costs,
+        each of which is an average of `number` costs.
     min_repeat_ms : int = 0
-        The minimum duration of one repeat in milliseconds.
+        The minimum duration of one `repeat` in milliseconds.
+        By default, one `repeat` contains `number` runs. If this parameter is set,
+        the parameter `number` will be dynamically adjusted to meet the
+        minimum duration requirement of one `repeat`.
+        i.e., when the run time of one `repeat` falls below this time, the `number` parameter
+        will be automatically increased.
     cooldown_interval : float = 0.0
         The cool down interval between two measurements.
     """
@@ -235,7 +246,7 @@ def make_error_msg():

 def local_build_worker(index):
     """ Local builder function. """
-    # We use fork to copy arguments from a global variable.
+    # We use fork and a global variable to copy arguments between processes.
     # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
     if not GLOBAL_BUILD_ARGUMENTS:
         raise ValueError("GLOBAL_BUILD_ARGUMENTS not found")
@@ -302,7 +313,7 @@ def timed_func():
 @tvm._ffi.register_func("ansor.local_builder.build")
 def local_builder_build(inputs, timeout, n_parallel, build_func, verbose):
     """ Local builder build function. """
-    # We use fork to copy arguments from a global variable.
+    # We use fork and a global variable to copy arguments between processes.
     # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
     global GLOBAL_BUILD_ARGUMENTS
     GLOBAL_BUILD_ARGUMENTS = (inputs, build_func, timeout, verbose)
diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/record.py
similarity index 92%
rename from python/tvm/ansor/serialization.py
rename to python/tvm/ansor/record.py
index 65534bb117a0..6f94105494aa 100644
--- a/python/tvm/ansor/serialization.py
+++ b/python/tvm/ansor/record.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.

-"""Serialization and other I/O support for tuning logs (measurement records)"""
+""" Serialization and other I/O support for tuning logs (measurement records). 
""" import numpy as np @@ -98,16 +98,16 @@ def load_from_file(filename): def append_measure_records_to_file(filename, inputs, results): """ - Aappend measure records to file. + Append measure records to file. Parameters ---------- filename : str File name to write log to. inputs: List[MeasureInputs] - The target MeasureInputs to be written. + The MeasureInputs to be written. results: List[MeasureResults] - The target MeasureResults to be written. + The MeasureResults to be written. """ _ffi_api.AppendMeasureRecordsToFile(filename, inputs, results) @@ -119,10 +119,10 @@ def best_measure_pair_in_file(filename, workload_key=None, target=None): ---------- filename : str File name to load log from. - workload_key : Optional[str] = None - The workload key of the target compute declaration. + workload_key : Optional[str] + The workload key of the compute declaration. With `None`, this retuns the best measure pair of all workloads. - target : Optional[tvm.target.Target] = None + target : Optional[tvm.target.Target] The target device. With `None`, this retuns the best measure pair of all target devices. diff --git a/python/tvm/ansor/utils.py b/python/tvm/ansor/utils.py index 9dbcd81f36e7..c698812e54c7 100644 --- a/python/tvm/ansor/utils.py +++ b/python/tvm/ansor/utils.py @@ -40,7 +40,7 @@ def get_func_name(func): Parameters ---------- func: Function - The target function. + The input function. Returns ------- @@ -55,7 +55,7 @@ def get_const_int(exp): Parameters ---------- - exp : tvm.Expr or int + exp : Union[tvm.tir.expr, int] The input expression. Returns @@ -65,10 +65,10 @@ def get_const_int(exp): """ if isinstance(exp, int): return exp - if not isinstance(exp, (expr.IntImm)): + if not isinstance(exp, expr.IntImm): opt = Sequential([Simplify()]) exp = opt(exp) - if not isinstance(exp, (expr.IntImm)): + if not isinstance(exp, expr.IntImm): raise ValueError("Expect value to be constant int") return exp.value @@ -78,12 +78,12 @@ def get_const_tuple(in_tuple): Parameters ---------- - in_tuple : tuple of Expr + in_tuple : Tuple[tvm.tir.expr] The input. Returns ------- - out_tuple : tuple of int + out_tuple : Tuple[int] The output. """ return tuple(get_const_int(x) for x in in_tuple) diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py index 03ca1d771682..5726ae3a7507 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/ansor/workload_registry.py @@ -18,7 +18,7 @@ """ Workload registration and serialization. -We use a json string to represent a workload (a compute dag). +We use a json string to represent a workload (a computation graph). The format of the string is `[func_name, [args...]]`. The dag should be the return value of this `func_name(*args)`. @@ -47,7 +47,7 @@ def register_workload(func): Parameters ---------- func : Function - The target function that returns the compute declaration Tensors. + The generation function that returns the compute declaration Tensors. Examples -------- @@ -74,15 +74,15 @@ def make_workload_key(func, args): Parameters ---------- func : Union[Function, str] - The target function that returns the compute declaration Tensors. + The function that returns the compute declaration Tensors. Can be the a function or the function name. args : Args - The args of the target function. + The args of the function. Returns ------- workload_key : Str - The workload key of the target function. + The workload key of the function. 
""" if callable(func): func_name = func.__name__ @@ -106,7 +106,7 @@ def decode_workload_key_to_func_args(workload_key): Parameters ---------- workload_key : str - The target workload key. + The input workload key. Returns ------- @@ -131,7 +131,7 @@ def workload_key_to_tensors(workload_key): Parameters ---------- workload_key : str - The target workload key. + The input workload key. Returns ------- diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index dfaff797a179..ee51fc9a0210 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -47,9 +47,8 @@ TuningOptions::TuningOptions(int num_measure_trials, int early_stopping, int num data_ = std::move(node); } -std::pair > AutoSchedule(SearchTask task, - SearchPolicy search_policy, - TuningOptions tuning_options) { +std::pair> AutoSchedule(SearchTask task, SearchPolicy search_policy, + TuningOptions tuning_options) { // Create a ProgramMeasurer to handle the schedule build and performance measure ProgramMeasurer measurer = ProgramMeasurer(tuning_options->builder, tuning_options->runner, diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index 8127990ca2ec..c493c0dfc47c 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -97,15 +97,15 @@ class TuningOptions : public ObjectRef { }; /*! - * \brief Auto schedule search for a given compute declaration, by SearchTask. - * \param task The target search task. + * \brief Auto schedule search for a given compute declaration. + * \param task The search task of the compute declaration. * \param search_policy The search policy to be used for schedule search. * \param tuning_options Tuning and measurement options. - * \return A `te::Schedule` and the target `te::Tensor` to be used in `tvm.lower` or `tvm.build`. + * \return A `te::schedule` and the a Array of `te::Tensor` to be used in `tvm.lower` or + * `tvm.build`. */ -std::pair > AutoSchedule(SearchTask task, - SearchPolicy search_policy, - TuningOptions tuning_options); +std::pair> AutoSchedule(SearchTask task, SearchPolicy search_policy, + TuningOptions tuning_options); } // namespace ansor } // namespace tvm diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 80843c420044..b99b609be42b 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -50,7 +50,7 @@ TVM_REGISTER_NODE_TYPE(ComputeDAGNode); // Results are stored in ops void TopoSortOps(const Array& tensors, Array* ops) { std::unordered_map degree; - std::unordered_map > edge_set; + std::unordered_map> edge_set; std::unordered_map priority; std::unordered_set visited; @@ -240,7 +240,7 @@ void UpdateStageAxis(const te::Stage& stage, StageToAxesMap* stage_to_axes) { } } -std::pair > ComputeDAG::ApplySteps( +std::pair> ComputeDAG::ApplySteps( const Array& transform_steps, Array* stages, StageToAxesMap* stage_to_axes) const { // Temporal object to be used if the input pointer is nullptr diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index b32f6ea7c373..66459bb70864 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -22,15 +22,14 @@ * \brief The Ansor computational graph and related program analyses. * * We convert a compute declaration described by `tvm.compute` (could be a single operator or a - * subgraph) to a ComputeDAG. It keeps the input/output tensors of the target compute declaration, - * a list of all related operations in topo order as well as a set of analyses over each operation - * stage (e.g. 
the total float operation count, consumer/producer relations of each operation
- * stage, whether a operation stage should be tiled/compute inlined ...). These analyses can
- * help the search policy to do some specific decisions during schedule search process.
- *
- * ComputeDAG is also responsible for the interaction between Ansor LoopState and TVM schedule
- * (e.g. applying the LoopState transform steps to TVM schedule, providing LoopState with extra
- * information get from TVM schedule ...).
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during the search process.
+ * ComputeDAG is also responsible for the interaction between Ansor `LoopState` and TVM schedule
+ * (e.g. applying the `LoopState` transform steps to TVM schedule, providing `LoopState` with extra
+ * information got from TVM schedule ...).
 */

#ifndef TVM_ANSOR_COMPUTE_DAG_H_
@@ -81,29 +80,30 @@ class ComputeDAG : public ObjectRef {
   explicit ComputeDAG(Array tensors);

   /*!
-   * \brief Apply transform steps to the init state of this DAG, and get the
-   * equivalent `tvm::schedule`.
-   * \param transform_steps Transform steps of the target state.
+   * \brief Apply the history transform steps from a State to get a TVM schedule.
+   * \param transform_steps Transform steps of a state.
    * \param stages A pointer to a `te::Stage` Array, default to be nullptr.
    * Pass a valid pointer if these information needs to be used outside this function.
    * \param stage_to_axes A pointer to a StageToAxesMap, default to be nullptr.
    * Pass a valid pointer if these information needs to be used outside this function.
-   * \return The return values can be used as arguments to `tvm.build` or `tvm.lower`.
+   * \return A `te.schedule` and a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
    */
   std::pair> ApplySteps(
       const Array& transform_steps, Array* stages = nullptr,
       StageToAxesMap* stage_to_axes = nullptr) const;

   /*!
    * \brief Print transform steps as equivalent python schedule API.
-   * \param transform_steps Transform steps of the target state.
-   * \return Python schedule code.
+   * This can be used for debugging or to apply the schedule on a former TVM version without Ansor
+   * support.
+   * \param transform_steps Transform steps of a state.
+   * \return The Python schedule code.
    */
   String PrintStepsAsPython(const Array& transform_steps) const;

   /*!
    * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound.
-   * \param state The target state.
+   * \param state The state to be filled with complete bound information.
    * \return The State after inferbound.
    */
   State InferBound(const State& state) const;
diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h
index 654896b880b4..f91a5dbb45bb 100644
--- a/src/ansor/loop_state.h
+++ b/src/ansor/loop_state.h
@@ -21,27 +21,28 @@
 * \file ansor/loop_state.h
 * \brief The definition of the "state" in search.
 *
- * Each LoopState corresponds to a specific schedule for its target ComputeDAG.
- * A LoopState consists of: 1. a current loop structure; 2. a history of transformations used to
- * construct it.
+ * Each LoopState corresponds to a schedule for its ComputeDAG.
+ * A LoopState consists of: 1. a current loop structure; 2.
a list of transformation steps used to + * construct the loop structure. * The loop structure keeps a preview of how the schedule will finally look like after lowering the * current state (e.g. number of iterators, the extent of each iterator, the compute_at locations - * ...). During the schedule search process, the loop structure can provide search policy with - * necessary information on how to perform further operations with the current state. - * The transform history is a sequence of TransformStep which will finally be mapped to schedule - * primitives. The steps can also be used for serialization of a state. + * ...). + * During the schedule search process, the loop structure can provide search policy with necessary + * information on how to manipulate the current state. + * The transform history is a sequence of `TransformStep` which will finally be mapped to TVM + * schedule primitives. The steps can also be used for the serialization of a state. * * The LoopState can be seen as a lightweight loop structure IR specifically for schedule search. * We don't use the existing TVM IR but to extend a new structure on it is because: - * 1. We want fast incremental change to the loop structures, search policy needs to get the + * 1. We want fast incremental change to the loop structures. The search policy needs to get the * immediate loop structures update rather than after TVM lowering; * 2. We want serializable transform history for replay, backtracking, and mutation; - * 3. We may create some macro schedule primitives that represent the combination of several TVM - * schedule primitives. + * 3. We may create some macro schedule primitives that represent the combination of several + * TVM schedule primitives. * * When the search is complete, we will lower the state to TVM IR with TVM's schedule primitives. - * Since we share a lot of common objects during search, the transformation is implemented in copy - * on write style. All objects are immutable, which is similar to TVM IR. + * Since we share a lot of common objects during search, the transformation is implemented in + * copy on write style. All objects are immutable, which is similar to TVM IR. */ #ifndef TVM_ANSOR_LOOP_STATE_H_ @@ -122,7 +123,7 @@ class IteratorNode : public Object { public: /*! \brief The name of this iterator. */ String name; - /*! \brief The target range of this iterator. */ + /*! \brief The range of this iterator. */ Range range; /*! \brief The iterator type of this iterator. */ IteratorType iter_type; @@ -147,7 +148,7 @@ class Iterator : public ObjectRef { /*! * \brief The constructor. * \param name The name of this iterator. - * \param range The target range of this iterator. + * \param range The range of this iterator. * \param iter_type The iterator type of this iterator. * \param annotation The annotation type of this iterator. */ @@ -228,8 +229,9 @@ class Stage : public ObjectRef { /*! * \brief A state in the search process. - * It consists of the current loop structure and a history of transformations used to construct it. - * Each State corresponds to a specific schedule for its target ComputeDAG. + * It consists of the current loop structure and a list of transformation steps used to construct + * it. + * Each State corresponds to a specific schedule for its ComputeDAG. */ class StateNode : public Object { public: @@ -252,7 +254,8 @@ class StateNode : public Object { private: /*! 
   * \brief The up-to-date ComputeDAG of this state, used for some steps that may change the
-   * stage structure of the ComputeDAG, for exp. CacheReadStep/CacheWriteStep(Will be added later).
+   * stage structure of the ComputeDAG (e.g. CacheReadStep/CacheWriteStep, which will be added
+   * later).
    * The default value is an empty ObjectRef. (means no modification to the original DAG)
    */
   ObjectRef current_compute_dag;
@@ -279,8 +282,11 @@ class State : public ObjectRef {
   String ToStr(bool delete_trivial_loop = true) const;

   /*!
-   * \brief General do step functions with a runtime dynamic dispatcher.
-   * \param dag The target ComputeDAG.
+   * \brief General do step functions with a runtime dynamic dispatcher. This will re-apply all the
+   * transform steps to the initial state.
+   * \param dag The original ComputeDAG of this state.
+   * \note This is different from the class member `current_compute_dag`, for some transform step
+   * may change the op stage structure of the ComputeDAG.
    */
   void DoSteps(const ComputeDAG& dag);

  /*!
   * \brief Schedule primitive corresponds to te.reorder.
-   * \param stage_id The index of the target stage.
-   * \param order The target iterator order.
+   * \param stage_id The index of the stage to be reordered.
+   * \param order The expected iterator order.
    */
   void reorder(int stage_id, const Array& order);
   /*!
    * \brief Schedule primitive corresponds to te.split.
-   * \param stage_id The index of the target stage.
-   * \param it The target iterator.
-   * \param lengths The target split factors. Can be None to be filled by search policy.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param lengths The multiple split factors. Can be None to be filled by search policy.
    * \param inner_to_outer True for split from inner to outer & False for outer to inner.
    * \return The iterator results after split.
    */
   Array split(int stage_id, const Iterator& it, const Array& lengths,
                         bool inner_to_outer = true);
   /*!
    * \brief Schedule primitive corresponds to te.fuse.
-   * \param stage_id The index of the target stage.
-   * \param iters The target iterators to be fused.
+   * \param stage_id The index of the stage to be fused.
+   * \param iters The iterators to be fused.
    * \return The iterator result after fuse.
    */
   Iterator fuse(int stage_id, const Array& iters);
@@ -338,9 +344,9 @@
   /*!
    * \brief Common function for DoSplitStep and DoFollowSplitStep(Will be added later).
-   * \param stage_id The index of the target stage.
-   * \param iter_id The index of the target iterator.
-   * \param lengths The target split factors.
+   * \param stage_id The index of the stage to be split.
+   * \param iter_id The index of the iterator to be split.
+   * \param lengths The multiple split factors.
    * \param inner_to_outer The split direction.
    * \return The iterator results after split.
    */
diff --git a/src/ansor/measure.h b/src/ansor/measure.h
index c9978523a1cd..3442e8b3e18f 100644
--- a/src/ansor/measure.h
+++ b/src/ansor/measure.h
@@ -20,12 +20,13 @@
 /*!
  * \file ansor/measure.h
  * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs.
- * MeasureInput -> BuildeResult -> MeasureResult
+ * The flow of data structures is MeasureInput -> BuildResult -> MeasureResult.
  */

#ifndef TVM_ANSOR_MEASURE_H_
#define TVM_ANSOR_MEASURE_H_

+#include
 #include
 #include

@@ -91,8 +92,8 @@ class MeasureInput : public ObjectRef {
 public:
  /*!
   * \brief The constructor.
-   * \param task The target SearchTeask.
-   * \param state The target State.
+   * \param task The SearchTask of this measure.
+   * \param state The State to be measured.
    */
   MeasureInput(SearchTask task, State state);
@@ -335,7 +336,8 @@ class LocalRunnerNode : public ProgramRunnerNode {
 class LocalRunner : public ProgramRunner {
  public:
   /*!
-   * \brief The constructor.
+   * \brief The constructor. See the corresponding class in python/tvm/ansor/measure.py for more
+   * detailed parameter explanation.
    * \param timeout The timeout limit for each run.
    * \param number Number of measure times.
    * \param repeat Number of repeat times in each measure.
@@ -357,11 +359,11 @@ class ProgramMeasurerNode : public Object {
   /*! \brief Continuous error counter. */
   int error_ct;
   /*! \brief Workload key to best flops map. */
-  std::unordered_map best_flops;
+  std::unordered_map best_flops;
   /*! \brief Workload key to best state map. */
-  std::unordered_map best_state;
+  std::unordered_map best_state;
   /*! \brief Workload key to best state's count index map. */
-  std::unordered_map best_ct;
+  std::unordered_map best_ct;
   /*! \brief The ProgramBuilder to build each program. */
   ProgramBuilder builder;
   /*! \brief The ProgramRunner to measure each program. */
@@ -380,7 +382,7 @@ class ProgramMeasurerNode : public Object {
   /*!
    * \brief Do measurement.
    * \param task The current SearchTask.
    * \param policy The current SearchPolicy.
-   * \param inputs The target MeasureInputs.
+   * \param inputs The MeasureInputs.
    * \param results A pointer to a MeasureResult Array, this is used as output.
    * \param batch_size Number of programs to be measured in one batch.
    */
@@ -391,7 +393,7 @@
   /*!
    * \brief Do measurement silently.
    * This API will not print the measure results to screen.
    * \param task The current SearchTask.
-   * \param inputs The target MeasureInputs.
+   * \param inputs The MeasureInputs.
    * \param results A pointer to a MeasureResult Array, this is used as output.
    */
   void SilentMeasure(const SearchTask& task, const Array& inputs,
diff --git a/src/ansor/serialization.cc b/src/ansor/record.cc
similarity index 99%
rename from src/ansor/serialization.cc
rename to src/ansor/record.cc
index ec97ba74666d..b99a67f4e64c 100644
--- a/src/ansor/serialization.cc
+++ b/src/ansor/record.cc
@@ -18,11 +18,11 @@
  */

 /*!
- * \file ansor/serialization.cc
+ * \file ansor/record.cc
  * \brief Json serialization format for dumping and loading tuning records.
  */

-#include "serialization.h"
+#include "record.h"

 #include
 #include
diff --git a/src/ansor/serialization.h b/src/ansor/record.h
similarity index 89%
rename from src/ansor/serialization.h
rename to src/ansor/record.h
index 3b8fe124e0bf..0e26b6bdaf1e 100644
--- a/src/ansor/serialization.h
+++ b/src/ansor/record.h
@@ -18,12 +18,12 @@
  */

 /*!
- * \file ansor/serialization.h
+ * \file ansor/record.h
  * \brief Json serialization format for dumping and loading tuning records.
  */

-#ifndef TVM_ANSOR_SERIALIZATION_H_
-#define TVM_ANSOR_SERIALIZATION_H_
+#ifndef TVM_ANSOR_RECORD_H_
+#define TVM_ANSOR_RECORD_H_

 #include
 #include

@@ -62,7 +62,7 @@ class LogToFile : public MeasureCallback {
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LogToFile, MeasureCallback, LogToFileNode);
 };

-/*! \brief Log reader to load step logs from a target file.*/
+/*! \brief Log reader to load step logs from a file.*/
 class LogReaderNode : public Object {
  public:
   /*! \brief File name for this reader to load log from.
*/
  String filename;
@@ -85,8 +85,8 @@ class LogReaderNode : public Object {
   * \param skip_size Skip the first n lines.
   * \return The MeasureInputs and MeasureResults loaded from the log file.
   */
-  std::pair, Array > ReadLines(int max_size = -1,
-                                                                            int skip_size = 0);
+  std::pair, Array> ReadLines(int max_size = -1,
+                                                                           int skip_size = 0);

  static constexpr const char* _type_key = "ansor.LogReader";
  TVM_DECLARE_FINAL_OBJECT_INFO(LogReaderNode, Object);
@@ -114,15 +114,15 @@ class LogReader : public ObjectRef {
/*!
 * \brief Write measure records to an output stream.
 * \param os A pointer to a output stream.
- * \param inputs The target MeasureInputs to be written.
- * \param results The target MeasureResults to be written.
+ * \param inputs The MeasureInputs to be written.
+ * \param results The MeasureResults to be written.
 */
void WriteMeasureRecords(std::ostream* os, const Array& inputs,
                         const Array& results);

/*!
 * \brief Read one measure record from a string.
- * \param str The target record string to be extract.
+ * \param str The record string to be extracted.
 * \param inp A pointer to a MeasureInputNode, this is used as output.
 * \param res A pointer to a MeasureResultNode, this is used as output.
 * \param log_version A pointer to a log version string.
 */
void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res,
                       std::string* log_version);

}  // namespace ansor
}  // namespace tvm

-#endif  // TVM_ANSOR_SERIALIZATION_H_
+#endif  // TVM_ANSOR_RECORD_H_
diff --git a/src/ansor/search_policy/empty_policy.h b/src/ansor/search_policy/empty_policy.h
index a8fd4fb424e9..58209b666048 100644
--- a/src/ansor/search_policy/empty_policy.h
+++ b/src/ansor/search_policy/empty_policy.h
@@ -32,9 +32,8 @@ namespace tvm {
namespace ansor {

/*!
- * \brief The EmptyPolicy will always generates the init state of a ComputeDAG.
- * This is an brief example of search policy, while can show the design of search policy,
- * the formal search policy will continue to follow it.
+ * \brief A brief example of a search policy which always returns the initial naive schedule
+ * (state). The formal search policy will continue to follow its design.
 * The key implementation for this structure is `Search()`, check `empty_policy.cc` for more
 * details.
 */
diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h
index 7d1f0a3349fa..1ae601896ce0 100644
--- a/src/ansor/search_policy/search_policy.h
+++ b/src/ansor/search_policy/search_policy.h
@@ -37,15 +37,14 @@
 * mechanism will be provided to enable user-defined template search to serve the same functionality
 * as the current AutoTVM template.
 *
- * This guide is to help understand it better and incase some advanced users have special
- * requirements.
- * 1. The only funcion that must be implemented is Search(), the design principe for it is to be
- * the entry of starting a schedule search process and returns the best schedule get.
+ * This guide is for advanced users who have special requirements.
+ * 1. The only function that must be implemented is Search(), which takes a task as input and
+ * returns the best states found.
 * 2. Information about the compute declaration of ops/subgraphs can be acquired from SearchTask.
- * This structure also contains some information about the target device. (e.g. knowing the weight
- * of the device vector unit, we can limit the max vectorize size during schedule generating)
- * 3. SearchCallback provides more flexibility to do extra affairs during the search process.
- * 4.
ProgramMeasurer provides a simple but useful api to help check the performance of states get
+ * This structure also contains some information about the target device. (e.g. knowing the width
+ * of the device vector unit, we can limit the max vectorize size during schedule search)
+ * 3. SearchCallback provides more flexibility to do extra affairs before/after the search process.
+ * 4. ProgramMeasurer provides a simple but useful api to help check the performance of states got
 * during the search process.
 */
@@ -112,7 +111,7 @@ class SearchPolicyNode : public Object {
  /*!
   * \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state
   * get during the search process.
-   * \param task The target search task.
+   * \param task The SearchTask for the computation declaration.
   * \param num_measure_trials Total schedules to be tried during this search.
   * \param early_stopping Early stop if no better schedule is found.
   * \param num_measures_per_round Max measure batch in one search round.
diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h
index 489dafa615b2..ba418049e655 100644
--- a/src/ansor/search_task.h
+++ b/src/ansor/search_task.h
@@ -44,8 +44,7 @@ class HardwareParamsNode : public Object {
  /*! \brief The size of cache line in bytes. */
  int cache_line_bytes;

-  // Some GPU related limitations
-  // Get from TVM device api
+  // GPU related parameters obtained from the device query API

  /*! \brief The max shared memory per block. */
  int max_shared_memory_per_block{INT32_MAX};
@@ -104,9 +103,9 @@ class HardwareParams : public ObjectRef {
 */
class SearchTaskNode : public Object {
 public:
-  /*! \brief The ComputeDAG for target compute declaration. */
+  /*! \brief The ComputeDAG for the compute declaration. */
  ComputeDAG compute_dag;
-  /*! \brief The workload key for target compute declaration. */
+  /*! \brief The workload key for the compute declaration. */
  String workload_key;
  /*! \brief The target device of this search task. */
  Target target;
@@ -135,8 +134,8 @@ class SearchTask : public ObjectRef {
 public:
  /*!
   * \brief The constructor.
-   * \param compute_dag The ComputeDAG for target compute declaration.
-   * \param workload_key The workload key for target compute declaration.
+   * \param compute_dag The ComputeDAG for the compute declaration.
+   * \param workload_key The workload key for the compute declaration.
   * \param target The target device of this search task.
   * \param target_host The target host device of this search task.
   * \param hardware_params Hardware parameters used in this search task.
diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc
index 8769714d0513..d3b5b39750c1 100644
--- a/src/ansor/transform_step.cc
+++ b/src/ansor/transform_step.cc
@@ -162,7 +162,7 @@ SplitStep::SplitStep(int stage_id, int iter_id, PrimExpr extent, const Arraystage_id = stage_id;
  // Extent can be a unreducible expression in some special cases
  if (extent->IsInstance()) {
-    node->extent = std::move(tvm::Downcast(extent));
+    node->extent = tvm::Downcast(extent);
  }
  node->iter_id = iter_id;
  node->lengths = lengths;
diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h
index 140b7b0539f1..0c053693d9f5 100644
--- a/src/ansor/transform_step.h
+++ b/src/ansor/transform_step.h
@@ -21,7 +21,7 @@
 * \file ansor/transform_step.h
 * \brief Transformation steps. For each schedule primitive, there is a corresponding transform
 * step.
The implementation of each step consists of 2 parts:
- * - transform_step.cc: How each step interact with TVM system
+ * - transform_step.cc: How each step interacts with TE and TE's schedule primitives
 * - loop_state.cc: How each step reflect on LoopState
 *
 * \note Adding a new transform step.
@@ -34,8 +34,8 @@
 * - In these two functions you need to incrementally update all data structures in State with
 * CopyOnWrite style
 * 4. Add you step to `ComputeDAG::ApplySteps` and make sure it works.
- * 5. Add serialization support in `struct Handler >`
- * in `serialization.cc`.
+ * 5. Add log record serialization support in `struct Handler>`
+ * in `record.cc`.
 * 6. Add its corresponding Python API to `loop_state.py` and necessary unit test.
 */

typedef Map, ObjectHash, ObjectEqual> StageT
 */
class StepNode : public Object {
 public:
-  /*! \brief The index of the target stage. */
+  /*! \brief The index of the stage. */
  int stage_id;

  static constexpr const char* _type_key = "ansor.Step";
@@ -111,8 +111,8 @@ class ReorderStep : public Step {
 public:
  /*!
   * \brief The constructor.
-   * \param stage_id The index of the target stage.
-   * \param after_ids The index of the iterators after reorder.
+   * \param stage_id The index of the stage to be reordered.
+   * \param after_ids The expected indexes of the iterators after reorder.
   */
  ReorderStep(int stage_id, const Array& after_ids);
@@ -166,9 +166,10 @@ class SplitStep : public Step {
 public:
  /*!
   * \brief The constructor.
-   * \param stage_id The index of the target stage.
-   * \param extent The index of the target iterator.
-   * \param lengths The extent length of the axis to split.
+   * \param stage_id The index of the stage to be split.
+   * \param iter_id The index of the iterator to be split.
+   * \param extent The extent length of the axis to split.
+   * \param lengths The multiple split factors. Can be None to be filled by search policy.
   * \param inner_to_outer The split direction.
   */
  SplitStep(int stage_id, int iter_id, PrimExpr extent, const Array& lengths,
@@ -211,8 +212,8 @@ class FuseStep : public Step {
 public:
  /*!
   * \brief The constructor.
-   * \param stage_id The index of the target stage.
-   * \param fused_ids The index of the target iterators to be fused.
+   * \param stage_id The index of the stage to be fused.
+   * \param fused_ids The indices of the iterators to be fused.
   */
  FuseStep(int stage_id, const Array& fused_ids);
diff --git a/src/ansor/utils.h b/src/ansor/utils.h
index 54e434804549..4b0d0a0b180a 100644
--- a/src/ansor/utils.h
+++ b/src/ansor/utils.h
@@ -115,14 +115,14 @@ inline double FloatArrayMean(const Array& float_array) {
}

/********** Other Utilities **********/
-/*! \brief Get an int value from an Expr */
+/*! \brief Get an int value from an Expr */
inline int64_t GetIntImm(const PrimExpr& expr) {
  auto pint = expr.as();
  CHECK(pint != nullptr);
  return pint->value;
}

-/*! \brief Compute the product of the lengths of axes */
+/*! \brief Compute the product of the lengths of axes */
inline int64_t AxisLengthProd(const Array& axes) {
  int64_t ret = 1.0;
  for (const auto& x : axes) {
@@ -149,7 +149,7 @@
inline std::string CleanName(const std::string& str) {
  return ret;
}

-/*! \brief An empty output stream */
+/*!
\brief An empty output stream */
class NullStream : public std::ostream {
 public:
  NullStream() : std::ostream(nullptr) {}
diff --git a/tests/python/unittest/test_ansor_measure.py b/tests/python/unittest/test_ansor_measure.py
index a21f70f0d956..5d100025bf1e 100644
--- a/tests/python/unittest/test_ansor_measure.py
+++ b/tests/python/unittest/test_ansor_measure.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.

-"""Test measurement and log serialization"""
+""" Test measurement and log serialization. """

 import tvm
 from tvm import ansor
 import tempfile

 from test_ansor_common import get_tiled_matmul


-def test_serialization():
+def test_record():
     dag, s = get_tiled_matmul()

     if not tvm.runtime.enabled("llvm"):
         return
@@ -36,9 +36,9 @@ def test_record():
     res = ansor.measure.MeasureResult([0.1], 0, "", 0.2, 1)

     with tempfile.NamedTemporaryFile() as fp:
-        ansor.serialization.append_measure_records_to_file(fp.name, [inp], [res])
+        ansor.record.append_measure_records_to_file(fp.name, [inp], [res])

-        log_reader = ansor.serialization.LogReader(fp.name)
+        log_reader = ansor.record.LogReader(fp.name)
         inputs, results = log_reader.read_lines()
         assert len(inputs) == 1
@@ -68,5 +68,5 @@ def test_measure_local_builder_runner():

 if __name__ == "__main__":
-    test_serialization()
+    test_record()
     test_measure_local_builder_runner()

From d418a57ca9008b2882243221dab2ede64b364d9d Mon Sep 17 00:00:00 2001
From: "chengfan.jcf" <chengfan.jcf@alibaba-inc.com>
Date: Sun, 5 Jul 2020 13:15:39 +0800
Subject: [PATCH 67/78] Update

---
 python/tvm/ansor/loop_state.py            |  3 +-
 python/tvm/ansor/measure.py               | 87 +++++++++++++++++++++---
 python/tvm/ansor/record.py                | 13 ++--
 src/ansor/auto_schedule.h                 |  5 +-
 src/ansor/compute_dag.cc                  | 66 ++++++------------
 src/ansor/compute_dag.h                   | 10 ++-
 src/ansor/loop_state.cc                   | 66 ++++++------------
 src/ansor/loop_state.h                    |  2 +-
 src/ansor/measure.cc                      | 15 ++--
 src/ansor/search_policy/empty_policy.h    |  3 +-
 src/ansor/search_policy/search_policy.cc  |  8 +--
 src/ansor/search_policy/search_policy.h   |  2 +-
 src/ansor/transform_step.h                |  2 +-
 13 files changed, 150 insertions(+), 132 deletions(-)

diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py
index 61ac371762a4..b407eadd7f3d 100644
--- a/python/tvm/ansor/loop_state.py
+++ b/python/tvm/ansor/loop_state.py
@@ -145,8 +145,7 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
         lengths: List[int]
             The multiple split factors. Can be None to be filled by search policy.
         inner_to_outer: bool = True
-            True to use `factor` to split from inner to outer,
-            False to use `nparts` to split from outer to inner
+            Whether the factors go from inner to outer, or from outer to inner.

         Returns
         -------
diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py
index 6c3410c87076..01e722c9a944 100644
--- a/python/tvm/ansor/measure.py
+++ b/python/tvm/ansor/measure.py
@@ -15,12 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.

-"""Distributed measurement infrastructure to measure the runtime costs of tensor programs
+"""
+Distributed measurement infrastructure to measure the runtime costs of tensor programs.

 These functions are responsible for building the tvm module, uploading it to
 remote devices, recording the running time costs, and checking the correctness of the output.

-We implement these in python to utilize python's multiprocessing and error handling
+We implement these in python to utilize python's multiprocessing and error handling.
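+
+A minimal sketch of the measurement flow (assuming `task` and `state` are available,
+e.g. from an existing search, and that the default constructor parameters are used):
+
+.. code-block:: python
+
+    inp = MeasureInput(task, state)
+    # Build the program first, then measure its running time
+    build_results = LocalBuilder().build([inp])
+    measure_results = LocalRunner().run([inp], build_results)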
""" import os @@ -127,7 +128,7 @@ def build(self, measure_inputs, verbose=1): ---------- measure_inputs : List[MeasureInput] A List of MeasureInput. - verbost : int = 1 + verbose : int = 1 Verbosity level. 0 for silent, 1 to output information during program building. Returns @@ -150,7 +151,7 @@ def run(self, measure_inputs, build_results, verbose=1): A List of MeasureInput. build_results : List[BuildResult] A List of BuildResult to be ran. - verbost : int = 1 + verbose : int = 1 Verbosity level. 0 for silent, 1 to output information during program running. Returns @@ -245,7 +246,19 @@ def make_error_msg(): def local_build_worker(index): - """ Local builder function. """ + """ + Build function of LocalBuilder to be ran in the Builder thread pool. + + Parameters + ---------- + index : int + The MeasureInput index to be processed by the current Builder thread. + + Returns + ------- + res : BuildResult + The build result of this Builder thread. + """ # We use fork and a global variable to copy arguments between processings. # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool if not GLOBAL_BUILD_ARGUMENTS: @@ -311,8 +324,28 @@ def timed_func(): @tvm._ffi.register_func("ansor.local_builder.build") -def local_builder_build(inputs, timeout, n_parallel, build_func, verbose): - """ Local builder build function. """ +def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbose=1): + """ + Build function of LocalBuilder to build the MeasureInputs to runnable modules. + + Parameters + ---------- + inputs : List[MeasureInput] + The MeasureInputs to be built. + timeout : int + The timeout limit for each build thread. + n_parallel : int + Number of threads used to build in parallel. + build_func : str = 'default' + The name of build function to process the built module. + verbose : int = 1 + Verbosity level. 0 for silent, 1 to output information during program building. + + Returns + ------- + res : List[BuildResult] + The build results of these MeasureInputs. + """ # We use fork and a global variable to copy arguments between processings. # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool global GLOBAL_BUILD_ARGUMENTS @@ -332,8 +365,44 @@ def local_builder_build(inputs, timeout, n_parallel, build_func, verbose): @tvm._ffi.register_func("ansor.local_runner.run") def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval, - verbose): - """ Local runner run function. """ + verbose=1): + """ + Run function of LocalRunner to test the performance of the input BuildResults. + + Parameters + ---------- + inputs : List[MeasureInput] + The MeasureInputs to be measured. + build_results : List[BuildResult] + The BuildResults to be measured. + timeout : int + The timeout limit for each build thread. + number : int = 3 + The number of times to run the generated code for taking average. + We call these runs as one `repeat` of measurement. + repeat : int = 1 + The number of times to repeat the measurement. + In total, the generated code will be run (1 + number x repeat) times, + where the first "1" is warm up and will be discarded. + The returned result contains `repeat` costs, + each of which is an average of `number` costs. + min_repeat_ms : int = 0 + The minimum duration of one `repeat` in milliseconds. + By default, one `repeat` contains `number` runs. If this parameter is set, + the parameters `number` will be dynamically adjusted to meet the + minimum duration requirement of one `repeat`. 
+        i.e., when the run time of one `repeat` falls below this time, the `number` parameter
+        will be automatically increased.
+    cooldown_interval : float = 0.0
+        The cool down interval between two measurements.
+    verbose : int = 1
+        Verbosity level. 0 for silent, 1 to output information during program measuring.
+
+    Returns
+    -------
+    res : List[MeasureResult]
+        The measure results of these MeasureInputs.
+    """
     max_float = 1e10  # We use 1e10 instead of sys.float_info.max for better readability in log

     def timed_func(inp, build_res):
diff --git a/python/tvm/ansor/record.py b/python/tvm/ansor/record.py
index 6f94105494aa..b541dbf7a110 100644
--- a/python/tvm/ansor/record.py
+++ b/python/tvm/ansor/record.py
@@ -52,15 +52,15 @@ class LogReader(Object):
     def __init__(self, filename="ansor_tuning.json"):
         self.__init_handle_by_constructor__(_ffi_api.LogReader, filename)

-    def read_lines(self, max_lines=-1, skip_lines=0):
+    def read_lines(self, max_lines=None, skip_lines=None):
         """ Read multiple lines from the log file.

         Parameters
         ----------
-        max_lines : int = -1
-            The maximum number of lines. -1 means to read all lines.
-        skip_lines : int = 0
-            Skip the first n lines.
+        max_lines : Optional[int]
+            The maximum number of lines. None to read all lines.
+        skip_lines : Optional[int]
+            Skip the first n lines. None means skipping no lines.

         Returns
         -------
@@ -69,7 +69,8 @@
         results : List[MeasureResult]
             The MeasureResults loaded from the log file.
         """
-        inputs, results = _ffi_api.LogReaderReadLines(self, max_lines, skip_lines)
+        inputs, results = _ffi_api.LogReaderReadLines(self, max_lines if max_lines is not None else -1,
+                                                      skip_lines if skip_lines is not None else 0)
         return inputs, results

     def __iter__(self):
diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h
index c493c0dfc47c..991eda5eb8e8 100644
--- a/src/ansor/auto_schedule.h
+++ b/src/ansor/auto_schedule.h
@@ -104,8 +104,9 @@ class TuningOptions : public ObjectRef {
 * \return A `te::schedule` and an Array of `te::Tensor` to be used in `tvm.lower` or
 * `tvm.build`.
 */
-std::pair> AutoSchedule(SearchTask task, SearchPolicy search_policy,
-                                                         TuningOptions tuning_options);
+TVM_DLL std::pair> AutoSchedule(SearchTask task,
+                                                                 SearchPolicy search_policy,
+                                                                 TuningOptions tuning_options);
 }  // namespace ansor
 }  // namespace tvm
diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc
index b99b609be42b..5fdc7fa3563c 100644
--- a/src/ansor/compute_dag.cc
+++ b/src/ansor/compute_dag.cc
@@ -47,8 +47,7 @@ using namespace tvm::tir;
 TVM_REGISTER_NODE_TYPE(ComputeDAGNode);

 // Topo-sort ops from tensors according to their read-write relations.
-// Results are stored in ops -void TopoSortOps(const Array& tensors, Array* ops) { +Array TopoSortOps(const Array& tensors) { std::unordered_map degree; std::unordered_map> edge_set; std::unordered_map priority; @@ -88,7 +87,7 @@ void TopoSortOps(const Array& tensors, Array* ops) { } // topo sort - ops->clear(); + Array ops; using Item = std::pair; auto cmp = [](const Item& left, const Item& right) { return left.second < right.second; }; @@ -99,11 +98,11 @@ void TopoSortOps(const Array& tensors, Array* ops) { } } - ops->reserve(degree.size()); + ops.reserve(degree.size()); while (!queue.empty()) { Item item = queue.top(); queue.pop(); - ops->push_back(GetRef(item.first)); + ops.push_back(GetRef(item.first)); for (const auto& dst : edge_set[item.first]) { degree[dst] -= 1; if (degree[dst] == 0) { @@ -111,6 +110,8 @@ void TopoSortOps(const Array& tensors, Array* ops) { } } } + + return ops; } // Estimate number of float operations in an expression @@ -212,18 +213,15 @@ class FlopEstimator : public ExprFunctor { ComputeDAG::ComputeDAG(Array tensors) { auto node = make_object(); - FlopEstimator estimator; - Array ops; node->tensors = std::move(tensors); - TopoSortOps(node->tensors, &ops); - node->ops = std::move(ops); - node->flop_ct = estimator.EstimateFlop(node->ops); + node->ops = std::move(TopoSortOps(node->tensors)); + node->flop_ct = FlopEstimator().EstimateFlop(node->ops); node->init_state = State(node->ops); data_ = std::move(node); } // Update the te::stage to tir::IterVar axis mapping -void UpdateStageAxis(const te::Stage& stage, StageToAxesMap* stage_to_axes) { +void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes) { if (auto pop = stage->op.as()) { Array axes; for (const auto& axis : pop->axis) { @@ -259,13 +257,13 @@ std::pair> ComputeDAG::ApplySteps( } } // Create the initial schedule - te::Schedule schedule = te::create_schedule({ops.back()}); + te::Schedule schedule = te::create_schedule(ops); // init axes for (const auto& x : operator->()->ops) { - const te::Stage& stage = schedule.operator[](x); + const te::Stage& stage = schedule[x]; stages->push_back(stage); - UpdateStageAxis(stage, stage_to_axes); + UpdateStageToAxesMap(stage, stage_to_axes); } // Use complete rate for the study in the paper @@ -307,13 +305,13 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const } } // Create the initial schedule - te::Schedule schedule = te::create_schedule({ops.back()}); + te::Schedule schedule = te::create_schedule(ops); // init axes for (const auto& x : operator->()->ops) { - const te::Stage& stage = schedule.operator[](x); + const te::Stage& stage = schedule[x]; stages.push_back(stage); - UpdateStageAxis(stage, &stage_to_axes); + UpdateStageToAxesMap(stage, &stage_to_axes); } std::stringstream ss; @@ -351,16 +349,16 @@ State ComputeDAG::InferBound(const State& state) const { State ret_state; StateNode* pstate; - if (state->stages.size()) { - ret_state = state; - pstate = ret_state.CopyOnWrite(); - } else { + if (state->stages.empty()) { // If the input state is incomplete with empty operation stage // create a new state from init_state and update it first ret_state = operator->()->init_state; pstate = ret_state.CopyOnWrite(); pstate->transform_steps = state->transform_steps; - ret_state.DoSteps((*this)); + ret_state.DoSteps(*this); + } else { + ret_state = state; + pstate = ret_state.CopyOnWrite(); } Array stages; @@ -405,30 +403,6 @@ State ComputeDAG::InferBound(const State& state) const { return ret_state; } -void 
ComputeDAG::InferBound(Array* states) const { - Array out_states(states->size(), State()); - - auto worker_func = [&states, &out_states, this](int idx) { - try { - out_states.Set(idx, this->InferBound((*states)[idx])); - } catch (dmlc::Error& e) { - LOG(WARNING) << "InferBound fails on the state:\n" - << (*states)[idx] << "\n" - << e.what() << std::endl; - } - }; - - // Lower states in parallel - ThreadPool& pool = ThreadPool::Global(); - pool.BeginBatch(states->size()); - for (size_t i = 0; i < states->size(); ++i) { - pool.Enqueue(worker_func, i); - } - pool.WaitBatch(); - - *states = std::move(out_states); -} - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 66459bb70864..a8a1126e7a60 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -103,16 +103,14 @@ class ComputeDAG : public ObjectRef { /*! * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound. + * The states can lose complete bound information after some transform steps (e.g., compute_at). + * We can call this function to infer and fill all the bound information. + * This function calls TVM InferBound pass internally to get the bound. + * The returned state of this function is guaranteed to have complete iterator extent information. * \param state The state to. * \return The State after inferbound. */ State InferBound(const State& state) const; - /*! - * \brief Fill the correct bound information for a list of given states. - * Return the new states inplace. - * \param states A pointer to a State Array, States are updated inplace. - */ - void InferBound(Array* states) const; TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode); diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index d792ffbd833f..9319c218fd87 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -295,8 +295,21 @@ void State::DoSteps(const ComputeDAG& dag) { } } +static const char* IteratorAnnotationString[] = { + "for", // kNone = 0 + "unroll", // kUnroll = 1 + "vectorize", // kVectorize = 2 + "parallel", // kParallel = 3 + "vthread", // kVThread = 4 + "gpu.blockIdx.x", // kBlockX = 5 + "gpu.threadIdx.x", // kThreadX = 6 + "gpu.blockIdx.y", // kBlockY = 7 + "gpu.threadIdx.y", // kThreadY = 8 + "tensorize" // kTensorized = 9 +}; + // Print stage to ostream -void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t base_indent, +void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_indent, bool delete_trivial_loop) { const Stage& stage = state->stages[stage_id]; @@ -321,41 +334,7 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t b for (size_t j = 0; j < base_indent + indent; ++j) { *os << " "; } - switch (iter->annotation) { - case kNone: - *os << "for "; - break; - case kUnroll: - *os << "unroll "; - break; - case kParallel: - *os << "parallel "; - break; - case kVectorize: - *os << "vectorize "; - break; - case kVThread: - *os << "vthread "; - break; - case kBlockX: - *os << "gpu.blockIdx.x "; - break; - case kBlockY: - *os << "gpu.blockIdx.y "; - break; - case kThreadX: - *os << "gpu.threadIdx.x "; - break; - case kThreadY: - *os << "gpu.threadIdx.y "; - break; - case kTensorized: - *os << "tensorize "; - break; - default: - LOG(FATAL) << "Invalid Annotation " << iter->annotation; - 
break;
- }
+ *os << IteratorAnnotationString[iter->annotation] << " ";
if (iter->range.defined()) {
*os << iter->name << " (" << iter->range->min << "," << iter->range->extent << ")";
} else {
@@ -374,10 +353,10 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t b
}
// Print state to ostream
-void PrintState(std::ostream* os, const StateNode* node, bool delete_trivial_loop) {
+void PrintState(std::ostream* os, const State& state, bool delete_trivial_loop) {
// Gather placeholders
Array placeholders;
- for (const auto& stage : node->stages) {
+ for (const auto& stage : state->stages) {
if (stage->op_type == kPlaceholder) {
placeholders.push_back(stage->op->name);
}
@@ -393,13 +372,13 @@ void PrintState(std::ostream* os, const StateNode* node, bool delete_trivial_loo
*os << "\n";
// Print all stages
- for (size_t i = 0; i < node->stages.size(); ++i) {
- const Stage& stage = node->stages[i];
+ for (size_t i = 0; i < state->stages.size(); ++i) {
+ const Stage& stage = state->stages[i];
if (stage->op_type == kPlaceholder) {
continue;
} else if (stage->op_type == kCompute) {
if (stage->compute_at == kRoot) {
- PrintStage(os, i, node, 0, delete_trivial_loop);
+ PrintStage(os, i, state, 0, delete_trivial_loop);
}
} else {
LOG(FATAL) << "Invalid op type";
@@ -409,14 +388,13 @@ void PrintState(std::ostream* os, const StateNode* node, bool delete_trivial_loo
String State::ToStr(bool delete_trivial_loop) const {
std::ostringstream os;
- PrintState(&os, operator->(), delete_trivial_loop);
+ PrintState(&os, (*this), delete_trivial_loop);
return os.str();
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast(ref.get());
- PrintState(&p->stream, node, true);
+ PrintState(&p->stream, tvm::Downcast(ref), true);
});
/********** State interface API for ffi **********/
diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h
index f91a5dbb45bb..424749cd4696 100644
--- a/src/ansor/loop_state.h
+++ b/src/ansor/loop_state.h
@@ -303,7 +303,7 @@ class State : public ObjectRef {
* \param stage_id The index of the stage to be split.
* \param it The iterator to be split.
* \param lengths The multiple split factors. Can be None to be filled by search policy.
- * \param inner_to_outer True for split from inner to outer & False for outer to inner.
+ * \param inner_to_outer Whether the factors go from inner to outer, or from outer to inner.
* \return The iterator results after split.
*/ Array split(int stage_id, const Iterator& it, const Array& lengths, diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 003e25d95aff..6bfa9c96c43e 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -114,10 +114,9 @@ Array LocalBuilderNode::Build(const Array& inputs, in if (const auto* f = runtime::Registry::Get("ansor.local_builder.build")) { Array results = (*f)(inputs, timeout, n_parallel, build_func, verbose); return results; - } else { - LOG(FATAL) << "ansor.local_builder.build is not registered"; } - return Array(); + LOG(FATAL) << "ansor.local_builder.build is not registered"; + throw; } /********** LocalRunner **********/ @@ -138,10 +137,9 @@ Array LocalRunnerNode::Run(const Array& inputs, Array results = (*f)(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval, verbose); return results; - } else { - LOG(FATAL) << "ansor.local_runner.run is not registered"; } - return Array(); + LOG(FATAL) << "ansor.local_runner.run is not registered"; + throw; } /********** ProgramMeasurer **********/ @@ -237,11 +235,10 @@ void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, const Arrayclear(); results->reserve(inputs.size()); - Array input_batch(inputs.begin(), inputs.end()); // Call builder and runner - Array build_res_batch = builder->Build(input_batch, verbose); - Array result_batch = runner->Run(input_batch, build_res_batch, verbose); + Array build_res_batch = builder->Build(inputs, verbose); + Array result_batch = runner->Run(inputs, build_res_batch, verbose); // Store result batch for (auto& res : result_batch) { diff --git a/src/ansor/search_policy/empty_policy.h b/src/ansor/search_policy/empty_policy.h index 58209b666048..094e71913823 100644 --- a/src/ansor/search_policy/empty_policy.h +++ b/src/ansor/search_policy/empty_policy.h @@ -19,7 +19,8 @@ /*! * \file ansor/search_policy/empty_policy.h - * \brief This is an brief example of search policy. + * \brief A brief example of the search policy which always returns the initial naive schedule + * (state). */ #ifndef TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_ diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index 70bed321f802..0b62efda79cc 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -32,16 +32,16 @@ namespace ansor { TVM_REGISTER_OBJECT_TYPE(SearchCallbackNode); TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); -void SearchPolicyNode::RunCallbacks(const Array& callbacks) { - if (callbacks.defined() && callbacks.size()) { - for (const auto& callback : callbacks) { +void SearchPolicyNode::RunCallbacks(const Optional>& callbacks) { + if (callbacks.defined()) { + for (const auto& callback : callbacks.value()) { callback->Callback(this); } } } TVM_REGISTER_GLOBAL("ansor.SearchPolicyRunCallbacks") - .set_body_typed([](SearchPolicy policy, Array callbacks) { + .set_body_typed([](SearchPolicy policy, Optional> callbacks) { policy->RunCallbacks(callbacks); }); diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index 1ae601896ce0..2d551f1c5ce1 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -128,7 +128,7 @@ class SearchPolicyNode : public Object { * \brief Call SearchCallback with the current SearchPolicyNode * \param callbacks SearchCallback to be called. 
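The fatal-error paths earlier in this patch (`LocalBuilderNode::Build` and `LocalRunnerNode::Run`) fire when the Python side never registered the packed functions, e.g. because the TVM Python modules were not imported. A quick sanity check from Python, sketched with the public `tvm.get_global_func` API:

```python
import tvm
import tvm.ansor  # registers "ansor.local_builder.build" and "ansor.local_runner.run"

for name in ("ansor.local_builder.build", "ansor.local_runner.run"):
    if tvm.get_global_func(name, allow_missing=True) is None:
        raise RuntimeError(name + " is not registered; "
                           "make sure the TVM Python runtime is loaded")
```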
*/ - void RunCallbacks(const Array& callbacks); + void RunCallbacks(const Optional>& callbacks); static constexpr const char* _type_key = "ansor.SearchPolicy"; TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index 0c053693d9f5..3d40f35d184d 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -24,7 +24,7 @@ * - transform_step.cc: How each step interact with TE and TE's schedule primitives * - loop_state.cc: How each step reflect on LoopState * - * \note Adding a new transform step. + * \note To add a new transform step: * Take fuse step for example: * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its construction * function `FuseStep::FuseStep(...)` in `transform_steps.cc` From 8e1d65d94c411278505f29af0e8607ea6bd027e8 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sun, 5 Jul 2020 23:28:34 +0800 Subject: [PATCH 68/78] Update --- python/tvm/ansor/auto_schedule.py | 12 +++++++++--- python/tvm/ansor/measure.py | 12 ++++++++---- python/tvm/ansor/record.py | 8 ++++---- src/ansor/compute_dag.cc | 2 +- src/ansor/loop_state.cc | 23 ++++++----------------- src/ansor/loop_state.h | 12 +----------- src/ansor/measure.cc | 8 ++++++-- src/ansor/measure.h | 4 ++-- src/ansor/record.cc | 16 +++++++--------- 9 files changed, 44 insertions(+), 53 deletions(-) diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 86c4adeabdb2..9b58e81855df 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -39,7 +39,7 @@ class HardwareParams(Object): """ The parameters of target hardware used to guide the search process of SearchPolicy. - TODO(jcf94): This is considering to merge with the new Target: + TODO(jcf94): This is considered to be merged with the new Target: https://discuss.tvm.ai/t/rfc-tvm-target-specification/6844 Parameters @@ -138,15 +138,21 @@ def __init__(self, num_measure_trials=0, early_stopping=-1, num_measures_per_rou builder = LocalBuilder() else: raise ValueError("Invalid builder: " + builder) + elif not isinstance(builder, tvm.ansor.measure.ProgramBuilder): + raise ValueError("Invalid builder: " + builder + + " . TuningOptions expects a ProgramBuilder or string.") if isinstance(runner, str): if runner == 'local': runner = LocalRunner() else: raise ValueError("Invalid runner: " + runner) + elif not isinstance(runner, tvm.ansor.measure.ProgramRunner): + raise ValueError("Invalid runner: " + runner + + " . TuningOptions expects a ProgramRunner or string.") - measure_callbacks = [] if measure_callbacks is None else measure_callbacks - pre_search_callbacks = [] if pre_search_callbacks is None else pre_search_callbacks + measure_callbacks = measure_callbacks if measure_callbacks else [] + pre_search_callbacks = pre_search_callbacks if pre_search_callbacks else [] self.__init_handle_by_constructor__( _ffi_api.TuningOptions, num_measure_trials, early_stopping, num_measures_per_round, diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 01e722c9a944..536919f04bcc 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -168,7 +168,8 @@ class LocalBuilder(ProgramBuilder): Parameters ---------- timeout : int = 15 - The timeout limit for each build. + The timeout limit (in second) for each build thread. + This is used in a wrapper of the multiprocessing.Process.join(). n_parallel : int = multiprocessing.cpu_count() Number of threads used to build in parallel. 
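With the stricter builder/runner argument checking added above, constructing the tuning options looks roughly like the following. The keyword names are taken from the docstrings in this patch series; treat it as an illustrative sketch, not verified against the final API.

```python
from tvm import ansor

tuning_options = ansor.TuningOptions(
    num_measure_trials=64,
    builder=ansor.LocalBuilder(timeout=15),        # or simply builder='local'
    runner=ansor.LocalRunner(number=3, repeat=1),  # or simply runner='local'
    measure_callbacks=[ansor.LogToFile("ansor_tuning.json")])

# Passing anything else, e.g. builder=42, now raises the ValueError added above.
```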
build_func : str = 'default' @@ -190,7 +191,8 @@ class LocalRunner(ProgramRunner): Parameters ---------- timeout : int = 10 - The timeout limit for each run. + The timeout limit (in second) for each run. + This is used in a wrapper of the multiprocessing.Process.join(). number : int = 3 The number of times to run the generated code for taking average. We call these runs as one `repeat` of measurement. @@ -333,7 +335,8 @@ def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbo inputs : List[MeasureInput] The MeasureInputs to be built. timeout : int - The timeout limit for each build thread. + The timeout limit (in second) for each build thread. + This is used in a wrapper of the multiprocessing.Process.join(). n_parallel : int Number of threads used to build in parallel. build_func : str = 'default' @@ -376,7 +379,8 @@ def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, coo build_results : List[BuildResult] The BuildResults to be measured. timeout : int - The timeout limit for each build thread. + The timeout limit (in second) for each run. + This is used in a wrapper of the multiprocessing.Process.join(). number : int = 3 The number of times to run the generated code for taking average. We call these runs as one `repeat` of measurement. diff --git a/python/tvm/ansor/record.py b/python/tvm/ansor/record.py index b541dbf7a110..8770252dca2a 100644 --- a/python/tvm/ansor/record.py +++ b/python/tvm/ansor/record.py @@ -52,15 +52,15 @@ class LogReader(Object): def __init__(self, filename="ansor_tuning.json"): self.__init_handle_by_constructor__(_ffi_api.LogReader, filename) - def read_lines(self, max_lines=None, skip_lines=None): + def read_lines(self, max_lines=None, skip_lines=0): """ Read multiple lines from the log file. Parameters ---------- max_lines : Optional[int] The maximum number of lines. None to read all lines. - skip_lines : Optional[int] - Skip the first n lines. None to read all lines. + skip_lines : int = 0 + Skip the first n lines. Returns ------- @@ -70,7 +70,7 @@ def read_lines(self, max_lines=None, skip_lines=None): The MeasureResults loaded from the log file. 
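A small usage sketch for the `read_lines` signature above; the file name is a placeholder, and the record attribute names (`task.workload_key`, `costs`) are taken from the MeasureInput/MeasureResult declarations in this patch series.

```python
from tvm import ansor

reader = ansor.LogReader("ansor_tuning.json")
# Read at most 100 records after skipping the first 10; max_lines=None reads all.
inputs, results = reader.read_lines(max_lines=100, skip_lines=10)
for inp, res in zip(inputs, results):
    print(inp.task.workload_key, [float(c) for c in res.costs])
```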
""" inputs, results = _ffi_api.LogReaderReadLines(self, max_lines if max_lines else -1, - skip_lines if skip_lines else 0) + skip_lines) return inputs, results def __iter__(self): diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 5fdc7fa3563c..67ba7f34ccce 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -214,7 +214,7 @@ class FlopEstimator : public ExprFunctor { ComputeDAG::ComputeDAG(Array tensors) { auto node = make_object(); node->tensors = std::move(tensors); - node->ops = std::move(TopoSortOps(node->tensors)); + node->ops = TopoSortOps(node->tensors); node->flop_ct = FlopEstimator().EstimateFlop(node->ops); node->init_state = State(node->ops); data_ = std::move(node); diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 9319c218fd87..54d29137f4d5 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -88,17 +88,6 @@ Stage::Stage(te::Operation op, StageType op_type, const Array& iters, data_ = std::move(node); } -Stage::Stage(te::Operation op, StageType op_type, Array&& iters, ComputeAtType compute_at, - StageAttributes attrs) { - auto node = make_object(); - node->op = std::move(op); - node->op_type = op_type; - node->iters = std::move(iters); - node->compute_at = compute_at; - node->attrs = attrs; - data_ = std::move(node); -} - /********** State **********/ State::State(const Array& ops) { auto node = make_object(); @@ -148,8 +137,8 @@ void State::DoReorderStep(const ReorderStep& step) { iters.push_back(stage->iters[x]); } StateNode* pstate = CopyOnWrite(); - pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(iters), - stage->compute_at, stage->attrs)); + pstate->stages.Set(step->stage_id, + Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs)); } // common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep @@ -209,8 +198,8 @@ Array State::DoSplitStepCommon(int stage_id, int iter_id, const Array< new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), - stage->compute_at, stage->attrs)); + pstate->stages.Set(stage_id, + Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); return outs; } @@ -263,8 +252,8 @@ Iterator State::DoFuseStep(const FuseStep& step) { stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), - stage->compute_at, stage->attrs)); + pstate->stages.Set(stage_id, + Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); return new_it; } diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 424749cd4696..5db611448fd1 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -206,22 +206,12 @@ class Stage : public ObjectRef { * \brief The constructor. * \param op A `te::Operation`. * \param op_type The stage type of this op. - * \param iters The iterators of this op. (copy) + * \param iters The iterators of this op. * \param compute_at The compute at type of this op. * \param attrs Other stage-level attributes. */ Stage(te::Operation op, StageType op_type, const Array& iters, ComputeAtType compute_at, StageAttributes attrs); - /*! - * \brief The constructor. - * \param op A `te::Operation`. - * \param op_type The stage type of this op. - * \param iters The iterators of this op. 
(move) - * \param compute_at The compute at type of this op. - * \param attrs Other stage-level attributes. - */ - Stage(te::Operation op, StageType op_type, Array&& iters, ComputeAtType compute_at, - StageAttributes attrs); TVM_DEFINE_OBJECT_REF_METHODS(Stage, ObjectRef, StageNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode); diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 6bfa9c96c43e..9e1e15dd0830 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -115,7 +115,9 @@ Array LocalBuilderNode::Build(const Array& inputs, in Array results = (*f)(inputs, timeout, n_parallel, build_func, verbose); return results; } - LOG(FATAL) << "ansor.local_builder.build is not registered"; + LOG(FATAL) << "ansor.local_builder.build is not registered. " + << "This is a function registered in Python, " + << "make sure the TVM Python runtime has been loaded successfully."; throw; } @@ -138,7 +140,9 @@ Array LocalRunnerNode::Run(const Array& inputs, min_repeat_ms, cooldown_interval, verbose); return results; } - LOG(FATAL) << "ansor.local_runner.run is not registered"; + LOG(FATAL) << "ansor.local_runner.run is not registered. " + << "This is a function registered in Python, " + << "make sure the TVM Python runtime has been loaded successfully."; throw; } diff --git a/src/ansor/measure.h b/src/ansor/measure.h index 3442e8b3e18f..e441c931a9e7 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -77,7 +77,7 @@ class MeasureInputNode : public Object { v->Visit("state", &state); } - /*! \brief Do deep copy. */ + /*! \brief Do shallow copy. */ MeasureInput copy() const; static constexpr const char* _type_key = "ansor.MeasureInput"; @@ -167,7 +167,7 @@ class MeasureResultNode : public Object { v->Visit("timestamp", ×tamp); } - /*! \brief Do deep copy. */ + /*! \brief Do shallow copy. */ MeasureResult copy() const; static constexpr const char* _type_key = "ansor.MeasureResult"; diff --git a/src/ansor/record.cc b/src/ansor/record.cc index b99a67f4e64c..e82ba66c8f3a 100644 --- a/src/ansor/record.cc +++ b/src/ansor/record.cc @@ -42,14 +42,13 @@ namespace dmlc { namespace json { -inline std::vector& IntArrayToVector(std::vector* out, - const ::tvm::Array<::tvm::Integer>& data) { - out->clear(); +inline std::vector IntArrayToVector(const ::tvm::Array<::tvm::Integer>& data) { + std::vector out; for (const auto& x : data) { CHECK(x.defined()); - out->push_back(x); + out.push_back(x); } - return *out; + return out; } template <> @@ -70,7 +69,6 @@ struct Handler<::tvm::Array<::tvm::ansor::Stage>> { template <> struct Handler<::tvm::Array<::tvm::ansor::Step>> { inline static void Write(dmlc::JSONWriter* writer, const ::tvm::Array<::tvm::ansor::Step>& data) { - std::vector tmp; writer->BeginArray(false); for (size_t i = 0; i < data.size(); ++i) { writer->WriteArraySeperator(); @@ -78,18 +76,18 @@ struct Handler<::tvm::Array<::tvm::ansor::Step>> { if (auto ps = data[i].as<::tvm::ansor::ReorderStepNode>()) { writer->WriteArrayItem(std::string("RE")); writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(IntArrayToVector(&tmp, ps->after_ids)); + writer->WriteArrayItem(IntArrayToVector(ps->after_ids)); } else if (auto ps = data[i].as<::tvm::ansor::SplitStepNode>()) { writer->WriteArrayItem(std::string("SP")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); writer->WriteArrayItem(ps->extent.defined() ? 
::tvm::ansor::GetIntImm(ps->extent) : 0); - writer->WriteArrayItem(IntArrayToVector(&tmp, ps->lengths)); + writer->WriteArrayItem(IntArrayToVector(ps->lengths)); writer->WriteArrayItem(static_cast(ps->inner_to_outer)); } else if (auto ps = data[i].as<::tvm::ansor::FuseStepNode>()) { writer->WriteArrayItem(std::string("FU")); writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(IntArrayToVector(&tmp, ps->fused_ids)); + writer->WriteArrayItem(IntArrayToVector(ps->fused_ids)); } else { LOG(FATAL) << "Invalid step: " << data[i]; } From 3a67a72c58fb9667095933d5d73461e40d91ad80 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 9 Jul 2020 11:08:46 +0800 Subject: [PATCH 69/78] Update --- python/tvm/ansor/__init__.py | 6 +- python/tvm/ansor/auto_schedule.py | 12 +-- python/tvm/ansor/compute_dag.py | 6 +- python/tvm/ansor/loop_state.py | 23 +---- python/tvm/ansor/measure.py | 10 ++ .../ansor/{record.py => measure_record.py} | 30 +++--- python/tvm/ansor/utils.py | 2 +- python/tvm/ansor/workload_registry.py | 23 +++-- src/ansor/auto_schedule.cc | 8 +- src/ansor/auto_schedule.h | 8 +- src/ansor/compute_dag.cc | 35 +++---- src/ansor/compute_dag.h | 8 +- src/ansor/loop_state.cc | 97 +++++++++---------- src/ansor/loop_state.h | 48 +++++---- src/ansor/measure.cc | 8 +- src/ansor/measure.h | 13 ++- src/ansor/{record.cc => measure_record.cc} | 62 +++++++----- src/ansor/{record.h => measure_record.h} | 42 ++++---- src/ansor/search_policy/empty_policy.cc | 2 +- src/ansor/search_policy/empty_policy.h | 2 +- src/ansor/search_policy/search_policy.cc | 2 +- src/ansor/search_policy/search_policy.h | 12 ++- src/ansor/search_task.cc | 8 +- src/ansor/search_task.h | 2 +- src/ansor/transform_step.cc | 23 ++--- src/ansor/transform_step.h | 12 +-- tests/python/unittest/test_ansor_measure.py | 4 +- .../unittest/test_ansor_search_policy.py | 4 +- 28 files changed, 272 insertions(+), 240 deletions(-) rename python/tvm/ansor/{record.py => measure_record.py} (83%) rename src/ansor/{record.cc => measure_record.cc} (87%) rename src/ansor/{record.h => measure_record.h} (77%) diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 4baba87e1231..5a2f210938f7 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -19,7 +19,7 @@ from . import compute_dag from . import measure -from . import record +from . import measure_record from . import loop_state from . import utils from . import workload_registry @@ -29,6 +29,6 @@ from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \ auto_schedule, EmptyPolicy from .measure import MeasureInput, LocalBuilder, LocalRunner -from .record import LogToFile, LogReader, best_measure_pair_in_file, \ - load_from_file, append_measure_records_to_file +from .measure_record import RecordToFile, RecordReader, load_best, \ + load_records, save_records from .workload_registry import register_workload, make_workload_key diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 9b58e81855df..1f5bdf8f14d9 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -107,7 +107,7 @@ class TuningOptions(Object): With `num_measure_trials` == 0, the policy will do the schedule search but won't involve measurement. This can be used to get a runnable schedule quickly without auto-tuning. - early_stopping: int = -1 + early_stopping: Optional[int] Stop the tuning early if getting no improvement after n measurements. 
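Stepping back to the step serializer in record.cc shown earlier in this patch (renamed to measure_record.cc in this commit): each transform step is written as a tagged JSON array. Illustrative, not byte-exact, examples of the three step kinds:

```python
# ["SP", stage_id, iter_id, extent, lengths, inner_to_outer]
split_step = ["SP", 2, 0, 512, [8, 16], 1]
# ["RE", stage_id, after_ids]
reorder_step = ["RE", 2, [0, 3, 1, 4, 2, 5]]
# ["FU", stage_id, fused_ids]
fuse_step = ["FU", 2, [0, 1]]
```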
num_measures_per_round: int = 64 The number of schedules to be measured at each search round. @@ -122,7 +122,7 @@ class TuningOptions(Object): measure_callbacks: Optional[List[MeasureCallback]] Callback functions called after each measurement. Candidates: - - ansor.LogToFile + - ansor.RecordToFile pre_search_callbacks: Optional[List[SearchCallback]] Callback functions called before the search process. Candidates: @@ -130,7 +130,7 @@ class TuningOptions(Object): - ansor.PreloadCustomSketchRule TODO(jcf94): Add these implementation in later PRs. """ - def __init__(self, num_measure_trials=0, early_stopping=-1, num_measures_per_round=64, + def __init__(self, num_measure_trials=0, early_stopping=None, num_measures_per_round=64, verbose=1, builder='local', runner='local', measure_callbacks=None, pre_search_callbacks=None): if isinstance(builder, str): @@ -151,11 +151,9 @@ def __init__(self, num_measure_trials=0, early_stopping=-1, num_measures_per_rou raise ValueError("Invalid runner: " + runner + " . TuningOptions expects a ProgramRunner or string.") - measure_callbacks = measure_callbacks if measure_callbacks else [] - pre_search_callbacks = pre_search_callbacks if pre_search_callbacks else [] - self.__init_handle_by_constructor__( - _ffi_api.TuningOptions, num_measure_trials, early_stopping, num_measures_per_round, + _ffi_api.TuningOptions, num_measure_trials, early_stopping if early_stopping else -1, + num_measures_per_round, verbose, builder, runner, measure_callbacks, pre_search_callbacks) diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index b72c8649133e..f52c3e2d3192 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -""" Computational graph and its analysis tools """ +""" The Ansor computational graph and related program analyses. """ import hashlib @@ -92,8 +92,8 @@ def print_python_code_from_state(self, state): """ Print transform steps in the history of a State as TVM's python schedule primitive. - This can be used for debugging or to apply the schedule on a former TVM version without - Ansor support. + This is used to print transformation steps for debugging. + Use `apply_steps_from_state` if you want to get a schedule for code generation. Parameters ---------- diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index b407eadd7f3d..324cc20d48ed 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -87,7 +87,6 @@ def __init__(self, state_object, dag): self.state_object = state_object self.compute_dag = dag - self.stages_cache = None # A list to cache all stages self.stage_id_map = {} # A dict maps operation to stage id self._update_stage_id_map() @@ -98,9 +97,7 @@ def stages(self): ------- stages : List[Stage] """ - if not self.stages_cache: - self.stages_cache = self.state_object.stages - return self.stages_cache + return self.state_object.stages @property def stage_ops(self): @@ -109,9 +106,7 @@ def stage_ops(self): ------- ops: List[Operation] """ - if not self.stages_cache: - self.stages_cache = self.state_object.stages - return [stage.op for stage in self.stages_cache] + return [stage.op for stage in self.stages] def reorder(self, stage, order): """ Schedule primitive corresponds to te.reorder. 
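A sketch of driving these schedule primitives from Python. Only the method names come from this file; how the initial `State` is obtained (`dag.get_init_state()`) is a hypothetical accessor used for illustration.

```python
# Hypothetical setup: some ComputeDAG `dag` and its initial state.
state = State(dag.get_init_state(), dag)  # exact construction is assumed

C = state.stage_ops[-1]                   # the output op indexes its stage
i, j, k = state[C].iters                  # assuming a 3-level loop nest
io, ii = state.split(C, i, [8])           # one factor -> two iterators back
state.reorder(C, [io, j, ii, k])
fused = state.fuse(C, [io, j])            # fuse the two outermost loops
```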
@@ -127,7 +122,6 @@ def reorder(self, stage, order): stage_id = self._resolve_stage_id(stage) self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order) - self._clear_cache() def split(self, stage, iterator, lengths, inner_to_outer=True): """ Schedule primitive corresponds to te.split. @@ -156,7 +150,6 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, iterator, lengths, inner_to_outer) - self._clear_cache() return res def fuse(self, stage, iters): @@ -178,7 +171,6 @@ def fuse(self, stage, iters): stage_id = self._resolve_stage_id(stage) self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters) - self._clear_cache() return res def copy(self): @@ -198,21 +190,14 @@ def _resolve_stage_id(self, stage_id): " . Expect to be a int, Operation or Tensor") def _update_stage_id_map(self): - if not self.stages_cache: - self.stages_cache = self.state_object.stages - for index, stage in enumerate(self.stages_cache): + for index, stage in enumerate(self.stages): self.stage_id_map[stage.op] = index - def _clear_cache(self): - self.stages_cache = None - def __getitem__(self, key): - if not self.stages_cache: - self.stages_cache = self.state_object.stages if isinstance(key, Tensor): key = key.op if isinstance(key, Operation): - return self.stages_cache[self.stage_id_map[key]] + return self.stages[self.stage_id_map[key]] raise ValueError("Invalid item: " + key + " . Expect to be a Operation or Tensor") diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 536919f04bcc..2fe1169ad3fa 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -21,6 +21,13 @@ These functions are responsible for building the tvm module, uploading it to remote devices, recording the running time costs, and checking the correctness of the output. +We separate the measurement into two steps: build and run. +A builder builds the executable binary files and a runner runs the binary files to +get the measurement results. The flow of data structures is + + `ProgramBuilder` `ProgramRunner` +`MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult` + We implement these in python to utilize python's multiprocessing and error handling. """ @@ -261,6 +268,8 @@ def local_build_worker(index): res : BuildResult The build result of this Builder thread. """ + global GLOBAL_BUILD_ARGUMENTS + # We use fork and a global variable to copy arguments between processings. # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool if not GLOBAL_BUILD_ARGUMENTS: @@ -352,6 +361,7 @@ def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbo # We use fork and a global variable to copy arguments between processings. # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool global GLOBAL_BUILD_ARGUMENTS + GLOBAL_BUILD_ARGUMENTS = (inputs, build_func, timeout, verbose) pool = NoDaemonPool(n_parallel) diff --git a/python/tvm/ansor/record.py b/python/tvm/ansor/measure_record.py similarity index 83% rename from python/tvm/ansor/record.py rename to python/tvm/ansor/measure_record.py index 8770252dca2a..46b94be18719 100644 --- a/python/tvm/ansor/record.py +++ b/python/tvm/ansor/measure_record.py @@ -25,8 +25,8 @@ from . 
import _ffi_api
-@tvm._ffi.register_object("ansor.LogToFile")
-class LogToFile(MeasureCallback):
+@tvm._ffi.register_object("ansor.RecordToFile")
+class RecordToFile(MeasureCallback):
"""
A measurement callback that writes measurement records into a file.
@@ -36,11 +36,11 @@ class LogToFile(MeasureCallback):
File name for this callback to write log to.
"""
def __init__(self, filename="ansor_tuning.json"):
- self.__init_handle_by_constructor__(_ffi_api.LogToFile, filename)
+ self.__init_handle_by_constructor__(_ffi_api.RecordToFile, filename)
-@tvm._ffi.register_object("ansor.LogReader")
-class LogReader(Object):
+@tvm._ffi.register_object("ansor.RecordReader")
+class RecordReader(Object):
"""
Reader of the json log file.
@@ -50,7 +50,7 @@ class LogReader(Object):
File name for this reader to load log from.
"""
def __init__(self, filename="ansor_tuning.json"):
- self.__init_handle_by_constructor__(_ffi_api.LogReader, filename)
+ self.__init_handle_by_constructor__(_ffi_api.RecordReader, filename)
def read_lines(self, max_lines=None, skip_lines=0):
""" Read multiple lines from the log file.
@@ -69,19 +69,19 @@ def read_lines(self, max_lines=None, skip_lines=0):
results : List[MeasureResult]
The MeasureResults loaded from the log file.
"""
- inputs, results = _ffi_api.LogReaderReadLines(self, max_lines if max_lines else -1,
- skip_lines)
+ inputs, results = _ffi_api.RecordReaderReadLines(self, max_lines if max_lines else -1,
+ skip_lines)
return inputs, results
def __iter__(self):
while True:
- ret = _ffi_api.LogReaderReadNext(self)
+ ret = _ffi_api.RecordReaderReadNext(self)
if not ret:
break
yield ret[0], ret[1] # (input, result)
-def load_from_file(filename):
+def load_records(filename):
"""
Load measurement records from a file.
@@ -94,10 +94,10 @@ def load_from_file(filename):
-------
logs : List[MeasureInput, MeasureResult]
"""
- return zip(*LogReader(filename).read_lines())
+ return zip(*RecordReader(filename).read_lines())
-def append_measure_records_to_file(filename, inputs, results):
+def save_records(filename, inputs, results):
"""
Append measure records to file.
@@ -110,9 +110,9 @@ def append_measure_records_to_file(filename, inputs, results):
results: List[MeasureResults]
The MeasureResults to be written.
"""
- _ffi_api.AppendMeasureRecordsToFile(filename, inputs, results)
+ _ffi_api.SaveRecords(filename, inputs, results)
-def best_measure_pair_in_file(filename, workload_key=None, target=None):
+def load_best(filename, workload_key=None, target=None):
"""
Return the best measurement pair from a log file.
This may return no result if there is no legal
measure pair with the specified workload_key/target found from the log file.
@@ -134,7 +134,7 @@ def best_measure_pair_in_file(filename, workload_key=None, target=None):
result : MeasureResult
The best State's MeasureResult from this log file.
"""
- log_reader = LogReader(filename)
+ log_reader = RecordReader(filename)
best_cost = 1e30
best_inp = None
best_res = None
diff --git a/python/tvm/ansor/utils.py b/python/tvm/ansor/utils.py
index c698812e54c7..6052ef626033 100644
--- a/python/tvm/ansor/utils.py
+++ b/python/tvm/ansor/utils.py
@@ -47,7 +47,7 @@ def get_func_name(func):
name: str
The function name.
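The renamed record helpers compose as follows. This is a sketch: `inputs`, `results`, `workload_key`, and `target` are assumed to already exist in the caller's scope.

```python
from tvm import ansor

ansor.save_records("ansor_tuning.json", inputs, results)  # append records
for inp, res in ansor.load_records("ansor_tuning.json"):  # stream them back
    print(inp.task.workload_key, res.error_no)
best_inp, best_res = ansor.load_best("ansor_tuning.json",
                                     workload_key=workload_key, target=target)
```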
""" - return func.func_name if hasattr(func, 'func_name') else func.__name__ + return func.func_name if hasattr(func, 'func_name') else func.__qualname__ def get_const_int(exp): diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py index 5726ae3a7507..afaf31d5d874 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/ansor/workload_registry.py @@ -25,7 +25,7 @@ Rationale: The workload is actually a compute dag defined by tvm dsl. But serializing compute dags and matching them efficiently is not easy. Therefore, we use the above string to encode a compute dag. -These strings are efficient for serialization/matching and wont' be too long. +These strings are efficient for serialization/matching and won't be too long. When we need the dag, we decode the string and call the function, which will return the dag. """ @@ -33,13 +33,13 @@ import json import tvm._ffi -from .utils import serialize_args, deserialize_args +from .utils import serialize_args, deserialize_args, get_func_name WORKLOAD_FUNC_REGISTRY = {} def register_workload(func): - """ Register a workload by generation function. + """ Register a function that generates a certain workload. The input function should take hashable and jsonable arguments (int, float, tuple of int, tvm.tensor.Tensor, ...) and return a list of tvm.tensor.Tensor. @@ -59,8 +59,10 @@ def matmul(N, M, K): C = te.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C') return [A, B, C] """ + global WORKLOAD_FUNC_REGISTRY assert callable(func) - func_name = func.__name__ + + func_name = get_func_name(func) if func_name in WORKLOAD_FUNC_REGISTRY: raise RuntimeError('%s has been registered already' % func_name) @@ -69,7 +71,7 @@ def matmul(N, M, K): def make_workload_key(func, args): - """ make a workload key from function and arguments. + """ Make a workload key by function and arguments. Parameters ---------- @@ -84,12 +86,15 @@ def make_workload_key(func, args): workload_key : Str The workload key of the function. """ + global WORKLOAD_FUNC_REGISTRY + if callable(func): - func_name = func.__name__ + func_name = get_func_name(func) elif isinstance(func, str): func_name = func else: - raise ValueError("Invalid function: " + str(func)) + raise ValueError("Invalid function: " + str(func) + + " . `make_workload_key` expects a callable function or its function name") if not func_name in WORKLOAD_FUNC_REGISTRY: raise ValueError("%s is not registered. " % func, @@ -115,6 +120,8 @@ def decode_workload_key_to_func_args(workload_key): args : List[Tensor] The args of the generation function. """ + global WORKLOAD_FUNC_REGISTRY + workload = json.loads(workload_key) if not workload[0] in WORKLOAD_FUNC_REGISTRY: raise ValueError("%s is not registered. " % workload[0] + @@ -138,6 +145,8 @@ def workload_key_to_tensors(workload_key): tensors : List[Tensor] The registered compute declaration Tensors. 
""" + global WORKLOAD_FUNC_REGISTRY + name, args = decode_workload_key_to_func_args(workload_key) lookup = WORKLOAD_FUNC_REGISTRY[name] assert callable(lookup) diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index ee51fc9a0210..43e7d9f85ccc 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -33,8 +33,8 @@ TVM_REGISTER_NODE_TYPE(TuningOptionsNode); TuningOptions::TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, ProgramBuilder builder, ProgramRunner runner, - Array measure_callbacks, - Array pre_search_callbacks) { + Optional> measure_callbacks, + Optional> pre_search_callbacks) { auto node = make_object(); node->num_measure_trials = num_measure_trials; node->early_stopping = early_stopping; @@ -64,8 +64,8 @@ std::pair> AutoSchedule(SearchTask task, SearchP TVM_REGISTER_GLOBAL("ansor.TuningOptions") .set_body_typed([](int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, ProgramBuilder builder, ProgramRunner runner, - Array measure_callbacks, - Array pre_search_callbacks) { + Optional> measure_callbacks, + Optional> pre_search_callbacks) { return TuningOptions(num_measure_trials, early_stopping, num_measures_per_round, verbose, builder, runner, measure_callbacks, pre_search_callbacks); }); diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index 991eda5eb8e8..9e705cd350c4 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -51,9 +51,9 @@ class TuningOptionsNode : public Object { /*! \brief ProgramRunner which runs the program and measure time costs */ ProgramRunner runner; /*! \brief MeasureCallback functions to be called after each measure batch */ - Array measure_callbacks; + Optional> measure_callbacks; /*! \brief SearchCallback functions to be called before schedule search */ - Array pre_search_callbacks; + Optional> pre_search_callbacks; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("num_measure_trials", &num_measure_trials); @@ -90,8 +90,8 @@ class TuningOptions : public ObjectRef { */ TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, ProgramBuilder builder, ProgramRunner runner, - Array measure_callbacks, - Array pre_search_callbacks); + Optional> measure_callbacks, + Optional> pre_search_callbacks); TVM_DEFINE_OBJECT_REF_METHODS(TuningOptions, ObjectRef, TuningOptionsNode); }; diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 67ba7f34ccce..98b4ed3c7df7 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -123,7 +123,7 @@ class FlopEstimator : public ExprFunctor { if (auto pop = op.as()) { double num_element = AxisLengthProd(pop->axis); if (num_element == -1) { - fail = true; + fail_ = true; break; } double op_per_element = 0; @@ -138,7 +138,7 @@ class FlopEstimator : public ExprFunctor { } } - return fail ? -1 : ret; + return fail_ ? 
-1 : ret; } double VisitExpr_(const ReduceNode* op) final { @@ -147,7 +147,7 @@ class FlopEstimator : public ExprFunctor { if (auto imm = x->dom->extent.as()) { num_iter *= imm->value; } else { - fail = true; + fail_ = true; num_iter = -1; } } @@ -204,11 +204,12 @@ class FlopEstimator : public ExprFunctor { } double VisitExprDefault_(const Object* op) final { - fail = true; + fail_ = true; return -1.0; } - bool fail{false}; + private: + bool fail_{false}; }; ComputeDAG::ComputeDAG(Array tensors) { @@ -257,7 +258,9 @@ std::pair> ComputeDAG::ApplySteps( } } // Create the initial schedule - te::Schedule schedule = te::create_schedule(ops); + // TODO(jcf94): Currently we only checked single output dag for Ansor, + // update this after testing with multiple outputs. + te::Schedule schedule = te::create_schedule({ops.back()}); // init axes for (const auto& x : operator->()->ops) { @@ -266,18 +269,8 @@ std::pair> ComputeDAG::ApplySteps( UpdateStageToAxesMap(stage, stage_to_axes); } - // Use complete rate for the study in the paper - const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); - double complete_rate = -1.0; - if (complete_rate_str) { - complete_rate = std::stod(complete_rate_str); - } - size_t ct = 0; // Apply the history steps to TVM schedule for (const auto& step : transform_steps) { - if (complete_rate >= 0 && ct++ > transform_steps.size() * complete_rate) { - break; - } // Call each step's ApplyToSchedule method // Note: some steps have extra parameters that must be passed and they may need different // return value, so the ApplyToSchedule is not able to be merged to single interface @@ -305,7 +298,9 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const } } // Create the initial schedule - te::Schedule schedule = te::create_schedule(ops); + // TODO(jcf94): Currently we only checked single output dag for Ansor, + // update this after testing with multiple outputs. + te::Schedule schedule = te::create_schedule({ops.back()}); // init axes for (const auto& x : operator->()->ops) { @@ -346,6 +341,8 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const } State ComputeDAG::InferBound(const State& state) const { + CHECK(state->concrete) << "Only concrete state can be processed to get bound info."; + State ret_state; StateNode* pstate; @@ -375,7 +372,7 @@ State ComputeDAG::InferBound(const State& state) const { for (size_t i = 0; i < pstate->stages.size(); ++i) { const Stage& stage = pstate->stages[i]; - if (stage->compute_at == kInlined) { + if (stage->compute_at == ComputeAtKind::kInlined) { continue; } @@ -390,7 +387,7 @@ State ComputeDAG::InferBound(const State& state) const { auto find_res = bounds.find(axis); if (find_res != bounds.end()) { new_iters.push_back( - Iterator(iter->name, (*find_res).second, iter->iter_type, iter->annotation)); + Iterator(iter->name, (*find_res).second, iter->iter_kind, iter->annotation)); } else { LOG(FATAL) << "Infer bound fails"; } diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index a8a1126e7a60..8c244bd87778 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -47,7 +47,10 @@ namespace ansor { /*! \brief The Ansor computational graph and related program analyses. */ class ComputeDAGNode : public Object { public: - /*! \brief Input and output tensors. */ + /*! + * \brief Input and output tensors. + * This is used as the input of `tvm.lower` or `tvm.build`. + */ Array tensors; /*! \brief All related operations in topo order. 
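As a sanity check on the `flop_ct` that the FlopEstimator above produces: for a dense matmul, every output element performs one multiply and one add per reduction step, so the count reduces to the familiar closed form. A minimal sketch:

```python
def matmul_flop_ct(N, M, K):
    num_elements = N * M     # product of the spatial axis extents
    ops_per_element = 2 * K  # Mul + Add counted per reduction iteration
    return num_elements * ops_per_element

assert matmul_flop_ct(512, 512, 512) == 2 * 512 ** 3
```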
*/ Array ops; @@ -94,8 +97,7 @@ class ComputeDAG : public ObjectRef { /*! * \brief Print transform steps as equivalent python schedule API. - * This can be used for debugging or to apply the schedule on a former TVM version without Ansor - * support. + * This can be used for debugging. * \param transform_steps Transform steps of a state. * \return The Python schedule code. */ diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 54d29137f4d5..00d2bc759eb6 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -42,12 +42,12 @@ TVM_REGISTER_NODE_TYPE(StateNode); TVM_REGISTER_NODE_TYPE(IteratorNode); /********** Iterator **********/ -Iterator::Iterator(String name, Range range, IteratorType iter_type, +Iterator::Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation) { auto node = make_object(); node->name = std::move(name); node->range = std::move(range); - node->iter_type = iter_type; + node->iter_kind = iter_kind; node->annotation = annotation; data_ = std::move(node); } @@ -56,29 +56,31 @@ Iterator::Iterator(String name, Range range, IteratorType iter_type, Stage::Stage(te::Operation op) { auto node = make_object(); if (op->IsInstance()) { - node->op_type = kCompute; + node->op_type = StageKind::kCompute; auto* pop = op.as(); for (const auto& axis : pop->axis) { - node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, kSpace, kNone)); + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, + IteratorKind::kSpatial, IteratorAnnotation::kNone)); } for (const auto& axis : pop->reduce_axis) { - node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, kReduce, kNone)); + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, + IteratorKind::kReduction, IteratorAnnotation::kNone)); } } else if (op->IsInstance()) { - node->op_type = kPlaceholder; + node->op_type = StageKind::kPlaceholder; } else { LOG(FATAL) << "Unsupported operator type" << op->_type_key; } - node->compute_at = kRoot; + node->compute_at = ComputeAtKind::kRoot; node->op = std::move(op); node->attrs.auto_unroll_max_step = 0; node->attrs.storage_offset = 0; data_ = std::move(node); } -Stage::Stage(te::Operation op, StageType op_type, const Array& iters, - ComputeAtType compute_at, StageAttributes attrs) { +Stage::Stage(te::Operation op, StageKind op_type, const Array& iters, + ComputeAtKind compute_at, StageAttributes attrs) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; @@ -94,7 +96,7 @@ State::State(const Array& ops) { for (const auto& op : ops) { node->stages.push_back(Stage(op)); } - node->complete = true; + node->concrete = true; data_ = std::move(node); } @@ -110,8 +112,8 @@ void State::reorder(int stage_id, const Array& order) { DoReorderStep(step); } -Array State::split(int stage_id, const Iterator& it, const Array& lengths, - bool inner_to_outer) { +Array State::split(int stage_id, const Iterator& it, + const Array>& lengths, bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; SplitStep step = SplitStep(stage_id, GetIndex(stage->iters, it), @@ -142,22 +144,25 @@ void State::DoReorderStep(const ReorderStep& step) { } // common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep -Array State::DoSplitStepCommon(int stage_id, int iter_id, const Array& lengths, +Array State::DoSplitStepCommon(int stage_id, int iter_id, + const Array>& lengths, bool inner_to_outer) { const Stage& stage = 
operator->()->stages[stage_id]; const Iterator& it = stage->iters[iter_id]; + bool concrete = true; - PrimExpr tosplit_min, tosplit_extent; + Optional tosplit_min, tosplit_extent; if (it->range.defined()) { tosplit_min = it->range->min; tosplit_extent = it->range->extent; } else { - tosplit_min = tosplit_extent = PrimExpr(); + tosplit_min = NullOpt; + tosplit_extent = NullOpt; } Array outs; for (size_t i = 0; i < lengths.size(); ++i) { - PrimExpr l; + Optional l; String name; if (inner_to_outer) { l = lengths[lengths.size() - i - 1]; @@ -167,29 +172,32 @@ Array State::DoSplitStepCommon(int stage_id, int iter_id, const Array< name = it->name + "." + std::to_string(i); } Iterator res; - if (l.defined() && tosplit_min.defined() && tosplit_extent.defined()) { - res = Iterator(name, Range::FromMinExtent(tosplit_min, l), it->iter_type, kNone); - tosplit_min = 0; - tosplit_extent = indexdiv(tosplit_extent + l - 1, l); + if (l && tosplit_min && tosplit_extent) { + res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l.value()), it->iter_kind, + IteratorAnnotation::kNone); + tosplit_min = Integer(0); + tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1, l.value()); } else { - res = Iterator(name, Range(), it->iter_type, kNone); - tosplit_min = tosplit_extent = PrimExpr(); + res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone); + tosplit_min = NullOpt; + tosplit_extent = NullOpt; + concrete = false; } outs.push_back(std::move(res)); } Range range; - if (tosplit_min.defined() && tosplit_extent.defined()) { - range = Range::FromMinExtent(tosplit_min, tosplit_extent); + if (tosplit_min && tosplit_extent) { + range = Range::FromMinExtent(tosplit_min.value(), tosplit_extent.value()); } if (inner_to_outer) { - outs.push_back(Iterator(it->name + ".0", range, it->iter_type, kNone)); + outs.push_back(Iterator(it->name + ".0", range, it->iter_kind, IteratorAnnotation::kNone)); // Reverse the Iterator array Array temp(outs.rbegin(), outs.rend()); outs = std::move(temp); } else { - outs.push_back( - Iterator(it->name + "." + std::to_string(lengths.size()), range, it->iter_type, kNone)); + outs.push_back(Iterator(it->name + "." 
+ std::to_string(lengths.size()), range, it->iter_kind, + IteratorAnnotation::kNone)); } Array new_iters; @@ -200,6 +208,7 @@ Array State::DoSplitStepCommon(int stage_id, int iter_id, const Array< StateNode* pstate = CopyOnWrite(); pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); + pstate->concrete &= concrete; return outs; } @@ -214,7 +223,7 @@ Iterator State::DoFuseStep(const FuseStep& step) { String new_name; PrimExpr new_extent = 1; - IteratorType new_iter_type = kSpecial; + IteratorKind new_iter_kind = IteratorKind::kSpecial; for (size_t i = 0; i < step->fused_ids.size(); ++i) { if (i > 0) { @@ -231,10 +240,10 @@ Iterator State::DoFuseStep(const FuseStep& step) { } if (i == 0) { - new_iter_type = it->iter_type; + new_iter_kind = it->iter_kind; } else { - if (new_iter_type != it->iter_type) { - new_iter_type = kMixed; + if (new_iter_kind != it->iter_kind) { + new_iter_kind = IteratorKind::kMixed; } } } @@ -243,7 +252,7 @@ Iterator State::DoFuseStep(const FuseStep& step) { if (new_extent.defined()) { range = Range::FromMinExtent(0, new_extent); } - Iterator new_it = Iterator(new_name, range, new_iter_type, kNone); + Iterator new_it = Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone); Array new_iters; new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + step->fused_ids.front()); @@ -261,17 +270,7 @@ Iterator State::DoFuseStep(const FuseStep& step) { void State::DoSteps(const ComputeDAG& dag) { CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages."; - // Use complete rate for the study in the paper - const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); - double complete_rate = -1.0; - if (complete_rate_str) { - complete_rate = std::stod(complete_rate_str); - } - size_t ct = 0; for (const auto& step : operator->()->transform_steps) { - if (complete_rate >= 0 && ct++ > operator->()->transform_steps.size() * complete_rate) { - break; - } if (auto ps = step.as()) { DoReorderStep(GetRef(ps)); } else if (auto ps = step.as()) { @@ -323,7 +322,7 @@ void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_ for (size_t j = 0; j < base_indent + indent; ++j) { *os << " "; } - *os << IteratorAnnotationString[iter->annotation] << " "; + *os << IteratorAnnotationString[static_cast(iter->annotation)] << " "; if (iter->range.defined()) { *os << iter->name << " (" << iter->range->min << "," << iter->range->extent << ")"; } else { @@ -346,7 +345,7 @@ void PrintState(std::ostream* os, const State& state, bool delete_trivial_loop) // Gather placeholders Array placeholders; for (const auto& stage : state->stages) { - if (stage->op_type == kPlaceholder) { + if (stage->op_type == StageKind::kPlaceholder) { placeholders.push_back(stage->op->name); } } @@ -363,10 +362,10 @@ void PrintState(std::ostream* os, const State& state, bool delete_trivial_loop) // Print all stages for (size_t i = 0; i < state->stages.size(); ++i) { const Stage& stage = state->stages[i]; - if (stage->op_type == kPlaceholder) { + if (stage->op_type == StageKind::kPlaceholder) { continue; - } else if (stage->op_type == kCompute) { - if (stage->compute_at == kRoot) { + } else if (stage->op_type == StageKind::kCompute) { + if (stage->compute_at == ComputeAtKind::kRoot) { PrintStage(os, i, state, 0, delete_trivial_loop); } } else { @@ -394,8 +393,8 @@ TVM_REGISTER_GLOBAL("ansor.StateReorder") }); TVM_REGISTER_GLOBAL("ansor.StateSplit") - .set_body_typed([](State state, int 
stage_id, const Iterator& it, const Array& lengths, - bool inner_to_outer) { + .set_body_typed([](State state, int stage_id, const Iterator& it, + const Array>& lengths, bool inner_to_outer) { const auto& res = state.split(stage_id, it, lengths, inner_to_outer); return Array{state, res}; }); diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 5db611448fd1..c91a65cee528 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -62,7 +62,7 @@ using namespace tvm::tir; class ComputeDAG; /*! \brief The type of a stage. */ -enum StageType { +enum class StageKind : int { /*! \brief A placeholder stage. */ kPlaceholder = 0, /*! \brief A compute stage. */ @@ -70,7 +70,7 @@ enum StageType { }; /*! \brief The type of compute location. */ -enum ComputeAtType { +enum class ComputeAtKind : int { /*! \brief Compute at root. */ kRoot = 0, /*! \brief Compute inlined. */ @@ -80,11 +80,11 @@ enum ComputeAtType { }; /*! \brief The type of an iterator. */ -enum IteratorType { +enum class IteratorKind : int { /*! \brief Spatial iterator. */ - kSpace = 0, + kSpatial = 0, /*! \brief Reduction iterator. */ - kReduce = 1, + kReduction = 1, /*! \brief Fused spatial and reduction iterator. */ kMixed = 2, /*! \brief Special iterator. (e.g. virtual root iterator) */ @@ -92,7 +92,7 @@ enum IteratorType { }; /*! \brief The type of an iterator's annotation. */ -enum IteratorAnnotation { +enum class IteratorAnnotation : int { /*! \brief This iterator has no annotation. */ kNone = 0, /*! \brief This iterator has been unrolled. */ @@ -126,7 +126,7 @@ class IteratorNode : public Object { /*! \brief The range of this iterator. */ Range range; /*! \brief The iterator type of this iterator. */ - IteratorType iter_type; + IteratorKind iter_kind; /*! \brief The annotation type of this iterator. */ IteratorAnnotation annotation; @@ -149,10 +149,10 @@ class Iterator : public ObjectRef { * \brief The constructor. * \param name The name of this iterator. * \param range The range of this iterator. - * \param iter_type The iterator type of this iterator. + * \param iter_kind The iterator type of this iterator. * \param annotation The annotation type of this iterator. */ - Iterator(String name, Range range, IteratorType iter_type, IteratorAnnotation annotation); + Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation); TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode); }; @@ -174,11 +174,11 @@ class StageNode : public Object { /*! \brief The operator of this stage */ te::Operation op; /*! \brief The type of this stage. */ - StageType op_type; + StageKind op_type; /*! \brief The iterators in this stage. */ Array iters; /*! \brief The compute location of this stage. */ - ComputeAtType compute_at; + ComputeAtKind compute_at; /*! \brief Other stage-level attributes. */ StageAttributes attrs; @@ -210,7 +210,7 @@ class Stage : public ObjectRef { * \param compute_at The compute at type of this op. * \param attrs Other stage-level attributes. */ - Stage(te::Operation op, StageType op_type, const Array& iters, ComputeAtType compute_at, + Stage(te::Operation op, StageKind op_type, const Array& iters, ComputeAtKind compute_at, StageAttributes attrs); TVM_DEFINE_OBJECT_REF_METHODS(Stage, ObjectRef, StageNode); @@ -229,13 +229,16 @@ class StateNode : public Object { Array stages; /*! \brief History transformation steps. */ Array transform_steps; - /*! \brief Indicate whether this state has unfilled tile sizes. */ - bool complete; + /*! 
+ * \brief Indicate whether this state has unfilled tile sizes. A concrete state means that all + * tile sizes of the state are filled. Only a concrete state can be applied to a TVM schedule. + */ + bool concrete; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("stages", &stages); v->Visit("transform_steps", &transform_steps); - v->Visit("complete", &complete); + v->Visit("concrete", &concrete); } static constexpr const char* _type_key = "ansor.State"; @@ -296,7 +299,7 @@ class State : public ObjectRef { * \param inner_to_outer Whether the factors go from inner to outer, or from outer to inner. * \return The iterator results after split. */ - Array<Iterator> split(int stage_id, const Iterator& it, const Array<Integer>& lengths, + Array<Iterator> split(int stage_id, const Iterator& it, const Array<Optional<Integer>>& lengths, bool inner_to_outer = true); /*! * \brief Schedule primitive corresponding to te.fuse. @@ -340,8 +343,8 @@ class State : public ObjectRef { * \param inner_to_outer The split direction. * \return The iterator results after split. */ - Array<Iterator> DoSplitStepCommon(int stage_id, int iter_id, const Array<Integer>& lengths, - bool inner_to_outer); + Array<Iterator> DoSplitStepCommon(int stage_id, int iter_id, + const Array<Optional<Integer>>& lengths, bool inner_to_outer); }; } // namespace ansor @@ -358,7 +361,14 @@ struct hash<::tvm::ansor::State> { } }; -/*! \brief The equal_to function for ansor::State. */ +/*! + * \brief The equal_to function for ansor::State. + * We use the schedule result (its string format) of a state to check if two states are `equal`. + * Equal states: 1. the transform steps are exactly the same; 2. even with different steps, two + * states may still result in the same schedule. E.g., to split an axis with extent 512 into + * 3 parts [8, 16, 4], we can split from inner to outer by factors [16, 4], or get the same + * result by splitting from outer to inner by factors [8, 16]. + */ template <> struct equal_to<::tvm::ansor::State> { bool operator()(const ::tvm::ansor::State& lhs, const ::tvm::ansor::State& rhs) const { diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 9e1e15dd0830..6cf9769e68bc 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -148,7 +148,7 @@ Array<MeasureResult> LocalRunnerNode::Run(const Array<MeasureInput>& inputs, /********** ProgramMeasurer **********/ ProgramMeasurer::ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, - Array<MeasureCallback> callbacks, int verbose, + Optional<Array<MeasureCallback>> callbacks, int verbose, int max_continous_error) { auto node = make_object<ProgramMeasurerNode>(); node->builder = std::move(builder); @@ -217,8 +217,10 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& po } // Call callback functions - for (const auto& callback : callbacks) { - callback->Callback(policy, input_batch, result_batch); + if (callbacks) { + for (const auto& callback : callbacks.value()) { + callback->Callback(policy, input_batch, result_batch); + } } // Store result batch diff --git a/src/ansor/measure.h b/src/ansor/measure.h index e441c931a9e7..036022eac060 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -301,7 +301,8 @@ class LocalBuilder : public ProgramBuilder { public: /*! * \brief The constructor. - * \param timeout The timeout limit for each build. + * \param timeout The timeout limit (in seconds) for each build thread. + * This will be used in a wrapper of the multiprocessing.Process.join(). * \param n_parallel Number of threads used to build in parallel. * \param build_func The name of registered build function. */ @@ -338,7 +339,8 @@ class LocalRunner : public ProgramRunner { /*! * \brief The constructor.
See the corresponding class in python/tvm/ansor/measure.py for more * detailed parameter explanation. - * \param timeout The timeout limit for each run. + * \param timeout The timeout limit (in seconds) for each run. + * This is used in a wrapper of the multiprocessing.Process.join(). * \param number Number of measure times. * \param repeat Number of repeat times in each measure. * \param min_repeat_ms The minimum duration of one repeat in milliseconds. @@ -369,7 +371,7 @@ class ProgramMeasurerNode : public Object { /*! \brief The ProgramRunner to measure each program. */ ProgramRunner runner; /*! \brief MeasureCallback to be called after each measure batch. */ - Array<MeasureCallback> callbacks; + Optional<Array<MeasureCallback>> callbacks; /*! \brief Verbosity level. 0 for silent, 1 to output information during program measuring. */ int verbose; /*! \brief The number of max continuous error. */ int max_continous_error; @@ -420,8 +422,9 @@ class ProgramMeasurer : public ObjectRef { * \param verbose Verbosity level. 0 for silent, 1 to output information during program measuring. * \param max_continous_error The number of max continuous error. */ - ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, Array<MeasureCallback> callbacks, - int verbose, int max_continous_error = -1); + ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, + Optional<Array<MeasureCallback>> callbacks, int verbose, + int max_continous_error = -1); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramMeasurer, ObjectRef, ProgramMeasurerNode); }; diff --git a/src/ansor/record.cc b/src/ansor/measure_record.cc similarity index 87% rename from src/ansor/record.cc rename to src/ansor/measure_record.cc index e82ba66c8f3a..a959b00cc27f 100644 --- a/src/ansor/record.cc +++ b/src/ansor/measure_record.cc @@ -18,11 +18,11 @@ */ /*! - * \file ansor/record.cc + * \file ansor/measure_record.cc * \brief Json serialization format for dumping and loading tuning records. */ -#include "record.h" +#include "measure_record.h" #include #include @@ -51,6 +51,16 @@ inline std::vector<int> IntArrayToVector(const ::tvm::Array<::tvm::Integer>& dat return out; } +inline std::vector<int> IntArrayToVector( + const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& data) { + std::vector<int> out; + for (const auto& x : data) { + CHECK(x); + out.push_back(x.value()); + } + return out; +} + template <> struct Handler<::tvm::Array<::tvm::ansor::Stage>> { inline static void Write(dmlc::JSONWriter* writer, @@ -81,7 +91,7 @@ struct Handler<::tvm::Array<::tvm::ansor::Step>> { writer->WriteArrayItem(std::string("SP")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); - writer->WriteArrayItem(ps->extent.defined() ? ::tvm::ansor::GetIntImm(ps->extent) : 0); + writer->WriteArrayItem(ps->extent ? ::tvm::ansor::GetIntImm(ps->extent.value()) : 0); writer->WriteArrayItem(IntArrayToVector(ps->lengths)); writer->WriteArrayItem(static_cast<int>(ps->inner_to_outer)); } else if (auto ps = data[i].as<::tvm::ansor::FuseStepNode>()) { @@ -137,9 +147,9 @@ struct Handler<::tvm::Array<::tvm::ansor::Step>> { s = reader->NextArrayItem(); CHECK(s); reader->Read(&inner_to_outer); - ::tvm::Array<::tvm::Integer> lengths; + ::tvm::Array<::tvm::Optional<::tvm::Integer>> lengths; for (const auto& i : int_list) { - lengths.push_back(i); + lengths.push_back(::tvm::Integer(i)); } data->push_back(::tvm::ansor::SplitStep( stage_id, iter_id, extent == 0 ?
::tvm::PrimExpr() : extent, lengths, inner_to_outer)); @@ -224,7 +234,7 @@ struct Handler<::tvm::ansor::MeasureInputNode> { bool s; auto task_node = ::tvm::make_object<::tvm::ansor::SearchTaskNode>(); auto state_node = ::tvm::make_object<::tvm::ansor::StateNode>(); - state_node->complete = true; + state_node->concrete = true; reader->BeginArray(); s = reader->NextArrayItem(); @@ -290,13 +300,13 @@ struct Handler<::tvm::ansor::MeasureResultNode> { namespace tvm { namespace ansor { -TVM_REGISTER_OBJECT_TYPE(LogToFileNode); -TVM_REGISTER_OBJECT_TYPE(LogReaderNode); +TVM_REGISTER_OBJECT_TYPE(RecordToFileNode); +TVM_REGISTER_OBJECT_TYPE(RecordReaderNode); const std::string ANSOR_LOG_VERSION = "v0.2"; // NOLINT(*) -LogToFile::LogToFile(String filename) { - auto node = make_object(); +RecordToFile::RecordToFile(String filename) { + auto node = make_object(); node->filename = std::move(filename); data_ = std::move(node); } @@ -334,22 +344,22 @@ void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureRes } } -void LogToFileNode::Callback(const SearchPolicy& policy, const Array& inputs, - const Array& results) { +void RecordToFileNode::Callback(const SearchPolicy& policy, const Array& inputs, + const Array& results) { std::ofstream ofs(filename, std::ofstream::app); WriteMeasureRecords(&ofs, inputs, results); } -LogReader::LogReader(String filename) { - auto node = make_object(); +RecordReader::RecordReader(String filename) { + auto node = make_object(); node->filename = filename; node->infile.open(filename, std::ifstream::in); data_ = std::move(node); } -LogReaderNode::~LogReaderNode() { infile.close(); } +RecordReaderNode::~RecordReaderNode() { infile.close(); } -bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { +bool RecordReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { std::string log_version; while (std::getline(infile, cur_line)) { @@ -364,8 +374,8 @@ bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { return false; } -std::pair, Array> LogReaderNode::ReadLines(int max_size, - int skip_size) { +std::pair, Array> RecordReaderNode::ReadLines(int max_size, + int skip_size) { auto inp = make_object(); auto res = make_object(); Array inputs; @@ -388,21 +398,21 @@ std::pair, Array> LogReaderNode::ReadLines(in return std::make_pair(inputs, results); } -TVM_REGISTER_GLOBAL("ansor.LogToFile").set_body_typed([](const String& filename) { - return LogToFile(filename); +TVM_REGISTER_GLOBAL("ansor.RecordToFile").set_body_typed([](const String& filename) { + return RecordToFile(filename); }); -TVM_REGISTER_GLOBAL("ansor.LogReader").set_body_typed([](const String& filename) { - return LogReader(filename); +TVM_REGISTER_GLOBAL("ansor.RecordReader").set_body_typed([](const String& filename) { + return RecordReader(filename); }); -TVM_REGISTER_GLOBAL("ansor.LogReaderReadLines") - .set_body_typed([](LogReader reader, int size, int skip_size) { +TVM_REGISTER_GLOBAL("ansor.RecordReaderReadLines") + .set_body_typed([](RecordReader reader, int size, int skip_size) { const auto& res = reader->ReadLines(size, skip_size); return Array{res.first, res.second}; }); -TVM_REGISTER_GLOBAL("ansor.LogReaderReadNext").set_body_typed([](LogReader reader) { +TVM_REGISTER_GLOBAL("ansor.RecordReaderReadNext").set_body_typed([](RecordReader reader) { auto inp = make_object(); auto res = make_object(); if (reader->ReadNext(inp.get(), res.get())) { @@ -412,7 +422,7 @@ 
TVM_REGISTER_GLOBAL("ansor.LogReaderReadNext").set_body_typed([](LogReader reade } }); -TVM_REGISTER_GLOBAL("ansor.AppendMeasureRecordsToFile") +TVM_REGISTER_GLOBAL("ansor.SaveRecords") .set_body_typed([](String filename, Array in, Array res) { std::ofstream ofs(filename, std::ofstream::app); WriteMeasureRecords(&ofs, in, res); diff --git a/src/ansor/record.h b/src/ansor/measure_record.h similarity index 77% rename from src/ansor/record.h rename to src/ansor/measure_record.h index 0e26b6bdaf1e..1b6ed8f5bdba 100644 --- a/src/ansor/record.h +++ b/src/ansor/measure_record.h @@ -18,12 +18,12 @@ */ /*! - * \file ansor/record.h + * \file ansor/measure_record.h * \brief Json serialization format for dumping and loading tuning records. */ -#ifndef TVM_ANSOR_RECORD_H_ -#define TVM_ANSOR_RECORD_H_ +#ifndef TVM_ANSOR_MEASURE_RECORD_H_ +#define TVM_ANSOR_MEASURE_RECORD_H_ #include #include @@ -35,7 +35,7 @@ namespace tvm { namespace ansor { /*! \brief Callback for logging the input and results of measurements to file */ -class LogToFileNode : public MeasureCallbackNode { +class RecordToFileNode : public MeasureCallbackNode { public: /*! \brief File name for this callback to write log to. */ String filename; @@ -43,34 +43,34 @@ class LogToFileNode : public MeasureCallbackNode { void Callback(const SearchPolicy& policy, const Array& inputs, const Array& results) final; - static constexpr const char* _type_key = "ansor.LogToFile"; - TVM_DECLARE_FINAL_OBJECT_INFO(LogToFileNode, MeasureCallbackNode); + static constexpr const char* _type_key = "ansor.RecordToFile"; + TVM_DECLARE_FINAL_OBJECT_INFO(RecordToFileNode, MeasureCallbackNode); }; /*! - * \brief Managed reference to LogToFileNode. - * \sa LogToFileNode + * \brief Managed reference to RecordToFileNode. + * \sa RecordToFileNode */ -class LogToFile : public MeasureCallback { +class RecordToFile : public MeasureCallback { public: /*! * \brief The constructor. * \param filename File name for this callback to write log. */ - explicit LogToFile(String filename); + explicit RecordToFile(String filename); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LogToFile, MeasureCallback, LogToFileNode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RecordToFile, MeasureCallback, RecordToFileNode); }; /*! \brief Log reader to load step logs from a file.*/ -class LogReaderNode : public Object { +class RecordReaderNode : public Object { public: /*! \brief File name for this reader to load log from. */ String filename; /*! \brief The reading file stream. */ std::ifstream infile; - ~LogReaderNode(); + ~RecordReaderNode(); /*! * \brief Read next line in the log file. @@ -88,8 +88,8 @@ class LogReaderNode : public Object { std::pair, Array> ReadLines(int max_size = -1, int skip_size = 0); - static constexpr const char* _type_key = "ansor.LogReader"; - TVM_DECLARE_FINAL_OBJECT_INFO(LogReaderNode, Object); + static constexpr const char* _type_key = "ansor.RecordReader"; + TVM_DECLARE_FINAL_OBJECT_INFO(RecordReaderNode, Object); private: /*! \brief A string object to store the next line. */ @@ -97,18 +97,18 @@ class LogReaderNode : public Object { }; /*! - * \brief Managed reference to LogReaderNode. - * \sa LogReaderNode + * \brief Managed reference to RecordReaderNode. + * \sa RecordReaderNode */ -class LogReader : public ObjectRef { +class RecordReader : public ObjectRef { public: /*! * \brief The constructor. * \param filename File name for this callback to write log. 
*/ - explicit LogReader(String filename); + explicit RecordReader(String filename); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LogReader, ObjectRef, LogReaderNode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RecordReader, ObjectRef, RecordReaderNode); }; /*! @@ -133,4 +133,4 @@ void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureRes } // namespace ansor } // namespace tvm -#endif // TVM_ANSOR_RECORD_H_ +#endif // TVM_ANSOR_MEASURE_RECORD_H_ diff --git a/src/ansor/search_policy/empty_policy.cc b/src/ansor/search_policy/empty_policy.cc index 659e0441d940..f897b29615d8 100644 --- a/src/ansor/search_policy/empty_policy.cc +++ b/src/ansor/search_policy/empty_policy.cc @@ -35,7 +35,7 @@ TVM_REGISTER_NODE_TYPE(EmptyPolicyNode); State EmptyPolicyNode::Search(SearchTask task, int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, ProgramMeasurer measurer, - Array<SearchCallback> pre_search_callbacks) { + Optional<Array<SearchCallback>> pre_search_callbacks) { cur_task = task; // Run pre_search_callbacks before the search process diff --git a/src/ansor/search_policy/empty_policy.h b/src/ansor/search_policy/empty_policy.h index 094e71913823..ce8ac78fc2fc 100644 --- a/src/ansor/search_policy/empty_policy.h +++ b/src/ansor/search_policy/empty_policy.h @@ -42,7 +42,7 @@ class EmptyPolicyNode : public SearchPolicyNode { public: State Search(SearchTask task, int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, ProgramMeasurer measurer, - Array<SearchCallback> pre_search_callbacks) final; + Optional<Array<SearchCallback>> pre_search_callbacks) final; static constexpr const char* _type_key = "ansor.EmptyPolicy"; TVM_DECLARE_FINAL_OBJECT_INFO(EmptyPolicyNode, SearchPolicyNode); diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index 0b62efda79cc..67edc53be009 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -33,7 +33,7 @@ TVM_REGISTER_OBJECT_TYPE(SearchCallbackNode); TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); void SearchPolicyNode::RunCallbacks(const Optional<Array<SearchCallback>>& callbacks) { - if (callbacks.defined()) { + if (callbacks) { for (const auto& callback : callbacks.value()) { callback->Callback(this); } diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index 2d551f1c5ce1..f70fc265e56c 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -122,7 +122,7 @@ class SearchPolicyNode : public Object { */ virtual State Search(SearchTask task, int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, ProgramMeasurer measurer, - Array<SearchCallback> pre_search_callbacks) = 0; + Optional<Array<SearchCallback>> pre_search_callbacks) = 0; /*! * \brief Call SearchCallback with the current SearchPolicyNode @@ -136,10 +136,16 @@ class SearchPolicyNode : public Object { protected: /*! * \brief The set of already measured states. - * We store the string format for redundancy check. + * During the schedule search process, we may generate `equal states` through different search + * branches. (Equal states: 1. the transform steps are exactly the same; 2. even with different + * steps, two states may still result in the same schedule. E.g., to split an axis with extent + * 512 into 3 parts [8, 16, 4], we can split from inner to outer by factors [16, 4], or get the + * same result by splitting from outer to inner by factors [8, 16].) + * We store the string format of a state for redundancy check.
This is used to make sure a + * measured state will never be measured again. */ std::unordered_set measured_states_set_; - /*! \brief The array of already measured states. */ + /*! \brief The array of already measured states. This can be used in evolutionary search. */ std::vector measured_states_vector_; /*! \brief The throughputs of already measured states */ std::vector measured_states_throughputs_; diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index 090f6c58f175..3de986bb81fd 100644 --- a/src/ansor/search_task.cc +++ b/src/ansor/search_task.cc @@ -54,14 +54,14 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target } SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target, - Target target_host, HardwareParams hardware_params) { + Target target_host, Optional hardware_params) { auto node = make_object(); node->compute_dag = std::move(compute_dag); node->workload_key = std::move(workload_key); node->target = std::move(target); node->target_host = std::move(target_host); - if (hardware_params.defined()) { - node->hardware_params = std::move(hardware_params); + if (hardware_params) { + node->hardware_params = std::move(hardware_params.value()); } else { node->hardware_params = HardwareParamsNode::GetDefaultHardwareParams(node->target, node->target_host); @@ -76,7 +76,7 @@ TVM_REGISTER_GLOBAL("ansor.HardwareParams") TVM_REGISTER_GLOBAL("ansor.SearchTask") .set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target, - Target target_host, HardwareParams hardware_params) { + Target target_host, Optional hardware_params) { return SearchTask(compute_dag, workload_key, target, target_host, hardware_params); }); diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h index ba418049e655..fb9a6e098023 100644 --- a/src/ansor/search_task.h +++ b/src/ansor/search_task.h @@ -141,7 +141,7 @@ class SearchTask : public ObjectRef { * \param hardware_params Hardware parameters used in this search task. 
*/ SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, - HardwareParams hardware_params); + Optional hardware_params); TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); }; diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index d3b5b39750c1..f096e63a4e54 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -41,7 +41,7 @@ ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { auto node = make_object(); node->stage_id = stage_id; for (const auto& x : after_ids) { - CHECK(x.defined() && x->IsInstance()); + CHECK(x->IsInstance()); } node->after_ids = after_ids; data_ = std::move(node); @@ -84,8 +84,8 @@ String ReorderStepNode::PrintAsPythonAPI(Array* stages, /********** Split **********/ Array ApplySplitToSchedule(Array* stages, StageToAxesMap* stage_to_axes, - int stage_id, int iter_id, const Array& lengths, - bool inner_to_outer) { + int stage_id, int iter_id, + const Array>& lengths, bool inner_to_outer) { auto stage = (*stages)[stage_id]; const Array& axes = stage_to_axes->at(stage); @@ -94,7 +94,7 @@ Array ApplySplitToSchedule(Array* stages, StageToAxesMap* st IterVar outer = axes[iter_id], inner; for (int i = static_cast(lengths.size()) - 1; i >= 0; i--) { IterVar to_split = outer; - stage.split(to_split, lengths[i], &outer, &inner); + stage.split(to_split, lengths[i].value(), &outer, &inner); outs.push_back(inner); } outs.push_back(outer); @@ -102,7 +102,7 @@ Array ApplySplitToSchedule(Array* stages, StageToAxesMap* st IterVar outer, inner = axes[iter_id]; for (size_t i = 0; i < lengths.size(); i++) { IterVar to_split = inner; - stage.split_by_nparts(to_split, lengths[i], &outer, &inner); + stage.split_by_nparts(to_split, lengths[i].value(), &outer, &inner); outs.push_back(outer); } outs.push_back(inner); @@ -127,7 +127,8 @@ Array ApplySplitToSchedule(Array* stages, StageToAxesMap* st } String PrintSplitAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, int stage_id, - int iter_id, const Array& lengths, bool inner_to_outer) { + int iter_id, const Array>& lengths, + bool inner_to_outer) { const auto& stage = (*stages)[stage_id]; auto to_split = stage_to_axes->at(stage)[iter_id]; const auto& func_name = CleanName(stage->op->name); @@ -156,13 +157,13 @@ String PrintSplitAsPythonAPI(Array* stages, StageToAxesMap* stage_to_ return ss.str(); } -SplitStep::SplitStep(int stage_id, int iter_id, PrimExpr extent, const Array& lengths, - bool inner_to_outer) { +SplitStep::SplitStep(int stage_id, int iter_id, Optional extent, + const Array>& lengths, bool inner_to_outer) { auto node = make_object(); node->stage_id = stage_id; // Extent can be a unreducible expression in some special cases - if (extent->IsInstance()) { - node->extent = tvm::Downcast(extent); + if (extent && extent.value()->IsInstance()) { + node->extent = tvm::Downcast(extent.value()); } node->iter_id = iter_id; node->lengths = lengths; @@ -185,7 +186,7 @@ FuseStep::FuseStep(int stage_id, const Array& fused_ids) { auto node = make_object(); node->stage_id = stage_id; for (const auto& x : fused_ids) { - CHECK(x.defined() && x->IsInstance()); + CHECK(x->IsInstance()); } node->fused_ids = fused_ids; data_ = std::move(node); diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index 3d40f35d184d..4feec4355c07 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -21,8 +21,8 @@ * \file ansor/transform_step.h * \brief Transformation steps. 
For each schedule primitive, there is a corresponding transform * step. The implementation of each step consists of 2 parts: - * - transform_step.cc: How each step interact with TE and TE's schedule primitives - * - loop_state.cc: How each step reflect on LoopState + * - transform_step.cc: How each step interacts with TE and TE's schedule primitives + * - loop_state.cc: How each step updates LoopState * * \note To add a new transform step: * Take fuse step for example: @@ -128,9 +128,9 @@ class SplitStepNode : public StepNode { /*! \brief The id of the iter to split. */ int iter_id; /*! \brief The extent length of the axis to split. */ - Integer extent; + Optional extent; /*! \brief The split factors. */ - Array lengths; + Array> lengths; /*! * \brief If true, the `lengths` denote the lengths of iterators * from inner level to outer level @@ -172,8 +172,8 @@ class SplitStep : public Step { * \param lengths The multiple split factors. Can be None to be filled by search policy. * \param inner_to_outer The split direction. */ - SplitStep(int stage_id, int iter_id, PrimExpr extent, const Array& lengths, - bool inner_to_outer); + SplitStep(int stage_id, int iter_id, Optional extent, + const Array>& lengths, bool inner_to_outer); TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); }; diff --git a/tests/python/unittest/test_ansor_measure.py b/tests/python/unittest/test_ansor_measure.py index 5d100025bf1e..3820b7f0d168 100644 --- a/tests/python/unittest/test_ansor_measure.py +++ b/tests/python/unittest/test_ansor_measure.py @@ -36,9 +36,9 @@ def test_record(): res = ansor.measure.MeasureResult([0.1], 0, "", 0.2, 1) with tempfile.NamedTemporaryFile() as fp: - ansor.record.append_measure_records_to_file(fp.name, [inp], [res]) + ansor.save_records(fp.name, [inp], [res]) - log_reader = ansor.record.LogReader(fp.name) + log_reader = ansor.RecordReader(fp.name) inputs, results = log_reader.read_lines() assert len(inputs) == 1 diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index 8922fd722690..202c12e4afbb 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -45,11 +45,11 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' # search_policy = ansor.SketchSearchPolicy(cost_model, params=params, seed=seed) tuning_options = ansor.TuningOptions(num_measure_trials=num_measure_trials, runner=runner, verbose=0, - measure_callbacks=[ansor.LogToFile(log_file)], + measure_callbacks=[ansor.RecordToFile(log_file)], pre_search_callbacks=pre_search_callbacks) sch, args = ansor.auto_schedule(task, target, search_policy=search_policy, tuning_options=tuning_options) - inp, res = ansor.best_measure_pair_in_file(log_file, workload_key, target) + inp, res = ansor.load_best(log_file, workload_key, target) print("==== Python Code ====") print(dag.print_python_code_from_state(inp.state)) From 28a7b8f14ab63ac45225a93fa4e6b07f43695dae Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 9 Jul 2020 15:35:10 +0800 Subject: [PATCH 70/78] Update --- python/tvm/ansor/auto_schedule.py | 10 +++--- python/tvm/ansor/loop_state.py | 2 +- python/tvm/ansor/measure.py | 34 +++++++++---------- python/tvm/ansor/workload_registry.py | 32 +++++++++++------ src/ansor/auto_schedule.cc | 4 +-- src/ansor/auto_schedule.h | 13 ++++--- src/ansor/compute_dag.cc | 2 +- src/ansor/measure.cc | 12 +++---- src/ansor/measure.h | 23 +++++++------ src/ansor/measure_record.cc | 6 ++-- 
src/ansor/measure_record.h | 2 +- src/ansor/search_policy/empty_policy.cc | 2 +- src/ansor/search_policy/empty_policy.h | 2 +- src/ansor/search_policy/search_policy.cc | 2 +- src/ansor/search_policy/search_policy.h | 9 ++--- src/ansor/search_task.cc | 2 +- src/ansor/utils.h | 6 ++-- tests/python/unittest/test_ansor_common.py | 9 +++++ .../unittest/test_ansor_search_policy.py | 18 +++++++--- 19 files changed, 112 insertions(+), 78 deletions(-) diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 1f5bdf8f14d9..70a05bfc0cf7 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -113,8 +113,8 @@ class TuningOptions(Object): The number of schedules to be measured at each search round. The whole schedule search process will try a total number of `num_measure_trials` in several rounds. - verbose: int = 1 - Verbosity level. 0 for silent, 1 to output information during schedule search. + verbose: boolean = True + Verbosity level. False for silent, True to output information during schedule search. builder: Union[ProgramBuilder, str] = 'local' ProgramBuilder which builds the program. runner: Union[ProgramRunner, str] = 'local' @@ -131,7 +131,7 @@ class TuningOptions(Object): TODO(jcf94): Add these implementations in later PRs. """ def __init__(self, num_measure_trials=0, early_stopping=None, num_measures_per_round=64, - verbose=1, builder='local', runner='local', measure_callbacks=None, + verbose=True, builder='local', runner='local', measure_callbacks=None, pre_search_callbacks=None): if isinstance(builder, str): if builder == 'local': @@ -153,8 +153,8 @@ def __init__(self, num_measure_trials=0, early_stopping=None, num_measures_per_r self.__init_handle_by_constructor__( _ffi_api.TuningOptions, num_measure_trials, early_stopping if early_stopping else -1, - num_measures_per_round, - verbose, builder, runner, measure_callbacks, pre_search_callbacks) + num_measures_per_round, verbose, builder, runner, measure_callbacks, + pre_search_callbacks) def auto_schedule(task, target, target_host=None, search_policy='default', diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index 324cc20d48ed..610392a22523 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -138,7 +138,7 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): The iterator to be split. lengths: List[int] The multiple split factors. Can be None to be filled by search policy. - inner_to_outer: bool = True + inner_to_outer: boolean = True Whether the factors go from inner to outer, or from outer to inner. Returns diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 2fe1169ad3fa..25170ea7ebfa 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -128,15 +128,15 @@ def __init__(self, costs, error_no, error_msg, all_cost, timestamp): class ProgramBuilder(Object): """ The base class of ProgramBuilders. """ - def build(self, measure_inputs, verbose=1): + def build(self, measure_inputs, verbose=True): """ Build programs and return results. Parameters ---------- measure_inputs : List[MeasureInput] A List of MeasureInput. - verbose : int = 1 - Verbosity level. 0 for silent, 1 to output information during program building. + verbose : boolean = True + Verbosity level. False for silent, True to output information during program building.
Returns ------- @@ -149,7 +149,7 @@ def build(self, measure_inputs, verbose=1): class ProgramRunner(Object): """ The base class of ProgramRunners. """ - def run(self, measure_inputs, build_results, verbose=1): + def run(self, measure_inputs, build_results, verbose=True): """ Run measurement and return results. Parameters ---------- measure_inputs : List[MeasureInput] A List of MeasureInput. build_results : List[BuildResult] A List of BuildResult to be run. - verbose : int = 1 - Verbosity level. 0 for silent, 1 to output information during program running. + verbose : boolean = True + Verbosity level. False for silent, True to output information during program running. Returns ------- @@ -318,7 +318,7 @@ def timed_func(): else: filename = "" - if verbose == 1: + if verbose: if error_no == MeasureErrorNo.NO_ERROR: print(".", end="") else: @@ -327,7 +327,7 @@ def timed_func(): res = call_func_with_timeout(timeout, timed_func) if isinstance(res, TimeoutError): - if verbose == 1: + if verbose: print(".T", end="") # Build timeout res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout @@ -335,7 +335,7 @@ def timed_func(): @tvm._ffi.register_func("ansor.local_builder.build") -def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbose=1): +def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbose=True): """ Build function of LocalBuilder to build the MeasureInputs to runnable modules. @@ -350,8 +350,8 @@ def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbo Number of threads used to build in parallel. build_func : str = 'default' The name of build function to process the built module. - verbose : int = 1 - Verbosity level. 0 for silent, 1 to output information during program building. + verbose : boolean = True + Verbosity level. False for silent, True to output information during program building. Returns ------- @@ -378,7 +378,7 @@ def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbo @tvm._ffi.register_func("ansor.local_runner.run") def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval, - verbose=1): + verbose=True): """ Run function of LocalRunner to test the performance of the input BuildResults. @@ -409,8 +409,8 @@ def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, coo will be automatically increased. cooldown_interval : float = 0.0 The cool down interval between two measurements. - verbose : int = 1 - Verbosity level. 0 for silent, 1 to output information during program measuring. + verbose : boolean = True + Verbosity level. False for silent, True to output information during program measuring.
Returns ------- @@ -450,7 +450,7 @@ def timed_func(inp, build_res): toc = time.time() time.sleep(cooldown_interval) - if verbose == 1: + if verbose: if error_no == MeasureErrorNo.NO_ERROR: print("*", end="") else: @@ -468,13 +468,13 @@ def timed_func(inp, build_res): res = call_func_with_timeout( timeout, timed_func, args=(inp, build_res)) if isinstance(res, TimeoutError): - if verbose == 1: + if verbose: print("*T", end="") # Run timeout res = (max_float,), MeasureErrorNo.RUN_TIMEOUT, None, \ build_res.time_cost + timeout, time.time() measure_results.append(MeasureResult(*res)) - if verbose == 1: + if verbose: print("") return measure_results diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py index afaf31d5d874..450bb64cbc66 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/ansor/workload_registry.py @@ -38,7 +38,7 @@ WORKLOAD_FUNC_REGISTRY = {} -def register_workload(func): +def register_workload(func_name, f=None, override=False): """ Register a function that generates a certain workload. The input function should take hashable and jsonable arguments @@ -46,8 +46,12 @@ Parameters ---------- - func : Function - The generation function that returns the compute declaration Tensors. + func_name : Union[Function, str] + The generation function that returns the compute declaration Tensors, or its function name. + f : Optional[Function] + The generation function to be registered. + override : boolean = False + Whether to override the existing entry. Examples -------- @@ -60,14 +64,22 @@ def matmul(N, M, K): return [A, B, C] """ global WORKLOAD_FUNC_REGISTRY + if callable(func_name): + f = func_name + func_name = get_func_name(f) + if not isinstance(func_name, str): + raise ValueError("expect string function name") + + def register(myf): + """internal register function""" + if func_name in WORKLOAD_FUNC_REGISTRY and not override: + raise RuntimeError('%s has been registered already' % func_name) + WORKLOAD_FUNC_REGISTRY[func_name] = myf + return myf + if f: + return register(f) + return register - assert callable(func) - func_name = get_func_name(func) - if func_name in WORKLOAD_FUNC_REGISTRY: - raise RuntimeError('%s has been registered already' % func_name) - - WORKLOAD_FUNC_REGISTRY[func_name] = func - return func def make_workload_key(func, args): diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index 43e7d9f85ccc..184989ca1db8 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -32,7 +32,7 @@ namespace ansor { TVM_REGISTER_NODE_TYPE(TuningOptionsNode); TuningOptions::TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, - int verbose, ProgramBuilder builder, ProgramRunner runner, + bool verbose, ProgramBuilder builder, ProgramRunner runner, Optional<Array<MeasureCallback>> measure_callbacks, Optional<Array<SearchCallback>> pre_search_callbacks) { auto node = make_object<TuningOptionsNode>(); @@ -63,7 +63,7 @@ std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchTask task, SearchP TVM_REGISTER_GLOBAL("ansor.TuningOptions") .set_body_typed([](int num_measure_trials, int early_stopping, int num_measures_per_round, - int verbose, ProgramBuilder builder, ProgramRunner runner, + bool verbose, ProgramBuilder builder, ProgramRunner runner, Optional<Array<MeasureCallback>> measure_callbacks, Optional<Array<SearchCallback>> pre_search_callbacks) { return TuningOptions(num_measure_trials, early_stopping, num_measures_per_round, verbose, diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index 9e705cd350c4..84979f040cac 100644 --- a/src/ansor/auto_schedule.h +++
b/src/ansor/auto_schedule.h @@ -44,8 +44,11 @@ class TuningOptionsNode : public Object { int early_stopping; /*! \brief The number of programs to be measured at each search round. */ int num_measures_per_round; - /*! \brief Verbosity level. 0 for silent, 1 to output information during schedule searching. */ - int verbose; + /*! + * \brief Verbosity level. + * False for silent, true to output information during schedule searching. + */ + bool verbose; /*! \brief ProgramBuilder which builds the program */ ProgramBuilder builder; /*! \brief ProgramRunner which runs the program and measure time costs */ @@ -81,15 +84,15 @@ class TuningOptions : public ObjectRef { * \param num_measure_trials Number of total measurement trials. * \param early_stopping Stops early the tuning if no improvement after n measurements. * \param num_measures_per_round The number of programs to be measured at each search round. - * \param verbose Verbosity level. 0 for silent, 1 to output information during schedule + * \param verbose Verbosity level. False for silent, true to output information during schedule * search. * \param builder ProgramBuilder which builds the program. * \param runner ProgramRunner which runs the program and measure time costs. * \param measure_callbacks MeasureCallback functions to be called after each measure batch. * \param pre_search_callbacks SearchCallback functions to be called before schedule search. */ - TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, - ProgramBuilder builder, ProgramRunner runner, + TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, + bool verbose, ProgramBuilder builder, ProgramRunner runner, Optional> measure_callbacks, Optional> pre_search_callbacks); diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 98b4ed3c7df7..35c8daafad29 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -394,7 +394,7 @@ State ComputeDAG::InferBound(const State& state) const { } pstate->stages.Set( - i, Stage(stage->op, stage->op_type, std::move(new_iters), stage->compute_at, stage->attrs)); + i, Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); } return ret_state; diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 6cf9769e68bc..2a2f65c0590c 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -110,7 +110,7 @@ LocalBuilder::LocalBuilder(int timeout, int n_parallel, const String& build_func data_ = std::move(node); } -Array LocalBuilderNode::Build(const Array& inputs, int verbose) { +Array LocalBuilderNode::Build(const Array& inputs, bool verbose) { if (const auto* f = runtime::Registry::Get("ansor.local_builder.build")) { Array results = (*f)(inputs, timeout, n_parallel, build_func, verbose); return results; @@ -134,7 +134,7 @@ LocalRunner::LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, } Array LocalRunnerNode::Run(const Array& inputs, - const Array& build_results, int verbose) { + const Array& build_results, bool verbose) { if (const auto* f = runtime::Registry::Get("ansor.local_runner.run")) { Array results = (*f)(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval, verbose); @@ -148,7 +148,7 @@ Array LocalRunnerNode::Run(const Array& inputs, /********** ProgramMeasurer **********/ ProgramMeasurer::ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, - Optional> callbacks, int verbose, + Optional> callbacks, bool verbose, int max_continous_error) { auto 
node = make_object(); node->builder = std::move(builder); @@ -261,7 +261,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); - if (node->error_no == kNoError) { + if (node->error_no == static_cast(MeasureErrorNO::kNoError)) { p->stream << "MeasureResult(cost:["; auto old_config = p->stream.precision(4); for (size_t i = 0; i < node->costs.size(); ++i) { @@ -312,12 +312,12 @@ TVM_REGISTER_GLOBAL("ansor.MeasureResult") TVM_REGISTER_GLOBAL("ansor.ProgramBuilderBuild") .set_body_typed([](const ProgramBuilder& builder, const Array& inputs, - int verbose) { return builder->Build(inputs, verbose); }); + bool verbose) { return builder->Build(inputs, verbose); }); TVM_REGISTER_GLOBAL("ansor.ProgramRunnerRun") .set_body_typed([](const ProgramRunner& runner, const Array& inputs, const Array& build_results, - int verbose) { return runner->Run(inputs, build_results, verbose); }); + bool verbose) { return runner->Run(inputs, build_results, verbose); }); TVM_REGISTER_GLOBAL("ansor.LocalBuilder") .set_body_typed([](int timeout, int n_parallel, const String& build_func) { diff --git a/src/ansor/measure.h b/src/ansor/measure.h index 036022eac060..0b8d2dd5649e 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -41,7 +41,7 @@ class MeasureInput; class MeasureResult; /*! \brief The error code of one measurement */ -enum MeasureErrorNO { +enum class MeasureErrorNO : int { /*! \brief No error. */ kNoError = 0, /*! \brief Errors happen when apply transform steps from init state. */ @@ -232,10 +232,11 @@ class ProgramBuilderNode : public Object { /*! * \brief Build programs and return results. * \param inputs An Array of MeasureInput. - * \param verbose Verbosity level. 0 for silent, 1 to output information during program building. + * \param verbose Verbosity level. False for silent, true to output information during program + * building. * \return An Array of MeasureResult. */ - virtual Array Build(const Array& inputs, int verbose) = 0; + virtual Array Build(const Array& inputs, bool verbose) = 0; static constexpr const char* _type_key = "ansor.ProgramBuilder"; TVM_DECLARE_BASE_OBJECT_INFO(ProgramBuilderNode, Object); @@ -260,11 +261,12 @@ class ProgramRunnerNode : public Object { * \brief Run measurement and return results. * \param inputs An Array of MeasureInput. * \param build_results An Array of BuildResult. - * \param verbose Verbosity level. 0 for silent, 1 to output information during program running. + * \param verbose Verbosity level. False for silent, true to output information during program + * running. * \return An Array of MeasureResult. */ virtual Array Run(const Array& inputs, - const Array& build_results, int verbose) = 0; + const Array& build_results, bool verbose) = 0; static constexpr const char* _type_key = "ansor.ProgramRunner"; TVM_DECLARE_BASE_OBJECT_INFO(ProgramRunnerNode, Object); @@ -287,7 +289,7 @@ class LocalBuilderNode : public ProgramBuilderNode { /*! \brief Build function. 
*/ String build_func; - Array Build(const Array& inputs, int verbose) final; + Array Build(const Array& inputs, bool verbose) final; static constexpr const char* _type_key = "ansor.LocalBuilder"; TVM_DECLARE_FINAL_OBJECT_INFO(LocalBuilderNode, ProgramBuilderNode); @@ -324,7 +326,7 @@ class LocalRunnerNode : public ProgramRunnerNode { double cooldown_interval; Array Run(const Array& inputs, - const Array& build_results, int verbose) final; + const Array& build_results, bool verbose) final; static constexpr const char* _type_key = "ansor.LocalRunner"; TVM_DECLARE_FINAL_OBJECT_INFO(LocalRunnerNode, ProgramRunnerNode); @@ -373,7 +375,7 @@ class ProgramMeasurerNode : public Object { /*! \brief MeasureCallback to be called after each measure batch. */ Optional> callbacks; /*! \brief Verbosity level. 0 for silent, 1 to output information during program measuring. */ - int verbose; + bool verbose; /*! \brief The number of max continuous error. */ int max_continous_error; @@ -419,11 +421,12 @@ class ProgramMeasurer : public ObjectRef { * \param builder The ProgramBuilder to build each program. * \param runner The ProgramRunner to measure each program. * \param callbacks MeasureCallback to be called after each measure batch. - * \param verbose Verbosity level. 0 for silent, 1 to output information during program measuring. + * \param verbose Verbosity level. False for silent, true to output information during program + * measuring. * \param max_continous_error The number of max continuous error. */ ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, - Optional> callbacks, int verbose, + Optional> callbacks, bool verbose, int max_continous_error = -1); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramMeasurer, ObjectRef, ProgramMeasurerNode); diff --git a/src/ansor/measure_record.cc b/src/ansor/measure_record.cc index a959b00cc27f..7a3573fa18a3 100644 --- a/src/ansor/measure_record.cc +++ b/src/ansor/measure_record.cc @@ -362,12 +362,12 @@ RecordReaderNode::~RecordReaderNode() { infile.close(); } bool RecordReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { std::string log_version; - while (std::getline(infile, cur_line)) { - if (cur_line[0] == '#' || cur_line[0] == ' ') { + while (std::getline(infile, cur_line_)) { + if (cur_line_[0] == '#' || cur_line_[0] == ' ') { // skip comment lines begin with '#' or ' ' continue; } - ReadMeasureRecord(cur_line, inp, res, &log_version); + ReadMeasureRecord(cur_line_, inp, res, &log_version); return true; } diff --git a/src/ansor/measure_record.h b/src/ansor/measure_record.h index 1b6ed8f5bdba..f14dabfd8fa3 100644 --- a/src/ansor/measure_record.h +++ b/src/ansor/measure_record.h @@ -93,7 +93,7 @@ class RecordReaderNode : public Object { private: /*! \brief A string object to store the next line. */ - std::string cur_line; + std::string cur_line_; }; /*! 
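The renamed record API above (LogToFile -> RecordToFile, LogReader -> RecordReader) is easiest to see end to end from Python. The following is a minimal sketch based on the bindings exercised by the tests in this series (save_records, RecordReader.read_lines, load_best); the log file name is illustrative, and inputs, results, workload_key, and target are assumed to exist from an earlier measurement run:

    from tvm import ansor

    # Append measured (MeasureInput, MeasureResult) pairs to a JSON log file.
    ansor.save_records("records.json", inputs, results)

    # Read every record back as (Array[MeasureInput], Array[MeasureResult]).
    reader = ansor.RecordReader("records.json")
    inputs, results = reader.read_lines()

    # Load the best record for one workload on one target.
    inp, res = ansor.load_best("records.json", workload_key, target)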
diff --git a/src/ansor/search_policy/empty_policy.cc b/src/ansor/search_policy/empty_policy.cc index f897b29615d8..ea3325bd09e1 100644 --- a/src/ansor/search_policy/empty_policy.cc +++ b/src/ansor/search_policy/empty_policy.cc @@ -34,7 +34,7 @@ namespace ansor { TVM_REGISTER_NODE_TYPE(EmptyPolicyNode); State EmptyPolicyNode::Search(SearchTask task, int num_measure_trials, int early_stopping, - int num_measures_per_round, int verbose, ProgramMeasurer measurer, + int num_measures_per_round, bool verbose, ProgramMeasurer measurer, Optional> pre_search_callbacks) { cur_task = task; diff --git a/src/ansor/search_policy/empty_policy.h b/src/ansor/search_policy/empty_policy.h index ce8ac78fc2fc..3757ec281836 100644 --- a/src/ansor/search_policy/empty_policy.h +++ b/src/ansor/search_policy/empty_policy.h @@ -41,7 +41,7 @@ namespace ansor { class EmptyPolicyNode : public SearchPolicyNode { public: State Search(SearchTask task, int num_measure_trials, int early_stopping, - int num_measures_per_round, int verbose, ProgramMeasurer measurer, + int num_measures_per_round, bool verbose, ProgramMeasurer measurer, Optional> pre_search_callbacks) final; static constexpr const char* _type_key = "ansor.EmptyPolicy"; diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index 67edc53be009..138d9f10639c 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -49,7 +49,7 @@ TVM_REGISTER_GLOBAL("ansor.SearchPolicySetTask") .set_body_typed([](SearchPolicy policy, SearchTask task) { policy->cur_task = task; }); TVM_REGISTER_GLOBAL("ansor.SearchPolicySetVerbose") - .set_body_typed([](SearchPolicy policy, int verbose) { policy->verbose = verbose; }); + .set_body_typed([](SearchPolicy policy, bool verbose) { policy->verbose = verbose; }); } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index f70fc265e56c..0edd07bd0ad6 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -99,9 +99,9 @@ class SearchPolicyNode : public Object { SearchTask cur_task; /*! * \brief Verbose level to control the screen output during schedule search. - * 0 for silent, 1 to output information. + * False for silent, true to output state & measure information during search process. */ - int verbose; + bool verbose; void VisitAttrs(AttrVisitor* v) { v->Visit("cur_task", &cur_task); @@ -115,13 +115,14 @@ class SearchPolicyNode : public Object { * \param num_measure_trials Total schedules to be tried during this search. * \param early_stopping Early stop if no better schedule is found. * \param num_measures_per_round Max measure batch in one search round. - * \param verbose Verbose level. 0 for silent, 1 to output information during schedule search. + * \param verbose Verbose level. False for silent, true to output information during schedule + * search. * \param measurer A ProgramMeasurer which packs ProgramBuilder & ProgramRunner inside. * \param pre_search_callbacks SearchCallback to be called before schedule search. * \return The best state get. */ virtual State Search(SearchTask task, int num_measure_trials, int early_stopping, - int num_measures_per_round, int verbose, ProgramMeasurer measurer, + int num_measures_per_round, bool verbose, ProgramMeasurer measurer, Optional> pre_search_callbacks) = 0; /*! 
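The `equal states` comment in search_policy.h is the reason the measured-state set stores printed schedules rather than raw step sequences. A small illustrative sketch using the Python loop_state API, assuming s0 and s1 are two copies of the same initial State and C/it name a stage and an iterator with extent 512 (all of these names are placeholders):

    # Two different step sequences that produce the same loop nest [8, 16, 4]:
    s0.split(C, it, [16, 4], inner_to_outer=True)   # factors given inner-to-outer
    s1.split(C, it, [8, 16], inner_to_outer=False)  # factors given outer-to-inner
    # s0 and s1 print to the same schedule string, so the policy treats them as
    # equal and will not measure the second one again.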
diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index 3de986bb81fd..633d99fa41a8 100644 --- a/src/ansor/search_task.cc +++ b/src/ansor/search_task.cc @@ -61,7 +61,7 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target targe node->target = std::move(target); node->target_host = std::move(target_host); if (hardware_params) { - node->hardware_params = std::move(hardware_params.value()); + node->hardware_params = hardware_params.value(); } else { node->hardware_params = HardwareParamsNode::GetDefaultHardwareParams(node->target, node->target_host); diff --git a/src/ansor/utils.h b/src/ansor/utils.h index 4b0d0a0b180a..cd2d32344899 100644 --- a/src/ansor/utils.h +++ b/src/ansor/utils.h @@ -163,9 +163,7 @@ NullStream& operator<<(NullStream& os, const T& value) { } /*! \brief Get std cout with verbose control */ -inline std::ostream& StdCout(int verbose) { - return verbose == 1 ? std::cout : NullStream::Global(); -} +inline std::ostream& StdCout(bool verbose) { return verbose ? std::cout : NullStream::Global(); } /*! \brief Print multiple chars */ inline std::string Chars(const char& str, int times) { @@ -177,7 +175,7 @@ inline std::string Chars(const char& str, int times) { } /*! \brief Print a title */ -inline void PrintTitle(const std::string& title, int verbose) { +inline void PrintTitle(const std::string& title, bool verbose) { StdCout(verbose) << Chars('-', 60) << "\n" << Chars('-', 25) << " [ " << title << " ]\n" << Chars('-', 60) << std::endl; diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index 8c3895128849..9f4e62466095 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -32,6 +32,15 @@ def matmul_ansor_test(N, M, K): return [A, B, C] +@ansor.register_workload("matmul_ansor_test_rename_1") +def matmul_ansor_test_rename_0(N, M, K): + A = te.placeholder((N, K), name='A') + B = te.placeholder((K, M), name='B') + k = te.reduce_axis((0, K), name='k') + C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C') + return [A, B, C] + + def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): data = te.placeholder((N, CI, H, W), name='Data') kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='Kernel') diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index 202c12e4afbb..40ea1112671e 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -26,14 +26,14 @@ from test_ansor_common import matmul_ansor_test, PropagatingThread -def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local', - cost_model=None, num_measure_trials=2, params=None, +def search_common(workload=matmul_ansor_test, target="llvm", seed=random.randint(1, 1 << 30), + runner='local', cost_model=None, num_measure_trials=2, params=None, pre_search_callbacks=None): print("Test %s schedule search with the default search policy" % (target)) random.seed(seed) N = 128 - workload_key = ansor.make_workload_key(matmul_ansor_test, (N, N, N)) + workload_key = ansor.make_workload_key(workload, (N, N, N)) dag = ansor.ComputeDAG(workload_key) target = tvm.target.create(target) task = ansor.SearchTask(dag, workload_key, target) @@ -73,7 +73,7 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' print() -def test_search_basic(): +def 
test_workload_registry_search_basic(): if not tvm.runtime.enabled("llvm"): return # wrap the search in a new thread to avoid the conflict @@ -81,6 +81,14 @@ def test_search_basic(): t = PropagatingThread(target=search_common, kwargs={'seed': 944563397}) t.start() t.join() + t = PropagatingThread(target=search_common, + kwargs={'seed': 944563397, 'workload': "matmul_ansor_test"}) + t.start() + t.join() + t = PropagatingThread(target=search_common, + kwargs={'seed': 944563397, 'workload': "matmul_ansor_test_rename_1"}) + t.start() + t.join() if __name__ == "__main__": - test_search_basic() + test_workload_registry_search_basic() From 1360b1b22c4a33ae6b8468cc52d59ab9507c2736 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 9 Jul 2020 17:37:16 +0800 Subject: [PATCH 71/78] Update --- python/tvm/ansor/__init__.py | 2 +- python/tvm/ansor/_ffi_api.py | 2 +- python/tvm/ansor/auto_schedule.py | 33 +++++++------------ python/tvm/ansor/compute_dag.py | 4 +-- .../unittest/test_ansor_search_policy.py | 11 +++---- 5 files changed, 20 insertions(+), 32 deletions(-) diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 5a2f210938f7..216bfe25c89f 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=unused-import, redefined-builtin -"""Namespace for Ansor auto-scheduler""" +""" Namespace for Ansor auto-scheduler. """ from . import compute_dag from . import measure diff --git a/python/tvm/ansor/_ffi_api.py b/python/tvm/ansor/_ffi_api.py index e7b8a59eb83b..622c6f6ea43d 100644 --- a/python/tvm/ansor/_ffi_api.py +++ b/python/tvm/ansor/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -"""Register FFI APIs from C++ for the namespace tvm.ansor""" +""" Register FFI APIs from C++ for the namespace tvm.ansor. """ import tvm._ffi diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 70a05bfc0cf7..0ae6ede8a168 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -30,7 +30,6 @@ import tvm._ffi from tvm.runtime import Object -from .compute_dag import ComputeDAG from .measure import LocalBuilder, LocalRunner from . import _ffi_api @@ -157,8 +156,7 @@ def __init__(self, num_measure_trials=0, early_stopping=None, num_measures_per_r pre_search_callbacks) -def auto_schedule(task, target, target_host=None, search_policy='default', - hardware_params=None, tuning_options=None): +def auto_schedule(task, search_policy='default', tuning_options=None): """ Do auto scheduling for a computation declaration. The task parameter can be a `string` as workload_key, or directly @@ -166,16 +164,10 @@ def auto_schedule(task, target, target_host=None, search_policy='default', Parameters ---------- - task : Union[SearchTask, str] - The SearchTask or workload key for the computation declaration. - target : tvm.target.Target - The target device of this schedule search. - target_host : Optional[tvm.target.Target] - The target host device of this schedule search. + task : SearchTask + The SearchTask for the computation declaration. search_policy : Union[SearchPolicy, str] = 'default' The search policy to be used for schedule search. - hardware_params : Optional[HardwareParams] - The hardware parameters of this schedule search. tuning_options : Optional[TuningOptions] Tuning and measurement options. 
@@ -183,6 +175,10 @@ def auto_schedule(task, target, target_host=None, search_policy='default', ------- A `te.schedule` and a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`. """ + if not isinstance(task, SearchTask): + raise ValueError("Invalid task: " + task + + " . `ansor.auto_schedule` expects a SearchTask.") + if isinstance(search_policy, str): if search_policy == 'default': # TODO(jcf94): This is an example policy for minimum system, will be upgraded to @@ -190,15 +186,10 @@ def auto_schedule(task, target, target_host=None, search_policy='default', search_policy = EmptyPolicy() else: raise ValueError("Invalid search policy: " + search_policy) + elif not isinstance(search_policy, SearchPolicy): + raise ValueError("Invalid search policy: " + search_policy + + " . `ansor.auto_schedule` expects a SearchPolicy or a string.") - tuning_options = tuning_options if tuning_options else TuningOptions() - - if isinstance(task, str): - dag = ComputeDAG(task) - task = SearchTask(dag, task, target, target_host, hardware_params) - elif not isinstance(task, SearchTask): - raise ValueError("Invalid task: " + task + - " . `ansor.auto_schedule` expects a `str` or `SearchTask`.") - - sch, tensors = _ffi_api.AutoSchedule(task, search_policy, tuning_options) + sch, tensors = _ffi_api.AutoSchedule(task, search_policy, + tuning_options if tuning_options else TuningOptions()) return sch, tensors diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index f52c3e2d3192..587225affa84 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -59,7 +59,7 @@ def __init__(self, compute): raise ValueError("The input of ComputeDAG should be a list of Tensor") else: raise ValueError("Invalid compute: " + compute + - " . `ComputeDAG` expects a string or list of Tensor") + " . ComputeDAG expects a string or list of Tensor") self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute) def get_init_state(self): @@ -112,7 +112,7 @@ def infer_bound_from_state(self, state): """ Infer and fill the bound of all iterators of a state. - The states can lose complete bound information after some transform steps + The states may lose complete bound information after some transform steps (e.g., compute_at). We can call this function to infer and fill all the bound information. This function calls TVM InferBound pass internally to get the bound.
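To make the `infer_bound_from_state` note above concrete, a short hedged sketch, reusing the `dag` built in the previous example (the transform calls are placeholders):

    # Iterator ranges can become unknown after steps such as compute_at;
    # this re-derives them through TVM's InferBound pass.
    state = dag.get_init_state()
    # ... apply transform steps here, e.g. state.split(...) / state.fuse(...) ...
    state = dag.infer_bound_from_state(state)  # state with bounds filled back in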
diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index 40ea1112671e..fd07c7aa48f4 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -26,9 +26,9 @@ from test_ansor_common import matmul_ansor_test, PropagatingThread -def search_common(workload=matmul_ansor_test, target="llvm", seed=random.randint(1, 1 << 30), - runner='local', cost_model=None, num_measure_trials=2, params=None, - pre_search_callbacks=None): +def search_common(workload=matmul_ansor_test, target="llvm", search_policy = ansor.EmptyPolicy(), + seed=random.randint(1, 1 << 30), runner='local', cost_model=None, + num_measure_trials=2, params=None, pre_search_callbacks=None): print("Test %s schedule search with the default search policy" % (target)) random.seed(seed) @@ -41,14 +41,11 @@ def search_common(workload=matmul_ansor_test, target="llvm", seed=random.randint with tempfile.NamedTemporaryFile() as fp: log_file = fp.name - search_policy = ansor.EmptyPolicy() - # search_policy = ansor.SketchSearchPolicy(cost_model, params=params, seed=seed) tuning_options = ansor.TuningOptions(num_measure_trials=num_measure_trials, runner=runner, verbose=0, measure_callbacks=[ansor.RecordToFile(log_file)], pre_search_callbacks=pre_search_callbacks) - sch, args = ansor.auto_schedule(task, target, search_policy=search_policy, - tuning_options=tuning_options) + sch, args = ansor.auto_schedule(task, search_policy, tuning_options) inp, res = ansor.load_best(log_file, workload_key, target) print("==== Python Code ====") From 52afe74f1765b24f7885427805d5e7bae6beb89c Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 11 Jul 2020 20:15:55 +0800 Subject: [PATCH 72/78] Rename ansor namespace to auto_schedule --- CMakeLists.txt | 2 +- .../tvm/{ansor => auto_schedule}/__init__.py | 2 +- .../tvm/{ansor => auto_schedule}/_ffi_api.py | 4 +- .../{ansor => auto_schedule}/auto_schedule.py | 28 +++--- .../{ansor => auto_schedule}/compute_dag.py | 12 +-- .../{ansor => auto_schedule}/loop_state.py | 6 +- .../tvm/{ansor => auto_schedule}/measure.py | 20 ++--- .../measure_record.py | 10 +-- python/tvm/{ansor => auto_schedule}/utils.py | 2 +- .../workload_registry.py | 8 +- src/{ansor => auto_schedule}/auto_schedule.cc | 12 +-- src/{ansor => auto_schedule}/auto_schedule.h | 16 ++-- src/{ansor => auto_schedule}/compute_dag.cc | 18 ++-- src/{ansor => auto_schedule}/compute_dag.h | 24 +++--- src/{ansor => auto_schedule}/loop_state.cc | 16 ++-- src/{ansor => auto_schedule}/loop_state.h | 31 +++---- src/{ansor => auto_schedule}/measure.cc | 28 +++--- src/{ansor => auto_schedule}/measure.h | 34 ++++---- .../measure_record.cc | 85 ++++++++++--------- src/{ansor => auto_schedule}/measure_record.h | 16 ++-- .../search_policy/empty_policy.cc | 8 +- .../search_policy/empty_policy.h | 14 +-- .../search_policy/search_policy.cc | 12 +-- .../search_policy/search_policy.h | 18 ++-- src/{ansor => auto_schedule}/search_task.cc | 10 +-- src/{ansor => auto_schedule}/search_task.h | 16 ++-- .../transform_step.cc | 6 +- src/{ansor => auto_schedule}/transform_step.h | 22 ++--- src/{ansor => auto_schedule}/utils.cc | 6 +- src/{ansor => auto_schedule}/utils.h | 12 +-- ...common.py => test_auto_schedule_common.py} | 16 ++-- ...g.py => test_auto_schedule_compute_dag.py} | 4 +- ...te.py => test_auto_schedule_loop_state.py} | 8 +- ...asure.py => test_auto_schedule_measure.py} | 22 ++--- ...py => test_auto_schedule_search_policy.py} | 24 +++--- 35 files 
changed, 290 insertions(+), 282 deletions(-) rename python/tvm/{ansor => auto_schedule}/__init__.py (96%) rename python/tvm/{ansor => auto_schedule}/_ffi_api.py (87%) rename python/tvm/{ansor => auto_schedule}/auto_schedule.py (89%) rename python/tvm/{ansor => auto_schedule}/compute_dag.py (92%) rename python/tvm/{ansor => auto_schedule}/loop_state.py (98%) rename python/tvm/{ansor => auto_schedule}/measure.py (96%) rename python/tvm/{ansor => auto_schedule}/measure_record.py (94%) rename python/tvm/{ansor => auto_schedule}/utils.py (99%) rename python/tvm/{ansor => auto_schedule}/workload_registry.py (95%) rename src/{ansor => auto_schedule}/auto_schedule.cc (93%) rename src/{ansor => auto_schedule}/auto_schedule.h (91%) rename src/{ansor => auto_schedule}/compute_dag.cc (97%) rename src/{ansor => auto_schedule}/compute_dag.h (85%) rename src/{ansor => auto_schedule}/loop_state.cc (97%) rename src/{ansor => auto_schedule}/loop_state.h (93%) rename src/{ansor => auto_schedule}/measure.cc (93%) rename src/{ansor => auto_schedule}/measure.h (93%) rename src/{ansor => auto_schedule}/measure_record.cc (77%) rename src/{ansor => auto_schedule}/measure_record.h (91%) rename src/{ansor => auto_schedule}/search_policy/empty_policy.cc (94%) rename src/{ansor => auto_schedule}/search_policy/empty_policy.h (85%) rename src/{ansor => auto_schedule}/search_policy/search_policy.cc (84%) rename src/{ansor => auto_schedule}/search_policy/search_policy.h (92%) rename src/{ansor => auto_schedule}/search_task.cc (93%) rename src/{ansor => auto_schedule}/search_task.h (93%) rename src/{ansor => auto_schedule}/transform_step.cc (98%) rename src/{ansor => auto_schedule}/transform_step.h (92%) rename src/{ansor => auto_schedule}/utils.cc (93%) rename src/{ansor => auto_schedule}/utils.h (97%) rename tests/python/unittest/{test_ansor_common.py => test_auto_schedule_common.py} (89%) rename tests/python/unittest/{test_ansor_compute_dag.py => test_auto_schedule_compute_dag.py} (93%) rename tests/python/unittest/{test_ansor_loop_state.py => test_auto_schedule_loop_state.py} (89%) rename tests/python/unittest/{test_ansor_measure.py => test_auto_schedule_measure.py} (74%) rename tests/python/unittest/{test_ansor_search_policy.py => test_auto_schedule_search_policy.py} (79%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5550b5f6b3a8..9f5c5084d6c3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -185,7 +185,7 @@ assign_source_group("Include" ${GROUP_INCLUDE}) # Source file lists file(GLOB_RECURSE COMPILER_SRCS - src/ansor/*.cc + src/auto_schedule/*.cc src/node/*.cc src/ir/*.cc src/arith/*.cc diff --git a/python/tvm/ansor/__init__.py b/python/tvm/auto_schedule/__init__.py similarity index 96% rename from python/tvm/ansor/__init__.py rename to python/tvm/auto_schedule/__init__.py index 216bfe25c89f..90bec8665cef 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/auto_schedule/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=unused-import, redefined-builtin -""" Namespace for Ansor auto-scheduler. """ +""" Namespace for TVM Auto-scheduler. """ from . import compute_dag from . 
import measure diff --git a/python/tvm/ansor/_ffi_api.py b/python/tvm/auto_schedule/_ffi_api.py similarity index 87% rename from python/tvm/ansor/_ffi_api.py rename to python/tvm/auto_schedule/_ffi_api.py index 622c6f6ea43d..9d2b9865ae95 100644 --- a/python/tvm/ansor/_ffi_api.py +++ b/python/tvm/auto_schedule/_ffi_api.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. -""" Register FFI APIs from C++ for the namespace tvm.ansor. """ +""" Register FFI APIs from C++ for the namespace tvm.auto_schedule. """ import tvm._ffi -tvm._ffi._init_api("ansor", __name__) +tvm._ffi._init_api("auto_schedule", __name__) diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/auto_schedule/auto_schedule.py similarity index 89% rename from python/tvm/ansor/auto_schedule.py rename to python/tvm/auto_schedule/auto_schedule.py index 0ae6ede8a168..7b5f3b2ca593 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/auto_schedule/auto_schedule.py @@ -16,9 +16,9 @@ # under the License. """ -User interface for Ansor auto-scheduler. +User interface for TVM Auto-scheduler. -The basic schedule search process for Ansor is designed to be: +The basic schedule search process for TVM Auto-scheduler is designed to be: `Program sampling` -> `Performance Tuning`. In `Program sampling`, we use some predefined precise or heuristic rules to generate several @@ -34,7 +34,7 @@ from . import _ffi_api -@tvm._ffi.register_object("ansor.HardwareParams") +@tvm._ffi.register_object("auto_schedule.HardwareParams") class HardwareParams(Object): """ The parameters of target hardware used to guide the search process of SearchPolicy. @@ -55,7 +55,7 @@ def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes): vector_unit_bytes, cache_line_bytes) -@tvm._ffi.register_object("ansor.SearchTask") +@tvm._ffi.register_object("auto_schedule.SearchTask") class SearchTask(Object): """ The computation information and hardware parameters for a specific schedule search task. @@ -79,12 +79,12 @@ def __init__(self, dag, workload_key, target, target_host=None, hardware_params) -@tvm._ffi.register_object("ansor.SearchPolicy") +@tvm._ffi.register_object("auto_schedule.SearchPolicy") class SearchPolicy(Object): """ The base class of search policies. """ -@tvm._ffi.register_object("ansor.EmptyPolicy") +@tvm._ffi.register_object("auto_schedule.EmptyPolicy") class EmptyPolicy(SearchPolicy): """ This is an example empty search policy which will always generate the init state of ComputeDAG. @@ -93,7 +93,7 @@ def __init__(self): self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy) -@tvm._ffi.register_object("ansor.TuningOptions") +@tvm._ffi.register_object("auto_schedule.TuningOptions") class TuningOptions(Object): """ This controls the options of performance tuning. @@ -121,12 +121,12 @@ class TuningOptions(Object): measure_callbacks: Optional[List[MeasureCallback]] Callback functions called after each measurement. Candidates: - - ansor.RecordToFile + - auto_schedule.RecordToFile pre_search_callbacks: Optional[List[SearchCallback]] Callback functions called before the search process. Candidates: - - ansor.PreloadMeasuredStates - - ansor.PreloadCustomSketchRule + - auto_schedule.PreloadMeasuredStates + - auto_schedule.PreloadCustomSketchRule TODO(jcf94): Add these implementation in later PRs. 
""" def __init__(self, num_measure_trials=0, early_stopping=None, num_measures_per_round=64, @@ -137,7 +137,7 @@ def __init__(self, num_measure_trials=0, early_stopping=None, num_measures_per_r builder = LocalBuilder() else: raise ValueError("Invalid builder: " + builder) - elif not isinstance(builder, tvm.ansor.measure.ProgramBuilder): + elif not isinstance(builder, tvm.auto_schedule.measure.ProgramBuilder): raise ValueError("Invalid builder: " + builder + " . TuningOptions expects a ProgramBuilder or string.") @@ -146,7 +146,7 @@ def __init__(self, num_measure_trials=0, early_stopping=None, num_measures_per_r runner = LocalRunner() else: raise ValueError("Invalid runner: " + runner) - elif not isinstance(runner, tvm.ansor.measure.ProgramRunner): + elif not isinstance(runner, tvm.auto_schedule.measure.ProgramRunner): raise ValueError("Invalid runner: " + runner + " . TuningOptions expects a ProgramRunner or string.") @@ -177,7 +177,7 @@ def auto_schedule(task, search_policy='default', tuning_options=None): """ if not isinstance(task, SearchTask): raise ValueError("Invalid task: " + task + - " . `ansor.auto_schedule` expects a SearchTask.") + " . `auto_schedule.auto_schedule` expects a SearchTask.") if isinstance(search_policy, str): if search_policy == 'default': @@ -188,7 +188,7 @@ def auto_schedule(task, search_policy='default', tuning_options=None): raise ValueError("Invalid search policy: " + search_policy) elif not isinstance(search_policy, SearchPolicy): raise ValueError("Invalid search policy: " + search_policy + - " . `ansor.auto_schedule` expects a SearchPolicy or a string.") + " . `auto_schedule.auto_schedule` expects a SearchPolicy or a string.") sch, tensors = _ffi_api.AutoSchedule(task, search_policy, tuning_options if tuning_options else TuningOptions()) diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/auto_schedule/compute_dag.py similarity index 92% rename from python/tvm/ansor/compute_dag.py rename to python/tvm/auto_schedule/compute_dag.py index 587225affa84..a4738a933b3e 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/auto_schedule/compute_dag.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -""" The Ansor computational graph and related program analyses. """ +""" The TVM Auto-scheduler computational graph and related program analyses. """ import hashlib @@ -30,10 +30,10 @@ from . import _ffi_api -@tvm._ffi.register_object("ansor.ComputeDAG") +@tvm._ffi.register_object("auto_schedule.ComputeDAG") class ComputeDAG(Object): """ - The Ansor computational graph and related program analyses. + The TVM Auto-scheduler computational graph and related program analyses. We convert a compute declaration described by `tvm.compute` (could be a single operator or a subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration, @@ -41,9 +41,9 @@ class ComputeDAG(Object): total float operation count, consumer/producer relations of each operation stage, whether an operation stage should be tiled/compute inlined ...). These analyses can help the search policy to make decisions during search process. - ComputeDAG is also responsible for the interaction between Ansor `LoopState` and TVM schedule - (e.g. applying the `LoopState` transform steps to TVM schedule, providing `LoopState` with extra - information got from TVM schedule ...). + ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and + TVM schedule (e.g. 
applying the `LoopState` transform steps to TVM schedule, providing + `LoopState` with extra information got from TVM schedule ...). Parameters ---------- diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/auto_schedule/loop_state.py similarity index 98% rename from python/tvm/ansor/loop_state.py rename to python/tvm/auto_schedule/loop_state.py index 610392a22523..7b8804c8be60 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/auto_schedule/loop_state.py @@ -48,17 +48,17 @@ from . import _ffi_api -@tvm._ffi.register_object("ansor.Iterator") +@tvm._ffi.register_object("auto_schedule.Iterator") class Iterator(Object): """ A loop iterator structure. """ -@tvm._ffi.register_object("ansor.Stage") +@tvm._ffi.register_object("auto_schedule.Stage") class Stage(Object): """ A stage in the compute declaration. Similar to tvm.te.schedule.Stage. """ -@tvm._ffi.register_object("ansor.State") +@tvm._ffi.register_object("auto_schedule.State") class StateObject(Object): """ The internal State object """ def __eq__(self, other): diff --git a/python/tvm/ansor/measure.py b/python/tvm/auto_schedule/measure.py similarity index 96% rename from python/tvm/ansor/measure.py rename to python/tvm/auto_schedule/measure.py index 25170ea7ebfa..9c8292126317 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/auto_schedule/measure.py @@ -54,12 +54,12 @@ # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool GLOBAL_BUILD_ARGUMENTS = None -@tvm._ffi.register_object("ansor.MeasureCallback") +@tvm._ffi.register_object("auto_schedule.MeasureCallback") class MeasureCallback(Object): """ The base class of measurement callback functions. """ -@tvm._ffi.register_object("ansor.MeasureInput") +@tvm._ffi.register_object("auto_schedule.MeasureInput") class MeasureInput(Object): """ Store the input of a measurement. @@ -74,7 +74,7 @@ def __init__(self, task, state): self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object) -@tvm._ffi.register_object("ansor.BuildResult") +@tvm._ffi.register_object("auto_schedule.BuildResult") class BuildResult(Object): """ Store the result of a build. @@ -99,7 +99,7 @@ def __init__(self, filename, args, error_no, error_msg, time_cost): _ffi_api.BuildResult, filename, args, error_no, error_msg, time_cost) -@tvm._ffi.register_object("ansor.MeasureResult") +@tvm._ffi.register_object("auto_schedule.MeasureResult") class MeasureResult(Object): """ Store the results of a measurement. @@ -124,7 +124,7 @@ def __init__(self, costs, error_no, error_msg, all_cost, timestamp): error_msg, all_cost, timestamp) -@tvm._ffi.register_object("ansor.ProgramBuilder") +@tvm._ffi.register_object("auto_schedule.ProgramBuilder") class ProgramBuilder(Object): """ The base class of ProgramBuilders. """ @@ -145,7 +145,7 @@ def build(self, measure_inputs, verbose=True): return _ffi_api.ProgramBuilderBuild(self, measure_inputs, verbose) -@tvm._ffi.register_object("ansor.ProgramRunner") +@tvm._ffi.register_object("auto_schedule.ProgramRunner") class ProgramRunner(Object): """ The base class of ProgramRunners. """ @@ -168,7 +168,7 @@ def run(self, measure_inputs, build_results, verbose=True): return _ffi_api.ProgramRunnerRun(self, measure_inputs, build_results, verbose) -@tvm._ffi.register_object("ansor.LocalBuilder") +@tvm._ffi.register_object("auto_schedule.LocalBuilder") class LocalBuilder(ProgramBuilder): """ LocalBuilder use local CPU cores to build programs in parallel. 
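A brief sketch of how this builder is meant to be paired with a runner inside `TuningOptions` under the renamed namespace (all parameter values are illustrative, and `matmul.json` is a hypothetical log file name):

    from tvm import auto_schedule

    # LocalBuilder compiles candidate schedules in parallel on local CPU cores;
    # LocalRunner then measures their run time on the local device.
    builder = auto_schedule.LocalBuilder(timeout=15)
    runner = auto_schedule.LocalRunner(repeat=1, min_repeat_ms=0)
    tuning_options = auto_schedule.TuningOptions(
        num_measure_trials=2, builder=builder, runner=runner,
        measure_callbacks=[auto_schedule.RecordToFile("matmul.json")])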
@@ -191,7 +191,7 @@ def __init__(self, _ffi_api.LocalBuilder, timeout, n_parallel, build_func) -@tvm._ffi.register_object("ansor.LocalRunner") +@tvm._ffi.register_object("auto_schedule.LocalRunner") class LocalRunner(ProgramRunner): """ LocalRunner that uses local CPU/GPU to measures the time cost of programs. @@ -334,7 +334,7 @@ def timed_func(): return res -@tvm._ffi.register_func("ansor.local_builder.build") +@tvm._ffi.register_func("auto_schedule.local_builder.build") def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbose=True): """ Build function of LocalBuilder to build the MeasureInputs to runnable modules. @@ -376,7 +376,7 @@ def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbo return results -@tvm._ffi.register_func("ansor.local_runner.run") +@tvm._ffi.register_func("auto_schedule.local_runner.run") def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval, verbose=True): """ diff --git a/python/tvm/ansor/measure_record.py b/python/tvm/auto_schedule/measure_record.py similarity index 94% rename from python/tvm/ansor/measure_record.py rename to python/tvm/auto_schedule/measure_record.py index 46b94be18719..5f97e5d737d4 100644 --- a/python/tvm/ansor/measure_record.py +++ b/python/tvm/auto_schedule/measure_record.py @@ -25,7 +25,7 @@ from . import _ffi_api -@tvm._ffi.register_object("ansor.RecordToFile") +@tvm._ffi.register_object("auto_schedule.RecordToFile") class RecordToFile(MeasureCallback): """ A measurement callback that writes measurement records into a file. @@ -35,21 +35,21 @@ class RecordToFile(MeasureCallback): filename : str File name for this callback to write log to. """ - def __init__(self, filename="ansor_tuning.json"): + def __init__(self, filename="auto_schedule_tuning.json"): self.__init_handle_by_constructor__(_ffi_api.RecordToFile, filename) -@tvm._ffi.register_object("ansor.RecordReader") +@tvm._ffi.register_object("auto_schedule.RecordReader") class RecordReader(Object): """ Reader of the json log file. Parameters ---------- - filename : str = "ansor_tuning.json" + filename : str = "auto_schedule_tuning.json" File name for this reader to load log from. """ - def __init__(self, filename="ansor_tuning.json"): + def __init__(self, filename="auto_schedule_tuning.json"): self.__init_handle_by_constructor__(_ffi_api.RecordReader, filename) def read_lines(self, max_lines=None, skip_lines=0): diff --git a/python/tvm/ansor/utils.py b/python/tvm/auto_schedule/utils.py similarity index 99% rename from python/tvm/ansor/utils.py rename to python/tvm/auto_schedule/utils.py index 6052ef626033..c29675074e22 100644 --- a/python/tvm/ansor/utils.py +++ b/python/tvm/auto_schedule/utils.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -""" Common utilities for ansor. """ +""" Common utilities for auto_schedule. 
""" from typing import Hashable import multiprocessing diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/auto_schedule/workload_registry.py similarity index 95% rename from python/tvm/ansor/workload_registry.py rename to python/tvm/auto_schedule/workload_registry.py index 450bb64cbc66..b50727ec955e 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/auto_schedule/workload_registry.py @@ -55,7 +55,7 @@ def register_workload(func_name, f=None, override=False): Examples -------- - @ansor.register_workload + @auto_schedule.register_workload def matmul(N, M, K): A = te.placeholder((N, K), name='A') B = te.placeholder((K, M), name='B') @@ -110,7 +110,7 @@ def make_workload_key(func, args): if not func_name in WORKLOAD_FUNC_REGISTRY: raise ValueError("%s is not registered. " % func, - "Please register it with @ansor.register_workload") + "Please register it with @auto_schedule.register_workload") args = serialize_args(args) @@ -137,11 +137,11 @@ def decode_workload_key_to_func_args(workload_key): workload = json.loads(workload_key) if not workload[0] in WORKLOAD_FUNC_REGISTRY: raise ValueError("%s is not registered. " % workload[0] + - "Please register it with @ansor.register_workload") + "Please register it with @auto_schedule.register_workload") return workload[0], deserialize_args(workload[1:]) -@tvm._ffi.register_func("ansor.workload_key_to_tensors") +@tvm._ffi.register_func("auto_schedule.workload_key_to_tensors") def workload_key_to_tensors(workload_key): """ Get the input/output tensors from the workload key. diff --git a/src/ansor/auto_schedule.cc b/src/auto_schedule/auto_schedule.cc similarity index 93% rename from src/ansor/auto_schedule.cc rename to src/auto_schedule/auto_schedule.cc index 184989ca1db8..72484a93b5e6 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/auto_schedule/auto_schedule.cc @@ -18,8 +18,8 @@ */ /*! - * \file ansor/auto_schedule.cc - * \brief The user interface of the Ansor auto-scheduler. + * \file auto_schedule/auto_schedule.cc + * \brief The user interface of the TVM Auto-scheduler. */ #include "auto_schedule.h" @@ -27,7 +27,7 @@ #include namespace tvm { -namespace ansor { +namespace auto_schedule { TVM_REGISTER_NODE_TYPE(TuningOptionsNode); @@ -61,7 +61,7 @@ std::pair> AutoSchedule(SearchTask task, SearchP return task->compute_dag.ApplySteps(state->transform_steps); } -TVM_REGISTER_GLOBAL("ansor.TuningOptions") +TVM_REGISTER_GLOBAL("auto_schedule.TuningOptions") .set_body_typed([](int num_measure_trials, int early_stopping, int num_measures_per_round, bool verbose, ProgramBuilder builder, ProgramRunner runner, Optional> measure_callbacks, @@ -70,12 +70,12 @@ TVM_REGISTER_GLOBAL("ansor.TuningOptions") builder, runner, measure_callbacks, pre_search_callbacks); }); -TVM_REGISTER_GLOBAL("ansor.AutoSchedule") +TVM_REGISTER_GLOBAL("auto_schedule.AutoSchedule") .set_body_typed([](SearchTask task, SearchPolicy search_policy, TuningOptions tuning_options) { te::Schedule sch; Array return_tensors; std::tie(sch, return_tensors) = AutoSchedule(task, search_policy, tuning_options); return Array{sch, return_tensors}; }); -} // namespace ansor +} // namespace auto_schedule } // namespace tvm diff --git a/src/ansor/auto_schedule.h b/src/auto_schedule/auto_schedule.h similarity index 91% rename from src/ansor/auto_schedule.h rename to src/auto_schedule/auto_schedule.h index 84979f040cac..05d3db6c28d7 100644 --- a/src/ansor/auto_schedule.h +++ b/src/auto_schedule/auto_schedule.h @@ -18,14 +18,14 @@ */ /*! 
- * \file ansor/auto_schedule.h - * \brief The user interface of the Ansor auto-scheduler. This is the entry structure to get + * \file auto_schedule/auto_schedule.h + * \brief The user interface of the TVM Auto-scheduler. This is the entry structure to get * schedule search requirements from upper level (Python API), and returns a high performance * schedule after search process. */ -#ifndef TVM_ANSOR_AUTO_SCHEDULE_H_ -#define TVM_ANSOR_AUTO_SCHEDULE_H_ +#ifndef TVM_AUTO_SCHEDULE_AUTO_SCHEDULE_H_ +#define TVM_AUTO_SCHEDULE_AUTO_SCHEDULE_H_ #include @@ -33,7 +33,7 @@ #include "search_policy/search_policy.h" namespace tvm { -namespace ansor { +namespace auto_schedule { /*! \brief Tuning and measurement options. */ class TuningOptionsNode : public Object { @@ -69,7 +69,7 @@ class TuningOptionsNode : public Object { v->Visit("pre_search_callbacks", &pre_search_callbacks); } - static constexpr const char* _type_key = "ansor.TuningOptions"; + static constexpr const char* _type_key = "auto_schedule.TuningOptions"; TVM_DECLARE_FINAL_OBJECT_INFO(TuningOptionsNode, Object); }; @@ -110,7 +110,7 @@ class TuningOptions : public ObjectRef { TVM_DLL std::pair> AutoSchedule(SearchTask task, SearchPolicy search_policy, TuningOptions tuning_options); -} // namespace ansor +} // namespace auto_schedule } // namespace tvm -#endif // TVM_ANSOR_AUTO_SCHEDULE_H_ +#endif // TVM_AUTO_SCHEDULE_AUTO_SCHEDULE_H_ diff --git a/src/ansor/compute_dag.cc b/src/auto_schedule/compute_dag.cc similarity index 97% rename from src/ansor/compute_dag.cc rename to src/auto_schedule/compute_dag.cc index 35c8daafad29..312a25ad62dd 100644 --- a/src/ansor/compute_dag.cc +++ b/src/auto_schedule/compute_dag.cc @@ -18,7 +18,7 @@ */ /*! - * \file ansor/compute_dag.cc + * \file auto_schedule/compute_dag.cc * \brief Compute declaration graph and its related analysis tools. */ @@ -40,7 +40,7 @@ #include "utils.h" namespace tvm { -namespace ansor { +namespace auto_schedule { using namespace tvm::tir; @@ -258,7 +258,7 @@ std::pair> ComputeDAG::ApplySteps( } } // Create the initial schedule - // TODO(jcf94): Currently we only checked single output dag for Ansor, + // TODO(jcf94): Currently we only checked single output dag for TVM Auto-scheduler, // update this after testing with multiple outputs. te::Schedule schedule = te::create_schedule({ops.back()}); @@ -298,7 +298,7 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const } } // Create the initial schedule - // TODO(jcf94): Currently we only checked single output dag for Ansor, + // TODO(jcf94): Currently we only checked single output dag for TVM Auto-scheduler, // update this after testing with multiple outputs. 
te::Schedule schedule = te::create_schedule({ops.back()}); @@ -450,11 +450,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ss.str(); }); -TVM_REGISTER_GLOBAL("ansor.ComputeDAG").set_body_typed([](Array tensors) { +TVM_REGISTER_GLOBAL("auto_schedule.ComputeDAG").set_body_typed([](Array tensors) { return ComputeDAG(tensors); }); -TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") +TVM_REGISTER_GLOBAL("auto_schedule.ComputeDAGApplyStepsFromState") .set_body_typed([](const ComputeDAG& dag, const State& state) { te::Schedule sch; Array return_tensors; @@ -462,15 +462,15 @@ TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") return Array{sch, return_tensors}; }); -TVM_REGISTER_GLOBAL("ansor.ComputeDAGPrintPythonCodeFromState") +TVM_REGISTER_GLOBAL("auto_schedule.ComputeDAGPrintPythonCodeFromState") .set_body_typed([](const ComputeDAG& dag, const State& state) { return dag.PrintStepsAsPython(state->transform_steps); }); -TVM_REGISTER_GLOBAL("ansor.ComputeDAGInferBoundFromState") +TVM_REGISTER_GLOBAL("auto_schedule.ComputeDAGInferBoundFromState") .set_body_typed([](const ComputeDAG& dag, const State& state) { return dag.InferBound(state); }); -} // namespace ansor +} // namespace auto_schedule } // namespace tvm diff --git a/src/ansor/compute_dag.h b/src/auto_schedule/compute_dag.h similarity index 85% rename from src/ansor/compute_dag.h rename to src/auto_schedule/compute_dag.h index 8c244bd87778..bb582d32ee7e 100644 --- a/src/ansor/compute_dag.h +++ b/src/auto_schedule/compute_dag.h @@ -18,8 +18,8 @@ */ /*! - * \file ansor/compute_dag.h - * \brief The Ansor computational graph and related program analyses. + * \file auto_schedule/compute_dag.h + * \brief The TVM Auto-scheduler computational graph and related program analyses. * * We convert a compute declaration described by `tvm.compute` (could be a single operator or a * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration, @@ -27,13 +27,13 @@ * total float operation count, consumer/producer relations of each operation stage, whether an * operation stage should be tiled/compute inlined ...). These analyses can help the search policy * to make decisions during search process. - * ComputeDAG is also responsible for the interaction between Ansor `LoopState` and TVM schedule - * (e.g. applying the `LoopState` transform steps to TVM schedule, providing `LoopState` with extra - * information got from TVM schedule ...). + * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and + * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing + * `LoopState` with extra information got from TVM schedule ...). */ -#ifndef TVM_ANSOR_COMPUTE_DAG_H_ -#define TVM_ANSOR_COMPUTE_DAG_H_ +#ifndef TVM_AUTO_SCHEDULE_COMPUTE_DAG_H_ +#define TVM_AUTO_SCHEDULE_COMPUTE_DAG_H_ #include @@ -42,9 +42,9 @@ #include "loop_state.h" namespace tvm { -namespace ansor { +namespace auto_schedule { -/*! \brief The Ansor computational graph and related program analyses. */ +/*! \brief The TVM Auto-scheduler computational graph and related program analyses. */ class ComputeDAGNode : public Object { public: /*! 
@@ -67,7 +67,7 @@ class ComputeDAGNode : public Object { v->Visit("init_state", &init_state); } - static constexpr const char* _type_key = "ansor.ComputeDAG"; + static constexpr const char* _type_key = "auto_schedule.ComputeDAG"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object); }; @@ -118,7 +118,7 @@ class ComputeDAG : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode); }; -} // namespace ansor +} // namespace auto_schedule } // namespace tvm -#endif // TVM_ANSOR_COMPUTE_DAG_H_ +#endif // TVM_AUTO_SCHEDULE_COMPUTE_DAG_H_ diff --git a/src/ansor/loop_state.cc b/src/auto_schedule/loop_state.cc similarity index 97% rename from src/ansor/loop_state.cc rename to src/auto_schedule/loop_state.cc index 00d2bc759eb6..666efcebce04 100644 --- a/src/ansor/loop_state.cc +++ b/src/auto_schedule/loop_state.cc @@ -18,9 +18,9 @@ */ /*! - * \file ansor/loop_state.cc + * \file auto_schedule/loop_state.cc * \brief An lightweight IR (intermediate representation) for loop structures. - * see ansor/loop_state.h for more explanation. + * see auto_schedule/loop_state.h for more explanation. */ #include "loop_state.h" @@ -34,7 +34,7 @@ #include "utils.h" namespace tvm { -namespace ansor { +namespace auto_schedule { TVM_REGISTER_OBJECT_TYPE(StepNode); TVM_REGISTER_NODE_TYPE(StageNode); @@ -386,28 +386,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); /********** State interface API for ffi **********/ -TVM_REGISTER_GLOBAL("ansor.StateReorder") +TVM_REGISTER_GLOBAL("auto_schedule.StateReorder") .set_body_typed([](State state, int stage_id, const Array& order) { state.reorder(stage_id, order); return state; }); -TVM_REGISTER_GLOBAL("ansor.StateSplit") +TVM_REGISTER_GLOBAL("auto_schedule.StateSplit") .set_body_typed([](State state, int stage_id, const Iterator& it, const Array>& lengths, bool inner_to_outer) { const auto& res = state.split(stage_id, it, lengths, inner_to_outer); return Array{state, res}; }); -TVM_REGISTER_GLOBAL("ansor.StateFuse") +TVM_REGISTER_GLOBAL("auto_schedule.StateFuse") .set_body_typed([](State state, int stage_id, const Array& iters) { const auto& res = state.fuse(stage_id, iters); return Array{state, res}; }); -TVM_REGISTER_GLOBAL("ansor.StateEqual").set_body_typed([](State state1, State state2) { +TVM_REGISTER_GLOBAL("auto_schedule.StateEqual").set_body_typed([](State state1, State state2) { return std::equal_to()(state1, state2); }); -} // namespace ansor +} // namespace auto_schedule } // namespace tvm diff --git a/src/ansor/loop_state.h b/src/auto_schedule/loop_state.h similarity index 93% rename from src/ansor/loop_state.h rename to src/auto_schedule/loop_state.h index c91a65cee528..5ba47b7263a1 100644 --- a/src/ansor/loop_state.h +++ b/src/auto_schedule/loop_state.h @@ -18,7 +18,7 @@ */ /*! - * \file ansor/loop_state.h + * \file auto_schedule/loop_state.h * \brief The definition of the "state" in search. * * Each LoopState corresponds to a schedule for its ComputeDAG. @@ -45,8 +45,8 @@ * copy on write style. All objects are immutable, which is similar to TVM IR. 
*/ -#ifndef TVM_ANSOR_LOOP_STATE_H_ -#define TVM_ANSOR_LOOP_STATE_H_ +#ifndef TVM_AUTO_SCHEDULE_LOOP_STATE_H_ +#define TVM_AUTO_SCHEDULE_LOOP_STATE_H_ #include @@ -55,7 +55,7 @@ #include "transform_step.h" namespace tvm { -namespace ansor { +namespace auto_schedule { using namespace tvm::tir; @@ -135,7 +135,7 @@ class IteratorNode : public Object { v->Visit("range", &range); } - static constexpr const char* _type_key = "ansor.Iterator"; + static constexpr const char* _type_key = "auto_schedule.Iterator"; TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); }; @@ -187,7 +187,7 @@ class StageNode : public Object { v->Visit("iters", &iters); } - static constexpr const char* _type_key = "ansor.Stage"; + static constexpr const char* _type_key = "auto_schedule.Stage"; TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object); }; @@ -241,7 +241,7 @@ class StateNode : public Object { v->Visit("concrete", &concrete); } - static constexpr const char* _type_key = "ansor.State"; + static constexpr const char* _type_key = "auto_schedule.State"; TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object); private: @@ -347,22 +347,22 @@ class State : public ObjectRef { const Array>& lengths, bool inner_to_outer); }; -} // namespace ansor +} // namespace auto_schedule } // namespace tvm // Hash and equal function for State namespace std { -/*! \brief The hash function for ansor::State. */ +/*! \brief The hash function for auto_schedule::State. */ template <> -struct hash<::tvm::ansor::State> { - std::size_t operator()(const ::tvm::ansor::State& state) const { +struct hash<::tvm::auto_schedule::State> { + std::size_t operator()(const ::tvm::auto_schedule::State& state) const { return tvm::runtime::ObjectHash()(state.ToStr()); } }; /*! - * \brief The equal_to function for ansor::State. + * \brief The equal_to function for auto_schedule::State. * We use the schedule result(its string format) of a state to check if two states are `euqal`. * Equal States: 1. the transform steps are totally the same; 2. even with different steps, two * states may still result in a same schedule. e.g. To split a axis with extent 512 to 3 parts @@ -370,12 +370,13 @@ struct hash<::tvm::ansor::State> { * to split from outter to inner by factors [8, 16]) */ template <> -struct equal_to<::tvm::ansor::State> { - bool operator()(const ::tvm::ansor::State& lhs, const ::tvm::ansor::State& rhs) const { +struct equal_to<::tvm::auto_schedule::State> { + bool operator()(const ::tvm::auto_schedule::State& lhs, + const ::tvm::auto_schedule::State& rhs) const { return lhs.ToStr() == rhs.ToStr(); } }; } // namespace std -#endif // TVM_ANSOR_LOOP_STATE_H_ +#endif // TVM_AUTO_SCHEDULE_LOOP_STATE_H_ diff --git a/src/ansor/measure.cc b/src/auto_schedule/measure.cc similarity index 93% rename from src/ansor/measure.cc rename to src/auto_schedule/measure.cc index 2a2f65c0590c..86a72163c682 100644 --- a/src/ansor/measure.cc +++ b/src/auto_schedule/measure.cc @@ -18,7 +18,7 @@ */ /*! - * \file ansor/measure.cc + * \file auto_schedule/measure.cc * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs. 
*/ @@ -31,7 +31,7 @@ #include "utils.h" namespace tvm { -namespace ansor { +namespace auto_schedule { TVM_REGISTER_NODE_TYPE(MeasureInputNode); TVM_REGISTER_NODE_TYPE(BuildResultNode); @@ -111,11 +111,11 @@ LocalBuilder::LocalBuilder(int timeout, int n_parallel, const String& build_func } Array LocalBuilderNode::Build(const Array& inputs, bool verbose) { - if (const auto* f = runtime::Registry::Get("ansor.local_builder.build")) { + if (const auto* f = runtime::Registry::Get("auto_schedule.local_builder.build")) { Array results = (*f)(inputs, timeout, n_parallel, build_func, verbose); return results; } - LOG(FATAL) << "ansor.local_builder.build is not registered. " + LOG(FATAL) << "auto_schedule.local_builder.build is not registered. " << "This is a function registered in Python, " << "make sure the TVM Python runtime has been loaded successfully."; throw; @@ -135,12 +135,12 @@ LocalRunner::LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, Array LocalRunnerNode::Run(const Array& inputs, const Array& build_results, bool verbose) { - if (const auto* f = runtime::Registry::Get("ansor.local_runner.run")) { + if (const auto* f = runtime::Registry::Get("auto_schedule.local_runner.run")) { Array results = (*f)(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval, verbose); return results; } - LOG(FATAL) << "ansor.local_runner.run is not registered. " + LOG(FATAL) << "auto_schedule.local_runner.run is not registered. " << "This is a function registered in Python, " << "make sure the TVM Python runtime has been loaded successfully."; throw; @@ -294,41 +294,41 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); /********** Measure interface API for ffi **********/ -TVM_REGISTER_GLOBAL("ansor.MeasureInput").set_body_typed([](SearchTask task, State state) { +TVM_REGISTER_GLOBAL("auto_schedule.MeasureInput").set_body_typed([](SearchTask task, State state) { return MeasureInput(task, state); }); -TVM_REGISTER_GLOBAL("ansor.BuildResult") +TVM_REGISTER_GLOBAL("auto_schedule.BuildResult") .set_body_typed([](String filename, Array args, int error_no, String error_msg, double time_cost) { return BuildResult(filename, args, error_no, error_msg, time_cost); }); -TVM_REGISTER_GLOBAL("ansor.MeasureResult") +TVM_REGISTER_GLOBAL("auto_schedule.MeasureResult") .set_body_typed([](Array costs, int error_no, String error_msg, double all_cost, double timestamp) { return MeasureResult(costs, error_no, error_msg, all_cost, timestamp); }); -TVM_REGISTER_GLOBAL("ansor.ProgramBuilderBuild") +TVM_REGISTER_GLOBAL("auto_schedule.ProgramBuilderBuild") .set_body_typed([](const ProgramBuilder& builder, const Array& inputs, bool verbose) { return builder->Build(inputs, verbose); }); -TVM_REGISTER_GLOBAL("ansor.ProgramRunnerRun") +TVM_REGISTER_GLOBAL("auto_schedule.ProgramRunnerRun") .set_body_typed([](const ProgramRunner& runner, const Array& inputs, const Array& build_results, bool verbose) { return runner->Run(inputs, build_results, verbose); }); -TVM_REGISTER_GLOBAL("ansor.LocalBuilder") +TVM_REGISTER_GLOBAL("auto_schedule.LocalBuilder") .set_body_typed([](int timeout, int n_parallel, const String& build_func) { return LocalBuilder(timeout, n_parallel, build_func); }); -TVM_REGISTER_GLOBAL("ansor.LocalRunner") +TVM_REGISTER_GLOBAL("auto_schedule.LocalRunner") .set_body_typed([](int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval) { return LocalRunner(timeout, number, repeat, min_repeat_ms, cooldown_interval); }); -} // namespace ansor +} // namespace 
auto_schedule } // namespace tvm diff --git a/src/ansor/measure.h b/src/auto_schedule/measure.h similarity index 93% rename from src/ansor/measure.h rename to src/auto_schedule/measure.h index 0b8d2dd5649e..c4f776abb003 100644 --- a/src/ansor/measure.h +++ b/src/auto_schedule/measure.h @@ -18,13 +18,13 @@ */ /*! - * \file ansor/measure.h + * \file auto_schedule/measure.h * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs. * The flow of data structures is MeasureInput -> BuildeResult -> MeasureResult. */ -#ifndef TVM_ANSOR_MEASURE_H_ -#define TVM_ANSOR_MEASURE_H_ +#ifndef TVM_AUTO_SCHEDULE_MEASURE_H_ +#define TVM_AUTO_SCHEDULE_MEASURE_H_ #include #include @@ -34,7 +34,7 @@ #include "search_task.h" namespace tvm { -namespace ansor { +namespace auto_schedule { class SearchPolicy; class MeasureInput; @@ -80,7 +80,7 @@ class MeasureInputNode : public Object { /*! \brief Do shallow copy. */ MeasureInput copy() const; - static constexpr const char* _type_key = "ansor.MeasureInput"; + static constexpr const char* _type_key = "auto_schedule.MeasureInput"; TVM_DECLARE_FINAL_OBJECT_INFO(MeasureInputNode, Object); }; @@ -122,7 +122,7 @@ class BuildResultNode : public Object { v->Visit("time_cost", &time_cost); } - static constexpr const char* _type_key = "ansor.BuildResult"; + static constexpr const char* _type_key = "auto_schedule.BuildResult"; TVM_DECLARE_FINAL_OBJECT_INFO(BuildResultNode, Object); }; @@ -170,7 +170,7 @@ class MeasureResultNode : public Object { /*! \brief Do shallow copy. */ MeasureResult copy() const; - static constexpr const char* _type_key = "ansor.MeasureResult"; + static constexpr const char* _type_key = "auto_schedule.MeasureResult"; TVM_DECLARE_FINAL_OBJECT_INFO(MeasureResultNode, Object); }; @@ -206,7 +206,7 @@ class MeasureCallbackNode : public Object { */ virtual void Callback(const SearchPolicy& policy, const Array& inputs, const Array& results) = 0; - static constexpr const char* _type_key = "ansor.MeasureCallback"; + static constexpr const char* _type_key = "auto_schedule.MeasureCallback"; TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object); }; @@ -238,7 +238,7 @@ class ProgramBuilderNode : public Object { */ virtual Array Build(const Array& inputs, bool verbose) = 0; - static constexpr const char* _type_key = "ansor.ProgramBuilder"; + static constexpr const char* _type_key = "auto_schedule.ProgramBuilder"; TVM_DECLARE_BASE_OBJECT_INFO(ProgramBuilderNode, Object); }; @@ -268,7 +268,7 @@ class ProgramRunnerNode : public Object { virtual Array Run(const Array& inputs, const Array& build_results, bool verbose) = 0; - static constexpr const char* _type_key = "ansor.ProgramRunner"; + static constexpr const char* _type_key = "auto_schedule.ProgramRunner"; TVM_DECLARE_BASE_OBJECT_INFO(ProgramRunnerNode, Object); }; @@ -291,7 +291,7 @@ class LocalBuilderNode : public ProgramBuilderNode { Array Build(const Array& inputs, bool verbose) final; - static constexpr const char* _type_key = "ansor.LocalBuilder"; + static constexpr const char* _type_key = "auto_schedule.LocalBuilder"; TVM_DECLARE_FINAL_OBJECT_INFO(LocalBuilderNode, ProgramBuilderNode); }; @@ -328,7 +328,7 @@ class LocalRunnerNode : public ProgramRunnerNode { Array Run(const Array& inputs, const Array& build_results, bool verbose) final; - static constexpr const char* _type_key = "ansor.LocalRunner"; + static constexpr const char* _type_key = "auto_schedule.LocalRunner"; TVM_DECLARE_FINAL_OBJECT_INFO(LocalRunnerNode, ProgramRunnerNode); }; @@ -339,8 +339,8 @@ class 
LocalRunnerNode : public ProgramRunnerNode { class LocalRunner : public ProgramRunner { public: /*! - * \brief The constructor. See the corresponding class in python/tvm/ansor/measure.py for more - * detailed parameter explaination. + * \brief The constructor. See the corresponding class in python/tvm/auto_schedule/measure.py + * for more detailed parameter explaination. * \param timeout The timeout limit (in second) for each run. * This is used in a wrapper of the multiprocessing.Process.join(). * \param number Number of measure times. @@ -406,7 +406,7 @@ class ProgramMeasurerNode : public Object { /*! \brief The default max continuous error setting. */ static const int DEFAULT_MAX_CONTINOUS_ERROR = 150; - static constexpr const char* _type_key = "ansor.ProgramMeasurer"; + static constexpr const char* _type_key = "auto_schedule.ProgramMeasurer"; TVM_DECLARE_FINAL_OBJECT_INFO(ProgramMeasurerNode, Object); }; @@ -432,7 +432,7 @@ class ProgramMeasurer : public ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramMeasurer, ObjectRef, ProgramMeasurerNode); }; -} // namespace ansor +} // namespace auto_schedule } // namespace tvm -#endif // TVM_ANSOR_MEASURE_H_ +#endif // TVM_AUTO_SCHEDULE_MEASURE_H_ diff --git a/src/ansor/measure_record.cc b/src/auto_schedule/measure_record.cc similarity index 77% rename from src/ansor/measure_record.cc rename to src/auto_schedule/measure_record.cc index 7a3573fa18a3..99bd5917f7c8 100644 --- a/src/ansor/measure_record.cc +++ b/src/auto_schedule/measure_record.cc @@ -18,7 +18,7 @@ */ /*! - * \file ansor/measure_record.cc + * \file auto_schedule/measure_record.cc * \brief Json serialization format for dumping and loading tuning records. */ @@ -62,13 +62,14 @@ inline std::vector IntArrayToVector( } template <> -struct Handler<::tvm::Array<::tvm::ansor::Stage>> { +struct Handler<::tvm::Array<::tvm::auto_schedule::Stage>> { inline static void Write(dmlc::JSONWriter* writer, - const ::tvm::Array<::tvm::ansor::Stage>& data) { + const ::tvm::Array<::tvm::auto_schedule::Stage>& data) { writer->BeginArray(false); writer->EndArray(); } - inline static void Read(dmlc::JSONReader* reader, ::tvm::Array<::tvm::ansor::Stage>* data) { + inline static void Read(dmlc::JSONReader* reader, + ::tvm::Array<::tvm::auto_schedule::Stage>* data) { bool s; reader->BeginArray(); s = reader->NextArrayItem(); @@ -77,24 +78,26 @@ struct Handler<::tvm::Array<::tvm::ansor::Stage>> { }; template <> -struct Handler<::tvm::Array<::tvm::ansor::Step>> { - inline static void Write(dmlc::JSONWriter* writer, const ::tvm::Array<::tvm::ansor::Step>& data) { +struct Handler<::tvm::Array<::tvm::auto_schedule::Step>> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::Array<::tvm::auto_schedule::Step>& data) { writer->BeginArray(false); for (size_t i = 0; i < data.size(); ++i) { writer->WriteArraySeperator(); writer->BeginArray(false); - if (auto ps = data[i].as<::tvm::ansor::ReorderStepNode>()) { + if (auto ps = data[i].as<::tvm::auto_schedule::ReorderStepNode>()) { writer->WriteArrayItem(std::string("RE")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(IntArrayToVector(ps->after_ids)); - } else if (auto ps = data[i].as<::tvm::ansor::SplitStepNode>()) { + } else if (auto ps = data[i].as<::tvm::auto_schedule::SplitStepNode>()) { writer->WriteArrayItem(std::string("SP")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); - writer->WriteArrayItem(ps->extent ? ::tvm::ansor::GetIntImm(ps->extent.value()) : 0); + writer->WriteArrayItem(ps->extent ? 
::tvm::auto_schedule::GetIntImm(ps->extent.value()) + : 0); writer->WriteArrayItem(IntArrayToVector(ps->lengths)); writer->WriteArrayItem(static_cast(ps->inner_to_outer)); - } else if (auto ps = data[i].as<::tvm::ansor::FuseStepNode>()) { + } else if (auto ps = data[i].as<::tvm::auto_schedule::FuseStepNode>()) { writer->WriteArrayItem(std::string("FU")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(IntArrayToVector(ps->fused_ids)); @@ -106,7 +109,8 @@ struct Handler<::tvm::Array<::tvm::ansor::Step>> { writer->EndArray(); } - inline static void Read(dmlc::JSONReader* reader, ::tvm::Array<::tvm::ansor::Step>* data) { + inline static void Read(dmlc::JSONReader* reader, + ::tvm::Array<::tvm::auto_schedule::Step>* data) { std::vector int_list; bool s, inner_to_outer; std::string name, scope_name, pragma_type, ti_func_name; @@ -130,7 +134,7 @@ struct Handler<::tvm::Array<::tvm::ansor::Step>> { for (const auto& i : int_list) { after_ids.push_back(i); } - data->push_back(::tvm::ansor::ReorderStep(stage_id, after_ids)); + data->push_back(::tvm::auto_schedule::ReorderStep(stage_id, after_ids)); } else if (name == "SP") { s = reader->NextArrayItem(); CHECK(s); @@ -151,7 +155,7 @@ struct Handler<::tvm::Array<::tvm::ansor::Step>> { for (const auto& i : int_list) { lengths.push_back(::tvm::Integer(i)); } - data->push_back(::tvm::ansor::SplitStep( + data->push_back(::tvm::auto_schedule::SplitStep( stage_id, iter_id, extent == 0 ? ::tvm::PrimExpr() : extent, lengths, inner_to_outer)); } else if (name == "FU") { s = reader->NextArrayItem(); @@ -164,7 +168,7 @@ struct Handler<::tvm::Array<::tvm::ansor::Step>> { for (const auto& i : int_list) { fused_ids.push_back(i); } - data->push_back(::tvm::ansor::FuseStep(stage_id, fused_ids)); + data->push_back(::tvm::auto_schedule::FuseStep(stage_id, fused_ids)); } else { LOG(FATAL) << "Invalid step format"; } @@ -175,14 +179,14 @@ struct Handler<::tvm::Array<::tvm::ansor::Step>> { }; template <> -struct Handler<::tvm::ansor::StateNode> { - inline static void Write(dmlc::JSONWriter* writer, const ::tvm::ansor::StateNode& data) { +struct Handler<::tvm::auto_schedule::StateNode> { + inline static void Write(dmlc::JSONWriter* writer, const ::tvm::auto_schedule::StateNode& data) { writer->BeginArray(false); writer->WriteArrayItem(data.stages); writer->WriteArrayItem(data.transform_steps); writer->EndArray(); } - inline static void Read(dmlc::JSONReader* reader, ::tvm::ansor::StateNode* data) { + inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_schedule::StateNode* data) { reader->BeginArray(); bool s; s = reader->NextArrayItem(); @@ -197,14 +201,15 @@ struct Handler<::tvm::ansor::StateNode> { }; template <> -struct Handler<::tvm::ansor::SearchTaskNode> { - inline static void Write(dmlc::JSONWriter* writer, const ::tvm::ansor::SearchTaskNode& data) { +struct Handler<::tvm::auto_schedule::SearchTaskNode> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::auto_schedule::SearchTaskNode& data) { writer->BeginArray(false); writer->WriteArrayItem(std::string(data.workload_key)); writer->WriteArrayItem(data.target->str()); writer->EndArray(); } - inline static void Read(dmlc::JSONReader* reader, ::tvm::ansor::SearchTaskNode* data) { + inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_schedule::SearchTaskNode* data) { std::string target_str; bool s; @@ -223,17 +228,18 @@ struct Handler<::tvm::ansor::SearchTaskNode> { }; template <> -struct Handler<::tvm::ansor::MeasureInputNode> { - inline static void 
Write(dmlc::JSONWriter* writer, const ::tvm::ansor::MeasureInputNode& data) { +struct Handler<::tvm::auto_schedule::MeasureInputNode> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::auto_schedule::MeasureInputNode& data) { writer->BeginArray(false); writer->WriteArrayItem(*data.task.operator->()); writer->WriteArrayItem(*data.state.operator->()); writer->EndArray(); } - inline static void Read(dmlc::JSONReader* reader, ::tvm::ansor::MeasureInputNode* data) { + inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_schedule::MeasureInputNode* data) { bool s; - auto task_node = ::tvm::make_object<::tvm::ansor::SearchTaskNode>(); - auto state_node = ::tvm::make_object<::tvm::ansor::StateNode>(); + auto task_node = ::tvm::make_object<::tvm::auto_schedule::SearchTaskNode>(); + auto state_node = ::tvm::make_object<::tvm::auto_schedule::StateNode>(); state_node->concrete = true; reader->BeginArray(); @@ -246,14 +252,15 @@ struct Handler<::tvm::ansor::MeasureInputNode> { s = reader->NextArrayItem(); CHECK(!s); - data->task = ::tvm::ansor::SearchTask(task_node); - data->state = ::tvm::ansor::State(state_node); + data->task = ::tvm::auto_schedule::SearchTask(task_node); + data->state = ::tvm::auto_schedule::State(state_node); } }; template <> -struct Handler<::tvm::ansor::MeasureResultNode> { - inline static void Write(dmlc::JSONWriter* writer, const ::tvm::ansor::MeasureResultNode& data) { +struct Handler<::tvm::auto_schedule::MeasureResultNode> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::auto_schedule::MeasureResultNode& data) { writer->BeginArray(false); writer->WriteArraySeperator(); writer->BeginArray(false); @@ -268,7 +275,7 @@ struct Handler<::tvm::ansor::MeasureResultNode> { writer->WriteArrayItem(static_cast((data.timestamp))); writer->EndArray(); } - inline static void Read(dmlc::JSONReader* reader, ::tvm::ansor::MeasureResultNode* data) { + inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_schedule::MeasureResultNode* data) { bool s; std::vector tmp; @@ -298,12 +305,12 @@ struct Handler<::tvm::ansor::MeasureResultNode> { } // namespace dmlc namespace tvm { -namespace ansor { +namespace auto_schedule { TVM_REGISTER_OBJECT_TYPE(RecordToFileNode); TVM_REGISTER_OBJECT_TYPE(RecordReaderNode); -const std::string ANSOR_LOG_VERSION = "v0.2"; // NOLINT(*) +const std::string AUTO_SCHEDULE_LOG_VERSION = "v0.2"; // NOLINT(*) RecordToFile::RecordToFile(String filename) { auto node = make_object(); @@ -318,7 +325,7 @@ void WriteMeasureRecords(std::ostream* os, const Array& inputs, writer.BeginObject(false); writer.WriteObjectKeyValue("i", *inputs[i].operator->()); writer.WriteObjectKeyValue("r", *results[i].operator->()); - writer.WriteObjectKeyValue("v", ANSOR_LOG_VERSION); + writer.WriteObjectKeyValue("v", AUTO_SCHEDULE_LOG_VERSION); writer.EndObject(); *os << "\n"; } @@ -398,21 +405,21 @@ std::pair, Array> RecordReaderNode::ReadLines return std::make_pair(inputs, results); } -TVM_REGISTER_GLOBAL("ansor.RecordToFile").set_body_typed([](const String& filename) { +TVM_REGISTER_GLOBAL("auto_schedule.RecordToFile").set_body_typed([](const String& filename) { return RecordToFile(filename); }); -TVM_REGISTER_GLOBAL("ansor.RecordReader").set_body_typed([](const String& filename) { +TVM_REGISTER_GLOBAL("auto_schedule.RecordReader").set_body_typed([](const String& filename) { return RecordReader(filename); }); -TVM_REGISTER_GLOBAL("ansor.RecordReaderReadLines") +TVM_REGISTER_GLOBAL("auto_schedule.RecordReaderReadLines") 
.set_body_typed([](RecordReader reader, int size, int skip_size) { const auto& res = reader->ReadLines(size, skip_size); return Array{res.first, res.second}; }); -TVM_REGISTER_GLOBAL("ansor.RecordReaderReadNext").set_body_typed([](RecordReader reader) { +TVM_REGISTER_GLOBAL("auto_schedule.RecordReaderReadNext").set_body_typed([](RecordReader reader) { auto inp = make_object(); auto res = make_object(); if (reader->ReadNext(inp.get(), res.get())) { @@ -422,10 +429,10 @@ TVM_REGISTER_GLOBAL("ansor.RecordReaderReadNext").set_body_typed([](RecordReader } }); -TVM_REGISTER_GLOBAL("ansor.SaveRecords") +TVM_REGISTER_GLOBAL("auto_schedule.SaveRecords") .set_body_typed([](String filename, Array in, Array res) { std::ofstream ofs(filename, std::ofstream::app); WriteMeasureRecords(&ofs, in, res); }); -} // namespace ansor +} // namespace auto_schedule } // namespace tvm diff --git a/src/ansor/measure_record.h b/src/auto_schedule/measure_record.h similarity index 91% rename from src/ansor/measure_record.h rename to src/auto_schedule/measure_record.h index f14dabfd8fa3..f97e30ae9f70 100644 --- a/src/ansor/measure_record.h +++ b/src/auto_schedule/measure_record.h @@ -18,12 +18,12 @@ */ /*! - * \file ansor/measure_record.h + * \file auto_schedule/measure_record.h * \brief Json serialization format for dumping and loading tuning records. */ -#ifndef TVM_ANSOR_MEASURE_RECORD_H_ -#define TVM_ANSOR_MEASURE_RECORD_H_ +#ifndef TVM_AUTO_SCHEDULE_MEASURE_RECORD_H_ +#define TVM_AUTO_SCHEDULE_MEASURE_RECORD_H_ #include #include @@ -32,7 +32,7 @@ #include "measure.h" namespace tvm { -namespace ansor { +namespace auto_schedule { /*! \brief Callback for logging the input and results of measurements to file */ class RecordToFileNode : public MeasureCallbackNode { @@ -43,7 +43,7 @@ class RecordToFileNode : public MeasureCallbackNode { void Callback(const SearchPolicy& policy, const Array& inputs, const Array& results) final; - static constexpr const char* _type_key = "ansor.RecordToFile"; + static constexpr const char* _type_key = "auto_schedule.RecordToFile"; TVM_DECLARE_FINAL_OBJECT_INFO(RecordToFileNode, MeasureCallbackNode); }; @@ -88,7 +88,7 @@ class RecordReaderNode : public Object { std::pair, Array> ReadLines(int max_size = -1, int skip_size = 0); - static constexpr const char* _type_key = "ansor.RecordReader"; + static constexpr const char* _type_key = "auto_schedule.RecordReader"; TVM_DECLARE_FINAL_OBJECT_INFO(RecordReaderNode, Object); private: @@ -130,7 +130,7 @@ void WriteMeasureRecords(std::ostream* os, const Array& inputs, void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res, std::string* log_version); -} // namespace ansor +} // namespace auto_schedule } // namespace tvm -#endif // TVM_ANSOR_MEASURE_RECORD_H_ +#endif // TVM_AUTO_SCHEDULE_MEASURE_RECORD_H_ diff --git a/src/ansor/search_policy/empty_policy.cc b/src/auto_schedule/search_policy/empty_policy.cc similarity index 94% rename from src/ansor/search_policy/empty_policy.cc rename to src/auto_schedule/search_policy/empty_policy.cc index ea3325bd09e1..1bda033e24a1 100644 --- a/src/ansor/search_policy/empty_policy.cc +++ b/src/auto_schedule/search_policy/empty_policy.cc @@ -18,7 +18,7 @@ */ /*! - * \file ansor/search_policy/empty_policy.cc + * \file auto_schedule/search_policy/empty_policy.cc * \brief This is an brief example of search policy. 
*/ @@ -29,7 +29,7 @@ #include "../measure.h" namespace tvm { -namespace ansor { +namespace auto_schedule { TVM_REGISTER_NODE_TYPE(EmptyPolicyNode); @@ -92,9 +92,9 @@ Array EmptyPolicyNode::SearchOneRound() { return res; } -TVM_REGISTER_GLOBAL("ansor.EmptyPolicy").set_body_typed([]() { +TVM_REGISTER_GLOBAL("auto_schedule.EmptyPolicy").set_body_typed([]() { return EmptyPolicy(make_object()); }); -} // namespace ansor +} // namespace auto_schedule } // namespace tvm diff --git a/src/ansor/search_policy/empty_policy.h b/src/auto_schedule/search_policy/empty_policy.h similarity index 85% rename from src/ansor/search_policy/empty_policy.h rename to src/auto_schedule/search_policy/empty_policy.h index 3757ec281836..a718b3d1de5f 100644 --- a/src/ansor/search_policy/empty_policy.h +++ b/src/auto_schedule/search_policy/empty_policy.h @@ -18,19 +18,19 @@ */ /*! - * \file ansor/search_policy/empty_policy.h + * \file auto_schedule/search_policy/empty_policy.h * \brief A brief example of the search policy which always returns the initial naive schedule * (state). */ -#ifndef TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_ -#define TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_ +#ifndef TVM_AUTO_SCHEDULE_SEARCH_POLICY_EMPTY_POLICY_H_ +#define TVM_AUTO_SCHEDULE_SEARCH_POLICY_EMPTY_POLICY_H_ #include "../loop_state.h" #include "search_policy.h" namespace tvm { -namespace ansor { +namespace auto_schedule { /*! * \brief A brief example of the search policy which always returns the initial naive schedule @@ -44,7 +44,7 @@ class EmptyPolicyNode : public SearchPolicyNode { int num_measures_per_round, bool verbose, ProgramMeasurer measurer, Optional> pre_search_callbacks) final; - static constexpr const char* _type_key = "ansor.EmptyPolicy"; + static constexpr const char* _type_key = "auto_schedule.EmptyPolicy"; TVM_DECLARE_FINAL_OBJECT_INFO(EmptyPolicyNode, SearchPolicyNode); private: @@ -64,7 +64,7 @@ class EmptyPolicy : public SearchPolicy { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EmptyPolicy, SearchPolicy, EmptyPolicyNode); }; -} // namespace ansor +} // namespace auto_schedule } // namespace tvm -#endif // TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_ +#endif // TVM_AUTO_SCHEDULE_SEARCH_POLICY_EMPTY_POLICY_H_ diff --git a/src/ansor/search_policy/search_policy.cc b/src/auto_schedule/search_policy/search_policy.cc similarity index 84% rename from src/ansor/search_policy/search_policy.cc rename to src/auto_schedule/search_policy/search_policy.cc index 138d9f10639c..e2f977300f35 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/auto_schedule/search_policy/search_policy.cc @@ -18,7 +18,7 @@ */ /*! - * \file ansor/search_policy/search_policy.cc + * \file auto_schedule/search_policy/search_policy.cc * \brief The base class of search policies. 
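
A new policy plugs in through the same interface as EmptyPolicyNode; the skeleton below is only a sketch (MyPolicyNode is a made-up name and the body is a placeholder, not a real search strategy):

    class MyPolicyNode : public SearchPolicyNode {
     public:
      State Search(SearchTask task, int num_measure_trials, int early_stopping,
                   int num_measures_per_round, bool verbose, ProgramMeasurer measurer,
                   Optional<Array<SearchCallback>> pre_search_callbacks) final {
        cur_task = task;
        RunCallbacks(pre_search_callbacks);
        // A real policy would generate candidate states here and send them to
        // `measurer` in batches of `num_measures_per_round`.
        return State();  // placeholder for the best state found
      }

      static constexpr const char* _type_key = "auto_schedule.MyPolicy";
      TVM_DECLARE_FINAL_OBJECT_INFO(MyPolicyNode, SearchPolicyNode);
    };
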
*/ @@ -27,7 +27,7 @@ #include namespace tvm { -namespace ansor { +namespace auto_schedule { TVM_REGISTER_OBJECT_TYPE(SearchCallbackNode); TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); @@ -40,16 +40,16 @@ void SearchPolicyNode::RunCallbacks(const Optional>& callb } } -TVM_REGISTER_GLOBAL("ansor.SearchPolicyRunCallbacks") +TVM_REGISTER_GLOBAL("auto_schedule.SearchPolicyRunCallbacks") .set_body_typed([](SearchPolicy policy, Optional> callbacks) { policy->RunCallbacks(callbacks); }); -TVM_REGISTER_GLOBAL("ansor.SearchPolicySetTask") +TVM_REGISTER_GLOBAL("auto_schedule.SearchPolicySetTask") .set_body_typed([](SearchPolicy policy, SearchTask task) { policy->cur_task = task; }); -TVM_REGISTER_GLOBAL("ansor.SearchPolicySetVerbose") +TVM_REGISTER_GLOBAL("auto_schedule.SearchPolicySetVerbose") .set_body_typed([](SearchPolicy policy, bool verbose) { policy->verbose = verbose; }); -} // namespace ansor +} // namespace auto_schedule } // namespace tvm diff --git a/src/ansor/search_policy/search_policy.h b/src/auto_schedule/search_policy/search_policy.h similarity index 92% rename from src/ansor/search_policy/search_policy.h rename to src/auto_schedule/search_policy/search_policy.h index 0edd07bd0ad6..224797282219 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/auto_schedule/search_policy/search_policy.h @@ -18,11 +18,11 @@ */ /*! - * \file ansor/search_policy/search_policy.h + * \file auto_schedule/search_policy/search_policy.h * \brief The base class of search policies, including the abstract definition of search policy and * other supporting data structures. * - * The basic schedule search process for Ansor is design to be: + * The basic schedule search process for TVM Auto-scheduler is designed to be: * `Program sampling` -> `Performance Tuning`. * * In `Program sampling`, we use some predefined precise or heuristic rules to generate several @@ -48,8 +48,8 @@ * during the search process. 
*/ -#ifndef TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ -#define TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ +#ifndef TVM_AUTO_SCHEDULE_SEARCH_POLICY_SEARCH_POLICY_H_ +#define TVM_AUTO_SCHEDULE_SEARCH_POLICY_SEARCH_POLICY_H_ #include @@ -59,7 +59,7 @@ #include "../search_task.h" namespace tvm { -namespace ansor { +namespace auto_schedule { class ProgramMeasurer; class SearchPolicyNode; @@ -77,7 +77,7 @@ class SearchCallbackNode : public Object { */ virtual void Callback(SearchPolicyNode* policy) = 0; - static constexpr const char* _type_key = "ansor.SearchCallback"; + static constexpr const char* _type_key = "auto_schedule.SearchCallback"; TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object); }; @@ -131,7 +131,7 @@ class SearchPolicyNode : public Object { */ void RunCallbacks(const Optional>& callbacks); - static constexpr const char* _type_key = "ansor.SearchPolicy"; + static constexpr const char* _type_key = "auto_schedule.SearchPolicy"; TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); protected: @@ -161,7 +161,7 @@ class SearchPolicy : public ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchPolicy, ObjectRef, SearchPolicyNode); }; -} // namespace ansor +} // namespace auto_schedule } // namespace tvm -#endif // TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ +#endif // TVM_AUTO_SCHEDULE_SEARCH_POLICY_SEARCH_POLICY_H_ diff --git a/src/ansor/search_task.cc b/src/auto_schedule/search_task.cc similarity index 93% rename from src/ansor/search_task.cc rename to src/auto_schedule/search_task.cc index 633d99fa41a8..1d7a08cc73db 100644 --- a/src/ansor/search_task.cc +++ b/src/auto_schedule/search_task.cc @@ -18,7 +18,7 @@ */ /*! - * \file ansor/search_task.cc + * \file auto_schedule/search_task.cc * \brief Meta information and hardware parameters for a search task. */ @@ -30,7 +30,7 @@ #include namespace tvm { -namespace ansor { +namespace auto_schedule { TVM_REGISTER_NODE_TYPE(HardwareParamsNode); TVM_REGISTER_NODE_TYPE(SearchTaskNode); @@ -69,16 +69,16 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target targe data_ = std::move(node); } -TVM_REGISTER_GLOBAL("ansor.HardwareParams") +TVM_REGISTER_GLOBAL("auto_schedule.HardwareParams") .set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes) { return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes); }); -TVM_REGISTER_GLOBAL("ansor.SearchTask") +TVM_REGISTER_GLOBAL("auto_schedule.SearchTask") .set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params) { return SearchTask(compute_dag, workload_key, target, target_host, hardware_params); }); -} // namespace ansor +} // namespace auto_schedule } // namespace tvm diff --git a/src/ansor/search_task.h b/src/auto_schedule/search_task.h similarity index 93% rename from src/ansor/search_task.h rename to src/auto_schedule/search_task.h index fb9a6e098023..c7d2ddc533ed 100644 --- a/src/ansor/search_task.h +++ b/src/auto_schedule/search_task.h @@ -18,19 +18,19 @@ */ /*! - * \file ansor/search_task.h + * \file auto_schedule/search_task.h * \brief Meta information and hardware parameters for a search task. 
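
SearchCallbackNode is the hook point for custom pre-search logic; a minimal sketch (the class name and message are illustrative, and it assumes SearchTaskNode exposes its workload_key field):

    class PrintTaskCallbackNode : public SearchCallbackNode {
     public:
      void Callback(SearchPolicyNode* policy) final {
        // Report which task the policy is about to search.
        StdCout(policy->verbose) << "Begin search for: "
                                 << policy->cur_task->workload_key << "\n";
      }

      static constexpr const char* _type_key = "auto_schedule.PrintTaskCallback";
      TVM_DECLARE_FINAL_OBJECT_INFO(PrintTaskCallbackNode, SearchCallbackNode);
    };
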
*/ -#ifndef TVM_ANSOR_SEARCH_TASK_H_ -#define TVM_ANSOR_SEARCH_TASK_H_ +#ifndef TVM_AUTO_SCHEDULE_SEARCH_TASK_H_ +#define TVM_AUTO_SCHEDULE_SEARCH_TASK_H_ #include #include "compute_dag.h" namespace tvm { -namespace ansor { +namespace auto_schedule { class HardwareParams; @@ -76,7 +76,7 @@ class HardwareParamsNode : public Object { */ static HardwareParams GetDefaultHardwareParams(const Target& target, const Target& target_host); - static constexpr const char* _type_key = "ansor.HardwareParams"; + static constexpr const char* _type_key = "auto_schedule.HardwareParams"; TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object); }; @@ -122,7 +122,7 @@ class SearchTaskNode : public Object { v->Visit("hardware_params", &hardware_params); } - static constexpr const char* _type_key = "ansor.SearchTask"; + static constexpr const char* _type_key = "auto_schedule.SearchTask"; TVM_DECLARE_FINAL_OBJECT_INFO(SearchTaskNode, Object); }; @@ -146,7 +146,7 @@ class SearchTask : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); }; -} // namespace ansor +} // namespace auto_schedule } // namespace tvm -#endif // TVM_ANSOR_SEARCH_TASK_H_ +#endif // TVM_AUTO_SCHEDULE_SEARCH_TASK_H_ diff --git a/src/ansor/transform_step.cc b/src/auto_schedule/transform_step.cc similarity index 98% rename from src/ansor/transform_step.cc rename to src/auto_schedule/transform_step.cc index f096e63a4e54..bffb2dcfab31 100644 --- a/src/ansor/transform_step.cc +++ b/src/auto_schedule/transform_step.cc @@ -18,7 +18,7 @@ */ /*! - * \file ansor/transform_step.cc + * \file auto_schedule/transform_step.cc * \brief Transformation steps. For each schedule primitive, there is a corresponding transform * step. */ @@ -34,7 +34,7 @@ #include "utils.h" namespace tvm { -namespace ansor { +namespace auto_schedule { /********** Reorder **********/ ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { @@ -235,5 +235,5 @@ String FuseStepNode::PrintAsPythonAPI(Array* stages, return ss.str(); } -} // namespace ansor +} // namespace auto_schedule } // namespace tvm diff --git a/src/ansor/transform_step.h b/src/auto_schedule/transform_step.h similarity index 92% rename from src/ansor/transform_step.h rename to src/auto_schedule/transform_step.h index 4feec4355c07..5e54c9de583f 100644 --- a/src/ansor/transform_step.h +++ b/src/auto_schedule/transform_step.h @@ -18,7 +18,7 @@ */ /*! - * \file ansor/transform_step.h + * \file auto_schedule/transform_step.h * \brief Transformation steps. For each schedule primitive, there is a corresponding transform * step. The implementation of each step consists of 2 parts: * - transform_step.cc: How each step interacts with TE and TE's schedule primitives @@ -34,13 +34,13 @@ * - In these two functions you need to incrementally update all data structures in State with * CopyOnWrite style * 4. Add you step to `ComputeDAG::ApplySteps` and make sure it works. - * 5. Add log record serialization support in `struct Handler>` + * 5. Add log record serialization support in `struct Handler>` * in `record.cc`. * 6. Add its corresponding Python API to `loop_state.py` and necessary unit test. */ -#ifndef TVM_ANSOR_TRANSFORM_STEP_H_ -#define TVM_ANSOR_TRANSFORM_STEP_H_ +#ifndef TVM_AUTO_SCHEDULE_TRANSFORM_STEP_H_ +#define TVM_AUTO_SCHEDULE_TRANSFORM_STEP_H_ #include #include @@ -49,7 +49,7 @@ #include "utils.h" namespace tvm { -namespace ansor { +namespace auto_schedule { typedef Map, ObjectHash, ObjectEqual> StageToAxesMap; @@ -62,7 +62,7 @@ class StepNode : public Object { /*! 
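
Putting HardwareParams and SearchTask together, a task can be assembled by hand roughly as follows (all argument values and the `dag`/`target` variables are illustrative placeholders):

    // num_cores, vector_unit_bytes, cache_line_bytes
    HardwareParams hw_params(4, 64, 64);
    // compute_dag, workload_key, target, target_host, hardware_params
    SearchTask task(dag, "example_workload_key", target, target, hw_params);
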
\brief The index of the stage. */ int stage_id; - static constexpr const char* _type_key = "ansor.Step"; + static constexpr const char* _type_key = "auto_schedule.Step"; TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object); }; @@ -99,7 +99,7 @@ class ReorderStepNode : public StepNode { */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; - static constexpr const char* _type_key = "ansor.ReorderStep"; + static constexpr const char* _type_key = "auto_schedule.ReorderStep"; TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object); }; @@ -154,7 +154,7 @@ class SplitStepNode : public StepNode { */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; - static constexpr const char* _type_key = "ansor.SplitStep"; + static constexpr const char* _type_key = "auto_schedule.SplitStep"; TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object); }; @@ -200,7 +200,7 @@ class FuseStepNode : public StepNode { */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; - static constexpr const char* _type_key = "ansor.FuseStep"; + static constexpr const char* _type_key = "auto_schedule.FuseStep"; TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); }; @@ -220,7 +220,7 @@ class FuseStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); }; -} // namespace ansor +} // namespace auto_schedule } // namespace tvm -#endif // TVM_ANSOR_TRANSFORM_STEP_H_ +#endif // TVM_AUTO_SCHEDULE_TRANSFORM_STEP_H_ diff --git a/src/ansor/utils.cc b/src/auto_schedule/utils.cc similarity index 93% rename from src/ansor/utils.cc rename to src/auto_schedule/utils.cc index 93a8e2257604..7e69a38e8831 100644 --- a/src/ansor/utils.cc +++ b/src/auto_schedule/utils.cc @@ -18,14 +18,14 @@ */ /*! - * \file ansor/utils.cc + * \file auto_schedule/utils.cc * \brief Common utilities. */ #include "utils.h" namespace tvm { -namespace ansor { +namespace auto_schedule { NullStream& NullStream::Global() { static NullStream stream; @@ -51,5 +51,5 @@ ThreadPool& ThreadPool::Global() { return *pool; } -} // namespace ansor +} // namespace auto_schedule } // namespace tvm diff --git a/src/ansor/utils.h b/src/auto_schedule/utils.h similarity index 97% rename from src/ansor/utils.h rename to src/auto_schedule/utils.h index cd2d32344899..6993cee8402b 100644 --- a/src/ansor/utils.h +++ b/src/auto_schedule/utils.h @@ -18,12 +18,12 @@ */ /*! - * \file ansor/utils.h + * \file auto_schedule/utils.h * \brief Common utilities. */ -#ifndef TVM_ANSOR_UTILS_H_ -#define TVM_ANSOR_UTILS_H_ +#ifndef TVM_AUTO_SCHEDULE_UTILS_H_ +#define TVM_AUTO_SCHEDULE_UTILS_H_ #include #include @@ -61,7 +61,7 @@ struct hash> { } // namespace std namespace tvm { -namespace ansor { +namespace auto_schedule { /********** Utilities for Array, std::string **********/ /*! 
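
Following the checklist in transform_step.h, the shell of a hypothetical new step mirrors ReorderStepNode and friends (AnnotationStepNode is invented for illustration, and the Array<te::Stage> element type is inferred from the surrounding declarations):

    class AnnotationStepNode : public StepNode {
     public:
      // Counterpart of the method on the step nodes above: print the step as
      // an equivalent te.Schedule python API call.
      String PrintAsPythonAPI(Array<te::Stage>* stages,
                              StageToAxesMap* stage_to_axes) const;

      static constexpr const char* _type_key = "auto_schedule.AnnotationStep";
      TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object);
    };
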
\brief Get the first appearance index of elements in an Array */ @@ -284,7 +284,7 @@ class ThreadPool { std::condition_variable finish_signal_; }; -} // namespace ansor +} // namespace auto_schedule } // namespace tvm -#endif // TVM_ANSOR_UTILS_H_ +#endif // TVM_AUTO_SCHEDULE_UTILS_H_ diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_auto_schedule_common.py similarity index 89% rename from tests/python/unittest/test_ansor_common.py rename to tests/python/unittest/test_auto_schedule_common.py index 9f4e62466095..691c7e767b6f 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_auto_schedule_common.py @@ -15,16 +15,16 @@ # specific language governing permissions and limitations # under the License. -"""Common functions for ansor test cases""" +"""Common functions for auto_schedule test cases""" import threading -from tvm import te, ansor +from tvm import te, auto_schedule import topi -@ansor.register_workload -def matmul_ansor_test(N, M, K): +@auto_schedule.register_workload +def matmul_auto_schedule_test(N, M, K): A = te.placeholder((N, K), name='A') B = te.placeholder((K, M), name='B') k = te.reduce_axis((0, K), name='k') @@ -32,8 +32,8 @@ def matmul_ansor_test(N, M, K): return [A, B, C] -@ansor.register_workload("matmul_ansor_test_rename_1") -def matmul_ansor_test_rename_0(N, M, K): +@auto_schedule.register_workload("matmul_auto_schedule_test_rename_1") +def matmul_auto_schedule_test_rename_0(N, M, K): A = te.placeholder((N, K), name='A') B = te.placeholder((K, M), name='B') k = te.reduce_axis((0, K), name='k') @@ -67,8 +67,8 @@ def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation def get_tiled_matmul(): - A, B, C = matmul_ansor_test(512, 512, 512) - dag = ansor.ComputeDAG([A, B, C]) + A, B, C = matmul_auto_schedule_test(512, 512, 512) + dag = auto_schedule.ComputeDAG([A, B, C]) s0 = dag.get_init_state() its0 = s0.split(C, s0[C].iters[0], [4, 8, 8]) diff --git a/tests/python/unittest/test_ansor_compute_dag.py b/tests/python/unittest/test_auto_schedule_compute_dag.py similarity index 93% rename from tests/python/unittest/test_ansor_compute_dag.py rename to tests/python/unittest/test_auto_schedule_compute_dag.py index 934c13f158ef..8a4f836765eb 100644 --- a/tests/python/unittest/test_ansor_compute_dag.py +++ b/tests/python/unittest/test_auto_schedule_compute_dag.py @@ -18,9 +18,9 @@ """Test ComputeDAG (replay, infer bound)""" import tvm -from tvm import ansor, te +from tvm import auto_schedule, te -from test_ansor_common import get_tiled_matmul +from test_auto_schedule_common import get_tiled_matmul def test_apply_steps(): diff --git a/tests/python/unittest/test_ansor_loop_state.py b/tests/python/unittest/test_auto_schedule_loop_state.py similarity index 89% rename from tests/python/unittest/test_ansor_loop_state.py rename to tests/python/unittest/test_auto_schedule_loop_state.py index 35894354349f..ed54da513d16 100644 --- a/tests/python/unittest/test_ansor_loop_state.py +++ b/tests/python/unittest/test_auto_schedule_loop_state.py @@ -20,15 +20,15 @@ import numpy as np import tvm -from tvm import ansor, te +from tvm import auto_schedule, te import topi -from test_ansor_common import matmul_ansor_test, conv2d_nchw_bn_relu +from test_auto_schedule_common import matmul_auto_schedule_test, conv2d_nchw_bn_relu def test_split_fuse_reorder(): - A, B, C = matmul_ansor_test(512, 512, 512) - dag = ansor.ComputeDAG([A, B, C]) + A, B, C = matmul_auto_schedule_test(512, 512, 512) + dag = 
auto_schedule.ComputeDAG([A, B, C]) s0 = dag.get_init_state() i, j, k = s0[C].iters diff --git a/tests/python/unittest/test_ansor_measure.py b/tests/python/unittest/test_auto_schedule_measure.py similarity index 74% rename from tests/python/unittest/test_ansor_measure.py rename to tests/python/unittest/test_auto_schedule_measure.py index 3820b7f0d168..52d016de0756 100644 --- a/tests/python/unittest/test_ansor_measure.py +++ b/tests/python/unittest/test_auto_schedule_measure.py @@ -18,10 +18,10 @@ """ Test measurement and log serialization. """ import tvm -from tvm import ansor +from tvm import auto_schedule import tempfile -from test_ansor_common import get_tiled_matmul +from test_auto_schedule_common import get_tiled_matmul def test_record(): @@ -30,15 +30,15 @@ def test_record(): if not tvm.runtime.enabled("llvm"): return target = tvm.target.create("llvm") - task = ansor.SearchTask(dag, "test", target) + task = auto_schedule.SearchTask(dag, "test", target) - inp = ansor.measure.MeasureInput(task, s) - res = ansor.measure.MeasureResult([0.1], 0, "", 0.2, 1) + inp = auto_schedule.measure.MeasureInput(task, s) + res = auto_schedule.measure.MeasureResult([0.1], 0, "", 0.2, 1) with tempfile.NamedTemporaryFile() as fp: - ansor.save_records(fp.name, [inp], [res]) + auto_schedule.save_records(fp.name, [inp], [res]) - log_reader = ansor.RecordReader(fp.name) + log_reader = auto_schedule.RecordReader(fp.name) inputs, results = log_reader.read_lines() assert len(inputs) == 1 @@ -55,11 +55,11 @@ def test_measure_local_builder_runner(): if not tvm.runtime.enabled("llvm"): return tgt = tvm.target.create("llvm") - task = ansor.SearchTask(dag, "test", tgt) + task = auto_schedule.SearchTask(dag, "test", tgt) - minp = ansor.MeasureInput(task, s0) - local_builder = ansor.LocalBuilder() - local_runner = ansor.LocalRunner() + minp = auto_schedule.MeasureInput(task, s0) + local_builder = auto_schedule.LocalBuilder() + local_runner = auto_schedule.LocalRunner() bress = local_builder.build([minp]) assert bress[0].error_no == 0 diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_auto_schedule_search_policy.py similarity index 79% rename from tests/python/unittest/test_ansor_search_policy.py rename to tests/python/unittest/test_auto_schedule_search_policy.py index fd07c7aa48f4..9e08218dcbce 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_auto_schedule_search_policy.py @@ -22,31 +22,31 @@ import tempfile import tvm -from tvm import ansor +from tvm import auto_schedule -from test_ansor_common import matmul_ansor_test, PropagatingThread +from test_auto_schedule_common import matmul_auto_schedule_test, PropagatingThread -def search_common(workload=matmul_ansor_test, target="llvm", search_policy = ansor.EmptyPolicy(), +def search_common(workload=matmul_auto_schedule_test, target="llvm", search_policy = auto_schedule.EmptyPolicy(), seed=random.randint(1, 1 << 30), runner='local', cost_model=None, num_measure_trials=2, params=None, pre_search_callbacks=None): print("Test %s schedule search with the default search policy" % (target)) random.seed(seed) N = 128 - workload_key = ansor.make_workload_key(workload, (N, N, N)) - dag = ansor.ComputeDAG(workload_key) + workload_key = auto_schedule.make_workload_key(workload, (N, N, N)) + dag = auto_schedule.ComputeDAG(workload_key) target = tvm.target.create(target) - task = ansor.SearchTask(dag, workload_key, target) + task = auto_schedule.SearchTask(dag, workload_key, target) with 
tempfile.NamedTemporaryFile() as fp: log_file = fp.name - tuning_options = ansor.TuningOptions(num_measure_trials=num_measure_trials, runner=runner, + tuning_options = auto_schedule.TuningOptions(num_measure_trials=num_measure_trials, runner=runner, verbose=0, - measure_callbacks=[ansor.RecordToFile(log_file)], + measure_callbacks=[auto_schedule.RecordToFile(log_file)], pre_search_callbacks=pre_search_callbacks) - sch, args = ansor.auto_schedule(task, search_policy, tuning_options) - inp, res = ansor.load_best(log_file, workload_key, target) + sch, args = auto_schedule.auto_schedule(task, search_policy, tuning_options) + inp, res = auto_schedule.load_best(log_file, workload_key, target) print("==== Python Code ====") print(dag.print_python_code_from_state(inp.state)) @@ -79,11 +79,11 @@ def test_workload_registry_search_basic(): t.start() t.join() t = PropagatingThread(target=search_common, - kwargs={'seed': 944563397, 'workload': "matmul_ansor_test"}) + kwargs={'seed': 944563397, 'workload': "matmul_auto_schedule_test"}) t.start() t.join() t = PropagatingThread(target=search_common, - kwargs={'seed': 944563397, 'workload': "matmul_ansor_test_rename_1"}) + kwargs={'seed': 944563397, 'workload': "matmul_auto_schedule_test_rename_1"}) t.start() t.join() From 6a61fb640be796d50358246d9c0fddafec6701d3 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 11 Jul 2020 22:51:57 +0800 Subject: [PATCH 73/78] Update --- python/tvm/auto_schedule/auto_schedule.py | 9 +++-- python/tvm/auto_schedule/measure.py | 34 +++++++++---------- python/tvm/auto_schedule/measure_record.py | 2 +- src/auto_schedule/auto_schedule.cc | 4 +-- src/auto_schedule/auto_schedule.h | 10 +++--- src/auto_schedule/measure.cc | 10 +++--- src/auto_schedule/measure.h | 30 ++++++++++------ .../search_policy/empty_policy.cc | 2 +- .../search_policy/empty_policy.h | 2 +- .../search_policy/search_policy.cc | 2 +- .../search_policy/search_policy.h | 8 ++--- src/auto_schedule/utils.h | 11 ++++-- 12 files changed, 69 insertions(+), 55 deletions(-) diff --git a/python/tvm/auto_schedule/auto_schedule.py b/python/tvm/auto_schedule/auto_schedule.py index 7b5f3b2ca593..8be6e99d6411 100644 --- a/python/tvm/auto_schedule/auto_schedule.py +++ b/python/tvm/auto_schedule/auto_schedule.py @@ -104,16 +104,15 @@ class TuningOptions(Object): The search policy measures `num_measure_trials` schedules in total and returns the best one among them. With `num_measure_trials` == 0, the policy will do the schedule search but won't involve - measurement. - This can be used to get a runnable schedule quickly without auto-tuning. + measurement. This can be used to get a runnable schedule quickly without auto-tuning. early_stopping: Optional[int] Stop the tuning early if getting no improvement after n measurements. num_measures_per_round: int = 64 The number of schedules to be measured at each search round. The whole schedule search process will try a total number of `num_measure_trials` in several rounds. - verbose: boolean = True - Verbosity level. False for silent, True to output information during schedule search. + verbose: int = 1 + Verbosity level. 0 for silent, 1 to output information during schedule search. builder: Union[ProgramBuilder, str] = 'local' ProgramBuilder which builds the program. runner: Union[ProgramRunner, str] = 'local' @@ -130,7 +129,7 @@ class TuningOptions(Object): TODO(jcf94): Add these implementation in later PRs. 
""" def __init__(self, num_measure_trials=0, early_stopping=None, num_measures_per_round=64, - verbose=True, builder='local', runner='local', measure_callbacks=None, + verbose=1, builder='local', runner='local', measure_callbacks=None, pre_search_callbacks=None): if isinstance(builder, str): if builder == 'local': diff --git a/python/tvm/auto_schedule/measure.py b/python/tvm/auto_schedule/measure.py index 9c8292126317..24e2af1d8f49 100644 --- a/python/tvm/auto_schedule/measure.py +++ b/python/tvm/auto_schedule/measure.py @@ -128,15 +128,15 @@ def __init__(self, costs, error_no, error_msg, all_cost, timestamp): class ProgramBuilder(Object): """ The base class of ProgramBuilders. """ - def build(self, measure_inputs, verbose=True): + def build(self, measure_inputs, verbose=1): """ Build programs and return results. Parameters ---------- measure_inputs : List[MeasureInput] A List of MeasureInput. - verbose : boolean = True - Verbosity level. False for silent, True to output information during program building. + verbose: int = 1 + Verbosity level. 0 for silent, 1 to output information during program building. Returns ------- @@ -149,7 +149,7 @@ def build(self, measure_inputs, verbose=True): class ProgramRunner(Object): """ The base class of ProgramRunners. """ - def run(self, measure_inputs, build_results, verbose=True): + def run(self, measure_inputs, build_results, verbose=1): """ Run measurement and return results. Parameters @@ -158,8 +158,8 @@ def run(self, measure_inputs, build_results, verbose=True): A List of MeasureInput. build_results : List[BuildResult] A List of BuildResult to be ran. - verbose : boolean = True - Verbosity level. False for silent, True to output information during program running. + verbose: int = 1 + Verbosity level. 0 for silent, 1 to output information during program running. Returns ------- @@ -318,7 +318,7 @@ def timed_func(): else: filename = "" - if verbose: + if verbose >= 1: if error_no == MeasureErrorNo.NO_ERROR: print(".", end="") else: @@ -327,7 +327,7 @@ def timed_func(): res = call_func_with_timeout(timeout, timed_func) if isinstance(res, TimeoutError): - if verbose: + if verbose >= 1: print(".T", end="") # Build timeout res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout @@ -335,7 +335,7 @@ def timed_func(): @tvm._ffi.register_func("auto_schedule.local_builder.build") -def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbose=True): +def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbose=1): """ Build function of LocalBuilder to build the MeasureInputs to runnable modules. @@ -350,8 +350,8 @@ def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbo Number of threads used to build in parallel. build_func : str = 'default' The name of build function to process the built module. - verbose : boolean = True - Verbosity level. False for silent, True to output information during program building. + verbose: int = 1 + Verbosity level. 0 for silent, 1 to output information during program building. Returns ------- @@ -378,7 +378,7 @@ def local_builder_build(inputs, timeout, n_parallel, build_func='default', verbo @tvm._ffi.register_func("auto_schedule.local_runner.run") def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval, - verbose=True): + verbose=1): """ Run function of LocalRunner to test the performance of the input BuildResults. 
@@ -409,8 +409,8 @@ def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, coo will be automatically increased. cooldown_interval : float = 0.0 The cool down interval between two measurements. - verbose : boolean = True - Verbosity level. False for silent, True to output information during program measuring. + verbose: int = 1 + Verbosity level. 0 for silent, 1 to output information during program measuring. Returns ------- @@ -450,7 +450,7 @@ def timed_func(inp, build_res): toc = time.time() time.sleep(cooldown_interval) - if verbose: + if verbose >= 1: if error_no == MeasureErrorNo.NO_ERROR: print("*", end="") else: @@ -468,13 +468,13 @@ def timed_func(inp, build_res): res = call_func_with_timeout( timeout, timed_func, args=(inp, build_res)) if isinstance(res, TimeoutError): - if verbose: + if verbose >= 1: print("*T", end="") # Run timeout res = (max_float,), MeasureErrorNo.RUN_TIMEOUT, None, \ build_res.time_cost + timeout, time.time() measure_results.append(MeasureResult(*res)) - if verbose: + if verbose >= 1: print("") return measure_results diff --git a/python/tvm/auto_schedule/measure_record.py b/python/tvm/auto_schedule/measure_record.py index 5f97e5d737d4..25a998566280 100644 --- a/python/tvm/auto_schedule/measure_record.py +++ b/python/tvm/auto_schedule/measure_record.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -""" Serialization and other I/O support for tuning logs (measurement records). """ +""" Serialization and other I/O support for measurement records (tuning logs). """ import numpy as np diff --git a/src/auto_schedule/auto_schedule.cc b/src/auto_schedule/auto_schedule.cc index 72484a93b5e6..aaf472b1f26a 100644 --- a/src/auto_schedule/auto_schedule.cc +++ b/src/auto_schedule/auto_schedule.cc @@ -32,7 +32,7 @@ namespace auto_schedule { TVM_REGISTER_NODE_TYPE(TuningOptionsNode); TuningOptions::TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, - bool verbose, ProgramBuilder builder, ProgramRunner runner, + int verbose, ProgramBuilder builder, ProgramRunner runner, Optional> measure_callbacks, Optional> pre_search_callbacks) { auto node = make_object(); @@ -63,7 +63,7 @@ std::pair> AutoSchedule(SearchTask task, SearchP TVM_REGISTER_GLOBAL("auto_schedule.TuningOptions") .set_body_typed([](int num_measure_trials, int early_stopping, int num_measures_per_round, - bool verbose, ProgramBuilder builder, ProgramRunner runner, + int verbose, ProgramBuilder builder, ProgramRunner runner, Optional> measure_callbacks, Optional> pre_search_callbacks) { return TuningOptions(num_measure_trials, early_stopping, num_measures_per_round, verbose, diff --git a/src/auto_schedule/auto_schedule.h b/src/auto_schedule/auto_schedule.h index 05d3db6c28d7..d2e49bbe7e4f 100644 --- a/src/auto_schedule/auto_schedule.h +++ b/src/auto_schedule/auto_schedule.h @@ -46,9 +46,9 @@ class TuningOptionsNode : public Object { int num_measures_per_round; /*! * \brief Verbosity level. - * False for silent, true to output information during schedule searching. + * 0 for silent, 1 to output information during schedule searching. */ - bool verbose; + int verbose; /*! \brief ProgramBuilder which builds the program */ ProgramBuilder builder; /*! \brief ProgramRunner which runs the program and measure time costs */ @@ -84,15 +84,15 @@ class TuningOptions : public ObjectRef { * \param num_measure_trials Number of total measurement trials. 
* \param early_stopping Stops early the tuning if no improvement after n measurements. * \param num_measures_per_round The number of programs to be measured at each search round. - * \param verbose Verbosity level. False for silent, true to output information during schedule + * \param verbose Verbosity level. 0 for silent, 1 to output information during schedule * search. * \param builder ProgramBuilder which builds the program. * \param runner ProgramRunner which runs the program and measure time costs. * \param measure_callbacks MeasureCallback functions to be called after each measure batch. * \param pre_search_callbacks SearchCallback functions to be called before schedule search. */ - TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, - bool verbose, ProgramBuilder builder, ProgramRunner runner, + TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, + ProgramBuilder builder, ProgramRunner runner, Optional> measure_callbacks, Optional> pre_search_callbacks); diff --git a/src/auto_schedule/measure.cc b/src/auto_schedule/measure.cc index 86a72163c682..ef273ccaa039 100644 --- a/src/auto_schedule/measure.cc +++ b/src/auto_schedule/measure.cc @@ -110,7 +110,7 @@ LocalBuilder::LocalBuilder(int timeout, int n_parallel, const String& build_func data_ = std::move(node); } -Array LocalBuilderNode::Build(const Array& inputs, bool verbose) { +Array LocalBuilderNode::Build(const Array& inputs, int verbose) { if (const auto* f = runtime::Registry::Get("auto_schedule.local_builder.build")) { Array results = (*f)(inputs, timeout, n_parallel, build_func, verbose); return results; @@ -134,7 +134,7 @@ LocalRunner::LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, } Array LocalRunnerNode::Run(const Array& inputs, - const Array& build_results, bool verbose) { + const Array& build_results, int verbose) { if (const auto* f = runtime::Registry::Get("auto_schedule.local_runner.run")) { Array results = (*f)(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval, verbose); @@ -148,7 +148,7 @@ Array LocalRunnerNode::Run(const Array& inputs, /********** ProgramMeasurer **********/ ProgramMeasurer::ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, - Optional> callbacks, bool verbose, + Optional> callbacks, int verbose, int max_continous_error) { auto node = make_object(); node->builder = std::move(builder); @@ -312,12 +312,12 @@ TVM_REGISTER_GLOBAL("auto_schedule.MeasureResult") TVM_REGISTER_GLOBAL("auto_schedule.ProgramBuilderBuild") .set_body_typed([](const ProgramBuilder& builder, const Array& inputs, - bool verbose) { return builder->Build(inputs, verbose); }); + int verbose) { return builder->Build(inputs, verbose); }); TVM_REGISTER_GLOBAL("auto_schedule.ProgramRunnerRun") .set_body_typed([](const ProgramRunner& runner, const Array& inputs, const Array& build_results, - bool verbose) { return runner->Run(inputs, build_results, verbose); }); + int verbose) { return runner->Run(inputs, build_results, verbose); }); TVM_REGISTER_GLOBAL("auto_schedule.LocalBuilder") .set_body_typed([](int timeout, int n_parallel, const String& build_func) { diff --git a/src/auto_schedule/measure.h b/src/auto_schedule/measure.h index c4f776abb003..a7890eaffd0d 100644 --- a/src/auto_schedule/measure.h +++ b/src/auto_schedule/measure.h @@ -20,7 +20,17 @@ /*! * \file auto_schedule/measure.h * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs. 
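
A rough C++ counterpart of the Python flow, using the constructor above (the numeric arguments are illustrative, `task` and `policy` are assumed to exist, the LocalRunner argument order follows its declaration, and both callback lists are left empty):

    TuningOptions options(/*num_measure_trials=*/64, /*early_stopping=*/-1,
                          /*num_measures_per_round=*/16, /*verbose=*/1,
                          LocalBuilder(15, 4, "default"),
                          LocalRunner(10, 3, 1, 100, 0.0),
                          NullOpt, NullOpt);
    // Run the search and get back the best schedule and its tensor arguments.
    auto sch_and_args = AutoSchedule(task, policy, options);
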
- * The flow of data structures is MeasureInput -> BuildeResult -> MeasureResult. + * These functions are responsible for building the tvm module, uploading it to remote devices, + * recording the running time costs, and checking the correctness of the output. + * + * We separate the measurement into two steps: build and run. + * A builder builds the executable binary files and a runner runs the binary files to get the + * measurement results. The flow of data structures is + * + * `ProgramBuilder` `ProgramRunner` + * `MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult` + * + * We implement these in python to utilize python's multiprocessing and error handling. */ #ifndef TVM_AUTO_SCHEDULE_MEASURE_H_ @@ -232,11 +242,11 @@ class ProgramBuilderNode : public Object { /*! * \brief Build programs and return results. * \param inputs An Array of MeasureInput. - * \param verbose Verbosity level. False for silent, true to output information during program + * \param verbose Verbosity level. 0 for silent, 1 to output information during program * building. * \return An Array of MeasureResult. */ - virtual Array Build(const Array& inputs, bool verbose) = 0; + virtual Array Build(const Array& inputs, int verbose) = 0; static constexpr const char* _type_key = "auto_schedule.ProgramBuilder"; TVM_DECLARE_BASE_OBJECT_INFO(ProgramBuilderNode, Object); @@ -261,12 +271,12 @@ class ProgramRunnerNode : public Object { * \brief Run measurement and return results. * \param inputs An Array of MeasureInput. * \param build_results An Array of BuildResult. - * \param verbose Verbosity level. False for silent, true to output information during program + * \param verbose Verbosity level. 0 for silent, 1 to output information during program * running. * \return An Array of MeasureResult. */ virtual Array Run(const Array& inputs, - const Array& build_results, bool verbose) = 0; + const Array& build_results, int verbose) = 0; static constexpr const char* _type_key = "auto_schedule.ProgramRunner"; TVM_DECLARE_BASE_OBJECT_INFO(ProgramRunnerNode, Object); @@ -289,7 +299,7 @@ class LocalBuilderNode : public ProgramBuilderNode { /*! \brief Build function. */ String build_func; - Array Build(const Array& inputs, bool verbose) final; + Array Build(const Array& inputs, int verbose) final; static constexpr const char* _type_key = "auto_schedule.LocalBuilder"; TVM_DECLARE_FINAL_OBJECT_INFO(LocalBuilderNode, ProgramBuilderNode); @@ -326,7 +336,7 @@ class LocalRunnerNode : public ProgramRunnerNode { double cooldown_interval; Array Run(const Array& inputs, - const Array& build_results, bool verbose) final; + const Array& build_results, int verbose) final; static constexpr const char* _type_key = "auto_schedule.LocalRunner"; TVM_DECLARE_FINAL_OBJECT_INFO(LocalRunnerNode, ProgramRunnerNode); @@ -375,7 +385,7 @@ class ProgramMeasurerNode : public Object { /*! \brief MeasureCallback to be called after each measure batch. */ Optional> callbacks; /*! \brief Verbosity level. 0 for silent, 1 to output information during program measuring. */ - bool verbose; + int verbose; /*! \brief The number of max continuous error. */ int max_continous_error; @@ -421,12 +431,12 @@ class ProgramMeasurer : public ObjectRef { * \param builder The ProgramBuilder to build each program. * \param runner The ProgramRunner to measure each program. * \param callbacks MeasureCallback to be called after each measure batch. - * \param verbose Verbosity level. 
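
Spelled out, the flow in the comment above is just two calls, optionally bundled in a ProgramMeasurer (a sketch; `builder`, `runner` and the prepared Array<MeasureInput> `inputs` are assumed to exist):

    Array<BuildResult> build_results = builder->Build(inputs, /*verbose=*/1);
    Array<MeasureResult> results = runner->Run(inputs, build_results, /*verbose=*/1);

    // Or wrapped together, with no callbacks and the default error limit:
    ProgramMeasurer measurer(builder, runner, NullOpt, /*verbose=*/1);
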
False for silent, true to output information during program + * \param verbose Verbosity level. 0 for silent, 1 to output information during program * measuring. * \param max_continous_error The number of max continuous error. */ ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, - Optional> callbacks, bool verbose, + Optional> callbacks, int verbose, int max_continous_error = -1); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramMeasurer, ObjectRef, ProgramMeasurerNode); diff --git a/src/auto_schedule/search_policy/empty_policy.cc b/src/auto_schedule/search_policy/empty_policy.cc index 1bda033e24a1..d91d563252b5 100644 --- a/src/auto_schedule/search_policy/empty_policy.cc +++ b/src/auto_schedule/search_policy/empty_policy.cc @@ -34,7 +34,7 @@ namespace auto_schedule { TVM_REGISTER_NODE_TYPE(EmptyPolicyNode); State EmptyPolicyNode::Search(SearchTask task, int num_measure_trials, int early_stopping, - int num_measures_per_round, bool verbose, ProgramMeasurer measurer, + int num_measures_per_round, int verbose, ProgramMeasurer measurer, Optional> pre_search_callbacks) { cur_task = task; diff --git a/src/auto_schedule/search_policy/empty_policy.h b/src/auto_schedule/search_policy/empty_policy.h index a718b3d1de5f..610a02a3cd12 100644 --- a/src/auto_schedule/search_policy/empty_policy.h +++ b/src/auto_schedule/search_policy/empty_policy.h @@ -41,7 +41,7 @@ namespace auto_schedule { class EmptyPolicyNode : public SearchPolicyNode { public: State Search(SearchTask task, int num_measure_trials, int early_stopping, - int num_measures_per_round, bool verbose, ProgramMeasurer measurer, + int num_measures_per_round, int verbose, ProgramMeasurer measurer, Optional> pre_search_callbacks) final; static constexpr const char* _type_key = "auto_schedule.EmptyPolicy"; diff --git a/src/auto_schedule/search_policy/search_policy.cc b/src/auto_schedule/search_policy/search_policy.cc index e2f977300f35..f8ac7ca39495 100644 --- a/src/auto_schedule/search_policy/search_policy.cc +++ b/src/auto_schedule/search_policy/search_policy.cc @@ -49,7 +49,7 @@ TVM_REGISTER_GLOBAL("auto_schedule.SearchPolicySetTask") .set_body_typed([](SearchPolicy policy, SearchTask task) { policy->cur_task = task; }); TVM_REGISTER_GLOBAL("auto_schedule.SearchPolicySetVerbose") - .set_body_typed([](SearchPolicy policy, bool verbose) { policy->verbose = verbose; }); + .set_body_typed([](SearchPolicy policy, int verbose) { policy->verbose = verbose; }); } // namespace auto_schedule } // namespace tvm diff --git a/src/auto_schedule/search_policy/search_policy.h b/src/auto_schedule/search_policy/search_policy.h index 224797282219..47cccec93661 100644 --- a/src/auto_schedule/search_policy/search_policy.h +++ b/src/auto_schedule/search_policy/search_policy.h @@ -99,9 +99,9 @@ class SearchPolicyNode : public Object { SearchTask cur_task; /*! * \brief Verbose level to control the screen output during schedule search. - * False for silent, true to output state & measure information during search process. + * 0 for silent, 1 to output state & measure information during search process. */ - bool verbose; + int verbose; void VisitAttrs(AttrVisitor* v) { v->Visit("cur_task", &cur_task); @@ -115,14 +115,14 @@ class SearchPolicyNode : public Object { * \param num_measure_trials Total schedules to be tried during this search. * \param early_stopping Early stop if no better schedule is found. * \param num_measures_per_round Max measure batch in one search round. - * \param verbose Verbose level. 
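
With the integer verbosity, screen output can be tiered using the utils.h helpers updated just below; a small sketch (`verbose` and `ct` are assumed locals):

    PrintTitle("Search", verbose);  // banner, printed when verbose >= 1
    StdCout(verbose) << "Measured " << ct << " programs\n";   // level 1 and up
    StdCout(verbose, 2) << "Per-candidate details...\n";      // only when verbose >= 2
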
False for silent, true to output information during schedule + * \param verbose Verbose level. 0 for silent, 1 to output information during schedule * search. * \param measurer A ProgramMeasurer which packs ProgramBuilder & ProgramRunner inside. * \param pre_search_callbacks SearchCallback to be called before schedule search. * \return The best state get. */ virtual State Search(SearchTask task, int num_measure_trials, int early_stopping, - int num_measures_per_round, bool verbose, ProgramMeasurer measurer, + int num_measures_per_round, int verbose, ProgramMeasurer measurer, Optional> pre_search_callbacks) = 0; /*! diff --git a/src/auto_schedule/utils.h b/src/auto_schedule/utils.h index 6993cee8402b..e0b6534203bd 100644 --- a/src/auto_schedule/utils.h +++ b/src/auto_schedule/utils.h @@ -163,7 +163,9 @@ NullStream& operator<<(NullStream& os, const T& value) { } /*! \brief Get std cout with verbose control */ -inline std::ostream& StdCout(bool verbose) { return verbose ? std::cout : NullStream::Global(); } +inline std::ostream& StdCout(int verbose, int setting = 1) { + return verbose >= setting ? std::cout : NullStream::Global(); +} /*! \brief Print multiple chars */ inline std::string Chars(const char& str, int times) { @@ -175,13 +177,16 @@ inline std::string Chars(const char& str, int times) { } /*! \brief Print a title */ -inline void PrintTitle(const std::string& title, bool verbose) { +inline void PrintTitle(const std::string& title, int verbose) { StdCout(verbose) << Chars('-', 60) << "\n" << Chars('-', 25) << " [ " << title << " ]\n" << Chars('-', 60) << std::endl; } -/*! \brief A simple thread pool */ +/*! + * \brief A simple thread pool. + * TODO(merrymercy): Move this to `src/support/parallel_for` + */ class ThreadPool { public: void Launch(size_t n = 1) { From 3a4e5dabe2fd55ff5e4727b4e90496940a80e8d3 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 14 Jul 2020 10:03:38 +0800 Subject: [PATCH 74/78] Rename ThreadPool to ParallelFor --- src/auto_schedule/measure.cc | 2 +- src/auto_schedule/utils.cc | 8 ++++---- src/auto_schedule/utils.h | 30 ++++++++++++++++++++++++++---- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/src/auto_schedule/measure.cc b/src/auto_schedule/measure.cc index ef273ccaa039..20d08bf9258d 100644 --- a/src/auto_schedule/measure.cc +++ b/src/auto_schedule/measure.cc @@ -237,7 +237,7 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& po void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, const Array& inputs, Array* results) { // Close the thread pool to avoid the conflits with python environment - ThreadPool::Global().Abort(); + ParallelFor::Global().Abort(); results->clear(); results->reserve(inputs.size()); diff --git a/src/auto_schedule/utils.cc b/src/auto_schedule/utils.cc index 7e69a38e8831..2de5af151a90 100644 --- a/src/auto_schedule/utils.cc +++ b/src/auto_schedule/utils.cc @@ -32,16 +32,16 @@ NullStream& NullStream::Global() { return stream; } -ThreadPool& ThreadPool::Global() { - static ThreadPool* pool = new ThreadPool(); +ParallelFor& ParallelFor::Global() { + static ParallelFor* pool = new ParallelFor(); static int ct = 0; - ct = (ct + 1) % ThreadPool::REFRESH_EVERY; + ct = (ct + 1) % ParallelFor::REFRESH_EVERY; if (ct == 0) { pool->Abort(); delete pool; - pool = new ThreadPool(); + pool = new ParallelFor(); } if (pool->NumWorkers() == 0) { diff --git a/src/auto_schedule/utils.h b/src/auto_schedule/utils.h index e0b6534203bd..e226b7f851d2 100644 --- a/src/auto_schedule/utils.h +++ 
b/src/auto_schedule/utils.h @@ -184,22 +184,36 @@ inline void PrintTitle(const std::string& title, int verbose) { } /*! - * \brief A simple thread pool. + * \brief A simple thread pool to perform parallel for. * TODO(merrymercy): Move this to `src/support/parallel_for` */ -class ThreadPool { +class ParallelFor { public: + /*! + * \brief Set the thread number used in this pool. + * \param n The thread number of this pool. + */ void Launch(size_t n = 1) { for (std::size_t i = 0; i < n; ++i) { threads_.emplace_back([this] { WorkerFunc(); }); } } + /*! + * \brief Set the total task number to be executed in this parallel for run batch. + * \param n The task number of this parallel for run batch. + */ void BeginBatch(int n) { finish_ct_ = n; is_finished_ = n <= 0; } + /*! + * \brief Add run task to task queue. The task added will be run in thread pool immediately. + * \param f The task function to be executed. + * \param args The args of the task function. + * \return The result of the task function. + */ template ::type> std::future Enqueue(F&& f, Args&&... args) { std::packaged_task p(std::bind(f, args...)); @@ -213,6 +227,7 @@ class ThreadPool { return r; } + /*! \brief Wait until the parallel for run batch is finished. */ void WaitBatch() { std::unique_lock l(finish_mutex_); if (!is_finished_) { @@ -220,16 +235,19 @@ class ThreadPool { } } + /*! \brief Stop the running process. */ void Abort() { CancelPending(); Join(); } + /*! \brief Cancel all the tasks in task queue. */ void CancelPending() { std::unique_lock l(m_); work_.clear(); } + /*! \brief Wait until all of the threads are finished. */ void Join() { { std::unique_lock l(m_); @@ -244,12 +262,16 @@ class ThreadPool { threads_.clear(); } + /*! + * \brief Get the working thread number of this pool. + * \return The thread number of this pool. 
+ */ size_t NumWorkers() { return threads_.size(); } static const int REFRESH_EVERY = 128; - static ThreadPool& Global(); + static ParallelFor& Global(); - ~ThreadPool() { Join(); } + ~ParallelFor() { Join(); } private: void WorkerFunc() { From dbe019bee3e9bace28eac6f7789f032ce2cde5a7 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 14 Jul 2020 11:02:02 +0800 Subject: [PATCH 75/78] Add parallel_for --- src/auto_schedule/measure.cc | 2 +- src/auto_schedule/utils.cc | 19 +++++++++++++++---- src/auto_schedule/utils.h | 20 +++++++++++++++++--- 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/auto_schedule/measure.cc b/src/auto_schedule/measure.cc index 20d08bf9258d..ef273ccaa039 100644 --- a/src/auto_schedule/measure.cc +++ b/src/auto_schedule/measure.cc @@ -237,7 +237,7 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& po void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, const Array& inputs, Array* results) { // Close the thread pool to avoid the conflits with python environment - ParallelFor::Global().Abort(); + ThreadPool::Global().Abort(); results->clear(); results->reserve(inputs.size()); diff --git a/src/auto_schedule/utils.cc b/src/auto_schedule/utils.cc index 2de5af151a90..d2901fa401ef 100644 --- a/src/auto_schedule/utils.cc +++ b/src/auto_schedule/utils.cc @@ -32,16 +32,16 @@ NullStream& NullStream::Global() { return stream; } -ParallelFor& ParallelFor::Global() { - static ParallelFor* pool = new ParallelFor(); +ThreadPool& ThreadPool::Global() { + static ThreadPool* pool = new ThreadPool(); static int ct = 0; - ct = (ct + 1) % ParallelFor::REFRESH_EVERY; + ct = (ct + 1) % ThreadPool::REFRESH_EVERY; if (ct == 0) { pool->Abort(); delete pool; - pool = new ParallelFor(); + pool = new ThreadPool(); } if (pool->NumWorkers() == 0) { @@ -51,5 +51,16 @@ ParallelFor& ParallelFor::Global() { return *pool; } +void parallel_for(int start, int end, std::function f, int stride) { + auto& pf = ThreadPool::Global(); + int batch_count = (end - start) / stride; + CHECK_GT(batch_count, 0); + pf.BeginBatch(batch_count); + for (int i = start; i < end; i += stride) { + pf.Enqueue(f, i); + } + pf.WaitBatch(); +} + } // namespace auto_schedule } // namespace tvm diff --git a/src/auto_schedule/utils.h b/src/auto_schedule/utils.h index e226b7f851d2..a554f5591aed 100644 --- a/src/auto_schedule/utils.h +++ b/src/auto_schedule/utils.h @@ -187,7 +187,7 @@ inline void PrintTitle(const std::string& title, int verbose) { * \brief A simple thread pool to perform parallel for. * TODO(merrymercy): Move this to `src/support/parallel_for` */ -class ParallelFor { +class ThreadPool { public: /*! * \brief Set the thread number used in this pool. @@ -269,9 +269,9 @@ class ParallelFor { size_t NumWorkers() { return threads_.size(); } static const int REFRESH_EVERY = 128; - static ParallelFor& Global(); + static ThreadPool& Global(); - ~ParallelFor() { Join(); } + ~ThreadPool() { Join(); } private: void WorkerFunc() { @@ -311,6 +311,20 @@ class ParallelFor { std::condition_variable finish_signal_; }; +/*! + * \brief A runtime API provided to run the task function in parallel. + * TODO(merrymercy): Move this to `src/support/parallel_for` + * Example: + * parallel_for(1, 11, [](int index) { + * std::cout << index << "\n"; + * }); + * \param start The start index. + * \param end The end index. + * \param f The task function to be executed. It must take an int index as input and produce no output. + * \param stride The stride of the index. 
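
One caution on top of the Example given in the comment above: `f` runs concurrently on the pool's worker threads, so accumulating into shared state needs synchronization. A sketch (assumes <mutex> and <cstdint> are available):

    std::mutex m;
    int64_t total = 0;
    parallel_for(0, 1000, [&](int i) {
      std::lock_guard<std::mutex> lock(m);
      total += static_cast<int64_t>(i) * i;  // parallel sum of squares
    });
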
+ */ +void parallel_for(int start, int end, std::function f, int stride = 1); + } // namespace auto_schedule } // namespace tvm From 1f1b8785be0b547e6483465e8cb3cc0fba3cfb4a Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 14 Jul 2020 14:47:51 +0800 Subject: [PATCH 76/78] Remove ThreadPool --- src/auto_schedule/measure.cc | 3 - src/auto_schedule/utils.cc | 30 -------- src/auto_schedule/utils.h | 142 ----------------------------------- 3 files changed, 175 deletions(-) diff --git a/src/auto_schedule/measure.cc b/src/auto_schedule/measure.cc index ef273ccaa039..b710745b02f9 100644 --- a/src/auto_schedule/measure.cc +++ b/src/auto_schedule/measure.cc @@ -236,9 +236,6 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& po void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, const Array& inputs, Array* results) { - // Close the thread pool to avoid the conflits with python environment - ThreadPool::Global().Abort(); - results->clear(); results->reserve(inputs.size()); diff --git a/src/auto_schedule/utils.cc b/src/auto_schedule/utils.cc index d2901fa401ef..ecb6145268d6 100644 --- a/src/auto_schedule/utils.cc +++ b/src/auto_schedule/utils.cc @@ -32,35 +32,5 @@ NullStream& NullStream::Global() { return stream; } -ThreadPool& ThreadPool::Global() { - static ThreadPool* pool = new ThreadPool(); - static int ct = 0; - - ct = (ct + 1) % ThreadPool::REFRESH_EVERY; - - if (ct == 0) { - pool->Abort(); - delete pool; - pool = new ThreadPool(); - } - - if (pool->NumWorkers() == 0) { - pool->Launch(std::thread::hardware_concurrency()); - } - - return *pool; -} - -void parallel_for(int start, int end, std::function f, int stride) { - auto& pf = ThreadPool::Global(); - int batch_count = (end - start) / stride; - CHECK_GT(batch_count, 0); - pf.BeginBatch(batch_count); - for (int i = start; i < end; i += stride) { - pf.Enqueue(f, i); - } - pf.WaitBatch(); -} - } // namespace auto_schedule } // namespace tvm diff --git a/src/auto_schedule/utils.h b/src/auto_schedule/utils.h index a554f5591aed..e91bc106fb51 100644 --- a/src/auto_schedule/utils.h +++ b/src/auto_schedule/utils.h @@ -183,148 +183,6 @@ inline void PrintTitle(const std::string& title, int verbose) { << Chars('-', 60) << std::endl; } -/*! - * \brief A simple thread pool to perform parallel for. - * TODO(merrymercy): Move this to `src/support/parallel_for` - */ -class ThreadPool { - public: - /*! - * \brief Set the thread number used in this pool. - * \param n The thread number of this pool. - */ - void Launch(size_t n = 1) { - for (std::size_t i = 0; i < n; ++i) { - threads_.emplace_back([this] { WorkerFunc(); }); - } - } - - /*! - * \brief Set the total task number to be executed in this parallel for run batch. - * \param n The task number of this parallel for run batch. - */ - void BeginBatch(int n) { - finish_ct_ = n; - is_finished_ = n <= 0; - } - - /*! - * \brief Add run task to task queue. The task added will be run in thread pool immediately. - * \param f The task function to be executed. - * \param args The args of the task function. - * \return The result of the task function. - */ - template ::type> - std::future Enqueue(F&& f, Args&&... args) { - std::packaged_task p(std::bind(f, args...)); - - auto r = p.get_future(); - { - std::unique_lock l(m_); - work_.emplace_back(std::move(p)); - } - work_signal_.notify_one(); - return r; - } - - /*! \brief Wait until the parallel for run batch is finished. 
*/ - void WaitBatch() { - std::unique_lock l(finish_mutex_); - if (!is_finished_) { - finish_signal_.wait(l); - } - } - - /*! \brief Stop the running process. */ - void Abort() { - CancelPending(); - Join(); - } - - /*! \brief Cancel all the tasks in task queue. */ - void CancelPending() { - std::unique_lock l(m_); - work_.clear(); - } - - /*! \brief Wait until all of the threads are finished. */ - void Join() { - { - std::unique_lock l(m_); - for (size_t i = 0; i < threads_.size(); ++i) { - work_.push_back({}); - } - } - work_signal_.notify_all(); - for (auto& t : threads_) { - t.join(); - } - threads_.clear(); - } - - /*! - * \brief Get the working thread number of this pool. - * \return The thread number of this pool. - */ - size_t NumWorkers() { return threads_.size(); } - - static const int REFRESH_EVERY = 128; - static ThreadPool& Global(); - - ~ThreadPool() { Join(); } - - private: - void WorkerFunc() { - while (true) { - std::packaged_task f; - { - std::unique_lock l(m_); - if (work_.empty()) { - work_signal_.wait(l, [&] { return !work_.empty(); }); - } - f = std::move(work_.front()); - work_.pop_front(); - } - if (!f.valid()) { - return; - } - f(); - - finish_ct_--; - if (finish_ct_ == 0) { - std::unique_lock l(finish_mutex_); - - is_finished_ = true; - finish_signal_.notify_one(); - } - } - } - - std::mutex m_; - std::condition_variable work_signal_; - std::deque> work_; - std::vector threads_; - - bool is_finished_; - std::mutex finish_mutex_; - std::atomic finish_ct_; - std::condition_variable finish_signal_; -}; - -/*! - * \brief A runtime api provided to run the task function in parallel. - * TODO(merrymercy): Move this to `src/support/parallel_for` - * Example: - * parallel_for(1, 11, [](int index) { - * std::cout << index << "\n"; - * }); - * \param start The start index. - * \param end The end index. - * \param f The task function to be excuted. Assert to take an int index as input with no output. - * \param stride The stride of the index. - */ -void parallel_for(int start, int end, std::function f, int stride = 1); - } // namespace auto_schedule } // namespace tvm From 02fede9d76b53b265a40efa83cb2fd5565cce4e3 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 14 Jul 2020 10:02:56 -0700 Subject: [PATCH 77/78] Update python/tvm/auto_schedule/auto_schedule.py --- python/tvm/auto_schedule/auto_schedule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/auto_schedule/auto_schedule.py b/python/tvm/auto_schedule/auto_schedule.py index 8be6e99d6411..8311b85febd6 100644 --- a/python/tvm/auto_schedule/auto_schedule.py +++ b/python/tvm/auto_schedule/auto_schedule.py @@ -36,7 +36,7 @@ @tvm._ffi.register_object("auto_schedule.HardwareParams") class HardwareParams(Object): - """ The parameters of target hardware used to guide the search process of SearchPolicy. 
+ """ The parameters of target hardware used to guide the search policy TODO(jcf94): This is considered to be merged with the new Target: https://discuss.tvm.ai/t/rfc-tvm-target-specification/6844 From eea098908ba5260cf151f94708e8b3ec717b8f32 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 14 Jul 2020 11:27:45 -0700 Subject: [PATCH 78/78] trigger CI --- python/tvm/auto_schedule/auto_schedule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/auto_schedule/auto_schedule.py b/python/tvm/auto_schedule/auto_schedule.py index 8311b85febd6..ffbfc3c914ff 100644 --- a/python/tvm/auto_schedule/auto_schedule.py +++ b/python/tvm/auto_schedule/auto_schedule.py @@ -38,7 +38,7 @@ class HardwareParams(Object): """ The parameters of target hardware used to guide the search policy - TODO(jcf94): This is considered to be merged with the new Target: + TODO(jcf94): This is considered to be merged with the new Target specification: https://discuss.tvm.ai/t/rfc-tvm-target-specification/6844 Parameters