Skip to content

Commit

Permalink
[Relay][Frontend][Onnx] Refactor where importer to support dynamic sh…
Browse files Browse the repository at this point in the history
…apes. (apache#7394)

* Refactor where importer to support dynamic shapes.

* Add a test for dynamic where.
  • Loading branch information
jwfromm authored and trevor-m committed Mar 2, 2021
1 parent 302511f commit cd8065b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 32 deletions.
48 changes: 20 additions & 28 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1560,34 +1560,26 @@ class Where(OnnxOpConverter):

@classmethod
def _impl_v9(cls, inputs, attr, params):
condition_shape = infer_shape(inputs[0])
x_shape = infer_shape(inputs[1])
y_shape = infer_shape(inputs[2])

# condition, x, and y can all be broadcasted.
# broadcast each of them to the longest shape.
# if two shapes have the same number of dimensions,
# try to choose the one that doesn't have "1" as
# a dimension.
shapes = [condition_shape, x_shape, y_shape]
shape_lens = [len(shape) for shape in shapes]
max_size = max(shape_lens)
max_size_idxs = [i for i, x in enumerate(shape_lens) if x == max_size]
broadcast_idx = max_size_idxs[0]
if len(max_size_idxs) > 1:
for idx in max_size_idxs:
if 1 not in shapes[idx]:
broadcast_idx = idx

broadcast_shape = shapes[broadcast_idx]

if condition_shape != broadcast_shape:
inputs[0] = _op.broadcast_to(inputs[0], broadcast_shape)
if x_shape != broadcast_shape:
inputs[1] = _op.broadcast_to(inputs[1], broadcast_shape)
if y_shape != broadcast_shape:
inputs[2] = _op.broadcast_to(inputs[2], broadcast_shape)
return _op.where(inputs[0], inputs[1], inputs[2])
condition_rank = len(infer_shape(inputs[0]))
x_rank = len(infer_shape(inputs[1]))
y_rank = len(infer_shape(inputs[2]))
ranks = [condition_rank, x_rank, y_rank]

# If one rank is longer than others, then we can broadcast
# to that shape.
max_rank = max(ranks)
max_rank_idxs = [i for i, x in enumerate(ranks) if x == max_rank]
broadcast_shape = _op.shape_of(inputs[max_rank_idxs[0]])
# If two or more inputs have the same rank, compute the broadcast
# shape by taking the maximum value of each dimensions.
if len(max_rank_idxs) > 1:
for idx in max_rank_idxs:
broadcast_shape = _op.maximum(broadcast_shape, _op.shape_of(inputs[idx]))

condition = _op.broadcast_to(inputs[0], broadcast_shape)
x = _op.broadcast_to(inputs[1], broadcast_shape)
y = _op.broadcast_to(inputs[2], broadcast_shape)
return _op.where(condition, x, y)


class Or(Elemwise):
Expand Down
17 changes: 13 additions & 4 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2107,10 +2107,18 @@ def test_erf():
verify_erf(x, z)


def verify_where(condition, x, y, dtype, outdata):
node = helper.make_node("Where", inputs=["condition", "x", "y"], outputs=["out"])
def verify_where(condition, x, y, dtype, outdata, dynamic=False):
node_list = []
where_inputs = ["condition", "x", "y"]
if dynamic:
shape_node = helper.make_node("Shape", ["x"], ["shape"])
reshape_node = helper.make_node("Reshape", ["x", "shape"], ["X"])
where_inputs[1] = "X"
node_list += [shape_node, reshape_node]
node = helper.make_node("Where", inputs=where_inputs, outputs=["out"])
node_list.append(node)
graph = helper.make_graph(
[node],
node_list,
"where_test",
inputs=[
helper.make_tensor_value_info("condition", TensorProto.BOOL, list(condition.shape)),
Expand All @@ -2120,7 +2128,7 @@ def verify_where(condition, x, y, dtype, outdata):
outputs=[helper.make_tensor_value_info("out", dtype, list(outdata.shape))],
)
model = helper.make_model(graph, producer_name="where_test")
verify_with_ort_with_inputs(model, [condition, x, y], [outdata.shape])
verify_with_ort_with_inputs(model, [condition, x, y], [outdata.shape], use_vm=True)


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -2156,6 +2164,7 @@ def test_where():
y = np.array([[1], [7]], dtype=np.float32)
outdata = np.where(condition, x, y)
verify_where(condition, x, y, TensorProto.FLOAT, outdata)
verify_where(condition, x, y, TensorProto.FLOAT, outdata, dynamic=True)


def verify_or(indata, dtype):
Expand Down

0 comments on commit cd8065b

Please sign in to comment.