From 14ef05a29284924402ea87d6e0754d03f7297509 Mon Sep 17 00:00:00 2001
From: justin-ngo-arm
Date: Mon, 16 Sep 2024 12:40:24 -0700
Subject: [PATCH] [TOSA] Extend Torch to TOSA reduction ops legalization (#3710)

- Add Torch to TOSA legalization for the following reduction ops:
  + aten.min.dim
  + aten.min
  + aten.max
  + aten.prod
  + aten.prod.dim_int
  + aten.all.dim
- Add dtype casting support for reduce sum and prod ops
- Extend aten.max.dim legalization to a template to support aten.min.dim
  legalization
- Update end-to-end test sets in xfail_sets.py

Signed-off-by: Justin Ngo
Change-Id: I854dd6c0c55e570c1fb7242f20c85cf64d6e7fe0
Signed-off-by: Justin Ngo
---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 210 +++++++++++++++------
 projects/pt1/e2e_testing/xfail_sets.py     | 140 ++++++++------
 test/Conversion/TorchToTosa/basic.mlir     |  97 ++++++++++
 3 files changed, 331 insertions(+), 116 deletions(-)

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 2bbacaf0015..0dbea2b5c94 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -676,6 +676,53 @@ class ConvertAtenReductionOp : public OpConversionPattern<AtenOpT> {
       return rewriter.notifyMatchFailure(
           op, "Only ranked tensor type outputs permitted for reduce_mean");
 
+    auto selfElemTy = selfTy.getElementType();
+    if (!selfElemTy.isIntOrFloat())
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point or integer datatype legalization supported");
+
+    // TOSA ReduceAll and ReduceAny ops only accept bool input
+    if constexpr (std::is_same<AtenOpT, AtenAllOp>() ||
+                  std::is_same<AtenOpT, AtenAllDimOp>() ||
+                  std::is_same<AtenOpT, AtenAnyOp>() ||
+                  std::is_same<AtenOpT, AtenAnyDimOp>()) {
+      self = tosa::promoteType(
+          rewriter, self,
+          RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)));
+    }
+
+    // Handle dtype output and bool elem type for ReduceSum and ReduceProd ops
+    if constexpr (std::is_same<AtenOpT, AtenSumOp>() ||
+                  std::is_same<AtenOpT, AtenSumDimIntListOp>() ||
+                  std::is_same<AtenOpT, AtenProdOp>() ||
+                  std::is_same<AtenOpT, AtenProdDimIntOp>()) {
+      auto dtype = op.getDtype();
+      int64_t dtypeInt;
+      if (!isa<Torch::NoneType>(dtype.getType())) {
+        if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
+          return rewriter.notifyMatchFailure(op, "dtype is not a constant int");
+
+        FailureOr<Type> maybeDtypeType = getTypeForScalarType(
+            op.getContext(), (torch_upstream::ScalarType)dtypeInt);
+        if (failed(maybeDtypeType)) {
+          return rewriter.notifyMatchFailure(op, "dtype is undefined");
+        } else {
+          Type dtypeType = maybeDtypeType.value();
+
+          if (isa<mlir::IntegerType>(dtypeType))
+            dtypeType =
+                rewriter.getIntegerType(dtypeType.getIntOrFloatBitWidth());
+
+          self = tosa::promoteType(
+              rewriter, self,
+              RankedTensorType::get(selfTy.getShape(), dtypeType));
+        }
+      } else {
+        if (selfElemTy.isInteger(1))
+          self = tosa::promoteType(rewriter, self, outputTy);
+      }
+    }
+
     ElementsAttr reduceDimsAttr;
     bool keepDims;
 
@@ -3248,81 +3295,104 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
-template <>
-LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
-    AtenMaxDimOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-
-  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
-  if (!selfType)
-    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
+template <typename AtenOpT, typename TosaOpT>
+class ConvertAtenMinMaxDimOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
-  auto indicesType =
-      dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType(1)));
-  if (!indicesType)
-    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
+    auto self = adaptor.getSelf();
+    auto selfType = dyn_cast<TensorType>(self.getType());
+    if (!selfType)
+      return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
 
-  auto selfElemType = selfType.getElementType();
-  auto indicesElemType = indicesType.getElementType();
+    const TypeConverter *typeConverter = this->getTypeConverter();
+    auto indicesType =
+        dyn_cast<TensorType>(typeConverter->convertType(op.getType(1)));
+    if (!indicesType)
+      return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
 
-  // Only statically deducible values are currently supported
-  int64_t dim;
-  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
-    return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");
+    auto selfElemType = selfType.getElementType();
+    auto indicesElemType = indicesType.getElementType();
 
-  dim = toPositiveDim(dim, selfType.getRank());
+    // Only statically deducible values are currently supported
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+      return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");
 
-  if (!isValidDim(dim, selfType.getRank()))
-    return rewriter.notifyMatchFailure(op, "dim must be less than tensor rank");
+    dim = toPositiveDim(dim, selfType.getRank());
 
-  bool keepDim;
-  if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
-    return rewriter.notifyMatchFailure(op, "keepdim must be a Scalar constant");
+    if (!isValidDim(dim, selfType.getRank()))
+      return rewriter.notifyMatchFailure(op,
+                                         "dim must be less than tensor rank");
 
-  SmallVector<int64_t> reducedShape, prunedShape;
-  for (auto en :
-       llvm::enumerate(makeShapeTorchCompatible(selfType.getShape()))) {
-    if (static_cast<int64_t>(en.index()) == dim) {
-      reducedShape.push_back(1);
-      continue;
+    bool keepDim;
+    if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
+      return rewriter.notifyMatchFailure(op,
+                                         "keepdim must be a Scalar constant");
+
+    SmallVector<int64_t> reducedShape, prunedShape;
+    for (auto en :
+         llvm::enumerate(makeShapeTorchCompatible(selfType.getShape()))) {
+      if (static_cast<int64_t>(en.index()) == dim) {
+        reducedShape.push_back(1);
+        continue;
+      }
+      reducedShape.push_back(en.value());
+      prunedShape.push_back(en.value());
     }
-    reducedShape.push_back(en.value());
-    prunedShape.push_back(en.value());
-  }
-
-  auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim);
-  auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape);
 
-  Value reduceMax = rewriter.create<tosa::ReduceMaxOp>(
-      op->getLoc(),
-      RankedTensorType::get(makeShapeLLVMCompatible(reducedShape),
-                            selfElemType),
-      adaptor.getSelf(), dimAttr);
-
-  Value argMax = rewriter.create<tosa::ArgMaxOp>(
-      op->getLoc(),
-      RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
-                            indicesElemType),
-      adaptor.getSelf(), dimAttr);
-
-  if (argMax.getType() != indicesType) {
-    argMax = rewriter.create<tosa::ReshapeOp>(
-        op->getLoc(), indicesType, argMax,
-        rewriter.getDenseI64ArrayAttr(reducedShape));
-  }
+    auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim);
+    auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape);
 
-  if (!keepDim) {
-    reduceMax = rewriter.create<tosa::ReshapeOp>(
+    Value reduceOp = rewriter.create<TosaOpT>(
         op->getLoc(),
-        RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
+        RankedTensorType::get(makeShapeLLVMCompatible(reducedShape),
                               selfElemType),
-        reduceMax, prunedShapeAttr);
-  }
+        self, dimAttr);
 
-  rewriter.replaceOp(op, {reduceMax, argMax});
+    // To handle ReduceMinDim indices, we apply ArgMaxOp on the negate
+    // of the input tensor, which will return indices of input's min values
+    Value argMaxOp;
+    if constexpr (std::is_same<AtenOpT, AtenMinDimOp>()) {
+      Value negateOp =
+          rewriter.create<tosa::NegateOp>(op->getLoc(), selfType, self);
 
-  return success();
-}
+      argMaxOp = rewriter.create<tosa::ArgMaxOp>(
+          op->getLoc(),
+          RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
+                                indicesElemType),
+          negateOp, dimAttr);
+    } else {
+      argMaxOp = rewriter.create<tosa::ArgMaxOp>(
+          op->getLoc(),
+          RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
+                                indicesElemType),
+          self, dimAttr);
+    }
+
+    if (argMaxOp.getType() != indicesType) {
+      argMaxOp = rewriter.create<tosa::ReshapeOp>(
+          op->getLoc(), indicesType, argMaxOp,
+          rewriter.getDenseI64ArrayAttr(reducedShape));
+    }
+
+    if (!keepDim) {
+      reduceOp = rewriter.create<tosa::ReshapeOp>(
+          op->getLoc(),
+          RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
+                                selfElemType),
+          reduceOp, prunedShapeAttr);
+    }
+
+    rewriter.replaceOp(op, {reduceOp, argMaxOp});
+
+    return success();
+  }
+};
 
 template <>
 LogicalResult ConvertAtenOp::matchAndRewrite(
@@ -5623,6 +5693,10 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
                                     typeConverter, context);
   INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp,
                                      mlir::tosa::convertReduceAnyOp)
+  INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp,
+                                     mlir::tosa::convertReduceAllOp)
+  INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp,
+                                     mlir::tosa::convertReduceProdOp)
 #undef INSERT_ONEDIM_REDUCTION_OP_PATTERN
 
 #define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc)            \
@@ -5635,8 +5709,21 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
                                       mlir::tosa::convertReduceAnyOp)
   INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp,
                                       mlir::tosa::convertReduceSumOp)
+  INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp,
+                                      mlir::tosa::convertReduceMaxOp)
+  INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp,
+                                      mlir::tosa::convertReduceMinOp)
+  INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp,
+                                      mlir::tosa::convertReduceProdOp)
 #undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN
 
+#define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, TosaOp)                    \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenMinMaxDimOp<AtenOp, TosaOp>>(typeConverter, context);
+  INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp);
+  INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp);
+#undef INSERT_INDICES_REDUCTION_OP_PATTERN
+
 #define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm)                        \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<TemplateForm<AtenOp>>(typeConverter, context);
@@ -5727,7 +5814,6 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
   INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
   INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
   INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
-  INSERT_ATENOP_PATTERN(AtenMaxDimOp);
   INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
   INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
   INSERT_ATENOP_PATTERN(AtenGatherOp);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index c99ef4d9687..0bb39ad3bf6 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1625,6 +1625,7 @@
 TOSA_CRASHING_SET = {
     # Runtime op verification: Out of bounds access
     "IndexTensorNegativeIndexModule_basic",
+    "ReduceAllDimEmpty_basic",
 }
 
 FX_IMPORTER_TOSA_CRASHING_SET = {
@@ -1643,6 +1644,36 @@
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "ArgminIntModule_basic",
+    "ArgminIntModule_multiple_mins",
+    "ArgminModule_basic",
+    "ArgminModule_keepDim",
+    "ReduceAllDimBool_basic",
+    "ReduceAllDimFloat_basic",
+    "ReduceAllDimInt_basic",
+    "ReduceAllFloatModule_basic",
+    "ReduceAllIntModule_basic",
+    "ReduceAnyFloatModule_basic",
+    "ReduceAnyIntModule_basic",
+    "ReduceMaxAllDims_basic",
+    "ReduceMaxFloatModule_basic",
+    "ReduceMaxSignedIntModule_basic",
+    "ReduceMaxUnsignedIntModule_basic",
+    "ReduceMinFloatModule_basic",
+    "ReduceMinSignedIntModule_basic",
+    "ReduceMinUnsignedIntModule_basic",
+    "ReduceProdDtypeFloatModule_basic",
+    "ReduceProdDtypeIntModule_basic",
+    "ReduceProdElementTypeBoolModule_basic",
+    "ReduceProdFloatModule_basic",
+    "ReduceProdSignedIntModule_basic",
+    "ReduceProdUnsignedIntModule_basic",
+    "ReduceSumDimIntListDtypeFloatModule_basic",
+    "ReduceSumDimIntListDtypeIntModule_basic",
+    "ReduceSumDimIntListElementTypeBoolModule_basic",
+    "ReduceSumDtypeFloatModule_basic",
+    "ReduceSumDtypeIntModule_basic",
+    "ReduceSumElementTypeBoolModule_basic",
     "AtenTrilStaticModule_basic",
     "AtenTrilWithNegDiagonalStaticModule_basic",
     "AtenTrilWithPosDiagonalStaticModule_basic",
@@ -2155,6 +2186,39 @@ TOSA_PASS_SET | {
     ### Tests additionally passing in make_fx_tosa
+    "ArgminIntModule_basic",
+    "ArgminIntModule_multiple_mins",
+    "ArgminModule_basic",
+    "ArgminModule_keepDim",
+    "ReduceAllDimBool_basic",
+    "ReduceAllDimFloat_basic",
+    "ReduceAllDimInt_basic",
+    "ReduceAllFloatModule_basic",
+    "ReduceAllIntModule_basic",
+    "ReduceAnyFloatModule_basic",
+    "ReduceAnyIntModule_basic",
+    "ReduceMaxAllDims_basic",
+    "ReduceMaxFloatModule_basic",
+    "ReduceMaxSignedIntModule_basic",
+    "ReduceMaxUnsignedIntModule_basic",
+    "ReduceMinFloatModule_basic",
+    "ReduceMinSignedIntModule_basic",
+    "ReduceMinUnsignedIntModule_basic",
+    "ReduceProdDtypeFloatModule_basic",
+    "ReduceProdDtypeIntModule_basic",
+    "ReduceProdElementTypeBoolModule_basic",
+    "ReduceProdFloatModule_basic",
+    "ReduceProdSignedIntModule_basic",
+    "ReduceProdUnsignedIntModule_basic",
+    "ReduceSumDimIntListDtypeFloatModule_basic",
+    "ReduceSumDimIntListDtypeIntModule_basic",
+    "ReduceSumDimIntListElementTypeBoolModule_basic",
+    "ReduceSumDtypeFloatModule_basic",
+    "ReduceSumDtypeIntModule_basic",
+    "ReduceSumElementTypeBoolModule_basic",
+    "ScaledDotProductAttentionDifferentModule_basic",
+    "ScaledDotProductAttentionMaskModule_basic",
+    "ScaledDotProductAttentionSameModule_basic",
     "AvgPool2dCountIncludePadFalseStaticModule_basic",
     "AtenLinear1D_basic",
     "AtenLinearMatVec_basic",
@@ -3038,6 +3102,17 @@
 }
 
 FX_IMPORTER_TOSA_XFAIL_SET = {
+    "AtenPolarDoubleModule_basic",
+    "AtenPolarFloatModule_basic",
+    "HstackBasicComplexModule_basic",
+    "HstackBasicFloatModule_basic",
+    "HstackBasicIntFloatModule_basic",
+    "HstackBasicIntModule_basic",
+    "Rot90BasicModule_basic",
+    "Rot90DynamicDimsModule_basic",
+    "Rot90MultipleRotationsModule_basic",
+    "Rot90NegativeEvenRotationsModule_basic",
+    "Rot90NegativeOddRotationsModule_basic",
     "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
     "AtenIntMM_basic",
     "AtenKthvalueDynamicDimsModule_basic",
@@ -3075,16 +3150,11 @@
     "MultinomialModule2D_F32",
     "MultinomialModule2D_basic",
     "MultinomialModule_basic",
-    "ReduceAminSingleDim_basic",
-    "ReduceAminmaxAllDims_basic",
-    "ReduceAminmaxSingleDim_basic",
-    "ReduceAnyDimFloatModule_basic",
     "RenormModuleFloat16_basic",
     # REMOVE WHEN ENABLE_GQA IS ADDED
     "ScaledDotProductAttentionBoolMaskModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
-    "ScaledDotProductAttentionSameDynamicModule_basic",
     "ScatterAddStaticModule_basic",
     "TensorsConcatComplex128FloatModule_basic",
     "TensorsConcatComplex128IntModule_basic",
@@ -3126,11 +3196,6 @@
     "AnyBoolFalseModule_basic",
     "AnyBoolTrueModule_basic",
     "ArangeStartOutViewModule_basic",
-    "ArgminIntModule_basic",
-    "ArgminIntModule_multiple_mins",
-    "ArgminModule_basic",
-    "ArgminModule_keepDim",
-    "ArgminModule_with_dim",
     "AtenComplexImagModule_basic",
     "AtenComplexRealModule_basic",
     "AtenComplexViewModule_basic",
@@ -3239,7 +3304,6 @@
     "ConvolutionModule2DTranspose_basic",
     "CopyWithDifferentDTypesModule_basic",
     "CosineSimilarityStaticBroadcastModule_basic",
-    "CrossEntropyLossModule_basic",
     "CumsumInputDtypeInt32Module_basic",
     "CumsumModule_basic",
     "CumsumStaticModule_basic",
@@ -3483,9 +3547,7 @@
     "LinalgVectorNormComplexModule_basic",
     "LinspaceDtypeModule_basic",
     "LinspaceEmptyModule_basic",
-    "LinspaceModule_basic",
     "LinspaceOneSizeModule_basic",
-    "LinspaceTwoSizeModule_basic",
     "MaskedFillTensorFloatValueModule_basic",
     "MatmulBroadcastBatchDim_basic",
     "MatmulStaticBroadcast_basic",
@@ -3524,10 +3586,8 @@
     "MaxPool3dWithIndicesNonDefaultParamsModule_basic",
     "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
     "MaxPool3dWithIndicesStaticModule_basic",
-    "MeanDimDtypeModule_basic",
     "MeanDimEmptyDimModule_basic",
     "MeanDimNoneDimModule_basic",
-    "MeanDtypeModule_basic",
     "MseLossMeanReductionModule_basic",
     "MseLossSumReductionWithDifferentElemTypeModule_basic",
     "MulFloatModule_basic",
@@ -3566,9 +3626,6 @@
     "NllLossModuleBackwardWeight_basic",
     "NllLossModuleBackward_basic",
     "NllLossModuleBackward_ignore_index",
-    "NllLossModule_1D_basic",
-    "NllLossModule_mean_basic",
-    "NllLossModule_sum_basic",
     "NormScalarComplexModule_basic",
     "NormScalarModule_basic",
     "NormScalarOptDimKeepDimComplexModule_basic",
@@ -3613,14 +3670,7 @@
     "RandnLikeDtypeModule_basic",
    "RandnLikeModule_basic",
     "RandnModule_basic",
-    "ReduceAllDimBool_basic",
     "ReduceAllDimEmpty_basic",
-    "ReduceAllDimFloat_basic",
-    "ReduceAllDimInt_basic",
-    "ReduceAllFloatModule_basic",
-    "ReduceAllIntModule_basic",
-    "ReduceAnyFloatModule_basic",
-    "ReduceAnyIntModule_basic",
     "ReduceFrobeniusNormComplexModule_basic",
     "ReduceL1NormComplexModule_basic",
     "ReduceL1NormWithDTypeModule_basic",
@@ -3628,34 +3678,9 @@
     "ReduceL3NormAllDimsModule_basic",
     "ReduceL3NormKeepDimComplexModule_basic",
     "ReduceL3NormKeepDimModule_basic",
-    "ReduceMaxAllDims_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
-    "ReduceMaxFloatModule_basic",
-    "ReduceMaxSignedIntModule_basic",
-    "ReduceMaxUnsignedIntModule_basic",
-    "ReduceMinAlongDimNegative_basic",
-    "ReduceMinAlongDimSignedInt_basic",
     "ReduceMinAlongDimUnsignedInt_basic",
-    "ReduceMinAlongDim_basic",
-    "ReduceMinFloatModule_basic",
-    "ReduceMinKeepDimReturnBoth_basic",
-    "ReduceMinKeepDim_basic",
-    "ReduceMinSignedIntModule_basic",
-    "ReduceMinUnsignedIntModule_basic",
-    "ReduceProdDimIntFloatModule_basic",
-    "ReduceProdDtypeFloatModule_basic",
-    "ReduceProdDtypeIntModule_basic",
-    "ReduceProdElementTypeBoolModule_basic",
-    "ReduceProdFloatModule_basic",
-    "ReduceProdSignedIntModule_basic",
-    "ReduceProdUnsignedIntModule_basic",
-    "ReduceSumDimIntListDtypeFloatModule_basic",
-    "ReduceSumDimIntListDtypeIntModule_basic",
-    "ReduceSumDimIntListElementTypeBoolModule_basic",
     "ReduceSumDimIntListEmptyDimModule_basic",
-    "ReduceSumDtypeFloatModule_basic",
-    "ReduceSumDtypeIntModule_basic",
-    "ReduceSumElementTypeBoolModule_basic",
     "ReflectionPad1dModule2dInput_Right",
     "ReflectionPad1dModule2dInput_basic",
     "ReflectionPad1dModule3dInput_Left",
@@ -3672,7 +3697,6 @@
     "ReplicationPad2dModule_top0",
     "RollModule_basic",
     "RsubInt0d_NumToTensor_Module_basic",
-    "RsubIntModule_basic",
     "RsubIntModule_noalpha_basic",
     "ScalarConstantTupleModule_basic",
     "ScalarImplicitFloatModule_basic",
@@ -3801,6 +3825,17 @@
 }
 
 ONNX_TOSA_XFAIL_SET = {
+    "HstackBasicComplexModule_basic",
+    "HstackBasicFloatModule_basic",
+    "HstackBasicIntFloatModule_basic",
+    "HstackBasicIntModule_basic",
+    "Rot90BasicModule_basic",
+    "Rot90DynamicDimsModule_basic",
+    "Rot90MultipleRotationsModule_basic",
+    "Rot90NegativeEvenRotationsModule_basic",
+    "Rot90NegativeOddRotationsModule_basic",
+    "SafeSoftmaxModule_basic",
+    "SafeSoftmaxNonNoneDtypeModule_basic",
     "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
     "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
@@ -3916,7 +3951,6 @@
     "ArgminIntModule_basic",
     "ArgminIntModule_multiple_mins",
     "ArgminModule_basic",
-    "ArgminModule_keepDim",
     "ArgminModule_with_dim",
     "AtenComplex64Module_basic",
     "AtenComplexImagModule_basic",
@@ -4162,7 +4196,6 @@
     "ElementwiseExpm1Module_basic",
     "ElementwiseFlattenBroadcastModule_basic",
     "ElementwiseFloatTensorGtIntTensorModule_basic",
-    "ElementwiseFmodTensor_Float_basic",
     "ElementwiseFmodTensor_Int_Float_basic",
     "ElementwiseFmodTensor_Int_basic",
     "ElementwiseGeFloatIntScalarModule_basic",
@@ -4624,7 +4657,6 @@
     "ScalarImplicitIntModule_basic",
     # REMOVE WHEN ENABLE_GQA IS ADDED
     "ScaledDotProductAttentionBoolMaskModule_basic",
-    "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
     "ScaledDotProductAttentionSameDynamicModule_basic",
     "ScatterReduceFloatMaxModule",
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 57bbac29624..c8a3d371fe7 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -1373,3 +1373,100 @@ func.func @torch.aten.tril$basic(%arg0: !torch.vtensor<[2,4], si32>) -> !torch.v
   %0 = torch.aten.tril %arg0, %int0 : !torch.vtensor<[2,4],si32>, !torch.int -> !torch.vtensor<[2,4],si32>
   return %0 : !torch.vtensor<[2,4],si32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.min.dim$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_5:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
+// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
+// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
+// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
+// CHECK: return %[[VAL_10]] : tensor<3x2x1xf32>
+// CHECK: }
+func.func @torch.aten.min.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
+  %0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
+  %true = torch.constant.bool true
+  %int2 = torch.constant.int 2
+  %values, %indices = torch.aten.min.dim %0, %int2, %true : !torch.vtensor<[3,2,3],f32>, !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],f32>, !torch.vtensor<[3,2,1],si64>
+  %1 = torch_c.to_builtin_tensor %values : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
+  return %1 : tensor<3x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.min$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
+// CHECK: %[[VAL_2:.*]] = tosa.reduce_min %[[VAL_1]] {axis = 0 : i32} : (tensor<3x2x3xf32>) -> tensor<1x2x3xf32>
+// CHECK: %[[VAL_3:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2x3xf32>) -> tensor<1x1x3xf32>
+// CHECK: %[[VAL_4:.*]] = tosa.reduce_min %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x3xf32>) -> tensor<1x1x1xf32>
+// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1>} : (tensor<1x1x1xf32>) -> tensor<1xf32>
+// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
+// CHECK: return %[[VAL_6]] : !torch.vtensor<[1],f32>
+// CHECK: }
+func.func @torch.aten.min$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> {
+  %0 = torch.aten.min %arg0: !torch.vtensor<[3,2,3],f32> -> !torch.vtensor<[1],f32>
+  return %0 : !torch.vtensor<[1],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.max$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
+// CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 0 : i32} : (tensor<3x2x3xf32>) -> tensor<1x2x3xf32>
+// CHECK: %[[VAL_3:.*]] = tosa.reduce_max %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2x3xf32>) -> tensor<1x1x3xf32>
+// CHECK: %[[VAL_4:.*]] = tosa.reduce_max %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x3xf32>) -> tensor<1x1x1xf32>
+// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1>} : (tensor<1x1x1xf32>) -> tensor<1xf32>
+// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
+// CHECK: return %[[VAL_6]] : !torch.vtensor<[1],f32>
+// CHECK: }
+func.func @torch.aten.max$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> {
+  %0 = torch.aten.max %arg0: !torch.vtensor<[3,2,3],f32> -> !torch.vtensor<[1],f32>
+  return %0 : !torch.vtensor<[1],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.prod.dim_int$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[3,2,1],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.constant.none
+// CHECK: %[[VAL_5:.*]] = tosa.reduce_prod %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
+// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
+// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,2,1],f32>
+// CHECK: }
+func.func @torch.aten.prod.dim_int$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[3,2,1],f32> {
+  %dim = torch.constant.int 2
+  %keepdims = torch.constant.bool true
+  %dtype = torch.constant.none
+  %0 = torch.aten.prod.dim_int %arg0, %dim, %keepdims, %dtype: !torch.vtensor<[3,2,3],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
+  return %0 : !torch.vtensor<[3,2,1],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.all.dim$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],i1>) -> !torch.vtensor<[3,2,1],i1> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],i1> -> tensor<3x2x3xi1>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = tosa.reduce_all %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xi1>) -> tensor<3x2x1xi1>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x2x1xi1> -> !torch.vtensor<[3,2,1],i1>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,2,1],i1>
+// CHECK: }
+func.func @torch.aten.all.dim$basic(%arg0: !torch.vtensor<[3,2,3],i1>) -> !torch.vtensor<[3,2,1],i1> {
+  %dim = torch.constant.int 2
+  %keepdims = torch.constant.bool true
+  %0 = torch.aten.all.dim %arg0, %dim, %keepdims: !torch.vtensor<[3,2,3],i1>, !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],i1>
+  return %0 : !torch.vtensor<[3,2,1],i1>
+}
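
Illustrative sketch (not part of the patch): the index-recovery trick used by the new ConvertAtenMinMaxDimOp pattern for aten.min.dim can be checked in plain PyTorch. TOSA has no argmin, so the lowering takes the values from tosa.reduce_min and the indices from tosa.argmax applied to the negated input. The tensor and variable names below are made up for the example and chosen to be tie-free.

import torch

x = torch.tensor([[3.0, 1.0, 2.0],
                  [5.0, 4.0, 6.0]])

# What the lowering computes, expressed in PyTorch terms:
values = torch.amin(x, dim=1, keepdim=True)       # mirrors tosa.reduce_min
indices = torch.argmax(-x, dim=1, keepdim=True)   # argmax of the negation acts as argmin

# Reference result from aten.min.dim itself.
ref_values, ref_indices = torch.min(x, dim=1, keepdim=True)
assert torch.equal(values, ref_values)
assert torch.equal(indices, ref_indices)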