diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index b4527f76e808..6fa5459829fc 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -10,6 +10,7 @@ #include "./op_util.h" #include "./compute_op.h" #include "../schedule/message_passing.h" +#include "../arithmetic/compute_expr.h" namespace tvm { @@ -322,6 +323,50 @@ void VerifyTensorizeBody( } } +/*! + * \brief Transform the update part when there is no init func in tensorizing + * \param stage The stage for tensorizing. + * \param dom_map The range of each iter var. + * \param n The loop nest structured used in compute. + * \param body The body func in tensorize intrin + * \param update The update func in tensorize intrin + * \return Transformed result. + */ +Stmt TransformUpdate(const Stage& stage, + const std::unordered_map& dom_map, + const ComputeLoopNest& n, + Stmt body, + Stmt update) { + Array conds; + std::unordered_set banned; + for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { + IterVar iv = stage->leaf_iter_vars[i]; + auto iit = stage->iter_var_attrs.find(iv); + if (iit != stage->iter_var_attrs.end()) { + const IterVarAttr& attr = (*iit).second; + if (attr->iter_type == kTensorized) { + break; + } + } + if (iv->iter_type == kCommReduce) { + auto vit = dom_map.find(iv); + CHECK(vit != dom_map.end()); + const Range& vrange = vit->second; + conds.push_back(likely(iv->var > vrange->min)); + banned.insert(iv->var.get()); + } + } + for (const Expr& pred : n.main_predicates) { + if (ir::ExprUseVar(pred, banned)) { + LOG(FATAL) << "Tensorize update transform failed, the condition " + << pred << " has a conflict with the reset condition"; + } + } + + return IfThenElse::make(arith::ComputeReduce(conds, const_true(1)), + update, body); +} + Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map) { @@ -416,32 +461,47 @@ Stmt MakeTensorize(const ComputeOpNode* self, return MergeNest(nest, body); } else { // Need to split reduction - CHECK(intrin->reduce_init.defined()) - << "Reduction init op for intrin " << intrin << " is not defined"; CHECK(intrin->reduce_update.defined()) << "Reduction update op for intrin " << intrin << " is not defined"; // Need init and update steps CHECK_NE(self->reduce_axis.size(), 0U); std::vector > common( n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); - // init nest - std::vector > init_nest( - n.init_nest.begin(), n.init_nest.begin() + tloc + 1); - init_nest.emplace_back(op::MakeIfNest(n.init_predicates)); - Stmt init = MergeNest(output_bind_nest, intrin->reduce_init); - init = Substitute(init, n.init_vmap); - init = MergeNest(init_nest, init); - // The update std::vector > update_nest( n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1); update_nest.emplace_back(op::MakeIfNest(n.main_predicates)); - Stmt update = MergeNest(output_bind_nest, intrin->reduce_update); - update = MergeNest(input_bind_nest, update); - update = Substitute(update, vmap); - update = MergeNest(binder.asserts(), update); - update = Substitute(update, n.main_vmap); - update = MergeNest(update_nest, update); - return MergeNest(common, Block::make(init, update)); + + if (intrin->reduce_init.defined()) { + // init nest + std::vector > init_nest( + n.init_nest.begin(), n.init_nest.begin() + tloc + 1); + init_nest.emplace_back(op::MakeIfNest(n.init_predicates)); + Stmt init = MergeNest(output_bind_nest, intrin->reduce_init); + init = Substitute(init, n.init_vmap); + init = MergeNest(init_nest, init); + // The update + Stmt update = MergeNest(output_bind_nest, intrin->reduce_update); + update = MergeNest(input_bind_nest, update); + update = Substitute(update, vmap); + update = MergeNest(binder.asserts(), update); + update = Substitute(update, n.main_vmap); + update = MergeNest(update_nest, update); + return MergeNest(common, Block::make(init, update)); + } else { + // When init op is not available, use body op for reset in the first iter. + CHECK(intrin->body.defined()) + << "Normal body op for intrin " << intrin << " is not defined"; + Stmt update = TransformUpdate(stage, dom_map, n, + intrin->body, + intrin->reduce_update); + update = MergeNest(output_bind_nest, update); + update = MergeNest(input_bind_nest, update); + update = Substitute(update, vmap); + update = MergeNest(binder.asserts(), update); + update = Substitute(update, n.main_vmap); + update = MergeNest(update_nest, update); + return MergeNest(common, update); + } } } diff --git a/tests/python/unittest/test_schedule_tensorize_init_none.py b/tests/python/unittest/test_schedule_tensorize_init_none.py new file mode 100644 index 000000000000..ce1d5633173a --- /dev/null +++ b/tests/python/unittest/test_schedule_tensorize_init_none.py @@ -0,0 +1,90 @@ +import tvm + +def intrin_gemv(m, n): + w = tvm.placeholder((m, n), name='w') + x = tvm.placeholder((n,), name='x') + k = tvm.reduce_axis((0, n), name='k') + z = tvm.compute((m,), lambda i: + tvm.sum(w[i, k] * x[k], axis=k), name='z') + Wb = tvm.decl_buffer(w.shape, w.dtype, + name="W", + offset_factor=16, + strides=[tvm.var('ldw'), 1]) + def intrin_func(ins, outs): + ww, xx = ins + zz = outs[0] + ww_ptr = ww.access_ptr("r") + xx_ptr = xx.access_ptr("r") + zz_ptr = zz.access_ptr("w") + body = tvm.call_packed( + "gemv", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) + update = tvm.call_packed( + "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) + return body, None, update + + with tvm.build_config(data_alignment=16, + offset_factor=16): + return tvm.decl_tensor_intrin(z.op, intrin_func, + binds={w: Wb}) + + +def test_tensorize_matmul(): + n = 1024 + m = n + l = n + A = tvm.placeholder((n, l), name='A') + B = tvm.placeholder((m, l), name='B') + k = tvm.reduce_axis((0, l), name='k') + C = tvm.compute((n, m), lambda i, j: + tvm.sum(B[j, k] * A[i, k], axis=k), name='C') + + def check(factor): + s = tvm.create_schedule(C.op) + x, y = C.op.axis + yo, yi = s[C].split(y, factor=factor) + gemv = intrin_gemv(factor, l) + s[C].tensorize(yi, gemv) + s = s.normalize() + dom_map = tvm.schedule.InferBound(s) + finfer = tvm.get_global_func("test.op.InferTensorizeRegion") + out_dom, in_dom = finfer(s[C], dom_map) + assert tvm.ir_pass.Equal(out_dom[x].extent, 1) + assert tvm.ir_pass.Equal(out_dom[y].extent, factor) + assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor) + fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") + body = fmatch(s[C], out_dom, in_dom, gemv) + assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]), + tvm.ir_pass.CanonicalSimplify(gemv.op.body[0])) + stmt = tvm.schedule.ScheduleOps(s, dom_map) + tvm.lower(s, [A, B, C]) + + + def check_rfactor(factor, rfactor): + s = tvm.create_schedule(C.op) + x, y = C.op.axis + rk = C.op.reduce_axis[0] + yo, yi = s[C].split(y, factor=factor) + ro, ri = s[C].split(rk, factor=rfactor) + s[C].reorder(yo, ro, yi, ri) + gemv = intrin_gemv(factor, rfactor) + s[C].tensorize(yi, gemv) + s = s.normalize() + dom_map = tvm.schedule.InferBound(s) + finfer = tvm.get_global_func("test.op.InferTensorizeRegion") + out_dom, in_dom = finfer(s[C], dom_map) + assert tvm.ir_pass.Equal(out_dom[x].extent, 1) + assert tvm.ir_pass.Equal(out_dom[y].extent, factor) + assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor) + fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") + body = fmatch(s[C], out_dom, in_dom, gemv) + assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]), + tvm.ir_pass.CanonicalSimplify(gemv.op.body[0])) + stmt = tvm.schedule.ScheduleOps(s, dom_map) + tvm.lower(s, [A, B, C]) + + check(16) + check_rfactor(16, 16) + + +if __name__ == "__main__": + test_tensorize_matmul()