From c305addba5165d9fefa54213f5694822930884aa Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Thu, 27 Dec 2018 13:21:53 -0800 Subject: [PATCH] tests for maxroipool, randomnormal, randomuniform --- tests/python-pytest/onnx/test_node.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/tests/python-pytest/onnx/test_node.py b/tests/python-pytest/onnx/test_node.py index 6a0f8bcd73c2..bb79ded7596c 100644 --- a/tests/python-pytest/onnx/test_node.py +++ b/tests/python-pytest/onnx/test_node.py @@ -139,9 +139,18 @@ def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, att test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific, fix_attrs, check_value, check_shape = test with self.subTest(test_name): names, input_tensors, inputsym = get_input_tensors(inputs) - test_op = mxnet_op(*inputsym, **attrs) - mxnet_output = forward_pass(test_op, None, None, names, inputs) - outputshape = np.shape(mxnet_output) + if inputs: + test_op = mxnet_op(*inputsym, **attrs) + mxnet_output = forward_pass(test_op, None, None, names, inputs) + outputshape = np.shape(mxnet_output) + else: + test_op = mxnet_op(**attrs) + shape = attrs.get('shape', (1,)) + x = mx.nd.zeros(shape, dtype='float32') + xgrad = mx.nd.zeros(shape, dtype='float32') + exe = test_op.bind(ctx=mx.cpu(), args={'x': x}, args_grad={'x': xgrad}) + mxnet_output = exe.forward(is_train=False)[0].asnumpy() + outputshape = np.shape(mxnet_output) if mxnet_specific: onnxmodelfile = onnx_mxnet.export_model(test_op, {}, [np.shape(ip) for ip in inputs], @@ -208,7 +217,11 @@ def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, att # since results would be random, checking for shape alone ("test_multinomial", mx.sym.sample_multinomial, "Multinomial", [np.array([0, 0.1, 0.2, 0.3, 0.4]).astype("float32")], - {'shape': (10,)}, False, {'modify': {'shape': 'sample_size'}}, False, True) + {'shape': (10,)}, False, {'modify': {'shape': 'sample_size'}}, False, True), + ("test_random_normal", mx.sym.random_normal, "RandomNormal", [], + {'shape': (2, 2), 'loc': 0, 'scale': 1}, False, {'modify': {'loc': 'mean'}}, False, True), + ("test_random_uniform", mx.sym.random_uniform, "RandomUniform", [], + {'shape': (2, 2), 'low': 0.5, 'high': 1.0}, False, {}, False, True) ] if __name__ == '__main__':