Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
tests for maxroipool, randomnormal, randomuniform
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Jan 9, 2019
1 parent 1deadd4 commit c305add
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions tests/python-pytest/onnx/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,18 @@ def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, att
test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific, fix_attrs, check_value, check_shape = test
with self.subTest(test_name):
names, input_tensors, inputsym = get_input_tensors(inputs)
test_op = mxnet_op(*inputsym, **attrs)
mxnet_output = forward_pass(test_op, None, None, names, inputs)
outputshape = np.shape(mxnet_output)
if inputs:
test_op = mxnet_op(*inputsym, **attrs)
mxnet_output = forward_pass(test_op, None, None, names, inputs)
outputshape = np.shape(mxnet_output)
else:
test_op = mxnet_op(**attrs)
shape = attrs.get('shape', (1,))
x = mx.nd.zeros(shape, dtype='float32')
xgrad = mx.nd.zeros(shape, dtype='float32')
exe = test_op.bind(ctx=mx.cpu(), args={'x': x}, args_grad={'x': xgrad})
mxnet_output = exe.forward(is_train=False)[0].asnumpy()
outputshape = np.shape(mxnet_output)

if mxnet_specific:
onnxmodelfile = onnx_mxnet.export_model(test_op, {}, [np.shape(ip) for ip in inputs],
Expand Down Expand Up @@ -208,7 +217,11 @@ def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, att
# since results would be random, checking for shape alone
("test_multinomial", mx.sym.sample_multinomial, "Multinomial",
[np.array([0, 0.1, 0.2, 0.3, 0.4]).astype("float32")],
{'shape': (10,)}, False, {'modify': {'shape': 'sample_size'}}, False, True)
{'shape': (10,)}, False, {'modify': {'shape': 'sample_size'}}, False, True),
("test_random_normal", mx.sym.random_normal, "RandomNormal", [],
{'shape': (2, 2), 'loc': 0, 'scale': 1}, False, {'modify': {'loc': 'mean'}}, False, True),
("test_random_uniform", mx.sym.random_uniform, "RandomUniform", [],
{'shape': (2, 2), 'low': 0.5, 'high': 1.0}, False, {}, False, True)
]

if __name__ == '__main__':
Expand Down

0 comments on commit c305add

Please sign in to comment.