Skip to content

Commit

Permalink
[relay][op] Add shape func to tile (#4441)
Browse files Browse the repository at this point in the history
* [relay][op] Add shape func to tile

* retrigger ci

* check dynamic axes

* retrigger ci
  • Loading branch information
zhiics authored and kevinthesun committed Dec 5, 2019
1 parent 6ef2418 commit ba9d96b
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 15 deletions.
32 changes: 32 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,3 +501,35 @@ def reshape_like_shape_func(attrs, inputs, _):
Shape function for reshape_like op.
"""
return [_reshape_like_shape_func(inputs[1])]

@script
def _tile_shape_func(data, reps, ndim, tndim, rndim):
out = output_tensor((tndim,), "int64")

if ndim == rndim:
for i in const_range(tndim):
out[i] = data[i] * int64(reps[i])
elif ndim > rndim:
ngap = ndim - rndim
for i in const_range(ndim):
if i < ngap:
out[i] = data[i]
else:
out[i] = data[i] * int64(reps[i - ngap])
else:
rgap = rndim - ndim
for i in const_range(rndim):
if i < rgap:
out[i] = int64(reps[i])
else:
out[i] = int64(reps[i]) * data[i - rgap]
return out

@_reg.register_shape_func("tile", False)
def tile_shape_func(attrs, inputs, _):
reps = get_const_tuple(attrs.reps)
ndim = inputs[0].shape[0].value
rndim = len(reps)
tndim = ndim if ndim > rndim else rndim
return [_tile_shape_func(inputs[0], convert(reps), convert(ndim),
convert(tndim), convert(rndim))]
41 changes: 26 additions & 15 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1393,28 +1393,39 @@ bool TileRel(const Array<Type>& types,
reps_shape.reserve(tndim);
if (ndim == rndim) {
for (size_t i = 0; i < tndim; ++i) {
data_shape.emplace_back(data->shape[i]);
reps_shape.emplace_back(reps[i]);
data_shape.emplace_back(data->shape[i]);
reps_shape.emplace_back(reps[i]);
}
} else if (ndim > rndim) {
for (size_t i = 0; i < ndim; ++i)
data_shape.emplace_back(data->shape[i]);
for (size_t i = 0; i < (ndim - rndim); ++i)
reps_shape.emplace_back(1);
for (size_t i = 0; i < rndim; ++i)
reps_shape.emplace_back(reps[i]);
for (size_t i = 0; i < ndim; ++i) {
data_shape.emplace_back(data->shape[i]);
}
for (size_t i = 0; i < (ndim - rndim); ++i) {
reps_shape.emplace_back(1);
}
for (size_t i = 0; i < rndim; ++i) {
reps_shape.emplace_back(reps[i]);
}
} else {
for (size_t i = 0; i < rndim; ++i)
reps_shape.emplace_back(reps[i]);
for (size_t i = 0; i < (rndim - ndim); ++i)
data_shape.emplace_back(1);
for (size_t i = 0; i < ndim; ++i)
data_shape.emplace_back(data->shape[i]);
for (size_t i = 0; i < rndim; ++i) {
reps_shape.emplace_back(reps[i]);
}
for (size_t i = 0; i < (rndim - ndim); ++i) {
data_shape.emplace_back(1);
}
for (size_t i = 0; i < ndim; ++i) {
data_shape.emplace_back(data->shape[i]);
}
}
std::vector<IndexExpr> oshape;
oshape.reserve(tndim);
for (size_t i = 0; i < tndim; ++i) {
oshape.emplace_back(data_shape[i] * reps_shape[i]);
// Save Any if it is dynamic shape
if (!data_shape[i].as<IntImm>()) {
oshape.emplace_back(Any::make());
} else {
oshape.emplace_back(data_shape[i] * reps_shape[i]);
}
}
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
Expand Down
20 changes: 20 additions & 0 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,25 @@ def test_any_take():
verify_any_take(any_dims(2), any_dims(3), None, (4, 5), (2, 3, 4))
verify_any_take(any_dims(2), any_dims(4), -1, (4, 5), (2, 3, 4, 5))

def verify_any_tile(dshape, reps, np_dshape, np_reps):
mod = relay.Module()
x = relay.var("x", shape=dshape, dtype="float32")
y = relay.tile(x, reps=reps)
mod["main"] = relay.Function([x], y)
x_data = np.random.uniform(size=np_dshape).astype("float32")
ref_res = np.tile(x_data, reps=np_reps)

for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
res = ex.evaluate()(x_data)
tvm.testing.assert_allclose(res.asnumpy(), ref_res, rtol=1e-5)

def test_any_tile():
verify_any_tile(any_dims(3), (3, 2, 1), (2, 3, 4), (3, 2, 1))
verify_any_tile(any_dims(3), (1, 2), (2, 3, 4), (1, 2))
verify_any_tile(any_dims(2), (3, 2, 1), (2, 3), (3, 2, 1))
verify_any_tile(any_dims(3), (1,), (2, 3, 4), (1,))

def test_any_shape_of():
x = relay.var('x', shape=any_dims(2), dtype='float32')
y = relay.shape_of(x)
Expand Down Expand Up @@ -586,6 +605,7 @@ def _body(i, st):
test_any_concat()
test_any_reshape()
test_any_take()
test_any_tile()
test_any_shape_of()
test_any_reduce()
test_any_layout_transform()
Expand Down

0 comments on commit ba9d96b

Please sign in to comment.