Relay i64 support
meta-project-ci committed Jun 13, 2020
1 parent 162a29e commit c79dc79
Showing 11 changed files with 289 additions and 15 deletions.
7 changes: 7 additions & 0 deletions include/tvm/te/operation.h
@@ -578,6 +578,13 @@ inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var, Va
inline const OperationNode* Operation::operator->() const {
return static_cast<const OperationNode*>(get());
}

/*!
 * \brief Converts IntImm in shape to DataType::Int(64) if necessary
* \param shape The shape to be converted
*/
TVM_DLL Array<PrimExpr> GetShape(Array<PrimExpr> shape);

} // namespace te
} // namespace tvm
#endif // TVM_TE_OPERATION_H_
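
The new GetShape helper is the heart of the change: when every dimension of a shape is a constant and the total element count no longer fits in int32, all dimensions are promoted to DataType::Int(64); otherwise the shape is returned unchanged. A minimal Python sketch of the intended behavior follows (it assumes a TVM build containing this commit; the dtypes in the comments are the expected results, not captured output):

    import tvm
    from tvm import te

    # 256 elements fit comfortably in int32, so the dims keep their default dtype.
    A = te.placeholder((16, 16), name="A")
    print([d.dtype for d in A.shape])    # expected: ['int32', 'int32']

    # 2**32 elements overflow int32, so every dim is promoted to int64.
    B = te.placeholder((1 << 16, 1 << 16), name="B")
    print([d.dtype for d in B.shape])    # expected: ['int64', 'int64']
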
3 changes: 1 addition & 2 deletions python/tvm/relay/backend/compile_engine.py
@@ -82,8 +82,7 @@ def get_shape(shape):
for dim in shape:
if isinstance(dim, tvm.tir.IntImm):
val = int(dim)
assert val <= np.iinfo(np.int32).max
ret.append(tvm.tir.IntImm("int32", val))
ret.append(val)
elif isinstance(dim, tvm.tir.Any):
ret.append(te.var("any_dim", "int32"))
else:
6 changes: 6 additions & 0 deletions python/tvm/te/operation.py
@@ -50,6 +50,7 @@ def placeholder(shape, dtype=None, name="placeholder"):
The created tensor
"""
shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape
shape = _ffi_api.GetShape(shape)
dtype = "float32" if dtype is None else dtype
return _ffi_api.Placeholder(
shape, dtype, name)
@@ -89,6 +90,7 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape
# for python3
shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
shape = _ffi_api.GetShape(shape)
ndim = len(shape)
code = fcompute.__code__

@@ -288,6 +290,10 @@ def extern(shape,
if len(shape) != len(out_buffers):
raise RuntimeError("Number of outputs and out_buffers mismatch: %d vs %d."
% (len(shape), len(out_buffers)))
promoted_shape = []
for shp in shape:
promoted_shape.append(_ffi_api.GetShape(shp))
shape = promoted_shape
input_placeholders = in_buffers or []
output_placeholders = out_buffers or []
types = set()
7 changes: 4 additions & 3 deletions src/relay/backend/compile_engine.cc
@@ -78,9 +78,10 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
for (IndexExpr val : shape) {
const int64_t* pval = tir::as_const_int(val);
if (pval != nullptr) {
CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
res.push_back(IntImm(DataType::Int(32), *pval));
// CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
// CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
// res.push_back(IntImm(DataType::Int(32), *pval));
res.push_back(val);
} else if (val->IsInstance<tir::AnyNode>()) {
res.push_back(val.as<tir::AnyNode>()->ToVar());
} else {
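
With the int32 clamp removed from the compile engine, Relay graphs whose shapes are written as int64 constants can be lowered without tripping the old CHECKs. A short sketch in the same style as the new fusion tests added below (illustrative only; it simply builds a small concat graph on llvm):

    import tvm
    from tvm import relay

    # Shape dims declared as int64 constants, as in the new tests.
    shape = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
    x = relay.var("x", shape=shape)
    y = relay.concatenate([x, x], axis=-1)
    f = relay.Function(relay.analysis.free_vars(y), y)
    mod = tvm.IRModule.from_expr(f)
    lib = relay.build(mod, "llvm")   # previously this path asserted val <= int32 max
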
8 changes: 6 additions & 2 deletions src/te/operation/compute_op.cc
@@ -93,11 +93,13 @@ Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name, std::
size_t ndim = shape.size();
std::vector<IterVar> axis;
std::vector<Var> args;
shape = GetShape(shape);
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "ax" << i;
axis.emplace_back(
IterVarNode::make(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
IterVarNode::make(Range(IntImm(shape[i].dtype(), 0), shape[i]),
Var(os.str(), shape[i].dtype()), kDataPar));
args.push_back(axis.back()->var);
}

@@ -111,11 +113,13 @@ Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute, std::string
size_t ndim = shape.size();
std::vector<IterVar> axis;
std::vector<Var> args;
shape = GetShape(shape);
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "ax" << i;
axis.emplace_back(
IterVarNode::make(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
IterVarNode::make(Range(IntImm(shape[i].dtype(), 0), shape[i]),
Var(os.str(), shape[i].dtype()), kDataPar));
args.push_back(axis.back()->var);
}

49 changes: 49 additions & 0 deletions src/te/operation/op_util.cc
@@ -23,11 +23,14 @@
*/
#include "op_util.h"

#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

#include <algorithm>
#include <string>
#include <limits>

#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
@@ -39,6 +42,23 @@ namespace te {
using namespace arith;
using namespace tir;

Range RangeMatchTypes(Range dom) {
PrimExpr a = dom->min;
PrimExpr b = dom->extent;
if (a.dtype() == b.dtype()) return dom;
DataType atype = a.dtype();
DataType btype = b.dtype();
// Only do int type promotion
CHECK(atype.is_scalar());
CHECK(btype.is_scalar());
CHECK(atype.code() == btype.code());
int bits = std::max(atype.bits(), btype.bits());
DataType dtype = atype.with_bits(bits);
a = cast(dtype, a);
b = cast(dtype, b);
return Range::make_by_min_extent(a, b);
}

std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
size_t begin_iter_pos, bool new_loop_var,
@@ -71,6 +91,9 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
// initialize the offset and loop_level
Var var = bind_iv->var;

// Match the type of dom
dom = RangeMatchTypes(dom);

// Mark the iter var in the IR, to remember the point
if (bind_iv->thread_tag.length() == 0) {
// Only generate new loop if we're not bound to a thread.
@@ -277,5 +300,31 @@ tir::ForType IterVarTypeToForType(IterVarType iter_type) {
}
}

Array<PrimExpr> GetShape(Array<PrimExpr> shape) {
bool is_const = true;
int64_t size = 1;
DataType dtype;
for (auto s : shape) {
if (const IntImmNode* i = s.as<IntImmNode>()) {
size *= i->value;
} else {
is_const = false;
dtype = s.dtype();
}
}
Array<PrimExpr> ret;
if (is_const && size > std::numeric_limits<int32_t>::max()) {
for (auto s : shape) {
int64_t value = Downcast<IntImm>(s)->value;
ret.push_back(IntImm(DataType::Int(64), value));
}
} else {
ret = shape;
}
return ret;
}

TVM_REGISTER_GLOBAL("te.GetShape").set_body_typed(GetShape);

} // namespace te
} // namespace tvm
2 changes: 2 additions & 0 deletions src/te/operation/placeholder_op.cc
@@ -23,6 +23,7 @@
*/
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include "op_util.h"

namespace tvm {
namespace te {
@@ -59,6 +60,7 @@ Operation PlaceholderOpNode::make(std::string name, Array<PrimExpr> shape, DataT
}

Tensor placeholder(Array<PrimExpr> shape, DataType dtype, std::string name) {
shape = GetShape(shape);
return PlaceholderOpNode::make(name, shape, dtype).output(0);
}

17 changes: 11 additions & 6 deletions src/tir/transforms/vectorize_loop.cc
@@ -94,8 +94,10 @@ class VecAllocAccess : public StmtExprMutator {

class Vectorizer : public StmtExprMutator {
public:
Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) {
ramp_ = RampNode::make(0, 1, var_lanes);
Vectorizer(Var var, IntImm var_lanes) : var_(var), var_lanes_(var_lanes) {
ramp_ = RampNode::make(IntImm(var_lanes.dtype(), 0),
IntImm(var_lanes.dtype(), 1),
var_lanes->value);
}

Stmt VisitStmt(const Stmt& stmt) final {
@@ -363,7 +365,9 @@ class Vectorizer : public StmtExprMutator {
// place the vector lanes in least significant dimension.
extents.push_back(var_lanes_);
// rewrite access to buffer internally.
Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
Stmt body = VecAllocAccess(op->buffer_var.get(),
var_,
static_cast<int>(var_lanes_->value))(op->body);
body = this->VisitStmt(body);
return AllocateNode::make(op->buffer_var, op->dtype, extents, condition, body);
}
@@ -372,7 +376,8 @@ class Vectorizer : public StmtExprMutator {
Var idx(var_->name_hint + ".s", var_->dtype);
Map<Var, PrimExpr> values{{var_, idx}};
stmt = Substitute(stmt, values);
return ForNode::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
return ForNode::make(idx, make_zero(var_lanes_.dtype()),
var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
}

private:
Expand All @@ -381,7 +386,7 @@ class Vectorizer : public StmtExprMutator {
// variable to be replaced
Var var_;
// the lanes.
int var_lanes_;
IntImm var_lanes_;
// ramp representing the var.
PrimExpr ramp_;
// flag to mark requirement of scalarization.
@@ -457,7 +462,7 @@ class LoopVectorizer : public StmtMutator {
if (!extent_as_int || extent_as_int->value < 1) {
LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
}
return Vectorizer(op->loop_var, static_cast<int>(extent_as_int->value))(op->body);
return Vectorizer(op->loop_var, GetRef<IntImm>(extent_as_int))(op->body);
} else {
return StmtMutator::VisitStmt_(op);
}
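
The Vectorizer now receives the loop extent as an IntImm rather than a plain int, so the ramp it builds and the scalar fallback loop both inherit the extent's dtype (for example int64). A minimal sketch of a schedule this enables, assuming a build with this change; the axis below has a constant int64 extent, and the names are illustrative:

    import tvm
    from tvm import te

    n = tvm.tir.IntImm("int64", 8)             # int64 loop extent
    A = te.placeholder((n,), name="A", dtype="float32")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")

    s = te.create_schedule(B.op)
    s[B].vectorize(B.op.axis[0])               # vectorize the int64-extent axis
    print(tvm.lower(s, [A, B], simple_mode=True))
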
115 changes: 115 additions & 0 deletions tests/python/relay/test_pass_fuse_ops.py
@@ -19,6 +19,7 @@
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass
import numpy as np


def test_fuse_simple():
@@ -621,6 +622,117 @@ def expected():
after = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(zz, after)


def test_fuse_strided_slice():
"""Test fusion case involving concat and strided_slice"""

def before():
shape = (tvm.tir.const(10, "int64"),
tvm.tir.const(1, "int64"))
x = relay.var("x", shape=shape)
concat = relay.concatenate([x,x], axis=-1)
out = relay.strided_slice(concat, begin=[np.int64(0)], end=[np.int64(3)])
t = relay.Function(relay.analysis.free_vars(out), out)
return relay.Function(relay.analysis.free_vars(out), out)

def expected():
shape = (tvm.tir.const(10, "int64"),
tvm.tir.const(1, "int64"))
x = relay.var("x", shape=shape)
p0 = relay.var("p0", shape=shape)
concat = relay.concatenate([p0,p0], axis=-1)
out = relay.strided_slice(concat, begin=[np.int64(0)], end=[np.int64(3)])

f0 = relay.Function([p0], out)
f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

y = relay.Call(f0, [x])
return relay.Function([x], y)
orig = before()
fuse0(tvm.IRModule.from_expr(orig))
t = tvm.IRModule.from_expr(orig)
m = fuse2(tvm.IRModule.from_expr(orig))
attention = m["main"].body.op.params

relay.build(m, 'llvm')
after = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(m["main"], after)


def test_fuse_take():
"""Test fusion case involving concat and take"""

def before():
shape = (tvm.tir.const(10, "int64"),
tvm.tir.const(1, "int64"))
x = relay.var("x", shape=shape)
concat = relay.concatenate([x,x], axis=-1)
out = relay.op.take(concat, indices=relay.const([0], dtype="int64"))
return relay.Function(relay.analysis.free_vars(out), out)

def expected():
shape1 = (tvm.tir.const(10, "int64"),
tvm.tir.const(1, "int64"))
shape2 = (tvm.tir.const(1, "int64"),)
x = relay.var("x", shape=shape1)
p0 = relay.var("p0", shape=shape1)
p1 = relay.var("p1", shape=shape2,
dtype="int64")
c = relay.const([0], dtype="int64")
concat = relay.concatenate([p0,p0], axis=-1)
out = relay.op.take(concat, indices=p1)

f0 = relay.Function([p0, p1], out)
f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

y = relay.Call(f0, [x, c])
return relay.Function([x], y)

orig = before()
fuse0(tvm.IRModule.from_expr(orig))
m = fuse2(tvm.IRModule.from_expr(orig))
relay.build(m, 'llvm')
after = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(m["main"], after)


def test_fuse_gather_nd():
"""Test fusion case involving concat and gather_nd"""

def before():
shape = (tvm.tir.const(10, "int64"),
tvm.tir.const(1, "int64"))
x = relay.var("x", shape=shape)
concat = relay.concatenate([x,x], axis=-1)
out = relay.gather_nd(concat, indices=relay.expr.const([[0,1],[1,0]], dtype="int64"))
return relay.Function(relay.analysis.free_vars(out), out)

def expected():
shape1 = (tvm.tir.const(10, "int64"),
tvm.tir.const(1, "int64"))
shape2 = (tvm.tir.const(2, "int64"),
tvm.tir.const(2, "int64"))
x = relay.var("x", shape=shape1)
p0 = relay.var("p0", shape=shape1)
p1 = relay.var("p1", shape=shape2, dtype="int64")
c = relay.const([[0,1],[1,0]], dtype="int64")
concat = relay.concatenate([p0,p0], axis=-1)
out = relay.gather_nd(concat, indices=p1)

f0 = relay.Function([p0, p1], out)
f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

y = relay.Call(f0, [x, c])
return relay.Function([x], y)

orig = before()
fuse0(tvm.IRModule.from_expr(orig))
m = fuse2(tvm.IRModule.from_expr(orig))
relay.build(m, 'llvm')
after = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(m["main"], after)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
@@ -637,3 +749,6 @@ def expected():
test_immutable()
test_split()
test_fuse_max()
test_fuse_strided_slice()
test_fuse_take()
test_fuse_gather_nd()