Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DML EP RoiAlign-16 #15812

Merged
merged 16 commits into from
May 10, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class DmlOperatorRegionOfInterestAlign : public DmlOperator, public RoiAlignHelp
ML_CHECK_VALID_ARGUMENT(samplesPerOutput >= 0, "sampling_ratio must be 0 or positive.");
ML_CHECK_VALID_ARGUMENT(!!optionalReductionFunction, "Unsupported RoiAlign mode.");

DML_ROI_ALIGN_OPERATOR_DESC operatorDesc = {};
DML_ROI_ALIGN1_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.ROITensor = &inputDescs[1];
operatorDesc.BatchIndicesTensor = &inputDescs[2];
Expand All @@ -48,12 +48,15 @@ class DmlOperatorRegionOfInterestAlign : public DmlOperator, public RoiAlignHelp
operatorDesc.MaximumSamplesPerOutput = (samplesPerOutput == 0) ? UINT32_MAX : samplesPerOutput;
operatorDesc.ReductionFunction = *optionalReductionFunction;
operatorDesc.InterpolationMode = DML_INTERPOLATION_MODE_LINEAR;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN, &operatorDesc };
operatorDesc.InputPixelOffset = m_coordinateTransformMode == "half_pixel" ? 0.5f : 0.0f;
linnealovespie marked this conversation as resolved.
Show resolved Hide resolved
operatorDesc.OutputPixelOffset = -0.5f;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN1, &operatorDesc };

SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};

DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign10, VersionedKernel<DmlOperatorRegionOfInterestAlign, 10>);
DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign16, VersionedKernel<DmlOperatorRegionOfInterestAlign, 16>);

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(LpPool);
DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool);
DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign16);
DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization15);
Expand Down Expand Up @@ -550,6 +551,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)},
{REG_INFO_VER( 16, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)},
{REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute.
Expand Down Expand Up @@ -806,10 +808,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 13, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 14, Relu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 16, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 16, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // bfloat added to T in 16
{REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
{REG_INFO( 16, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
{REG_INFO( 16, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)}, // bfloat added to T in 16
{REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1146,9 +1146,16 @@ class RoiAlignHelper : public RoiPoolingHelperBase
{
m_outputSizeW = info.template GetOptionalAttribute<uint32_t>(AttrName::OutputWidth, 1);
m_outputSizeH = info.template GetOptionalAttribute<uint32_t>(AttrName::OutputHeight, 1);
m_coordinateTransformMode = "half_pixel";
if(opsetVersion >= 16){
fdwr marked this conversation as resolved.
Show resolved Hide resolved
m_coordinateTransformMode = info.GetAttribute(AttrName::CoordinateTransformationMode);
}
}

std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;

protected:
std::string m_coordinateTransformMode;
linnealovespie marked this conversation as resolved.
Show resolved Hide resolved
};

class SqueezeHelper
Expand Down Expand Up @@ -1449,6 +1456,7 @@ using ShapeInferenceHelper_LpPool = PoolingHelper;
using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper;
using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper;
using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper<RoiAlignHelper, 10>;
using ShapeInferenceHelper_RoiAlign16 = VersionedOpsetHelper<RoiAlignHelper, 16>;
using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_BatchNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_BatchNormalization15 = BatchNormalizationHelper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ namespace OperatorHelper
static const int sc_sinceVer_LessOrEqual = 16;
static const int sc_sinceVer_ScatterND = 16;
static const int sc_sinceVer_ScatterElements = 16;
static const int sc_sinceVer_RoiAlign = 16;
} // namespace OnnxOperatorSet16

namespace OnnxOperatorSet17
Expand Down
60 changes: 57 additions & 3 deletions onnxruntime/test/providers/cpu/object_detection/roialign_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace onnxruntime {
namespace test {

TEST(RoiAlignTest, AvgModePositive) {
// TODO: Unskip when fixed #41968513
// TODO: Unskip when fixed ort issue #3428
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 2.9583299160003662, which exceeds threshold";
}
Expand All @@ -34,6 +34,60 @@ TEST(RoiAlignTest, AvgModePositive) {
test.Run();
}

TEST(RoiAlignTest, AvgModePositive_half_pixel) {
// TODO: Unskip when fixed ort issue #3428
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 2.9583299160003662, which exceeds threshold";
}

OpTester test("RoiAlign", 16);
test.AddAttribute<int64_t>("output_height", 3);
test.AddAttribute<int64_t>("output_width", 4);
test.AddAttribute<int64_t>("sampling_ratio", 2);
test.AddAttribute<float>("spatial_scale", 1.0f / 16.0f);
test.AddAttribute<std::string>("coordinate_transformation_mode", "half_pixel");

constexpr int N = 1;
constexpr int C = 3;
constexpr int H = 5;
constexpr int W = 5;

std::vector<float> rois{0., 7., 5., 7., 5., 0., -15., -15., -15., -15., 0., -10., 21., -10., 21., 0., 13., 8., 13., 8., 0., -14., 19., -14., 19.};
test.AddInput<float>("X", {N, C, H, W}, {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74.});
test.AddInput<float>("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.});
test.AddInput<int64_t>("batch_indices", {5}, {0, 0, 0, 0, 0});
test.AddOutput<float>("Y", {5, 3, 3, 4}, {2.95833f, 3.20833f, 3.45833f, 3.70833f, 4.625f, 4.875f, 5.125f, 5.375f, 6.29167f, 6.54167f, 6.79167f, 7.04167f, 27.9583f, 28.2083f, 28.4583f, 28.7083f, 29.625f, 29.875f, 30.125f, 30.375f, 31.2917f, 31.5417f, 31.7917f, 32.0417f, 52.9583f, 53.2083f, 53.4583f, 53.7083f, 54.625f, 54.875f, 55.125f, 55.375f, 56.2917f, 56.5417f, 56.7917f, 57.0417f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 7.39583f, 7.39583f, 7.42708f, 7.64583f, 9.0625f, 9.0625f, 9.09375f, 9.3125f, 10.7292f, 10.7292f, 10.7604f, 10.9792f, 32.3958f, 32.3958f, 32.4271f, 32.6458f, 34.0625f, 34.0625f, 34.0938f, 34.3125f, 35.7292f, 35.7292f, 35.7604f, 35.9792f, 57.3958f, 57.3958f, 57.4271f, 57.6458f, 59.0625f, 59.0625f, 59.0938f, 59.3125f, 60.7292f, 60.7292f, 60.7604f, 60.9792f, 4.27083f, 4.52083f, 4.77083f, 5.02083f, 5.9375f, 6.1875f, 6.4375f, 6.6875f, 7.60417f, 7.85417f, 8.10417f, 8.35417f, 29.2708f, 29.5208f, 29.7708f, 30.0208f, 30.9375f, 31.1875f, 31.4375f, 31.6875f, 32.6042f, 32.8542f, 33.1042f, 33.3542f, 54.2708f, 54.5208f, 54.7708f, 55.0208f, 55.9375f, 56.1875f, 56.4375f, 56.6875f, 57.6042f, 57.8542f, 58.1042f, 58.3542f, 6.77083f, 6.77083f, 6.77083f, 6.80208f, 8.4375f, 8.4375f, 8.4375f, 8.46875f, 10.1042f, 10.1042f, 10.1042f, 10.1354f, 31.7708f, 31.7708f, 31.7708f, 31.8021f, 33.4375f, 33.4375f, 33.4375f, 33.4688f, 35.1042f, 35.1042f, 35.1042f, 35.1354f, 56.7708f, 56.7708f, 56.7708f, 56.8021f, 58.4375f, 58.4375f, 58.4375f, 58.4688f, 60.1042f, 60.1042f, 60.1042f, 60.1354f});

test.Run();
}

TEST(RoiAlignTest, AvgModePositive_output_half_pixel) {
// TODO: Unskip when fixed ort issue #3428
fdwr marked this conversation as resolved.
Show resolved Hide resolved
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.95832991600036621, which exceeds threshold";
}

OpTester test("RoiAlign", 16);
test.AddAttribute<int64_t>("output_height", 3);
test.AddAttribute<int64_t>("output_width", 4);
test.AddAttribute<int64_t>("sampling_ratio", 2);
test.AddAttribute<float>("spatial_scale", 1.0f / 16.0f);
test.AddAttribute<std::string>("coordinate_transformation_mode", "output_half_pixel");

constexpr int N = 1;
constexpr int C = 3;
constexpr int H = 5;
constexpr int W = 5;

std::vector<float> rois{0., 7., 5., 7., 5., 0., -15., -15., -15., -15., 0., -10., 21., -10., 21., 0., 13., 8., 13., 8., 0., -14., 19., -14., 19.};
test.AddInput<float>("X", {N, C, H, W}, {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74.});
test.AddInput<float>("rois", {5, 4}, {7., 5., 7., 5., -15., -15., -15., -15., -10., 21., -10., 21., 13., 8., 13., 8., -14., 19., -14., 19.});
test.AddInput<int64_t>("batch_indices", {5}, {0, 0, 0, 0, 0});
test.AddOutput<float>("Y", {5, 3, 3, 4}, {2.95833f, 3.20833f, 3.45833f, 3.70833f, 4.625f, 4.875f, 5.125f, 5.375f, 6.29167f, 6.54167f, 6.79167f, 7.04167f, 27.9583f, 28.2083f, 28.4583f, 28.7083f, 29.625f, 29.875f, 30.125f, 30.375f, 31.2917f, 31.5417f, 31.7917f, 32.0417f, 52.9583f, 53.2083f, 53.4583f, 53.7083f, 54.625f, 54.875f, 55.125f, 55.375f, 56.2917f, 56.5417f, 56.7917f, 57.0417f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 7.39583f, 7.39583f, 7.42708f, 7.64583f, 9.0625f, 9.0625f, 9.09375f, 9.3125f, 10.7292f, 10.7292f, 10.7604f, 10.9792f, 32.3958f, 32.3958f, 32.4271f, 32.6458f, 34.0625f, 34.0625f, 34.0938f, 34.3125f, 35.7292f, 35.7292f, 35.7604f, 35.9792f, 57.3958f, 57.3958f, 57.4271f, 57.6458f, 59.0625f, 59.0625f, 59.0938f, 59.3125f, 60.7292f, 60.7292f, 60.7604f, 60.9792f, 4.27083f, 4.52083f, 4.77083f, 5.02083f, 5.9375f, 6.1875f, 6.4375f, 6.6875f, 7.60417f, 7.85417f, 8.10417f, 8.35417f, 29.2708f, 29.5208f, 29.7708f, 30.0208f, 30.9375f, 31.1875f, 31.4375f, 31.6875f, 32.6042f, 32.8542f, 33.1042f, 33.3542f, 54.2708f, 54.5208f, 54.7708f, 55.0208f, 55.9375f, 56.1875f, 56.4375f, 56.6875f, 57.6042f, 57.8542f, 58.1042f, 58.3542f, 6.77083f, 6.77083f, 6.77083f, 6.80208f, 8.4375f, 8.4375f, 8.4375f, 8.46875f, 10.1042f, 10.1042f, 10.1042f, 10.1354f, 31.7708f, 31.7708f, 31.7708f, 31.8021f, 33.4375f, 33.4375f, 33.4375f, 33.4688f, 35.1042f, 35.1042f, 35.1042f, 35.1354f, 56.7708f, 56.7708f, 56.7708f, 56.8021f, 58.4375f, 58.4375f, 58.4375f, 58.4688f, 60.1042f, 60.1042f, 60.1042f, 60.1354f});

test.Run();
}

template <typename T>
static void BasicTest() {
OpTester test("RoiAlign", 10);
Expand Down Expand Up @@ -235,7 +289,7 @@ static void BasicTest() {
}

TEST(RoiAlignTest, OnnxTest) {
// TODO: Unskip when fixed #41968513
// TODO: Unskip when fixed ort issue #3428
fdwr marked this conversation as resolved.
Show resolved Hide resolved
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.051382988691329956, which exceeds threshold";
}
Expand All @@ -245,7 +299,7 @@ TEST(RoiAlignTest, OnnxTest) {
}

TEST(RoiAlignTest, MaxModePositive) {
// TODO: Unskip when fixed #41968513
// TODO: Unskip when fixed ort issue #3428
fdwr marked this conversation as resolved.
Show resolved Hide resolved
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 2.1093800067901611, which exceeds threshold";
}
Expand Down