diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 778240050898..2ef94505b413 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -26,6 +26,7 @@
 from .. import expr as _expr
 from .. import function as _function
 from .. import op as _op
+from .. import vision as _vision
 from .common import AttrCvt, Renamer
 from .common import get_relay_op, new_var, infer_shape, infer_channels
 from .common import infer_type, infer_value, infer_value_simulated, get_name
@@ -1495,6 +1496,34 @@ def _impl_v1(cls, inputs, attr, params):
         return _op.topk(inputs[0], k=K, axis=axis)
 
 
+
+class RoiAlign(OnnxOpConverter):
+    """Operator converter for RoiAlign; only the avg mode is converted.
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        if len(inputs) != 3:
+            raise ValueError("Expect 3 inputs only")
+        x = inputs[0]
+        rois = inputs[1]
+        batch_indices = inputs[2]
+        mode = attr.get("mode", b'avg')  # bytes, to match the b'avg' check below
+        if mode != b'avg':
+            raise ValueError("RoiAlign in Relay only uses avg mode")
+        output_height = attr.get("output_height", 1)
+        output_width = attr.get("output_width", 1)
+
+        sampling_ratio = attr.get("sampling_ratio", 0)
+        spatial_scale = attr.get("spatial_scale", 1.0)
+        # Prepend each roi's batch index as column 0 of rois, cast to the rois dtype.
+        batch_indices = _op.expand_dims(batch_indices, axis=1, num_newaxis=1)
+        batch_indices = _op.cast(
+            batch_indices, infer_type(rois).type_annotation.dtype)
+        rois = _op.concatenate([batch_indices, rois], 1)
+
+        return _vision.roi_align(x, rois, [output_height, output_width],
+                                 spatial_scale, sampling_ratio)
+
 # compatible operators that do NOT require any conversion.
_identity_list = [] @@ -1592,6 +1621,9 @@ def _get_convert_map(opset): # Recurrent Layers 'LSTM': LSTM.get_converter(opset), + # defs/vision + 'RoiAlign': RoiAlign.get_converter(opset), + # defs/reduction 'ReduceMax': ReduceMax.get_converter(opset), 'ReduceMin': ReduceMin.get_converter(opset), diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc index 6b221a279bac..5661ebb74ac4 100644 --- a/src/relay/op/vision/rcnn_op.cc +++ b/src/relay/op/vision/rcnn_op.cc @@ -36,6 +36,8 @@ bool ROIAlignRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); const auto* rois = types[1].as(); + CHECK(data); + CHECK(rois); const auto& dshape = data->shape; const auto& rshape = rois->shape; CHECK(roi_align_attrs); diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index f33c5f9ab9b0..1185a5c6bf5d 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -2432,6 +2432,68 @@ def verify_topk(input_dims, K, axis=-1): verify_topk([n, n, n], 5, 2) +def test_roi_align(): + def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0): + output_dims = [num_roi, input_dims[1], output_height, output_width] + + node = helper.make_node('RoiAlign', + inputs=['X', 'rois', 'batch_indicies'], + outputs=['Y'], + mode="avg", + output_height=output_height, + output_width=output_width, + sampling_ratio=sampling_ratio, + spatial_scale=spatial_scale, + ) + + graph = helper.make_graph([node], + "roialign_test", + inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)), + helper.make_tensor_value_info( + "rois", TensorProto.FLOAT, [num_roi, 4]), + helper.make_tensor_value_info( + "batch_indicies", TensorProto.INT64, [num_roi, ]), + ], + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, output_dims)]) + + model = helper.make_model(graph, 
producer_name='roialign_test') + + np_data = np.random.uniform(size=input_dims).astype("float32") + np_rois = np.random.uniform(size=[num_roi, 4]).astype( + 'float32') * input_dims[2] + np_batch_indicies = np.random.randint( + low=0, high=input_dims[0], size=num_roi) + + onnx_out = get_onnxruntime_output( + model, [np_data, np_rois, np_batch_indicies]) + for target, ctx in [('llvm', tvm.cpu())]: + tvm_out = get_tvm_output(model, [np_data, np_rois, np_batch_indicies], target, ctx, output_dims, + output_dtype='float32') + tvm.testing.assert_allclose( + onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05) + + verify_roi_align((1, 4, 16, 16), 32, 7, 7, + sampling_ratio=0, spatial_scale=1.0) + verify_roi_align((4, 4, 16, 32), 32, 7, 7, + sampling_ratio=0, spatial_scale=1.0) + verify_roi_align((1, 8, 16, 16), 32, 7, 7, + sampling_ratio=0, spatial_scale=1.0) + verify_roi_align((1, 4, 8, 8), 32, 7, 7, + sampling_ratio=0, spatial_scale=1.0) + verify_roi_align((1, 4, 16, 16), 16, 5, 7, + sampling_ratio=0, spatial_scale=1.0) + verify_roi_align((1, 4, 16, 12), 8, 7, 3, + sampling_ratio=0, spatial_scale=1.0) + verify_roi_align((1, 4, 16, 16), 32, 7, 7, + sampling_ratio=0, spatial_scale=0.5) + verify_roi_align((3, 4, 12, 16), 32, 7, 7, + sampling_ratio=0, spatial_scale=1.5) + verify_roi_align((5, 4, 16, 14), 32, 7, 7, + sampling_ratio=1, spatial_scale=1.0) + verify_roi_align((1, 4, 16, 16), 32, 7, 7, + sampling_ratio=2, spatial_scale=1.0) + + if __name__ == '__main__': test_flatten() test_reshape() @@ -2498,3 +2560,4 @@ def verify_topk(input_dims, K, axis=-1): test_resize() test_nonzero() test_topk() + test_roialign()