Skip to content

Commit

Permalink
Support select_last_index attribute of onnx argmin op
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchen62 committed Apr 23, 2024
1 parent 797e4cd commit 2849876
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 17 deletions.
32 changes: 23 additions & 9 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,24 +137,38 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.s64BoolAttr(selectLastIndex, "select_last_index", false))
return failure();

if (selectLastIndex) {
// TODO: Figure out how to support this case. Need to add a reverse
// or something.
return rewriter.notifyMatchFailure(
binder.op, "unsupported conversion: select_last_index=true");
}

// ONNX allows negative axis.
auto operandSizes =
cast<Torch::ValueTensorType>(operand.getType()).getSizes();
if (axis < 0)
axis +=
cast<Torch::ValueTensorType>(operand.getType()).getSizes().size();
axis += operandSizes.size();

Value constAxis = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));
Value constKeepDims = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getType<Torch::BoolType>(),
rewriter.getBoolAttr(keepDims));

if (selectLastIndex) {
Value dims = createConstantIntList(binder, rewriter, {axis});
auto operandTy = dyn_cast<Torch::ValueTensorType>(operand.getType());
operand = rewriter.create<Torch::AtenFlipOp>(
binder.getLoc(), operandTy, operand, dims);
Value argmin = rewriter.create<Torch::AtenArgminOp>(
binder.getLoc(), resultType, operand, constAxis, constKeepDims);
Value offset = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(operandSizes[axis] - 1));
Value alpha = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value sub = rewriter.create<Torch::AtenSubScalarOp>(
binder.getLoc(), resultType, argmin, offset, alpha);
rewriter.replaceOpWithNewOp<Torch::AtenAbsOp>(binder.op, resultType,
sub);
return success();
}

rewriter.replaceOpWithNewOp<Torch::AtenArgminOp>(
binder.op, resultType, operand, constAxis, constKeepDims);
return success();
Expand Down
34 changes: 34 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,23 @@ func.func @test_argmin_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2

// -----

func.func @test_argmin_negative_axis_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[C2_0]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,3,4],f32>
// CHECK: %[[ARGMIN:.*]] = torch.aten.argmin %[[FLIP]], %[[C2]], %[[TRUE]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,3,1],si64>
// CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[SUB:.*]] = torch.aten.sub.Scalar %[[ARGMIN]], %[[C3]], %[[C1]] : !torch.vtensor<[2,3,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,3,1],si64>
// CHECK: %[[ABS:.*]] = torch.aten.abs %[[SUB]] : !torch.vtensor<[2,3,1],si64> -> !torch.vtensor<[2,3,1],si64>
%0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.keepdims = 1 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1],si64>
return %0 : !torch.vtensor<[2,3,1],si64>
}

// -----

// CHECK-LABEL: @test_argmin_no_keepdims_example
func.func @test_argmin_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT:.*]] = torch.constant.int 1
Expand All @@ -118,6 +135,23 @@ func.func @test_argmin_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) ->

// -----

func.func @test_argmin_no_keepdims_example_select_last_index(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[2,2],f32>, !torch.list<int> -> !torch.vtensor<[2,2],f32>
// CHECK: %[[ARGMIN:.*]] = torch.aten.argmin %[[FLIP]], %[[C1]], %[[FALSE]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2],si64>
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
// CHECK: %[[C1_2:.*]] = torch.constant.int 1
// CHECK: %[[SUB:.*]] = torch.aten.sub.Scalar %[[ARGMIN]], %[[C1_1]], %[[C1_2]] : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2],si64>
// CHECK: %[[ABS:.*]] = torch.aten.abs %[[SUB]] : !torch.vtensor<[2],si64> -> !torch.vtensor<[2],si64>
%0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64>
return %0 : !torch.vtensor<[2],si64>
}

// -----

// CHECK-LABEL: @test_atan
func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.atan %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
Expand Down
8 changes: 0 additions & 8 deletions test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,3 @@ module {
return %0 : !torch.vtensor<[2,4],si64>
}
}

// -----
func.func @test_argmin_no_keepdims_example_select_last_index(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// TODO: Unsupported torch.onnx.select_last_index
// expected-error @+1 {{failed to legalize operation 'torch.operator'}}
%0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64>
return %0 : !torch.vtensor<[2],si64>
}

0 comments on commit 2849876

Please sign in to comment.