From 4be941209d409d40abd1ba71d1d9d816c262aa71 Mon Sep 17 00:00:00 2001
From: jinchen62
Date: Tue, 23 Apr 2024 02:25:56 -0700
Subject: [PATCH] Support select_last_index attribute of onnx argmin op

---
 .../TorchOnnxToTorch/DefaultDomainAtoF.cpp    | 32 ++++++++++++-----
 .../TorchOnnxToTorch/simple_ops_a_to_f.mlir   | 36 +++++++++++++++++++
 .../unsupported_simple_ops.mlir               |  8 -----
 3 files changed, 59 insertions(+), 17 deletions(-)
 delete mode 100644 test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 9559e28ee2a..14aa41bef34 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -151,17 +151,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.s64BoolAttr(selectLastIndex, "select_last_index", false))
           return failure();
 
-        if (selectLastIndex) {
-          // TODO: Figure out how to support this case. Need to add a reverse
-          // or something.
-          return rewriter.notifyMatchFailure(
-              binder.op, "unsupported conversion: select_last_index=true");
-        }
-
         // ONNX allows negative axis.
+        auto operandSizes =
+            cast<Torch::ValueTensorType>(operand.getType()).getSizes();
         if (axis < 0)
-          axis +=
-              cast<Torch::ValueTensorType>(operand.getType()).getSizes().size();
+          axis += operandSizes.size();
 
         Value constAxis = rewriter.create<Torch::ConstantIntOp>(
             binder.getLoc(), rewriter.getType<Torch::IntType>(),
@@ -169,6 +163,26 @@
         Value constKeepDims = rewriter.create<Torch::ConstantBoolOp>(
             binder.getLoc(), rewriter.getType<Torch::BoolType>(),
             rewriter.getBoolAttr(keepDims));
+
+        if (selectLastIndex) {
+          Value dims = createConstantIntList(binder, rewriter, {axis});
+          auto operandTy = dyn_cast<Torch::ValueTensorType>(operand.getType());
+          operand = rewriter.create<Torch::AtenFlipOp>(
+              binder.getLoc(), operandTy, operand, dims);
+          Value argmin = rewriter.create<Torch::AtenArgminOp>(
+              binder.getLoc(), resultType, operand, constAxis, constKeepDims);
+          Value offset = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(),
+              rewriter.getI64IntegerAttr(operandSizes[axis] - 1));
+          Value alpha = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(1));
+          Value sub = rewriter.create<Torch::AtenSubScalarOp>(
+              binder.getLoc(), resultType, argmin, offset, alpha);
+          rewriter.replaceOpWithNewOp<Torch::AtenAbsOp>(binder.op, resultType,
+                                                        sub);
+          return success();
+        }
+
         rewriter.replaceOpWithNewOp<Torch::AtenArgminOp>(
             binder.op, resultType, operand, constAxis, constKeepDims);
         return success();
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index b776145834d..33d8d8f658b 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -143,6 +143,24 @@ func.func @test_argmin_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2
 
 // -----
 
+// CHECK-LABEL: @test_argmin_negative_axis_keepdims_random_select_last_index
+func.func @test_argmin_negative_axis_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[C2:.*]] = torch.constant.int 2
+  // CHECK: %[[TRUE:.*]] = torch.constant.bool true
+  // CHECK: %[[C2_0:.*]] = torch.constant.int 2
+  // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[C2_0]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,3,4],f32>
+  // CHECK: %[[ARGMIN:.*]] = torch.aten.argmin %[[FLIP]], %[[C2]], %[[TRUE]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,3,1],si64>
+  // CHECK: %[[C3:.*]] = torch.constant.int 3
+  // CHECK: %[[C1:.*]] = torch.constant.int 1
+  // CHECK: %[[SUB:.*]] = torch.aten.sub.Scalar %[[ARGMIN]], %[[C3]], %[[C1]] : !torch.vtensor<[2,3,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,3,1],si64>
+  // CHECK: %[[ABS:.*]] = torch.aten.abs %[[SUB]] : !torch.vtensor<[2,3,1],si64> -> !torch.vtensor<[2,3,1],si64>
+  %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.keepdims = 1 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1],si64>
+  return %0 : !torch.vtensor<[2,3,1],si64>
+}
+
+// -----
+
 // CHECK-LABEL: @test_argmin_no_keepdims_example
 func.func @test_argmin_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[INT:.*]] = torch.constant.int 1
@@ -154,6 +172,24 @@ func.func @test_argmin_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) ->
 
 // -----
 
+// CHECK-LABEL: @test_argmin_no_keepdims_example_select_last_index
+func.func @test_argmin_no_keepdims_example_select_last_index(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[C1:.*]] = torch.constant.int 1
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[C1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[2,2],f32>, !torch.list<int> -> !torch.vtensor<[2,2],f32>
+  // CHECK: %[[ARGMIN:.*]] = torch.aten.argmin %[[FLIP]], %[[C1]], %[[FALSE]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2],si64>
+  // CHECK: %[[C1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[SUB:.*]] = torch.aten.sub.Scalar %[[ARGMIN]], %[[C1_1]], %[[C1_2]] : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2],si64>
+  // CHECK: %[[ABS:.*]] = torch.aten.abs %[[SUB]] : !torch.vtensor<[2],si64> -> !torch.vtensor<[2],si64>
+  %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64>
+  return %0 : !torch.vtensor<[2],si64>
+}
+
+// -----
+
 // CHECK-LABEL: @test_atan
 func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.atan %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir
deleted file mode 100644
index 480a7dbb2da..00000000000
--- a/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir
+++ /dev/null
@@ -1,8 +0,0 @@
-// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch
-
-func.func @test_argmin_no_keepdims_example_select_last_index(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // TODO: Unsupported torch.onnx.select_last_index
-  // expected-error @+1 {{failed to legalize operation 'torch.operator'}}
-  %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64>
-  return %0 : !torch.vtensor<[2],si64>
-}
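
The conversion above handles select_last_index by flipping the input along the reduction axis, taking the ordinary argmin (which picks the first minimum), and mapping the flipped index i back to the original coordinate as |i - (n - 1)|, where n is the dimension size; that is the flip -> argmin -> sub.Scalar -> abs chain checked by the new tests. A minimal standalone C++ sketch of that index arithmetic (illustrative only; the sample vector and variable names below are not from the patch):

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <vector>

// Last-occurrence argmin via the same trick as the lowering: flip the data,
// take the first-occurrence argmin, then map the index back with
// |i - (n - 1)| (the sub.Scalar with alpha = 1 followed by abs).
int main() {
  std::vector<int> x = {3, 1, 2, 1, 5}; // illustrative data; the min 1 occurs at indices 1 and 3
  int n = static_cast<int>(x.size());

  std::vector<int> flipped(x.rbegin(), x.rend());
  // std::min_element returns the first minimum, like torch.aten.argmin.
  int argminFlipped = static_cast<int>(
      std::min_element(flipped.begin(), flipped.end()) - flipped.begin());

  int lastIndex = std::abs(argminFlipped - (n - 1));
  std::printf("last-occurrence argmin = %d\n", lastIndex); // prints 3
}

For this input the plain argmin is 1, while the last-index argmin recovered through the flip-and-remap sequence is 3, matching ONNX's select_last_index=1 semantics.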