Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DML EP RoiAlign-16 #15812

Merged
merged 16 commits into from
May 10, 2023
3 changes: 2 additions & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,8 @@ Do not modify directly.*
|||11+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float), tensor(float16)|
|||10+|**T** = tensor(float), tensor(float16)|
|ReverseSequence|*in* input:**T**<br> *in* sequence_lens:**tensor(int64)**<br> *out* Y:**T**|10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|10+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|||10+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|Round|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(float), tensor(float16)|
|STFT|*in* signal:**T1**<br> *in* frame_step:**T2**<br> *in* window:**T1**<br> *in* frame_length:**T2**<br> *out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|ScaledTanh|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,25 @@ class DmlOperatorRegionOfInterestAlign : public DmlOperator, public RoiAlignHelp
{"max", DML_REDUCE_FUNCTION_MAX},
{"avg", DML_REDUCE_FUNCTION_AVERAGE},
};

constexpr NameAndIndex coordinateTransformationModes[] =
{
{"half_pixel", 0},
{"output_half_pixel", 1},
};

std::string coordinateTransformationMode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::CoordinateTransformationMode, "half_pixel");
auto optionalCoordinateTransformationModeValue = TryMapStringToIndex(coordinateTransformationMode, coordinateTransformationModes);
const std::string mode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::Mode, "avg");
const auto optionalReductionFunction = TryMapStringToIndex<DML_REDUCE_FUNCTION>(mode, mapping);
const float spatialScale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::SpatialScale, 1.0f);
const int32_t samplesPerOutput = kernelCreationContext.GetOptionalAttribute<int32_t>(AttrName::SamplingRatio, 0u);
ML_CHECK_VALID_ARGUMENT(samplesPerOutput >= 0, "sampling_ratio must be 0 or positive.");
ML_CHECK_VALID_ARGUMENT(!!optionalReductionFunction, "Unsupported RoiAlign mode.");
ML_CHECK_VALID_ARGUMENT(!!optionalCoordinateTransformationModeValue, "Unsupported RoiAlign coordinate_transformation_mode.");


DML_ROI_ALIGN_OPERATOR_DESC operatorDesc = {};
DML_ROI_ALIGN1_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.ROITensor = &inputDescs[1];
operatorDesc.BatchIndicesTensor = &inputDescs[2];
Expand All @@ -48,12 +59,15 @@ class DmlOperatorRegionOfInterestAlign : public DmlOperator, public RoiAlignHelp
operatorDesc.MaximumSamplesPerOutput = (samplesPerOutput == 0) ? UINT32_MAX : samplesPerOutput;
operatorDesc.ReductionFunction = *optionalReductionFunction;
operatorDesc.InterpolationMode = DML_INTERPOLATION_MODE_LINEAR;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN, &operatorDesc };
operatorDesc.InputPixelOffset = (*optionalCoordinateTransformationModeValue == 0)? 0.5f : 0.0f;
operatorDesc.OutputPixelOffset = -0.5f;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN1, &operatorDesc };

SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};

DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign10, VersionedKernel<DmlOperatorRegionOfInterestAlign, 10>);
DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign16, VersionedKernel<DmlOperatorRegionOfInterestAlign, 16>);

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(LpPool);
DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool);
DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign16);
DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization15);
Expand Down Expand Up @@ -551,6 +552,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)},
{REG_INFO_VER( 16, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)},
{REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute.
Expand Down Expand Up @@ -807,10 +809,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 13, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 14, Relu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 16, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 16, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // bfloat added to T in 16
{REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
{REG_INFO( 16, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
{REG_INFO( 16, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)}, // bfloat added to T in 16
{REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1449,6 +1449,7 @@ using ShapeInferenceHelper_LpPool = PoolingHelper;
using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper;
using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper;
using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper<RoiAlignHelper, 10>;
using ShapeInferenceHelper_RoiAlign16 = VersionedOpsetHelper<RoiAlignHelper, 16>;
using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_BatchNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_BatchNormalization15 = BatchNormalizationHelper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ namespace OperatorHelper
static const int sc_sinceVer_LessOrEqual = 16;
static const int sc_sinceVer_ScatterND = 16;
static const int sc_sinceVer_ScatterElements = 16;
static const int sc_sinceVer_RoiAlign = 16;
} // namespace OnnxOperatorSet16

namespace OnnxOperatorSet17
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/test/onnx/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
{"bernoulli_expanded", "By design. Test data is for informational purpose because the generator is non deterministic."},
{"test_roialign_aligned_true", "Opset 16 not supported yet."},
{"test_roialign_aligned_false", "Opset 16 not supported yet."},
{"test_roialign_mode_max", "Onnx roialign mode expected output is incorrect."},
{"test_scatternd_add", "Opset 16 not supported yet."},
{"test_scatternd_multiply", "Opset 16 not supported yet."},
{"test_scatter_elements_with_duplicate_indices", "Opset 16 not supported yet."},
Expand Down
Loading