Skip to content

Commit

Permalink
Add ability to have multiple copies of same input to onnx_inputs. (ap…
Browse files Browse the repository at this point in the history
  • Loading branch information
jwfromm authored and dpankratz committed Apr 24, 2020
1 parent e132a55 commit 69a46a9
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
3 changes: 1 addition & 2 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def __setitem__(self, item, value):
if isinstance(item, int):
self.input_dict[self.input_keys[item]] = value
elif isinstance(item, str):
if item not in self.input_dict:
self.input_keys.append(item)
self.input_keys.append(item)
self.input_dict[item] = value
else:
raise ValueError("Only integer and string indexed writes allowed.")
Expand Down
11 changes: 6 additions & 5 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,16 +1366,16 @@ def test_binary_ops():
dtype = "float32"
out_shape = in_shape

def verify_binary_ops(op, x, y, out_np, broadcast=None):
def verify_binary_ops(op, x, y, out_np, x_name='in1', y_name='in2', broadcast=None):
if broadcast is None:
z = helper.make_node(op, ['in1', 'in2'], ['out'])
z = helper.make_node(op, [x_name, y_name], ['out'])
else:
z = helper.make_node(op, ['in1', 'in2'], ['out'], broadcast=1)
z = helper.make_node(op, [x_name, y_name], ['out'], broadcast=1)
graph = helper.make_graph([z],
'_test',
inputs=[helper.make_tensor_value_info("in1",
inputs=[helper.make_tensor_value_info(x_name,
TensorProto.FLOAT, list(in_shape)),
helper.make_tensor_value_info("in2",
helper.make_tensor_value_info(y_name,
TensorProto.FLOAT, list(in_shape))],
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_shape))])
Expand All @@ -1393,6 +1393,7 @@ def verify_binary_ops(op, x, y, out_np, broadcast=None):
verify_binary_ops("Sub", x, z, x - z, broadcast=True)
verify_binary_ops("Mul", x, y, x * y, broadcast=None)
verify_binary_ops("Mul", x, z, x * z, broadcast=True)
verify_binary_ops("Mul", x, x, x * x, x_name='in1', y_name='in1', broadcast=None)
verify_binary_ops("Div", x, y, x / y, broadcast=None)
verify_binary_ops("Div", x, z, x / z, broadcast=True)
verify_binary_ops("Sum", x, y, x + y, broadcast=None)
Expand Down

0 comments on commit 69a46a9

Please sign in to comment.