Skip to content

Commit

Permalink
User/linneamay/roi align 16 (#15812)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Add registration for DML RoiAlign-16 and tests for new
coordinate_transform_mode attribute. PR
[7354](#7354) is still open
to fix the CPU EP version, which is why there are skipped tests right
now. That will be completed separately so that, for now, we can
officially support opset16 with the next release.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Linnea May <[email protected]>
Co-authored-by: Dwayne Robinson <[email protected]>
  • Loading branch information
3 people authored and fs-eire committed May 12, 2023
1 parent 60edd50 commit dc7a8b5
Show file tree
Hide file tree
Showing 8 changed files with 451 additions and 12 deletions.
3 changes: 2 additions & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,8 @@ Do not modify directly.*
|||11+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float), tensor(float16)|
|||10+|**T** = tensor(float), tensor(float16)|
|ReverseSequence|*in* input:**T**<br> *in* sequence_lens:**tensor(int64)**<br> *out* Y:**T**|10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|10+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|||10+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|Round|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(float), tensor(float16)|
|STFT|*in* signal:**T1**<br> *in* frame_step:**T2**<br> *in* window:**T1**<br> *in* frame_length:**T2**<br> *out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|ScaledTanh|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,25 @@ class DmlOperatorRegionOfInterestAlign : public DmlOperator, public RoiAlignHelp
{"max", DML_REDUCE_FUNCTION_MAX},
{"avg", DML_REDUCE_FUNCTION_AVERAGE},
};

constexpr NameAndIndex coordinateTransformationModes[] =
{
{"half_pixel", 0},
{"output_half_pixel", 1},
};

std::string coordinateTransformationMode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::CoordinateTransformationMode, "half_pixel");
auto optionalCoordinateTransformationModeValue = TryMapStringToIndex(coordinateTransformationMode, coordinateTransformationModes);
const std::string mode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::Mode, "avg");
const auto optionalReductionFunction = TryMapStringToIndex<DML_REDUCE_FUNCTION>(mode, mapping);
const float spatialScale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::SpatialScale, 1.0f);
const int32_t samplesPerOutput = kernelCreationContext.GetOptionalAttribute<int32_t>(AttrName::SamplingRatio, 0u);
ML_CHECK_VALID_ARGUMENT(samplesPerOutput >= 0, "sampling_ratio must be 0 or positive.");
ML_CHECK_VALID_ARGUMENT(!!optionalReductionFunction, "Unsupported RoiAlign mode.");
ML_CHECK_VALID_ARGUMENT(!!optionalCoordinateTransformationModeValue, "Unsupported RoiAlign coordinate_transformation_mode.");


DML_ROI_ALIGN_OPERATOR_DESC operatorDesc = {};
DML_ROI_ALIGN1_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.ROITensor = &inputDescs[1];
operatorDesc.BatchIndicesTensor = &inputDescs[2];
Expand All @@ -48,12 +59,15 @@ class DmlOperatorRegionOfInterestAlign : public DmlOperator, public RoiAlignHelp
operatorDesc.MaximumSamplesPerOutput = (samplesPerOutput == 0) ? UINT32_MAX : samplesPerOutput;
operatorDesc.ReductionFunction = *optionalReductionFunction;
operatorDesc.InterpolationMode = DML_INTERPOLATION_MODE_LINEAR;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN, &operatorDesc };
operatorDesc.InputPixelOffset = (*optionalCoordinateTransformationModeValue == 0)? 0.5f : 0.0f;
operatorDesc.OutputPixelOffset = -0.5f;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN1, &operatorDesc };

SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};

DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign10, VersionedKernel<DmlOperatorRegionOfInterestAlign, 10>);
DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign16, VersionedKernel<DmlOperatorRegionOfInterestAlign, 16>);

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(LpPool);
DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool);
DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign16);
DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization15);
Expand Down Expand Up @@ -551,6 +552,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)},
{REG_INFO_VER( 16, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)},
{REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute.
Expand Down Expand Up @@ -807,10 +809,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 13, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 14, Relu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 16, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 16, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // bfloat added to T in 16
{REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
{REG_INFO( 16, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)},
{REG_INFO( 16, PRelu, typeNameListDefault, supportedTypeListFloat16to32SignedInts8to32, DmlGraphSupport::Supported)}, // bfloat added to T in 16
{REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,7 @@ using ShapeInferenceHelper_LpPool = PoolingHelper;
using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper;
using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper;
using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper<RoiAlignHelper, 10>;
using ShapeInferenceHelper_RoiAlign16 = VersionedOpsetHelper<RoiAlignHelper, 16>;
using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_BatchNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_BatchNormalization15 = BatchNormalizationHelper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ namespace OperatorHelper
static const int sc_sinceVer_LessOrEqual = 16;
static const int sc_sinceVer_ScatterND = 16;
static const int sc_sinceVer_ScatterElements = 16;
static const int sc_sinceVer_RoiAlign = 16;
} // namespace OnnxOperatorSet16

namespace OnnxOperatorSet17
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/test/onnx/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
{"bernoulli_expanded", "By design. Test data is for informational purpose because the generator is non deterministic."},
{"test_roialign_aligned_true", "Opset 16 not supported yet."},
{"test_roialign_aligned_false", "Opset 16 not supported yet."},
{"test_roialign_mode_max", "Onnx roialign mode expected output is incorrect."},
{"test_scatternd_add", "Opset 16 not supported yet."},
{"test_scatternd_multiply", "Opset 16 not supported yet."},
{"test_scatter_elements_with_duplicate_indices", "Opset 16 not supported yet."},
Expand Down
Loading

0 comments on commit dc7a8b5

Please sign in to comment.