Skip to content

Commit

Permalink
[TF] Fix some shape mismatches between TF and Relay
Browse files Browse the repository at this point in the history
  Make ndarray_size output scalar
  Make gather_nd output scalar if needed
  • Loading branch information
lixiaoquan committed Jul 29, 2020
1 parent 2e93aef commit c898cac
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 11 deletions.
3 changes: 0 additions & 3 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2740,9 +2740,6 @@ bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
Array<IndexExpr> oshape;
for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]);
for (size_t i = mdim->value; i < ndim; ++i) oshape.push_back(data->shape[i]);
if (oshape.size() == 0) {
oshape.push_back(tir::make_const(DataType::Int(32), 1));
}
reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ bool NdarraySizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
CHECK(tt != nullptr);
const auto* param = attrs.as<NdarraySizeAttrs>();
CHECK(param != nullptr);
reporter->Assign(types[1], TensorType({1}, param->dtype));
reporter->Assign(types[1], TensorType({}, param->dtype));
return true;
}

Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ class ConstantFolder : public ExprMutator {
ctx.device_id = 0;
runtime::NDArray value;
DLDataType cdtype = DataType::Int(32);
value = runtime::NDArray::Empty({1}, cdtype, ctx);
value = runtime::NDArray::Empty({}, cdtype, ctx);
int32_t* data = static_cast<int32_t*>(value->data);
if (ishape.size() == 0) {
*data = 0;
Expand Down
4 changes: 3 additions & 1 deletion tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def convert_to_list(x):

def vmobj_to_list(o):
if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()]
return [o.asnumpy()]
elif isinstance(o, tvm.runtime.container.ADT):
result = []
for f in o:
Expand Down Expand Up @@ -211,6 +211,8 @@ def name_without_num(name):
# since the names from tensorflow and relay runs are not exactly same,
# first len(tf_output) will be compared
for i in range(len(tf_output)):
if not isinstance(tf_output[i], np.ndarray):
assert len(tvm_output[i].shape) == 0
tvm.testing.assert_allclose(
tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)

Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_pass_fold_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def before(dtype):
def expected(dtype):
x = relay.var("x", shape=c_shape, dtype="float32")
y = relay.var("y", shape=c_shape, dtype="float32")
z = relay.const([np.size(np.zeros(c_shape))], dtype=dtype)
z = relay.const(np.size(np.zeros(c_shape)), dtype=dtype)
func = relay.Function([x, y], z)
return func

Expand Down
5 changes: 1 addition & 4 deletions topi/include/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1126,9 +1126,6 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string n
for (size_t i = indices_dim0; i < ndim_d; ++i) {
out_shape.push_back(data->shape[i]);
}
if (out_shape.size() == 0) {
out_shape.push_back(make_const(DataType::Int(32), 1));
}
return compute(
out_shape,
[&](const Array<Var>& out_index) {
Expand Down Expand Up @@ -1401,7 +1398,7 @@ inline Tensor ndarray_size(const Tensor& src, const DataType& dtype,
const std::string& name = "ndarray_size",
const std::string& tag = kInjective) {
int ndim = static_cast<int>(src->shape.size());
Array<PrimExpr> out_ndarray_size = {1};
Array<PrimExpr> out_ndarray_size = {};
return compute(
out_ndarray_size,
[&](const Array<Var>& indices) {
Expand Down

0 comments on commit c898cac

Please sign in to comment.