From f6726391075385afa0273afa9b41c371d483ac28 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 12 Jun 2020 08:39:24 -0700 Subject: [PATCH] Add ignore storage_order attribute to onnx pooling parser. (#5781) --- python/tvm/relay/frontend/onnx.py | 2 +- tests/python/frontend/onnx/test_forward.py | 24 ++++++++++------------ 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 42f28d4ba8e7..17cb148ec999 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -272,7 +272,7 @@ def _impl_v1(cls, inputs, attr, params): 'kernel_shape': 'pool_size', 'pads': ('padding', 0) }, - ignores=['dilations'], + ignores=['dilations', 'storage_order'], custom_check=dimension_constraint())(inputs, attr, params) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 178f059e2635..665cb7bffd4f 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -2187,20 +2187,18 @@ def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_p else: raise ValueError("Pool method {} is not supported.".format(mode)) + pool_node = helper.make_node( + node_type, inputs=["x"], outputs=["y"], kernel_shape=kernel_shape, strides=strides) + if pads is None: - pool_node = helper.make_node(node_type, - inputs=["x"], - outputs=["y"], - kernel_shape=kernel_shape, - auto_pad=auto_pad, - strides=strides) + pad_attr = helper.make_attribute('auto_pad', auto_pad) else: - pool_node = helper.make_node(node_type, - inputs=["x"], - outputs=["y"], - kernel_shape=kernel_shape, - pads=pads, - strides=strides) + pad_attr = helper.make_attribute('pads', pads) + pool_node.attribute.append(pad_attr) + + if mode == 'max': + storage_attr = helper.make_attribute('storage_order', 0) + pool_node.attribute.append(storage_attr) graph = helper.make_graph([pool_node], "pooling_test", @@ -2907,4 +2905,4 @@ def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ test_mod() test_xor() test_max_roi_pool() - test_roialign() + test_roi_align()