From c79dc79b8d5921260b43993a8b41eb0841626914 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Sat, 4 Apr 2020 13:04:04 +0800 Subject: [PATCH] Relay i64 support --- include/tvm/te/operation.h | 7 ++ python/tvm/relay/backend/compile_engine.py | 3 +- python/tvm/te/operation.py | 6 + src/relay/backend/compile_engine.cc | 7 +- src/te/operation/compute_op.cc | 8 +- src/te/operation/op_util.cc | 49 ++++++++ src/te/operation/placeholder_op.cc | 2 + src/tir/transforms/vectorize_loop.cc | 17 ++- tests/python/relay/test_pass_fuse_ops.py | 115 ++++++++++++++++++ .../test_tir_transform_narrow_datatype.py | 83 +++++++++++++ topi/include/topi/detail/constant_utils.h | 7 +- 11 files changed, 289 insertions(+), 15 deletions(-) diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 739ea8599179c..e3a16f1b78ab1 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -578,6 +578,13 @@ inline Tensor compute(Array shape, std::function() const { return static_cast(get()); } + +/*! 
+ * \brief Converts IntImm in shape to DataType::Int(64) if necessary + * \param shape The shape to be converted + */ +TVM_DLL Array GetShape(Array shape); + } // namespace te } // namespace tvm #endif // TVM_TE_OPERATION_H_ diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 3e35bd22e08f5..bcb6e283c3964 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -82,8 +82,7 @@ def get_shape(shape): for dim in shape: if isinstance(dim, tvm.tir.IntImm): val = int(dim) - assert val <= np.iinfo(np.int32).max - ret.append(tvm.tir.IntImm("int32", val)) + ret.append(val) elif isinstance(dim, tvm.tir.Any): ret.append(te.var("any_dim", "int32")) else: diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 3ccab5bfd9c30..dbc0ce1937b37 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -50,6 +50,7 @@ def placeholder(shape, dtype=None, name="placeholder"): The created tensor """ shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape + shape = _ffi_api.GetShape(shape) dtype = "float32" if dtype is None else dtype return _ffi_api.Placeholder( shape, dtype, name) @@ -89,6 +90,7 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape # for python3 shape = tuple([int(s) if isinstance(s, float) else s for s in shape]) + shape = _ffi_api.GetShape(shape) ndim = len(shape) code = fcompute.__code__ @@ -288,6 +290,10 @@ def extern(shape, if len(shape) != len(out_buffers): raise RuntimeError("Number of outputs and out_buffers mismatch: %d vs %d." 
% (len(shape), len(out_buffers))) + promoted_shape = [] + for shp in shape: + promoted_shape.append(_ffi_api.GetShape(shp)) + shape = promoted_shape input_placeholders = in_buffers or [] output_placeholders = out_buffers or [] types = set() diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 182207ad7bb59..611e0e58d4e87 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -78,9 +78,10 @@ Array GetShape(const Array& shape) { for (IndexExpr val : shape) { const int64_t* pval = tir::as_const_int(val); if (pval != nullptr) { - CHECK_LE(pval[0], std::numeric_limits::max()); - CHECK_GE(pval[0], std::numeric_limits::min()); - res.push_back(IntImm(DataType::Int(32), *pval)); + // CHECK_LE(pval[0], std::numeric_limits::max()); + // CHECK_GE(pval[0], std::numeric_limits::min()); + // res.push_back(IntImm(DataType::Int(32), *pval)); + res.push_back(val); } else if (val->IsInstance()) { res.push_back(val.as()->ToVar()); } else { diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index d8ad839e777eb..a1f76e5a4e599 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -93,11 +93,13 @@ Tensor compute(Array shape, FCompute fcompute, std::string name, std:: size_t ndim = shape.size(); std::vector axis; std::vector args; + shape = GetShape(shape); for (size_t i = 0; i < ndim; ++i) { std::ostringstream os; os << "ax" << i; axis.emplace_back( - IterVarNode::make(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); + IterVarNode::make(Range(IntImm(shape[i].dtype(), 0), shape[i]), + Var(os.str(), shape[i].dtype()), kDataPar)); args.push_back(axis.back()->var); } @@ -111,11 +113,13 @@ Array compute(Array shape, FBatchCompute fcompute, std::string size_t ndim = shape.size(); std::vector axis; std::vector args; + shape = GetShape(shape); for (size_t i = 0; i < ndim; ++i) { std::ostringstream os; os << "ax" << i; axis.emplace_back( - 
IterVarNode::make(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); + IterVarNode::make(Range(IntImm(shape[i].dtype(), 0), shape[i]), + Var(os.str(), shape[i].dtype()), kDataPar)); args.push_back(axis.back()->var); } diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc index 5b200ac0ce940..8700c342cdbdd 100644 --- a/src/te/operation/op_util.cc +++ b/src/te/operation/op_util.cc @@ -23,11 +23,14 @@ */ #include "op_util.h" +#include #include #include #include +#include #include +#include #include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" @@ -39,6 +42,23 @@ namespace te { using namespace arith; using namespace tir; +Range RangeMatchTypes(Range dom) { + PrimExpr a = dom->min; + PrimExpr b = dom->extent; + if (a.dtype() == b.dtype()) return dom; + DataType atype = a.dtype(); + DataType btype = b.dtype(); + // Only do int type promotion + CHECK(atype.is_scalar()); + CHECK(btype.is_scalar()); + CHECK(atype.code() == btype.code()); + int bits = std::max(atype.bits(), btype.bits()); + DataType dtype = atype.with_bits(bits); + a = cast(dtype, a); + b = cast(dtype, b); + return Range::make_by_min_extent(a, b); +} + std::vector > MakeLoopNest(const Stage& stage, const std::unordered_map& dom_map, size_t begin_iter_pos, bool new_loop_var, @@ -71,6 +91,9 @@ std::vector > MakeLoopNest(const Stage& stage, // initialize the offset and loop_level Var var = bind_iv->var; + // Match the type of dom + dom = RangeMatchTypes(dom); + // Mark the iter var in the IR, to remember the point if (bind_iv->thread_tag.length() == 0) { // Only generate new loop if we're not bound to a thread. 
@@ -277,5 +300,31 @@ tir::ForType IterVarTypeToForType(IterVarType iter_type) { } } +Array GetShape(Array shape) { + bool is_const = true; + int64_t size = 1; + DataType dtype; + for (auto s : shape) { + if (const IntImmNode* i = s.as()) { + size *= i->value; + } else { + is_const = false; + dtype = s.dtype(); + } + } + Array ret; + if (is_const && size > std::numeric_limits::max()) { + for (auto s : shape) { + int64_t value = Downcast(s)->value; + ret.push_back(IntImm(DataType::Int(64), value)); + } + } else { + ret = shape; + } + return ret; +} + +TVM_REGISTER_GLOBAL("te.GetShape").set_body_typed(GetShape); + } // namespace te } // namespace tvm diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index 9c536ebb87859..3c4902fb0aa34 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -23,6 +23,7 @@ */ #include #include +#include "op_util.h" namespace tvm { namespace te { @@ -59,6 +60,7 @@ Operation PlaceholderOpNode::make(std::string name, Array shape, DataT } Tensor placeholder(Array shape, DataType dtype, std::string name) { + shape = GetShape(shape); return PlaceholderOpNode::make(name, shape, dtype).output(0); } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 9e553cb12ceb9..92ee0dbb03719 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -94,8 +94,10 @@ class VecAllocAccess : public StmtExprMutator { class Vectorizer : public StmtExprMutator { public: - Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) { - ramp_ = RampNode::make(0, 1, var_lanes); + Vectorizer(Var var, IntImm var_lanes) : var_(var), var_lanes_(var_lanes) { + ramp_ = RampNode::make(IntImm(var_lanes.dtype(), 0), + IntImm(var_lanes.dtype(), 1), + var_lanes->value); } Stmt VisitStmt(const Stmt& stmt) final { @@ -363,7 +365,9 @@ class Vectorizer : public StmtExprMutator { // place the vector lanes in least significant 
dimension. extents.push_back(var_lanes_); // rewrite access to buffer internally. - Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body); + Stmt body = VecAllocAccess(op->buffer_var.get(), + var_, + static_cast(var_lanes_->value))(op->body); body = this->VisitStmt(body); return AllocateNode::make(op->buffer_var, op->dtype, extents, condition, body); } @@ -372,7 +376,8 @@ class Vectorizer : public StmtExprMutator { Var idx(var_->name_hint + ".s", var_->dtype); Map values{{var_, idx}}; stmt = Substitute(stmt, values); - return ForNode::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt); + return ForNode::make(idx, make_zero(var_lanes_.dtype()), + var_lanes_, ForType::Serial, DeviceAPI::None, stmt); } private: @@ -381,7 +386,7 @@ class Vectorizer : public StmtExprMutator { // variable to be replaced Var var_; // the lanes. - int var_lanes_; + IntImm var_lanes_; // ramp representing the var. PrimExpr ramp_; // flag to mark requirment of scalarization. @@ -457,7 +462,7 @@ class LoopVectorizer : public StmtMutator { if (!extent_as_int || extent_as_int->value < 1) { LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent; } - return Vectorizer(op->loop_var, static_cast(extent_as_int->value))(op->body); + return Vectorizer(op->loop_var, GetRef(extent_as_int))(op->body); } else { return StmtMutator::VisitStmt_(op); } diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 6b7d297541c75..a054728e0349d 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -19,6 +19,7 @@ from tvm import relay from tvm.relay import transform from tvm.relay.testing import run_opt_pass +import numpy as np def test_fuse_simple(): @@ -621,6 +622,117 @@ def expected(): after = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(zz, after) + +def test_fuse_strided_slice(): + """Test fusion case involving concat and 
strided_slice""" + + def before(): + shape = (tvm.tir.const(10, "int64"), + tvm.tir.const(1, "int64")) + x = relay.var("x", shape=shape) + concat = relay.concatenate([x,x], axis=-1) + out = relay.strided_slice(concat, begin=[np.int64(0)], end=[np.int64(3)]) + t = relay.Function(relay.analysis.free_vars(out), out) + return relay.Function(relay.analysis.free_vars(out), out) + + def expected(): + shape = (tvm.tir.const(10, "int64"), + tvm.tir.const(1, "int64")) + x = relay.var("x", shape=shape) + p0 = relay.var("p0", shape=shape) + concat = relay.concatenate([p0,p0], axis=-1) + out = relay.strided_slice(concat, begin=[np.int64(0)], end=[np.int64(3)]) + + f0 = relay.Function([p0], out) + f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + + y = relay.Call(f0, [x]) + return relay.Function([x], y) + orig = before() + fuse0(tvm.IRModule.from_expr(orig)) + t = tvm.IRModule.from_expr(orig) + m = fuse2(tvm.IRModule.from_expr(orig)) + attention = m["main"].body.op.params + + relay.build(m, 'llvm') + after = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(m["main"], after) + + +def test_fuse_take(): + """Test fusion case involving concat and take""" + + def before(): + shape = (tvm.tir.const(10, "int64"), + tvm.tir.const(1, "int64")) + x = relay.var("x", shape=shape) + concat = relay.concatenate([x,x], axis=-1) + out = relay.op.take(concat, indices=relay.const([0], dtype="int64")) + return relay.Function(relay.analysis.free_vars(out), out) + + def expected(): + shape1 = (tvm.tir.const(10, "int64"), + tvm.tir.const(1, "int64")) + shape2 = (tvm.tir.const(1, "int64"),) + x = relay.var("x", shape=shape1) + p0 = relay.var("p0", shape=shape1) + p1 = relay.var("p1", shape=shape2, + dtype="int64") + c = relay.const([0], dtype="int64") + concat = relay.concatenate([p0,p0], axis=-1) + out = relay.op.take(concat, indices=p1) + + f0 = relay.Function([p0, p1], out) + f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + + y = relay.Call(f0, 
[x, c]) + return relay.Function([x], y) + + orig = before() + fuse0(tvm.IRModule.from_expr(orig)) + m = fuse2(tvm.IRModule.from_expr(orig)) + relay.build(m, 'llvm') + after = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(m["main"], after) + + +def test_fuse_gather_nd(): + """Test fusion case involving concat and gather_nd""" + + def before(): + shape = (tvm.tir.const(10, "int64"), + tvm.tir.const(1, "int64")) + x = relay.var("x", shape=shape) + concat = relay.concatenate([x,x], axis=-1) + out = relay.gather_nd(concat, indices=relay.expr.const([[0,1],[1,0]], dtype="int64")) + return relay.Function(relay.analysis.free_vars(out), out) + + def expected(): + shape1 = (tvm.tir.const(10, "int64"), + tvm.tir.const(1, "int64")) + shape2 = (tvm.tir.const(2, "int64"), + tvm.tir.const(2, "int64")) + x = relay.var("x", shape=shape1) + p0 = relay.var("p0", shape=shape1) + p1 = relay.var("p1", shape=shape2, dtype="int64") + c = relay.const([[0,1],[1,0]], dtype="int64") + concat = relay.concatenate([p0,p0], axis=-1) + out = relay.gather_nd(concat, indices=p1) + + f0 = relay.Function([p0, p1], out) + f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + + y = relay.Call(f0, [x, c]) + return relay.Function([x], y) + + orig = before() + fuse0(tvm.IRModule.from_expr(orig)) + m = fuse2(tvm.IRModule.from_expr(orig)) + relay.build(m, 'llvm') + after = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(m["main"], after) + + if __name__ == "__main__": test_fuse_simple() test_conv2d_fuse() @@ -637,3 +749,6 @@ def expected(): test_immutable() test_split() test_fuse_max() + test_fuse_strided_slice() + test_fuse_take() + test_fuse_gather_nd() diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index 6179bbbfbd07c..026b087af79b8 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ 
b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm import relay from tvm.tir import const @@ -38,6 +39,7 @@ def lower_sch(sch, args, target_bits): arg_list.append(buf) else: raise ValueError("args must be Tensor, Buffer or Var") + sch = sch.normalize() bounds = te.schedule.InferBound(sch) stmt = te.schedule.ScheduleOps(sch, bounds) @@ -189,9 +191,90 @@ def check(m, n, target_bits, target_dtype): target_bits=32, target_dtype='int64') +def test_basic_from_relay(): + engine = relay.backend.compile_engine.get() + def check(shapex, shapey, target_bits, target_dtype): + x = relay.var('x', shape=shapex) + y = relay.var('y', shape=shapey) + z = relay.add(x, y) + func = relay.Function([x, y], z) + mod = tvm.IRModule.from_expr(func) + func = mod["main"] + z = engine.lower(func, "llvm") + stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32) + # outer loop + assert stmt.loop_var.dtype == target_dtype + # inner loop + if len(shapex) > 1 or len(shapey) > 1: + assert stmt.body.loop_var.dtype == target_dtype + + check((65536, 32769), (1, 32769), + target_bits=32, target_dtype="int64") + check((65536, 32768), (1, 32768), + target_bits=32, target_dtype="int32") + check((2**31,), (2**31,), + target_bits=32, target_dtype="int32") + check((2**31 + 1,), (2**31 + 1,), + target_bits=32, target_dtype="int64") + + +def test_te_extern(): + def check(shape, target_bits, target_dtype): + A = te.placeholder(shape, name='A') + B = te.placeholder(shape, name='B') + def add(A, B, C): + m, n = A.shape + ib = tvm.tir.ir_builder.create() + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + with ib.for_range(0, m, name="i") as i: + with ib.for_range(0, n, name="j") as j: + Cptr[i * n + j] = Aptr[i * n + j] + Bptr[i * n + j] + body = ib.get() + return body + C = te.extern(shape, [A, B], lambda ins, outs: add(ins[0], ins[1], outs[0]), + name="add") + s = 
te.create_schedule(C.op) + stmt = lower_sch(s, (A, B, C), 32) + t = stmt + assert stmt.body.loop_var.dtype == target_dtype + assert stmt.body.body.loop_var.dtype == target_dtype + + check((2**15, 2**16), + target_bits=32, target_dtype="int32") + check((2**15, 2**16 + 1), + target_bits=32, target_dtype="int64") + + +def test_te_scan(): + def check(shape, target_bits, init_dtype, upd_dtype): + m, n = shape + x = te.placeholder(shape, name='x') + s = te.placeholder(shape, name='s') + res = tvm.te.scan(te.compute((1, n), lambda _, i: x[0, i]), + te.compute((m, n), lambda t, i: s[t - 1, i] + x[t, i]), + s) + s = te.create_schedule(res.op) + stmt = lower_sch(s, (x, res), 32) + # check init + assert stmt[0].loop_var.dtype == init_dtype + #check update + assert stmt[1].loop_var.dtype == upd_dtype + assert stmt[1].body.loop_var.dtype == upd_dtype + + check((2**15, 2**16), + target_bits=32, init_dtype="int32", upd_dtype="int32") + check((2**15, 2**16 + 1), + target_bits=32, init_dtype="int32", upd_dtype="int64") + + if __name__ == "__main__": test_basic() test_thread_axis() test_multilanes() test_reduce() test_slice() + test_basic_from_relay() + test_te_extern() + test_te_scan() diff --git a/topi/include/topi/detail/constant_utils.h b/topi/include/topi/detail/constant_utils.h index 9bd1251199878..de6cf8ed818ce 100644 --- a/topi/include/topi/detail/constant_utils.h +++ b/topi/include/topi/detail/constant_utils.h @@ -114,8 +114,11 @@ inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) { tvm::tir::ExprDeepEqual expr_equal; bool result = expr_equal(lhs, rhs); if (!result) { - PrimExpr zero(0); - result = expr_equal(tvm::arith::Analyzer().Simplify(lhs - rhs), zero); + PrimExpr ret = tvm::arith::Analyzer().Simplify(lhs - rhs); + if (const IntImmNode* v = ret.as()) { + return v->value == 0; + } + return false; } return result; }