Skip to content

Commit

Permalink
[Relay] Improve Shape Func handling for Tuple inputs (apache#5467)
Browse files Browse the repository at this point in the history
* Improve Shape Func handling for Tuple inputs

* Fix lint

* Improve

* Fix build
  • Loading branch information
kevinthesun authored and Trevor Morris committed Jun 18, 2020
1 parent 7b11183 commit e341c4b
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 1 deletion.
7 changes: 7 additions & 0 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,13 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
return fields;
}

Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
  // Shape-func translation of TupleGetItem: translate the whole tuple, then
  // forward only the shape tensor for the accessed field.
  const Array<te::Tensor> tuple_shapes = VisitExpr(op->tuple);
  return {tuple_shapes[op->index]};
}

private:
/*! \brief String stream for function name */
std::ostringstream readable_name_stream_;
Expand Down
13 changes: 12 additions & 1 deletion src/relay/op/memory/memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,12 +368,23 @@ bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
auto tuple = TupleType(func_type->arg_types);
auto in_types = FlattenTupleType(tuple);
auto out_types = FlattenTupleType(func_type->ret_type);
Array<Integer> is_input;
for (size_t i = 0; i < func_type->arg_types.size(); ++i) {
auto const& aty = func_type->arg_types[i];
size_t num_types = 1;
if (aty.as<TupleTypeNode>()) {
num_types = FlattenTupleType(aty).size();
}
for (size_t j = 0; j < num_types; ++j) {
is_input.push_back(shape_func_attrs->is_input[i]);
}
}

Array<Type> shape_func_ins, shape_func_outs;
for (size_t i = 0; i < in_types.size(); i++) {
auto in_type = in_types[i];

if (shape_func_attrs->is_input[i]) {
if (is_input[i]) {
shape_func_ins.push_back(in_type);
} else {
auto shape = RankShape(in_type->shape);
Expand Down
44 changes: 44 additions & 0 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,47 @@ def _body(i, st):
except Exception as e:
assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)

def test_tuple_get_item():
    """TupleGetItem on a dynamic-shaped split: the shape function must
    propagate the shape of the selected tuple field through the VM executor.

    Splitting a (Any, 4) tensor into 2 parts along axis 1 and taking field 0
    of a (9, 4) input must yield shape (9, 2).
    """
    mod = tvm.IRModule()
    dtype = "float32"
    static_data_shape = (9, 4)
    data_shape = (relay.Any(), 4)
    indices_or_sections = 2
    axis = 1
    data = relay.var('data', shape=data_shape, dtype=dtype)
    y = relay.split(data, indices_or_sections, axis)
    # Access one field of the split result so the shape func must handle
    # TupleGetItem rather than a plain tensor input.
    y = relay.expr.TupleGetItem(y.astuple(), 0)
    mod["main"] = relay.Function([data], y)
    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
    ref_out_shape = (9, 2)
    for kind in ["vm"]:
        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
        result = ex.evaluate()(data_np)
        # Bug fix: the failure message previously referenced the undefined
        # name `ret`, raising NameError instead of the intended message.
        assert result.asnumpy().shape == ref_out_shape, \
            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape),
                                                       str(result.asnumpy().shape))

def test_mixed_input_type():
    """Mixed tensor and (nested) tuple inputs to a function with dynamic
    shapes: the shape function must flatten tuple arguments correctly.

    d0 is a tuple ((tensor, tensor), tensor); d1 is a plain tensor. With all
    tensors of static shape (9, 4) the elementwise result is also (9, 4).
    """
    mod = tvm.IRModule()
    dtype = "float32"
    static_data_shape = (9, 4)
    data_shape = (relay.Any(), 4)
    tensor_type = relay.TensorType(data_shape, dtype)
    tuple_type = relay.TupleType([tensor_type, tensor_type])
    data0 = relay.var("d0", type_annotation=relay.TupleType([tuple_type, tensor_type]))
    data1 = relay.var("d1", shape=(relay.Any(), 4), dtype=dtype)
    data_tuple = relay.expr.TupleWrapper(data0, 2)
    nested_data_tuple = relay.expr.TupleWrapper(data_tuple[0], 2)
    # Mix fields from the nested tuple, the outer tuple, and a plain tensor.
    y = nested_data_tuple[1] * data_tuple[1] + data1
    mod["main"] = relay.Function([data0, data1], y)
    data_np0 = np.random.uniform(size=static_data_shape).astype(dtype)
    data_np1 = np.random.uniform(size=static_data_shape).astype(dtype)
    ref_out_shape = (9, 4)
    for kind in ["vm"]:
        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
        result = ex.evaluate()([[data_np0, data_np0], data_np0], data_np1)
        # Bug fix: the failure message previously referenced the undefined
        # name `ret`, raising NameError instead of the intended message.
        assert result.asnumpy().shape == ref_out_shape, \
            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape),
                                                       str(result.asnumpy().shape))

if __name__ == "__main__":
test_any_full()
test_any_broadcast()
Expand Down Expand Up @@ -708,3 +749,6 @@ def _body(i, st):
test_arange_with_dynamic_shape()
test_recursive_concat()
test_recursive_concat_with_wrong_annotation()
test_tuple_get_item()
test_mixed_input_type()

0 comments on commit e341c4b

Please sign in to comment.