From ff1d72afc82fc8fcd884957923e0fd6f3d4847dc Mon Sep 17 00:00:00 2001 From: AmosLewis Date: Mon, 3 Oct 2022 13:33:59 -0700 Subject: [PATCH] [MLIR][TORCH] Add TorchToTosa lowering for aten.where.self op --- e2e_testing/xfail_sets.py | 1 + lib/Conversion/TorchToTosa/TorchToTosa.cpp | 25 +++++++++++++++++++ .../test_suite/elementwise.py | 24 ++++++++++++++++++ test/Conversion/TorchToTosa/basic.mlir | 17 +++++++++++++ 4 files changed, 67 insertions(+) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 43a1100eabdf..474e85924501 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -455,6 +455,7 @@ "ArgmaxModule_keepDim", "ArgmaxModule_with_dim", "_LogSoftmaxModuleStable_basic", + "ElementwiseAtenWhereSelfModule_basic", "LiftFreshCopyModule_basic", "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", "ReduceSumDimIntListFloatModule_basic", diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index b14337aea806..7e2bd72398ca 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -3004,6 +3004,30 @@ LogicalResult ConvertAtenOp::matchAndRewrite( "unimplemented: broadcasts other than same rank or zero ranked tensor."); } + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenWhereSelfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + // Not a tensor type. + auto selfType = adaptor.self().getType().dyn_cast(); + if (!selfType) + return rewriter.notifyMatchFailure( + op, "Only tensor types input are currently supported"); + auto condType = adaptor.condition().getType().dyn_cast(); + if (!condType) + return rewriter.notifyMatchFailure( + op, "Only tensor types condition are currently supported"); + + auto outType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, outType, adaptor.condition(), + adaptor.self(), adaptor.other()); + + return success(); +} + + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenArangeStartStepOp op, OpAdaptor adaptor, @@ -3829,6 +3853,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenMaxDimOp); INSERT_ATENOP_PATTERN(AtenSliceTensorOp); INSERT_ATENOP_PATTERN(AtenBroadcastToOp); + INSERT_ATENOP_PATTERN(AtenWhereSelfOp); INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); INSERT_ATENOP_PATTERN(ValsemVariantAtenCopyOp); diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index bfb0f30357e6..ec0074858160 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -134,6 +134,30 @@ def ElementwiseTernaryModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseAtenWhereSelfModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 5, 5], torch.bool, True), + ([1, 12, 5, 5], torch.float32, True), + ([], torch.float32, True), + ]) + def forward(self, a, b, c): + return torch.ops.aten.where(a, b, c) + + +@register_test_case(module_factory=lambda: ElementwiseAtenWhereSelfModule()) +def ElementwiseAtenWhereSelfModule_basic(module, tu: TestUtils): + module.forward(torch.zeros(1, 1, 5, 5, dtype=torch.bool), torch.rand(1, 12, 5, 5), torch.rand(())) + + +# ============================================================================== + + class ElementwiseWhereSelfModule(torch.nn.Module): def __init__(self): diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 4997d43acad1..a762a3719840 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -913,3 +913,20 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vten %0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[3,5],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],i1> return %0 : !torch.vtensor<[3,5],i1> } + +// ----- +// CHECK-LABEL: func.func @torch.aten.where.self( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,5,5],i1>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,12,5,5],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,5,5],f32> { +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,5,5],i1> -> tensor<1x1x5x5xi1> +// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,12,5,5],f32> -> tensor<1x12x5x5xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[VAL_6:.*]] = "tosa.select"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor) -> tensor<1x12x5x5xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,12,5,5],f32> +// CHECK: } +func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !torch.vtensor<[1,12,5,5],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,5,5],f32> { + %0 = torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,12,5,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,5,5],f32> + return %0 : !torch.vtensor<[1,12,5,5],f32> +} \ No newline at end of file