diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index aa835f00671e..5d9b3d582e73 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -1102,7 +1102,8 @@ Do not modify directly.*
|||11+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float), tensor(float16)|
|||10+|**T** = tensor(float), tensor(float16)|
|ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|10+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)|
+|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)|
+|||10+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)|
|Round|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)|
|STFT|*in* signal:**T1**
*in* frame_step:**T2**
*in* window:**T1**
*in* frame_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)|
|ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp
index c3a25ca8d464..892efca3058d 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp
@@ -29,14 +29,25 @@ class DmlOperatorRegionOfInterestAlign : public DmlOperator, public RoiAlignHelp
{"max", DML_REDUCE_FUNCTION_MAX},
{"avg", DML_REDUCE_FUNCTION_AVERAGE},
};
+
+ constexpr NameAndIndex coordinateTransformationModes[] =
+ {
+ {"half_pixel", 0},
+ {"output_half_pixel", 1},
+ };
+
+ std::string coordinateTransformationMode = kernelCreationContext.GetOptionalAttribute(AttrName::CoordinateTransformationMode, "half_pixel");
+ auto optionalCoordinateTransformationModeValue = TryMapStringToIndex(coordinateTransformationMode, coordinateTransformationModes);
const std::string mode = kernelCreationContext.GetOptionalAttribute(AttrName::Mode, "avg");
const auto optionalReductionFunction = TryMapStringToIndex(mode, mapping);
const float spatialScale = kernelCreationContext.GetOptionalAttribute(AttrName::SpatialScale, 1.0f);
const int32_t samplesPerOutput = kernelCreationContext.GetOptionalAttribute(AttrName::SamplingRatio, 0u);
ML_CHECK_VALID_ARGUMENT(samplesPerOutput >= 0, "sampling_ratio must be 0 or positive.");
ML_CHECK_VALID_ARGUMENT(!!optionalReductionFunction, "Unsupported RoiAlign mode.");
+ ML_CHECK_VALID_ARGUMENT(!!optionalCoordinateTransformationModeValue, "Unsupported RoiAlign coordinate_transformation_mode.");
+
- DML_ROI_ALIGN_OPERATOR_DESC operatorDesc = {};
+ DML_ROI_ALIGN1_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.ROITensor = &inputDescs[1];
operatorDesc.BatchIndicesTensor = &inputDescs[2];
@@ -48,12 +59,15 @@ class DmlOperatorRegionOfInterestAlign : public DmlOperator, public RoiAlignHelp
operatorDesc.MaximumSamplesPerOutput = (samplesPerOutput == 0) ? UINT32_MAX : samplesPerOutput;
operatorDesc.ReductionFunction = *optionalReductionFunction;
operatorDesc.InterpolationMode = DML_INTERPOLATION_MODE_LINEAR;
- DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN, &operatorDesc };
+ operatorDesc.InputPixelOffset = (*optionalCoordinateTransformationModeValue == 0)? 0.5f : 0.0f;
+ operatorDesc.OutputPixelOffset = -0.5f;
+ DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN1, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign10, VersionedKernel);
+DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign16, VersionedKernel);
} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index eea19ad8d883..34a8f496d7d9 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -206,6 +206,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(LpPool);
DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool);
DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10);
+DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign16);
DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization15);
@@ -551,6 +552,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)},
+ {REG_INFO_VER( 16, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)},
{REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute.
@@ -807,10 +809,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 13, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 14, Relu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
- {REG_INFO( 16, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
+ {REG_INFO( 16, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // bfloat added to T in 16
{REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
- {REG_INFO( 16, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
+ {REG_INFO( 16, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)}, // bfloat added to T in 16
{REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
index 837211339384..0206e483d92b 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -1449,6 +1449,7 @@ using ShapeInferenceHelper_LpPool = PoolingHelper;
using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper;
using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper;
using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper;
+using ShapeInferenceHelper_RoiAlign16 = VersionedOpsetHelper;
using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_BatchNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_BatchNormalization15 = BatchNormalizationHelper;
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index 5e2ca4cb116e..054faad5ba8f 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -381,6 +381,7 @@ namespace OperatorHelper
static const int sc_sinceVer_LessOrEqual = 16;
static const int sc_sinceVer_ScatterND = 16;
static const int sc_sinceVer_ScatterElements = 16;
+ static const int sc_sinceVer_RoiAlign = 16;
} // namespace OnnxOperatorSet16
namespace OnnxOperatorSet17
diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc
index 3ee2f4ddcf91..a00e3bc4f2b6 100644
--- a/onnxruntime/test/onnx/main.cc
+++ b/onnxruntime/test/onnx/main.cc
@@ -800,6 +800,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
{"bernoulli_expanded", "By design. Test data is for informational purpose because the generator is non deterministic."},
{"test_roialign_aligned_true", "Opset 16 not supported yet."},
{"test_roialign_aligned_false", "Opset 16 not supported yet."},
+ {"test_roialign_mode_max", "Onnx roialign mode expected output is incorrect."},
{"test_scatternd_add", "Opset 16 not supported yet."},
{"test_scatternd_multiply", "Opset 16 not supported yet."},
{"test_scatter_elements_with_duplicate_indices", "Opset 16 not supported yet."},
diff --git a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc
index a7cc7c536a19..ad9c561ffb51 100644
--- a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc
+++ b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc
@@ -9,11 +9,10 @@ namespace onnxruntime {
namespace test {
TEST(RoiAlignTest, AvgModePositive) {
- // TODO: Unskip when fixed #41968513
+ // TODO: Unskip when fixed ort issue #3428
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 2.9583299160003662, which exceeds threshold";
}
-
OpTester test("RoiAlign", 10);
test.AddAttribute("output_height", 3);
test.AddAttribute("output_width", 4);
@@ -30,7 +29,241 @@ TEST(RoiAlignTest, AvgModePositive) {
test.AddInput("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.});
test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0});
test.AddOutput("Y", {5, 3, 3, 4}, {2.95833f, 3.20833f, 3.45833f, 3.70833f, 4.625f, 4.875f, 5.125f, 5.375f, 6.29167f, 6.54167f, 6.79167f, 7.04167f, 27.9583f, 28.2083f, 28.4583f, 28.7083f, 29.625f, 29.875f, 30.125f, 30.375f, 31.2917f, 31.5417f, 31.7917f, 32.0417f, 52.9583f, 53.2083f, 53.4583f, 53.7083f, 54.625f, 54.875f, 55.125f, 55.375f, 56.2917f, 56.5417f, 56.7917f, 57.0417f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 7.39583f, 7.39583f, 7.42708f, 7.64583f, 9.0625f, 9.0625f, 9.09375f, 9.3125f, 10.7292f, 10.7292f, 10.7604f, 10.9792f, 32.3958f, 32.3958f, 32.4271f, 32.6458f, 34.0625f, 34.0625f, 34.0938f, 34.3125f, 35.7292f, 35.7292f, 35.7604f, 35.9792f, 57.3958f, 57.3958f, 57.4271f, 57.6458f, 59.0625f, 59.0625f, 59.0938f, 59.3125f, 60.7292f, 60.7292f, 60.7604f, 60.9792f, 4.27083f, 4.52083f, 4.77083f, 5.02083f, 5.9375f, 6.1875f, 6.4375f, 6.6875f, 7.60417f, 7.85417f, 8.10417f, 8.35417f, 29.2708f, 29.5208f, 29.7708f, 30.0208f, 30.9375f, 31.1875f, 31.4375f, 31.6875f, 32.6042f, 32.8542f, 33.1042f, 33.3542f, 54.2708f, 54.5208f, 54.7708f, 55.0208f, 55.9375f, 56.1875f, 56.4375f, 56.6875f, 57.6042f, 57.8542f, 58.1042f, 58.3542f, 6.77083f, 6.77083f, 6.77083f, 6.80208f, 8.4375f, 8.4375f, 8.4375f, 8.46875f, 10.1042f, 10.1042f, 10.1042f, 10.1354f, 31.7708f, 31.7708f, 31.7708f, 31.8021f, 33.4375f, 33.4375f, 33.4375f, 33.4688f, 35.1042f, 35.1042f, 35.1042f, 35.1354f, 56.7708f, 56.7708f, 56.7708f, 56.8021f, 58.4375f, 58.4375f, 58.4375f, 58.4688f, 60.1042f, 60.1042f, 60.1042f, 60.1354f});
+ // As per ORT issue https://github.com/microsoft/onnxruntime/issues/6921, the above output values are INCORRECT.
+ // DML has the correct outputs, which are defined below.
+ /*test.AddOutput("Y", {5, 3, 3, 4}, {
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ });*/
+ test.Run();
+}
+
+TEST(RoiAlignTest, AvgModePositive_half_pixel) {
+ OpTester test("RoiAlign", 16);
+ test.AddAttribute("output_height", 3);
+ test.AddAttribute("output_width", 4);
+ test.AddAttribute("sampling_ratio", 2);
+ test.AddAttribute("spatial_scale", 1.0f / 16.0f);
+ test.AddAttribute("coordinate_transformation_mode", "half_pixel");
+
+ constexpr int N = 1;
+ constexpr int C = 3;
+ constexpr int H = 5;
+ constexpr int W = 5;
+
+ std::vector rois{0., 7., 5., 7., 5., 0., -15., -15., -15., -15., 0., -10., 21., -10., 21., 0., 13., 8., 13., 8., 0., -14., 19., -14., 19.};
+ test.AddInput("X", {N, C, H, W}, {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74.});
+ test.AddInput("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.});
+ test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0});
+ test.AddOutput("Y", {5, 3, 3, 4}, {0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 25.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 50.0000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 0.312500000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 25.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 50.3125000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000000});
+ test.Run();
+}
+
+TEST(RoiAlignTest, AvgModePositive_output_half_pixel) {
+ // TODO: Unskip when fixed ort issue #3428
+ if (DefaultDmlExecutionProvider().get() != nullptr) {
+ GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.95832991600036621, which exceeds threshold";
+ }
+
+ OpTester test("RoiAlign", 16);
+ test.AddAttribute("output_height", 3);
+ test.AddAttribute("output_width", 4);
+ test.AddAttribute("sampling_ratio", 2);
+ test.AddAttribute("spatial_scale", 1.0f / 16.0f);
+ test.AddAttribute("coordinate_transformation_mode", "output_half_pixel");
+
+ constexpr int N = 1;
+ constexpr int C = 3;
+ constexpr int H = 5;
+ constexpr int W = 5;
+ std::vector rois{0., 7., 5., 7., 5., 0., -15., -15., -15., -15., 0., -10., 21., -10., 21., 0., 13., 8., 13., 8., 0., -14., 19., -14., 19.};
+ test.AddInput("X", {N, C, H, W}, {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74.});
+ test.AddInput("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.});
+ test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0});
+ test.AddOutput("Y", {5, 3, 3, 4}, {2.95833f, 3.20833f, 3.45833f, 3.70833f, 4.625f, 4.875f, 5.125f, 5.375f, 6.29167f, 6.54167f, 6.79167f, 7.04167f, 27.9583f, 28.2083f, 28.4583f, 28.7083f, 29.625f, 29.875f, 30.125f, 30.375f, 31.2917f, 31.5417f, 31.7917f, 32.0417f, 52.9583f, 53.2083f, 53.4583f, 53.7083f, 54.625f, 54.875f, 55.125f, 55.375f, 56.2917f, 56.5417f, 56.7917f, 57.0417f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 7.39583f, 7.39583f, 7.42708f, 7.64583f, 9.0625f, 9.0625f, 9.09375f, 9.3125f, 10.7292f, 10.7292f, 10.7604f, 10.9792f, 32.3958f, 32.3958f, 32.4271f, 32.6458f, 34.0625f, 34.0625f, 34.0938f, 34.3125f, 35.7292f, 35.7292f, 35.7604f, 35.9792f, 57.3958f, 57.3958f, 57.4271f, 57.6458f, 59.0625f, 59.0625f, 59.0938f, 59.3125f, 60.7292f, 60.7292f, 60.7604f, 60.9792f, 4.27083f, 4.52083f, 4.77083f, 5.02083f, 5.9375f, 6.1875f, 6.4375f, 6.6875f, 7.60417f, 7.85417f, 8.10417f, 8.35417f, 29.2708f, 29.5208f, 29.7708f, 30.0208f, 30.9375f, 31.1875f, 31.4375f, 31.6875f, 32.6042f, 32.8542f, 33.1042f, 33.3542f, 54.2708f, 54.5208f, 54.7708f, 55.0208f, 55.9375f, 56.1875f, 56.4375f, 56.6875f, 57.6042f, 57.8542f, 58.1042f, 58.3542f, 6.77083f, 6.77083f, 6.77083f, 6.80208f, 8.4375f, 8.4375f, 8.4375f, 8.46875f, 10.1042f, 10.1042f, 10.1042f, 10.1354f, 31.7708f, 31.7708f, 31.7708f, 31.8021f, 33.4375f, 33.4375f, 33.4375f, 33.4688f, 35.1042f, 35.1042f, 35.1042f, 35.1354f, 56.7708f, 56.7708f, 56.7708f, 56.8021f, 58.4375f, 58.4375f, 58.4375f, 58.4688f, 60.1042f, 60.1042f, 60.1042f, 60.1354f});
test.Run();
}
@@ -230,12 +463,11 @@ static void BasicTest() {
0.3661f,
0.2349f,
});
-
test.Run();
}
TEST(RoiAlignTest, OnnxTest) {
- // TODO: Unskip when fixed #41968513
+ // TODO: Unskip when fixed ort issue #3428
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.051382988691329956, which exceeds threshold";
}
@@ -245,7 +477,7 @@ TEST(RoiAlignTest, OnnxTest) {
}
TEST(RoiAlignTest, MaxModePositive) {
- // TODO: Unskip when fixed #41968513
+ // TODO: Unskip when fixed ort issue #3428
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 2.1093800067901611, which exceeds threshold";
}
@@ -267,10 +499,196 @@ TEST(RoiAlignTest, MaxModePositive) {
test.AddInput("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.});
test.AddInput("batch_indices", {5}, {0, 0, 0, 0, 0});
test.AddOutput("Y", {5, 3, 3, 4}, {2.10938f, 2.95313f, 3.375f, 2.53125f, 3.35938f, 4.70313f, 5.375f, 4.03125f, 3.51563f, 4.92188f, 5.625f, 4.21875f, 10.8984f, 15.2578f, 17.4375f, 13.0781f, 17.3568f, 24.2995f, 27.7708f, 20.8281f, 18.1641f, 25.4297f, 29.0625f, 21.7969f, 19.6875f, 27.5625f, 31.5f, 23.625f, 31.3542f, 43.8958f, 50.1667f, 37.625f, 32.8125f, 45.9375f, 52.5f, 39.375f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 5.625f, 5.625f, 5.625f, 4.57031f, 8.95833f, 8.95833f, 8.95833f, 7.27865f, 9.375f, 9.375f, 9.375f, 7.61719f, 19.6875f, 19.6875f, 19.6875f, 15.9961f, 31.3542f, 31.3542f, 31.3542f, 25.4753f, 32.8125f, 32.8125f, 32.8125f, 26.6602f, 33.75f, 33.75f, 33.75f, 27.4219f, 53.75f, 53.75f, 53.75f, 43.6719f, 56.25f, 56.25f, 56.25f, 45.7031f, 4.5f, 3.9375f, 2.8125f, 3.9375f, 5.5f, 4.8125f, 3.4375f, 4.8125f, 4.58333f, 4.01042f, 2.86458f, 3.9375f, 23.25f, 20.3438f, 14.5313f, 18.f, 28.4167f, 24.86458f, 17.76042f, 22.f, 23.25f, 20.3437f, 14.5312f, 18.f, 42.f, 36.75f, 26.25f, 32.0625f, 51.3333f, 44.9167f, 32.08333f, 39.1875f, 42.f, 36.75f, 26.25f, 32.0625f, 4.375f, 4.375f, 4.375f, 4.375f, 7.70833f, 7.70833f, 7.70833f, 7.70833f, 9.375f, 9.375f, 9.375f, 9.375f, 21.875f, 21.875f, 21.875f, 21.875f, 26.9792f, 26.9792f, 26.9792f, 26.9792f, 32.8125f, 32.8125f, 32.8125f, 32.8125f, 40.1042f, 40.1042f, 40.1042f, 40.1042f, 46.25f, 46.25f, 46.25f, 46.25f, 56.25f, 56.25f, 56.25f, 56.25f});
-
+ // As per ort issue #3428, the above output values are INCORRECT.
+ // DML has the correct outputs, which are defined below.
+ /*test.AddOutput("Y",{5, 3, 3, 4}, {
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 2.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 27.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+ 52.0000,
+
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 25.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+ 50.0000,
+
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 6.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 31.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+ 56.5625,
+
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 3.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 28.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+ 53.3125,
+
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 5.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 30.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ 55.9375,
+ });*/
test.Run();
}
-
TEST(RoiAlignTest, AvgModeNegativeInvalidMode) {
// TODO: Unskip when fixed #41968513
if (DefaultDmlExecutionProvider().get() != nullptr) {
diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
index 34e2087f1d15..cd9a90ee8ecf 100644
--- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
+++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
@@ -105,6 +105,7 @@
"^test_col2im_pads*", // remove this when using ONNX with this: https://github.com/onnx/onnx/pull/4769
// Following tests are for opset 16 ops and are not yet implemented in ORT
"^test_roialign_aligned_*",
+ "^test_roialign_mode_max", // TODO: Remove once onnx test is fixed
//GPU failures
"^test_batchnorm_epsilon_training_mode_cuda",
"^test_batchnorm_example_training_mode_cuda",