diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index aa835f00671e..5d9b3d582e73 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1102,7 +1102,8 @@ Do not modify directly.* |||11+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float), tensor(float16)| |||10+|**T** = tensor(float), tensor(float16)| |ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|10+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)| +|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)| +|||10+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)| |Round|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| |STFT|*in* signal:**T1**
*in* frame_step:**T2**
*in* window:**T1**
*in* frame_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)| |ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp index c3a25ca8d464..892efca3058d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp @@ -29,14 +29,25 @@ class DmlOperatorRegionOfInterestAlign : public DmlOperator, public RoiAlignHelp {"max", DML_REDUCE_FUNCTION_MAX}, {"avg", DML_REDUCE_FUNCTION_AVERAGE}, }; + + constexpr NameAndIndex coordinateTransformationModes[] = + { + {"half_pixel", 0}, + {"output_half_pixel", 1}, + }; + + std::string coordinateTransformationMode = kernelCreationContext.GetOptionalAttribute(AttrName::CoordinateTransformationMode, "half_pixel"); + auto optionalCoordinateTransformationModeValue = TryMapStringToIndex(coordinateTransformationMode, coordinateTransformationModes); const std::string mode = kernelCreationContext.GetOptionalAttribute(AttrName::Mode, "avg"); const auto optionalReductionFunction = TryMapStringToIndex(mode, mapping); const float spatialScale = kernelCreationContext.GetOptionalAttribute(AttrName::SpatialScale, 1.0f); const int32_t samplesPerOutput = kernelCreationContext.GetOptionalAttribute(AttrName::SamplingRatio, 0u); ML_CHECK_VALID_ARGUMENT(samplesPerOutput >= 0, "sampling_ratio must be 0 or positive."); ML_CHECK_VALID_ARGUMENT(!!optionalReductionFunction, "Unsupported RoiAlign mode."); + ML_CHECK_VALID_ARGUMENT(!!optionalCoordinateTransformationModeValue, "Unsupported RoiAlign coordinate_transformation_mode."); + - DML_ROI_ALIGN_OPERATOR_DESC operatorDesc = {}; + DML_ROI_ALIGN1_OPERATOR_DESC operatorDesc = {}; operatorDesc.InputTensor = &inputDescs[0]; operatorDesc.ROITensor = &inputDescs[1]; operatorDesc.BatchIndicesTensor = &inputDescs[2]; @@ -48,12 +59,15 @@ class 
DmlOperatorRegionOfInterestAlign : public DmlOperator, public RoiAlignHelp operatorDesc.MaximumSamplesPerOutput = (samplesPerOutput == 0) ? UINT32_MAX : samplesPerOutput; operatorDesc.ReductionFunction = *optionalReductionFunction; operatorDesc.InterpolationMode = DML_INTERPOLATION_MODE_LINEAR; - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN, &operatorDesc }; + operatorDesc.InputPixelOffset = (*optionalCoordinateTransformationModeValue == 0)? 0.5f : 0.0f; + operatorDesc.OutputPixelOffset = -0.5f; + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN1, &operatorDesc }; SetDmlOperatorDesc(opDesc, kernelCreationContext); } }; DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign10, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign16, VersionedKernel); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index eea19ad8d883..34a8f496d7d9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -206,6 +206,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(LpPool); DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool); DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool); DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10); +DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign16); DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization); DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization); DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization15); @@ -551,6 +552,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_VER( 10, RoiAlign, typeNameListTwo, 
supportedTypeListRoiAlign, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 16, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)}, {REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute. @@ -807,10 +809,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 13, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 14, Relu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 16, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 16, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // bfloat added to T in 16 {REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)}, - {REG_INFO( 16, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)}, + {REG_INFO( 16, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)}, // bfloat added to T in 16 {REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, Elu, typeNameListDefault, 
supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 837211339384..0206e483d92b 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1449,6 +1449,7 @@ using ShapeInferenceHelper_LpPool = PoolingHelper; using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper; using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper; using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper; +using ShapeInferenceHelper_RoiAlign16 = VersionedOpsetHelper; using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_BatchNormalization = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_BatchNormalization15 = BatchNormalizationHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 5e2ca4cb116e..054faad5ba8f 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -381,6 +381,7 @@ namespace OperatorHelper static const int sc_sinceVer_LessOrEqual = 16; static const int sc_sinceVer_ScatterND = 16; static const int sc_sinceVer_ScatterElements = 16; + static const int sc_sinceVer_RoiAlign = 16; } // namespace OnnxOperatorSet16 namespace OnnxOperatorSet17 diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 3ee2f4ddcf91..a00e3bc4f2b6 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -800,6 +800,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); {"bernoulli_expanded", "By design. 
Test data is for informational purpose because the generator is non deterministic."}, {"test_roialign_aligned_true", "Opset 16 not supported yet."}, {"test_roialign_aligned_false", "Opset 16 not supported yet."}, + {"test_roialign_mode_max", "Onnx roialign mode expected output is incorrect."}, {"test_scatternd_add", "Opset 16 not supported yet."}, {"test_scatternd_multiply", "Opset 16 not supported yet."}, {"test_scatter_elements_with_duplicate_indices", "Opset 16 not supported yet."}, diff --git a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc index a7cc7c536a19..ad9c561ffb51 100644 --- a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc +++ b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc @@ -9,11 +9,10 @@ namespace onnxruntime { namespace test { TEST(RoiAlignTest, AvgModePositive) { - // TODO: Unskip when fixed #41968513 + // TODO: Unskip when fixed ort issue #3428 if (DefaultDmlExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 2.9583299160003662, which exceeds threshold"; } - OpTester test("RoiAlign", 10); test.AddAttribute("output_height", 3); test.AddAttribute("output_width", 4); @@ -30,7 +29,241 @@ TEST(RoiAlignTest, AvgModePositive) { test.AddInput("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.}); test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0}); test.AddOutput("Y", {5, 3, 3, 4}, {2.95833f, 3.20833f, 3.45833f, 3.70833f, 4.625f, 4.875f, 5.125f, 5.375f, 6.29167f, 6.54167f, 6.79167f, 7.04167f, 27.9583f, 28.2083f, 28.4583f, 28.7083f, 29.625f, 29.875f, 30.125f, 30.375f, 31.2917f, 31.5417f, 31.7917f, 32.0417f, 52.9583f, 53.2083f, 53.4583f, 53.7083f, 54.625f, 54.875f, 55.125f, 55.375f, 56.2917f, 56.5417f, 56.7917f, 57.0417f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 
0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 7.39583f, 7.39583f, 7.42708f, 7.64583f, 9.0625f, 9.0625f, 9.09375f, 9.3125f, 10.7292f, 10.7292f, 10.7604f, 10.9792f, 32.3958f, 32.3958f, 32.4271f, 32.6458f, 34.0625f, 34.0625f, 34.0938f, 34.3125f, 35.7292f, 35.7292f, 35.7604f, 35.9792f, 57.3958f, 57.3958f, 57.4271f, 57.6458f, 59.0625f, 59.0625f, 59.0938f, 59.3125f, 60.7292f, 60.7292f, 60.7604f, 60.9792f, 4.27083f, 4.52083f, 4.77083f, 5.02083f, 5.9375f, 6.1875f, 6.4375f, 6.6875f, 7.60417f, 7.85417f, 8.10417f, 8.35417f, 29.2708f, 29.5208f, 29.7708f, 30.0208f, 30.9375f, 31.1875f, 31.4375f, 31.6875f, 32.6042f, 32.8542f, 33.1042f, 33.3542f, 54.2708f, 54.5208f, 54.7708f, 55.0208f, 55.9375f, 56.1875f, 56.4375f, 56.6875f, 57.6042f, 57.8542f, 58.1042f, 58.3542f, 6.77083f, 6.77083f, 6.77083f, 6.80208f, 8.4375f, 8.4375f, 8.4375f, 8.46875f, 10.1042f, 10.1042f, 10.1042f, 10.1354f, 31.7708f, 31.7708f, 31.7708f, 31.8021f, 33.4375f, 33.4375f, 33.4375f, 33.4688f, 35.1042f, 35.1042f, 35.1042f, 35.1354f, 56.7708f, 56.7708f, 56.7708f, 56.8021f, 58.4375f, 58.4375f, 58.4375f, 58.4688f, 60.1042f, 60.1042f, 60.1042f, 60.1354f}); + // As per ORT issue https://github.com/microsoft/onnxruntime/issues/6921, the above output values are INCORRECT. + // DML has the correct outputs, which are defined below. 
+ /*test.AddOutput("Y", {5, 3, 3, 4}, { + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + });*/ + test.Run(); +} + 
+TEST(RoiAlignTest, AvgModePositive_half_pixel) { + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 3); + test.AddAttribute("output_width", 4); + test.AddAttribute("sampling_ratio", 2); + test.AddAttribute("spatial_scale", 1.0f / 16.0f); + test.AddAttribute("coordinate_transformation_mode", "half_pixel"); + + constexpr int N = 1; + constexpr int C = 3; + constexpr int H = 5; + constexpr int W = 5; + + std::vector rois{0., 7., 5., 7., 5., 0., -15., -15., -15., -15., 0., -10., 21., -10., 21., 0., 13., 8., 13., 8., 0., -14., 19., -14., 19.}; + test.AddInput("X", {N, C, H, W}, {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74.}); + test.AddInput("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.}); + test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0}); + test.AddOutput("Y", {5, 3, 3, 4}, {0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 
0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000}); + test.Run(); +} + +TEST(RoiAlignTest, AvgModePositive_output_half_pixel) { + // TODO: Unskip when fixed ort issue #3428 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.95832991600036621, which exceeds threshold"; + } + + OpTester test("RoiAlign", 16); + test.AddAttribute("output_height", 3); + 
test.AddAttribute("output_width", 4); + test.AddAttribute("sampling_ratio", 2); + test.AddAttribute("spatial_scale", 1.0f / 16.0f); + test.AddAttribute("coordinate_transformation_mode", "output_half_pixel"); + + constexpr int N = 1; + constexpr int C = 3; + constexpr int H = 5; + constexpr int W = 5; + std::vector rois{0., 7., 5., 7., 5., 0., -15., -15., -15., -15., 0., -10., 21., -10., 21., 0., 13., 8., 13., 8., 0., -14., 19., -14., 19.}; + test.AddInput("X", {N, C, H, W}, {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74.}); + test.AddInput("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.}); + test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0}); + test.AddOutput("Y", {5, 3, 3, 4}, {2.95833f, 3.20833f, 3.45833f, 3.70833f, 4.625f, 4.875f, 5.125f, 5.375f, 6.29167f, 6.54167f, 6.79167f, 7.04167f, 27.9583f, 28.2083f, 28.4583f, 28.7083f, 29.625f, 29.875f, 30.125f, 30.375f, 31.2917f, 31.5417f, 31.7917f, 32.0417f, 52.9583f, 53.2083f, 53.4583f, 53.7083f, 54.625f, 54.875f, 55.125f, 55.375f, 56.2917f, 56.5417f, 56.7917f, 57.0417f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 7.39583f, 7.39583f, 7.42708f, 7.64583f, 9.0625f, 9.0625f, 9.09375f, 9.3125f, 10.7292f, 10.7292f, 10.7604f, 10.9792f, 32.3958f, 32.3958f, 32.4271f, 32.6458f, 34.0625f, 34.0625f, 34.0938f, 34.3125f, 35.7292f, 35.7292f, 35.7604f, 35.9792f, 57.3958f, 57.3958f, 57.4271f, 57.6458f, 59.0625f, 59.0625f, 59.0938f, 59.3125f, 60.7292f, 60.7292f, 60.7604f, 60.9792f, 4.27083f, 
4.52083f, 4.77083f, 5.02083f, 5.9375f, 6.1875f, 6.4375f, 6.6875f, 7.60417f, 7.85417f, 8.10417f, 8.35417f, 29.2708f, 29.5208f, 29.7708f, 30.0208f, 30.9375f, 31.1875f, 31.4375f, 31.6875f, 32.6042f, 32.8542f, 33.1042f, 33.3542f, 54.2708f, 54.5208f, 54.7708f, 55.0208f, 55.9375f, 56.1875f, 56.4375f, 56.6875f, 57.6042f, 57.8542f, 58.1042f, 58.3542f, 6.77083f, 6.77083f, 6.77083f, 6.80208f, 8.4375f, 8.4375f, 8.4375f, 8.46875f, 10.1042f, 10.1042f, 10.1042f, 10.1354f, 31.7708f, 31.7708f, 31.7708f, 31.8021f, 33.4375f, 33.4375f, 33.4375f, 33.4688f, 35.1042f, 35.1042f, 35.1042f, 35.1354f, 56.7708f, 56.7708f, 56.7708f, 56.8021f, 58.4375f, 58.4375f, 58.4375f, 58.4688f, 60.1042f, 60.1042f, 60.1042f, 60.1354f}); test.Run(); } @@ -230,12 +463,11 @@ static void BasicTest() { 0.3661f, 0.2349f, }); - test.Run(); } TEST(RoiAlignTest, OnnxTest) { - // TODO: Unskip when fixed #41968513 + // TODO: Unskip when fixed ort issue #3428 if (DefaultDmlExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.051382988691329956, which exceeds threshold"; } @@ -245,7 +477,7 @@ TEST(RoiAlignTest, OnnxTest) { } TEST(RoiAlignTest, MaxModePositive) { - // TODO: Unskip when fixed #41968513 + // TODO: Unskip when fixed ort issue #3428 if (DefaultDmlExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 2.1093800067901611, which exceeds threshold"; } @@ -267,10 +499,196 @@ TEST(RoiAlignTest, MaxModePositive) { test.AddInput("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.}); test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0}); test.AddOutput("Y", {5, 3, 3, 4}, {2.10938f, 2.95313f, 3.375f, 2.53125f, 3.35938f, 4.70313f, 5.375f, 4.03125f, 3.51563f, 4.92188f, 5.625f, 4.21875f, 10.8984f, 15.2578f, 17.4375f, 13.0781f, 17.3568f, 24.2995f, 27.7708f, 20.8281f, 
18.1641f, 25.4297f, 29.0625f, 21.7969f, 19.6875f, 27.5625f, 31.5f, 23.625f, 31.3542f, 43.8958f, 50.1667f, 37.625f, 32.8125f, 45.9375f, 52.5f, 39.375f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 5.625f, 5.625f, 5.625f, 4.57031f, 8.95833f, 8.95833f, 8.95833f, 7.27865f, 9.375f, 9.375f, 9.375f, 7.61719f, 19.6875f, 19.6875f, 19.6875f, 15.9961f, 31.3542f, 31.3542f, 31.3542f, 25.4753f, 32.8125f, 32.8125f, 32.8125f, 26.6602f, 33.75f, 33.75f, 33.75f, 27.4219f, 53.75f, 53.75f, 53.75f, 43.6719f, 56.25f, 56.25f, 56.25f, 45.7031f, 4.5f, 3.9375f, 2.8125f, 3.9375f, 5.5f, 4.8125f, 3.4375f, 4.8125f, 4.58333f, 4.01042f, 2.86458f, 3.9375f, 23.25f, 20.3438f, 14.5313f, 18.f, 28.4167f, 24.86458f, 17.76042f, 22.f, 23.25f, 20.3437f, 14.5312f, 18.f, 42.f, 36.75f, 26.25f, 32.0625f, 51.3333f, 44.9167f, 32.08333f, 39.1875f, 42.f, 36.75f, 26.25f, 32.0625f, 4.375f, 4.375f, 4.375f, 4.375f, 7.70833f, 7.70833f, 7.70833f, 7.70833f, 9.375f, 9.375f, 9.375f, 9.375f, 21.875f, 21.875f, 21.875f, 21.875f, 26.9792f, 26.9792f, 26.9792f, 26.9792f, 32.8125f, 32.8125f, 32.8125f, 32.8125f, 40.1042f, 40.1042f, 40.1042f, 40.1042f, 46.25f, 46.25f, 46.25f, 46.25f, 56.25f, 56.25f, 56.25f, 56.25f}); - + // As per ort issue #3428, the above output values are INCORRECT. + // DML has the correct outputs, which are defined below. 
+ /*test.AddOutput("Y",{5, 3, 3, 4}, { + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 27.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + 52.0000, + + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 0.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 25.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + 50.0000, + + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 6.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 31.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + 56.5625, + + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 3.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 28.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + 53.3125, + + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 5.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 30.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + 55.9375, + });*/ test.Run(); } - 
TEST(RoiAlignTest, AvgModeNegativeInvalidMode) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 34e2087f1d15..cd9a90ee8ecf 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -105,6 +105,7 @@ "^test_col2im_pads*", // remove this when using ONNX with this: https://github.com/onnx/onnx/pull/4769 // Following tests are for opset 16 ops and are not yet implemented in ORT "^test_roialign_aligned_*", + "^test_roialign_mode_max", // TODO: Remove once the incorrect expected output in the ONNX roialign mode=max test data is fixed //GPU failures "^test_batchnorm_epsilon_training_mode_cuda", "^test_batchnorm_example_training_mode_cuda",