Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX export: MaxRoiPool
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Dec 21, 2018
1 parent 72cdfb7 commit 8d55166
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 2 deletions.
21 changes: 21 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1718,3 +1718,24 @@ def convert_random_normal(node, **kwargs):
name=name
)
return [node]


@mx_op.register("ROIPooling")
def convert_roipooling(node, **kwargs):
"""Map MXNet's ROIPooling operator attributes to onnx's MaxRoiPool
operator and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

pooled_shape = convert_string_to_list(attrs.get('pooled_size'))
scale = float(attrs.get("spatial_scale"))

node = onnx.helper.make_node(
'MaxRoiPool',
input_nodes,
[name],
pooled_shape=pooled_shape,
spatial_scale=scale,
name=name
)
return [node]
37 changes: 35 additions & 2 deletions tests/python-pytest/onnx/export/mxnet_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,11 @@ def forward_pass(sym, arg, aux, data_names, input_data):
data_forward.append(mx.nd.array(val))

mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
mod.set_params(arg_params=arg, aux_params=aux,
allow_missing=True, allow_extra=True)
if not arg and not aux:
mod.init_params()
else:
mod.set_params(arg_params=arg, aux_params=aux,
allow_missing=True, allow_extra=True)

# run inference
batch = namedtuple('Batch', ['data'])
Expand Down Expand Up @@ -425,6 +428,36 @@ def test_random_normal():
assert output[0].shape == shape


@with_seed()
def testRoiPooling():
x = [[np.random.randint(1, 100, (8, 6))]]
y = [[0,0,0,4,4]]
pooled_size = (2,2)
spatial_scale = 0.7
sym = mx.sym.ROIPooling(mx.sym.Variable('x'), mx.sym.Variable('y'),
pooled_size=pooled_size, spatial_scale=spatial_scale)
roipool_output = forward_pass(sym, None, None, ['x', 'y'], [x, y])

inputs = [helper.make_tensor_value_info("x", TensorProto.FLOAT, shape=np.shape(x)),
helper.make_tensor_value_info("y", TensorProto.FLOAT, shape=np.shape(y))]

output_tensor = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(roipool_output))]
roipool_node = [helper.make_node('MaxRoiPool', ["x", "y"], ["output"],
pooled_shape=pooled_size, spatial_scale=spatial_scale)]

roipool_graph = helper.make_graph(roipool_node,
"roipool_test",
inputs,
output_tensor)

roipool_model = helper.make_model(roipool_graph)

bkd_rep = backend.prepare(roipool_model)
output = bkd_rep.run([x, y])

npt.assert_almost_equal(output[0], roipool_output)


def _assert_sym_equal(lhs, rhs):
assert lhs.list_inputs() == rhs.list_inputs() # input names must be identical
assert len(lhs.list_outputs()) == len(rhs.list_outputs()) # number of outputs must be identical
Expand Down

0 comments on commit 8d55166

Please sign in to comment.