Skip to content

Commit

Permalink
Add RoiAlign to Onnx frontend (apache#5454)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbrookhart authored and Trevor Morris committed Jun 8, 2020
1 parent a5396ed commit be47eb9
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 0 deletions.
32 changes: 32 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from .. import vision as _vision
from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels
from .common import infer_type, infer_value, infer_value_simulated, get_name
Expand Down Expand Up @@ -1495,6 +1496,34 @@ def _impl_v1(cls, inputs, attr, params):

return _op.topk(inputs[0], k=K, axis=axis)


class RoiAlign(OnnxOpConverter):
"""Operator converter for TopK
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
if len(inputs) != 3:
raise ValueError("Expect 3 inputs only")
x = inputs[0]
rois = inputs[1]
batch_indices = inputs[2]
mode = attr.get("mode", "avg")
if mode != b'avg':
raise ValueError("RoiAlign in Relay only uses avg mode")
output_height = attr.get("output_height", 1)
output_width = attr.get("output_width", 1)

sampling_ratio = attr.get("sampling_ratio", 0)
spatial_scale = attr.get("spatial_scale", 1.0)

batch_indices = _op.expand_dims(batch_indices, axis=1, num_newaxis=1)
batch_indices = _op.cast(
batch_indices, infer_type(rois).type_annotation.dtype)
rois = _op.concatenate([batch_indices, rois], 1)

return _vision.roi_align(x, rois, [output_height, output_width],
spatial_scale, sampling_ratio)

# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -1592,6 +1621,9 @@ def _get_convert_map(opset):
# Recurrent Layers
'LSTM': LSTM.get_converter(opset),

# defs/vision
'RoiAlign': RoiAlign.get_converter(opset),

# defs/reduction
'ReduceMax': ReduceMax.get_converter(opset),
'ReduceMin': ReduceMin.get_converter(opset),
Expand Down
2 changes: 2 additions & 0 deletions src/relay/op/vision/rcnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ bool ROIAlignRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* rois = types[1].as<TensorTypeNode>();
CHECK(data);
CHECK(rois);
const auto& dshape = data->shape;
const auto& rshape = rois->shape;
CHECK(roi_align_attrs);
Expand Down
63 changes: 63 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2432,6 +2432,68 @@ def verify_topk(input_dims, K, axis=-1):
verify_topk([n, n, n], 5, 2)


def test_roi_align():
def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0):
output_dims = [num_roi, input_dims[1], output_height, output_width]

node = helper.make_node('RoiAlign',
inputs=['X', 'rois', 'batch_indicies'],
outputs=['Y'],
mode="avg",
output_height=output_height,
output_width=output_width,
sampling_ratio=sampling_ratio,
spatial_scale=spatial_scale,
)

graph = helper.make_graph([node],
"roialign_test",
inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)),
helper.make_tensor_value_info(
"rois", TensorProto.FLOAT, [num_roi, 4]),
helper.make_tensor_value_info(
"batch_indicies", TensorProto.INT64, [num_roi, ]),
],
outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, output_dims)])

model = helper.make_model(graph, producer_name='roialign_test')

np_data = np.random.uniform(size=input_dims).astype("float32")
np_rois = np.random.uniform(size=[num_roi, 4]).astype(
'float32') * input_dims[2]
np_batch_indicies = np.random.randint(
low=0, high=input_dims[0], size=num_roi)

onnx_out = get_onnxruntime_output(
model, [np_data, np_rois, np_batch_indicies])
for target, ctx in [('llvm', tvm.cpu())]:
tvm_out = get_tvm_output(model, [np_data, np_rois, np_batch_indicies], target, ctx, output_dims,
output_dtype='float32')
tvm.testing.assert_allclose(
onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05)

verify_roi_align((1, 4, 16, 16), 32, 7, 7,
sampling_ratio=0, spatial_scale=1.0)
verify_roi_align((4, 4, 16, 32), 32, 7, 7,
sampling_ratio=0, spatial_scale=1.0)
verify_roi_align((1, 8, 16, 16), 32, 7, 7,
sampling_ratio=0, spatial_scale=1.0)
verify_roi_align((1, 4, 8, 8), 32, 7, 7,
sampling_ratio=0, spatial_scale=1.0)
verify_roi_align((1, 4, 16, 16), 16, 5, 7,
sampling_ratio=0, spatial_scale=1.0)
verify_roi_align((1, 4, 16, 12), 8, 7, 3,
sampling_ratio=0, spatial_scale=1.0)
verify_roi_align((1, 4, 16, 16), 32, 7, 7,
sampling_ratio=0, spatial_scale=0.5)
verify_roi_align((3, 4, 12, 16), 32, 7, 7,
sampling_ratio=0, spatial_scale=1.5)
verify_roi_align((5, 4, 16, 14), 32, 7, 7,
sampling_ratio=1, spatial_scale=1.0)
verify_roi_align((1, 4, 16, 16), 32, 7, 7,
sampling_ratio=2, spatial_scale=1.0)


if __name__ == '__main__':
test_flatten()
test_reshape()
Expand Down Expand Up @@ -2498,3 +2560,4 @@ def verify_topk(input_dims, K, axis=-1):
test_resize()
test_nonzero()
test_topk()
test_roialign()

0 comments on commit be47eb9

Please sign in to comment.