diff --git a/src/te/schedule/operation_inline.cc b/src/te/schedule/operation_inline.cc
index aab30ede5dbfa..7399af90f5709 100644
--- a/src/te/schedule/operation_inline.cc
+++ b/src/te/schedule/operation_inline.cc
@@ -63,7 +63,8 @@ class OperationInliner final : public StmtExprMutator {
     } else {
       Map<Var, PrimExpr> vmap;
       for (size_t i = 0; i < args_.size(); ++i) {
-        vmap.Set(args_[i], op->indices[i]);
+        // cast indices to the type of the original indexing variable
+        vmap.Set(args_[i], cast(args_[i].dtype(), op->indices[i]));
       }
       expr = Substitute(Evaluate(expr), vmap).as<EvaluateNode>()->value;
     }
diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py
index 6b7d297541c75..f4369c1f1d904 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -621,6 +621,79 @@ def expected():
     after = run_opt_pass(expected(), transform.InferType())
     assert tvm.ir.structural_equal(zz, after)
 
+
+def test_fuse_take():
+    """Test fusion case involving concat and take"""
+
+    def before():
+        shape = (tvm.tir.const(10, "int64"),
+                 tvm.tir.const(1, "int64"))
+        x = relay.var("x", shape=shape)
+        concat = relay.concatenate([x,x], axis=-1)
+        out = relay.op.take(concat, indices=relay.const([0], dtype="int64"))
+        return relay.Function(relay.analysis.free_vars(out), out)
+
+    def expected():
+        shape1 = (tvm.tir.const(10, "int64"),
+                  tvm.tir.const(1, "int64"))
+        shape2 = (tvm.tir.const(1, "int64"),)
+        x = relay.var("x", shape=shape1)
+        p0 = relay.var("p0", shape=shape1)
+        p1 = relay.var("p1", shape=shape2,
+                       dtype="int64")
+        c = relay.const([0], dtype="int64")
+        concat = relay.concatenate([p0,p0], axis=-1)
+        out = relay.op.take(concat, indices=p1)
+
+        f0 = relay.Function([p0, p1], out)
+        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+
+        y = relay.Call(f0, [x, c])
+        return relay.Function([x], y)
+
+    orig = before()
+    m = fuse2(tvm.IRModule.from_expr(orig))
+    relay.build(m, 'llvm')
+    after = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(m["main"], after)
+
+
+def test_fuse_gather_nd():
+    """Test fusion case involving concat and gather_nd"""
+
+    def before():
+        shape = (tvm.tir.const(10, "int64"),
+                 tvm.tir.const(1, "int64"))
+        x = relay.var("x", shape=shape)
+        concat = relay.concatenate([x,x], axis=-1)
+        out = relay.gather_nd(concat, indices=relay.expr.const([[0,1],[1,0]], dtype="int64"))
+        return relay.Function(relay.analysis.free_vars(out), out)
+
+    def expected():
+        shape1 = (tvm.tir.const(10, "int64"),
+                  tvm.tir.const(1, "int64"))
+        shape2 = (tvm.tir.const(2, "int64"),
+                  tvm.tir.const(2, "int64"))
+        x = relay.var("x", shape=shape1)
+        p0 = relay.var("p0", shape=shape1)
+        p1 = relay.var("p1", shape=shape2, dtype="int64")
+        c = relay.const([[0,1],[1,0]], dtype="int64")
+        concat = relay.concatenate([p0,p0], axis=-1)
+        out = relay.gather_nd(concat, indices=p1)
+
+        f0 = relay.Function([p0, p1], out)
+        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+
+        y = relay.Call(f0, [x, c])
+        return relay.Function([x], y)
+
+    orig = before()
+    m = fuse2(tvm.IRModule.from_expr(orig))
+    relay.build(m, 'llvm')
+    after = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(m["main"], after)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
@@ -637,3 +710,5 @@ def expected():
     test_immutable()
     test_split()
     test_fuse_max()
+    test_fuse_take()
+    test_fuse_gather_nd()