From 96df4429d9a56512663b643b42c8a695fd5135d0 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Sun, 27 Oct 2024 23:31:35 +0000 Subject: [PATCH 1/5] [mlir][affine] Add static basis support to affine.delinearize This commit makes `affine.delinearize_index` join other indexing operators, like `vector.extract`, which store a mixed static/dynamic set of sizes, offsets, or such. In this case, the `basis` (the set of values that will be used to decompose the linear index) is now stored as an array of index attributes where the basis is statically known, eliminating the need to create constants. This commit also adds copies of the delinearize utility in the affine dialect to allow it to take an array of `OpFoldResult`s and extends the DynamicIndexList parser/printer to allow specifying the delimiters in tablegen (this is needed to avoid breaking existing syntax). --- .../mlir/Dialect/Affine/IR/AffineOps.h | 2 +- .../mlir/Dialect/Affine/IR/AffineOps.td | 21 ++++- mlir/include/mlir/Dialect/Affine/Utils.h | 3 + .../mlir/Interfaces/ViewLikeInterface.h | 16 ++++ mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 64 +++++++++----- .../Transforms/AffineExpandIndexOps.cpp | 5 +- mlir/lib/Dialect/Affine/Utils/Utils.cpp | 38 +++++++++ .../AffineToStandard/lower-affine.mlir | 83 +++++++++---------- .../Affine/affine-expand-index-ops.mlir | 5 +- mlir/test/Dialect/Affine/canonicalize.mlir | 11 +-- mlir/test/Dialect/Affine/loop-coalescing.mlir | 27 ++---- mlir/test/Dialect/Affine/ops.mlir | 7 ++ .../extract-slice-from-collapse-shape.mlir | 30 ++----- .../Vector/vector-warp-distribute.mlir | 2 +- mlir/test/python/dialects/affine.py | 3 +- 15 files changed, 186 insertions(+), 131 deletions(-) diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h index 5c75e102c3d404..7c950623f77f48 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h @@ -16,11 +16,11 @@ #include 
"mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/LoopLikeInterface.h" - namespace mlir { namespace affine { diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index 8773fc5881461a..f53b5d97a7156a 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -1084,17 +1084,32 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", ``` }]; - let arguments = (ins Index:$linear_index, Variadic:$basis); + let arguments = (ins Index:$linear_index, + Variadic:$dynamic_basis, + DenseI64ArrayAttr:$static_basis); let results = (outs Variadic:$multi_index); let assemblyFormat = [{ - $linear_index `into` ` ` `(` $basis `)` attr-dict `:` type($multi_index) + $linear_index `into` ` ` + custom($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren") + attr-dict `:` type($multi_index) }]; let builders = [ - OpBuilder<(ins "Value":$linear_index, "ArrayRef":$basis)> + OpBuilder<(ins "Value":$linear_index, "ValueRange":$basis)>, + OpBuilder<(ins "Value":$linear_index, "ArrayRef":$basis)>, + OpBuilder<(ins "Value":$linear_index, "ArrayRef":$basis)> ]; + let extraClassDeclaration = [{ + /// Return a vector with all the static and dynamic basis values. 
+ SmallVector getMixedBasis() { + OpBuilder builder(getContext()); + return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder); + } + + }]; + let hasVerifier = 1; let hasCanonicalizer = 1; } diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h index 9a2767e0ad87f3..d2cfbaa85a60ef 100644 --- a/mlir/include/mlir/Dialect/Affine/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Utils.h @@ -311,6 +311,9 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs); FailureOr> delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef basis); +FailureOr> delinearizeIndex(OpBuilder &b, Location loc, + Value linearIndex, + ArrayRef basis); // Generate IR that extracts the linear index from a multi-index according to // a basis/shape. OpFoldResult linearizeIndex(ArrayRef multiIndex, diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h index d6479143a0a50b..3dcbd2f1af1936 100644 --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -109,6 +109,13 @@ void printDynamicIndexList( ArrayRef integers, ArrayRef scalables, TypeRange valueTypes = TypeRange(), AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); +inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, + OperandRange values, + ArrayRef integers, + AsmParser::Delimiter delimiter) { + return printDynamicIndexList(printer, op, values, integers, {}, TypeRange(), + delimiter); +} inline void printDynamicIndexList( OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef integers, TypeRange valueTypes = TypeRange(), @@ -144,6 +151,15 @@ ParseResult parseDynamicIndexList( DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl *valueTypes = nullptr, AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); +inline ParseResult 
+parseDynamicIndexList(OpAsmParser &parser, + SmallVectorImpl &values, + DenseI64ArrayAttr &integers, + AsmParser::Delimiter delimiter) { + DenseBoolArrayAttr scalableVals = {}; + return parseDynamicIndexList(parser, values, integers, scalableVals, nullptr, + delimiter); +} inline ParseResult parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 5e7a6b6ca883c3..f384f454bc4726 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/IntegerSet.h" @@ -4508,32 +4509,50 @@ LogicalResult AffineDelinearizeIndexOp::inferReturnTypes( RegionRange regions, SmallVectorImpl &inferredReturnTypes) { AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties, regions); - inferredReturnTypes.assign(adaptor.getBasis().size(), + inferredReturnTypes.assign(adaptor.getStaticBasis().size(), IndexType::get(context)); return success(); } -void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result, +void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder, + OperationState &odsState, + Value linearIndex, ValueRange basis) { + SmallVector dynamicBasis; + SmallVector staticBasis; + dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis, + staticBasis); + build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis); +} + +void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder, + OperationState &odsState, Value linearIndex, ArrayRef basis) { - result.addTypes(SmallVector(basis.size(), builder.getIndexType())); - result.addOperands(linearIndex); - SmallVector basisValues = - 
llvm::map_to_vector(basis, [&](OpFoldResult ofr) -> Value { - std::optional staticDim = getConstantIntValue(ofr); - if (staticDim.has_value()) - return builder.create(result.location, - *staticDim); - return llvm::dyn_cast_if_present(ofr); - }); - result.addOperands(basisValues); + SmallVector dynamicBasis; + SmallVector staticBasis; + dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis); + build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis); +} + +void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder, + OperationState &odsState, + Value linearIndex, + ArrayRef basis) { + build(odsBuilder, odsState, linearIndex, ValueRange{}, basis); } LogicalResult AffineDelinearizeIndexOp::verify() { - if (getBasis().empty()) + if (getStaticBasis().empty()) return emitOpError("basis should not be empty"); - if (getNumResults() != getBasis().size()) + if (getNumResults() != getStaticBasis().size()) return emitOpError("should return an index for each basis element"); + auto dynamicMarkersCount = + llvm::count_if(getStaticBasis(), ShapedType::isDynamic); + if (static_cast(dynamicMarkersCount) != getDynamicBasis().size()) + return emitOpError( + "mismatch between dynamic and static basis (kDynamic marker but no " + "corresponding dynamic basis entry) -- this can only happen due to an " + "incorrect fold/rewrite"); return success(); } @@ -4557,15 +4576,16 @@ struct DropUnitExtentBasis // Replace all indices corresponding to unit-extent basis with 0. // Remaining basis can be used to get a new `affine.delinearize_index` op. 
- SmallVector newOperands; - for (auto [index, basis] : llvm::enumerate(delinearizeOp.getBasis())) { - if (matchPattern(basis, m_One())) + SmallVector newOperands; + for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) { + std::optional basisVal = getConstantIntValue(basis); + if (basisVal && *basisVal == 1) replacements[index] = getZero(); else newOperands.push_back(basis); } - if (newOperands.size() == delinearizeOp.getBasis().size()) + if (newOperands.size() == delinearizeOp.getStaticBasis().size()) return failure(); if (!newOperands.empty()) { @@ -4607,9 +4627,9 @@ struct DropDelinearizeOfSingleLoop LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp, PatternRewriter &rewriter) const override { - auto basis = delinearizeOp.getBasis(); - if (basis.size() != 1) + if (delinearizeOp.getStaticBasis().size() != 1) return failure(); + auto basis = delinearizeOp.getMixedBasis(); // Check that the `linear_index` is an induction variable. auto inductionVar = dyn_cast(delinearizeOp.getLinearIndex()); @@ -4634,7 +4654,7 @@ struct DropDelinearizeOfSingleLoop // Check that the upper-bound is the basis. 
auto upperBounds = loopLikeOp.getLoopUpperBounds(); if (!upperBounds || upperBounds->size() != 1 || - upperBounds->front() != getAsOpFoldResult(basis.front())) { + upperBounds->front() != basis.front()) { return rewriter.notifyMatchFailure(delinearizeOp, "`basis` is not upper bound"); } diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp index c6bc3862256a75..d76968d3a71520 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp @@ -35,9 +35,8 @@ struct LowerDelinearizeIndexOps using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op, PatternRewriter &rewriter) const override { - FailureOr> multiIndex = - delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(), - llvm::to_vector(op.getBasis())); + FailureOr> multiIndex = delinearizeIndex( + rewriter, op->getLoc(), op.getLinearIndex(), op.getMixedBasis()); if (failed(multiIndex)) return failure(); rewriter.replaceOp(op, *multiIndex); diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index 910ad1733d03e8..e3b5d26e0ec3c3 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -1944,6 +1944,18 @@ static FailureOr getIndexProduct(OpBuilder &b, Location loc, return result; } +static FailureOr getIndexProduct(OpBuilder &b, Location loc, + ArrayRef set) { + if (set.empty()) + return failure(); + OpFoldResult result = set[0]; + AffineExpr s0, s1; + bindSymbols(b.getContext(), s0, s1); + for (unsigned i = 1, e = set.size(); i < e; i++) + result = makeComposedFoldedAffineApply(b, loc, s0 * s1, {result, set[i]}); + return result; +} + FailureOr> mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef basis) { @@ -1970,6 +1982,32 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value 
linearIndex, return results; } +FailureOr> +mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, + ArrayRef basis) { + unsigned numDims = basis.size(); + + SmallVector divisors; + for (unsigned i = 1; i < numDims; i++) { + ArrayRef slice = basis.drop_front(i); + FailureOr prod = getIndexProduct(b, loc, slice); + if (failed(prod)) + return failure(); + divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod)); + } + + SmallVector results; + results.reserve(divisors.size() + 1); + Value residual = linearIndex; + for (Value divisor : divisors) { + DivModValue divMod = getDivMod(b, loc, residual, divisor); + results.push_back(divMod.quotient); + residual = divMod.remainder; + } + results.push_back(residual); + return results; +} + OpFoldResult mlir::affine::linearizeIndex(ArrayRef multiIndex, ArrayRef basis, ImplicitLocOpBuilder &builder) { diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir index 23e0edd510cbb1..298e82df4f4cea 100644 --- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir +++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir @@ -931,53 +931,48 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me /////////////////////////////////////////////////////////////////////// func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index) { - %b0 = arith.constant 16 : index - %b1 = arith.constant 224 : index - %b2 = arith.constant 224 : index - %1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index + %1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index return %1#0, %1#1, %1#2 : index, index, index } // CHECK-LABEL: func.func @test_dilinearize_index( // CHECK-SAME: %[[VAL_0:.*]]: index) -> (index, index, index) { -// CHECK: %[[VAL_1:.*]] = arith.constant 16 : index +// CHECK: %[[VAL_1:.*]] = arith.constant 50176 : index // 
CHECK: %[[VAL_2:.*]] = arith.constant 224 : index -// CHECK: %[[VAL_3:.*]] = arith.constant 224 : index -// CHECK: %[[VAL_4:.*]] = arith.constant 50176 : index -// CHECK: %[[VAL_5:.*]] = arith.constant 50176 : index -// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_7:.*]] = arith.constant -1 : index -// CHECK: %[[VAL_8:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_6]] : index -// CHECK: %[[VAL_9:.*]] = arith.subi %[[VAL_7]], %[[VAL_0]] : index -// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_8]], %[[VAL_9]], %[[VAL_0]] : index -// CHECK: %[[VAL_11:.*]] = arith.divsi %[[VAL_10]], %[[VAL_5]] : index -// CHECK: %[[VAL_12:.*]] = arith.subi %[[VAL_7]], %[[VAL_11]] : index -// CHECK: %[[VAL_13:.*]] = arith.select %[[VAL_8]], %[[VAL_12]], %[[VAL_11]] : index -// CHECK: %[[VAL_14:.*]] = arith.constant 50176 : index -// CHECK: %[[VAL_15:.*]] = arith.remsi %[[VAL_0]], %[[VAL_14]] : index -// CHECK: %[[VAL_16:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_17:.*]] = arith.cmpi slt, %[[VAL_15]], %[[VAL_16]] : index -// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_14]] : index -// CHECK: %[[VAL_19:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_15]] : index -// CHECK: %[[VAL_20:.*]] = arith.constant 50176 : index -// CHECK: %[[VAL_21:.*]] = arith.remsi %[[VAL_0]], %[[VAL_20]] : index -// CHECK: %[[VAL_22:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_23:.*]] = arith.cmpi slt, %[[VAL_21]], %[[VAL_22]] : index -// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_20]] : index -// CHECK: %[[VAL_25:.*]] = arith.select %[[VAL_23]], %[[VAL_24]], %[[VAL_21]] : index -// CHECK: %[[VAL_26:.*]] = arith.constant 224 : index -// CHECK: %[[VAL_27:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_28:.*]] = arith.constant -1 : index -// CHECK: %[[VAL_29:.*]] = arith.cmpi slt, %[[VAL_25]], %[[VAL_27]] : index -// CHECK: %[[VAL_30:.*]] = arith.subi %[[VAL_28]], %[[VAL_25]] : index -// CHECK: %[[VAL_31:.*]] = arith.select %[[VAL_29]], %[[VAL_30]], 
%[[VAL_25]] : index -// CHECK: %[[VAL_32:.*]] = arith.divsi %[[VAL_31]], %[[VAL_26]] : index -// CHECK: %[[VAL_33:.*]] = arith.subi %[[VAL_28]], %[[VAL_32]] : index -// CHECK: %[[VAL_34:.*]] = arith.select %[[VAL_29]], %[[VAL_33]], %[[VAL_32]] : index -// CHECK: %[[VAL_35:.*]] = arith.constant 224 : index -// CHECK: %[[VAL_36:.*]] = arith.remsi %[[VAL_0]], %[[VAL_35]] : index -// CHECK: %[[VAL_37:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_38:.*]] = arith.cmpi slt, %[[VAL_36]], %[[VAL_37]] : index -// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_36]], %[[VAL_35]] : index -// CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_38]], %[[VAL_39]], %[[VAL_36]] : index -// CHECK: return %[[VAL_13]], %[[VAL_34]], %[[VAL_40]] : index, index, index +// CHECK: %[[VAL_3:.*]] = arith.constant 50176 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_5:.*]] = arith.constant -1 : index +// CHECK: %[[VAL_6:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_4]] : index +// CHECK: %[[VAL_7:.*]] = arith.subi %[[VAL_5]], %[[VAL_0]] : index +// CHECK: %[[VAL_8:.*]] = arith.select %[[VAL_6]], %[[VAL_7]], %[[VAL_0]] : index +// CHECK: %[[VAL_9:.*]] = arith.divsi %[[VAL_8]], %[[VAL_3]] : index +// CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_5]], %[[VAL_9]] : index +// CHECK: %[[VAL_11:.*]] = arith.select %[[VAL_6]], %[[VAL_10]], %[[VAL_9]] : index +// CHECK: %[[VAL_12:.*]] = arith.constant 50176 : index +// CHECK: %[[VAL_13:.*]] = arith.remsi %[[VAL_0]], %[[VAL_12]] : index +// CHECK: %[[VAL_14:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_15:.*]] = arith.cmpi slt, %[[VAL_13]], %[[VAL_14]] : index +// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index +// CHECK: %[[VAL_17:.*]] = arith.select %[[VAL_15]], %[[VAL_16]], %[[VAL_13]] : index +// CHECK: %[[VAL_18:.*]] = arith.constant 50176 : index +// CHECK: %[[VAL_19:.*]] = arith.remsi %[[VAL_0]], %[[VAL_18]] : index +// CHECK: %[[VAL_20:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_21:.*]] = arith.cmpi 
slt, %[[VAL_19]], %[[VAL_20]] : index +// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_18]] : index +// CHECK: %[[VAL_23:.*]] = arith.select %[[VAL_21]], %[[VAL_22]], %[[VAL_19]] : index +// CHECK: %[[VAL_24:.*]] = arith.constant 224 : index +// CHECK: %[[VAL_25:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_26:.*]] = arith.constant -1 : index +// CHECK: %[[VAL_27:.*]] = arith.cmpi slt, %[[VAL_23]], %[[VAL_25]] : index +// CHECK: %[[VAL_28:.*]] = arith.subi %[[VAL_26]], %[[VAL_23]] : index +// CHECK: %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_23]] : index +// CHECK: %[[VAL_30:.*]] = arith.divsi %[[VAL_29]], %[[VAL_24]] : index +// CHECK: %[[VAL_31:.*]] = arith.subi %[[VAL_26]], %[[VAL_30]] : index +// CHECK: %[[VAL_32:.*]] = arith.select %[[VAL_27]], %[[VAL_31]], %[[VAL_30]] : index +// CHECK: %[[VAL_33:.*]] = arith.constant 224 : index +// CHECK: %[[VAL_34:.*]] = arith.remsi %[[VAL_0]], %[[VAL_33]] : index +// CHECK: %[[VAL_35:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_36:.*]] = arith.cmpi slt, %[[VAL_34]], %[[VAL_35]] : index +// CHECK: %[[VAL_37:.*]] = arith.addi %[[VAL_34]], %[[VAL_33]] : index +// CHECK: %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[VAL_37]], %[[VAL_34]] : index +// CHECK: return %[[VAL_11]], %[[VAL_32]], %[[VAL_38]] : index, index, index // CHECK: } diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir index 70b7f397ad4fec..95773206a521e6 100644 --- a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir +++ b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir @@ -11,10 +11,7 @@ // CHECK: %[[Q:.+]] = affine.apply #[[$map2]]()[%[[IDX]]] // CHECK: return %[[N]], %[[P]], %[[Q]] func.func @static_basis(%linear_index: index) -> (index, index, index) { - %b0 = arith.constant 16 : index - %b1 = arith.constant 224 : index - %b2 = arith.constant 224 : index - %1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, 
index + %1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index return %1#0, %1#1, %1#2 : index, index, index } diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir index 906ae81c76d115..d78c3b667589b8 100644 --- a/mlir/test/Dialect/Affine/canonicalize.mlir +++ b/mlir/test/Dialect/Affine/canonicalize.mlir @@ -1472,7 +1472,7 @@ func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () { func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 : index) -> (index, index, index, index, index, index) { %c1 = arith.constant 1 : index - %0:6 = affine.delinearize_index %arg0 into (%c1, %arg1, %c1, %c1, %arg2, %c1) + %0:6 = affine.delinearize_index %arg0 into (1, %arg1, 1, 1, %arg2, %c1) : index, index, index, index, index, index return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : index, index, index, index, index, index } @@ -1487,8 +1487,7 @@ func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 : // ----- func.func @drop_all_unit_bases(%arg0 : index) -> (index, index) { - %c1 = arith.constant 1 : index - %0:2 = affine.delinearize_index %arg0 into (%c1, %c1) : index, index + %0:2 = affine.delinearize_index %arg0 into (1, 1) : index, index return %0#0, %0#1 : index, index } // CHECK-LABEL: func @drop_all_unit_bases( @@ -1519,9 +1518,8 @@ func.func @drop_single_loop_delinearize(%arg0 : index, %arg1 : index) -> index { // CHECK-LABEL: func @delinearize_non_induction_variable func.func @delinearize_non_induction_variable(%arg0: memref, %i : index, %t0 : index, %t1 : index, %t2 : index) -> index { - %c1024 = arith.constant 1024 : index %1 = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 + s0 + s1 * 64 + s2 * 128)>(%i)[%t0, %t1, %t2] - %2 = affine.delinearize_index %1 into (%c1024) : index + %2 = affine.delinearize_index %1 into (1024) : index return %2 : index } @@ -1529,7 +1527,6 @@ func.func @delinearize_non_induction_variable(%arg0: memref, %i : index, 
// CHECK-LABEL: func @delinearize_non_loop_like func.func @delinearize_non_loop_like(%arg0: memref, %i : index) -> index { - %c1024 = arith.constant 1024 : index - %2 = affine.delinearize_index %i into (%c1024) : index + %2 = affine.delinearize_index %i into (1024) : index return %2 : index } diff --git a/mlir/test/Dialect/Affine/loop-coalescing.mlir b/mlir/test/Dialect/Affine/loop-coalescing.mlir index f6e7b21bc66aba..3be14eaf5c3261 100644 --- a/mlir/test/Dialect/Affine/loop-coalescing.mlir +++ b/mlir/test/Dialect/Affine/loop-coalescing.mlir @@ -6,9 +6,6 @@ func.func @one_3d_nest() { // upper bound is also the number of iterations. // CHECK-DAG: %[[orig_lb:.*]] = arith.constant 0 // CHECK-DAG: %[[orig_step:.*]] = arith.constant 1 - // CHECK-DAG: %[[orig_ub_k:.*]] = arith.constant 3 - // CHECK-DAG: %[[orig_ub_i:.*]] = arith.constant 42 - // CHECK-DAG: %[[orig_ub_j:.*]] = arith.constant 56 // CHECK-DAG: %[[range:.*]] = arith.constant 7056 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -25,7 +22,7 @@ func.func @one_3d_nest() { // Reconstruct original IVs from the linearized one. // CHECK: %[[delinearize:.+]]:3 = affine.delinearize_index %[[i]] - // CHECK-SAME: into (%[[orig_ub_i]], %[[orig_ub_j]], %[[orig_ub_k]]) + // CHECK-SAME: into (42, 56, 3) scf.for %j = %c0 to %c56 step %c1 { scf.for %k = %c0 to %c3 step %c1 { // CHECK: "use"(%[[delinearize]]#0, %[[delinearize]]#1, %[[delinearize]]#2) @@ -73,11 +70,6 @@ func.func @unnormalized_loops() { // Normalized lower bound and step for the outer scf. // CHECK-DAG: %[[lb_i:.*]] = arith.constant 0 // CHECK-DAG: %[[step_i:.*]] = arith.constant 1 - // CHECK-DAG: %[[orig_step_j_and_numiter_i:.*]] = arith.constant 3 - - // Number of iterations in the inner loop, the pattern is the same as above, - // only capture the final result. 
- // CHECK-DAG: %[[numiter_j:.*]] = arith.constant 4 // CHECK-DAG: %[[range:.*]] = arith.constant 12 @@ -97,7 +89,7 @@ func.func @unnormalized_loops() { scf.for %j = %c7 to %c17 step %c3 { // The IVs are rewritten. // CHECK: %[[delinearize:.+]]:2 = affine.delinearize_index %[[i]] - // CHECK-SAME: into (%[[orig_step_j_and_numiter_i]], %[[numiter_j]]) + // CHECK-SAME: into (3, 4) // CHECK: %[[orig_j:.*]] = affine.apply affine_map<(d0) -> (d0 * 3 + 7)>(%[[delinearize]]#1) // CHECK: %[[orig_i:.*]] = affine.apply affine_map<(d0) -> (d0 * 2 + 5)>(%[[delinearize]]#0) // CHECK: "use"(%[[orig_i]], %[[orig_j]]) @@ -111,10 +103,7 @@ func.func @unnormalized_loops() { func.func @noramalized_loops_with_yielded_iter_args() { // CHECK-DAG: %[[orig_lb:.*]] = arith.constant 0 - // CHECK-DAG: %[[orig_ub_i:.*]] = arith.constant 42 // CHECK-DAG: %[[orig_step:.*]] = arith.constant 1 - // CHECK-DAG: %[[orig_ub_j:.*]] = arith.constant 56 - // CHECK-DAG: %[[orig_ub_k:.*]] = arith.constant 3 // CHECK-DAG: %[[range:.*]] = arith.constant 7056 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -130,7 +119,7 @@ func.func @noramalized_loops_with_yielded_iter_args() { // CHECK-NOT: scf.for // Reconstruct original IVs from the linearized one. 
- // CHECK: %[[delinearize:.+]]:3 = affine.delinearize_index %[[i]] into (%[[orig_ub_i]], %[[orig_ub_j]], %[[orig_ub_k]]) + // CHECK: %[[delinearize:.+]]:3 = affine.delinearize_index %[[i]] into (42, 56, 3) %1:1 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg1 = %arg0) -> (index){ %0:1 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg2 = %arg1) -> (index) { // CHECK: "use"(%[[delinearize]]#0, %[[delinearize]]#1, %[[delinearize]]#2) @@ -150,9 +139,6 @@ func.func @noramalized_loops_with_yielded_iter_args() { func.func @noramalized_loops_with_shuffled_yielded_iter_args() { // CHECK-DAG: %[[orig_lb:.*]] = arith.constant 0 // CHECK-DAG: %[[orig_step:.*]] = arith.constant 1 - // CHECK-DAG: %[[orig_ub_k:.*]] = arith.constant 3 - // CHECK-DAG: %[[orig_ub_i:.*]] = arith.constant 42 - // CHECK-DAG: %[[orig_ub_j:.*]] = arith.constant 56 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c3 = arith.constant 3 : index @@ -169,7 +155,7 @@ func.func @noramalized_loops_with_shuffled_yielded_iter_args() { // Reconstruct original IVs from the linearized one. 
// CHECK: %[[delinearize:.+]]:3 = affine.delinearize_index %[[i]] - // CHECK-SAME: into (%[[orig_ub_i]], %[[orig_ub_j]], %[[orig_ub_k]]) + // CHECK-SAME: into (42, 56, 3) %1:2 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg2 = %arg0, %arg3 = %arg1) -> (index, index){ %0:2 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg4 = %arg2, %arg5 = %arg3) -> (index, index) { // CHECK: "use"(%[[delinearize]]#0, %[[delinearize]]#1, %[[delinearize]]#2) @@ -189,9 +175,6 @@ func.func @noramalized_loops_with_shuffled_yielded_iter_args() { func.func @noramalized_loops_with_yielded_non_iter_args() { // CHECK-DAG: %[[orig_lb:.*]] = arith.constant 0 // CHECK-DAG: %[[orig_step:.*]] = arith.constant 1 - // CHECK-DAG: %[[orig_ub_k:.*]] = arith.constant 3 - // CHECK-DAG: %[[orig_ub_i:.*]] = arith.constant 42 - // CHECK-DAG: %[[orig_ub_j:.*]] = arith.constant 56 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c3 = arith.constant 3 : index @@ -208,7 +191,7 @@ func.func @noramalized_loops_with_yielded_non_iter_args() { // Reconstruct original IVs from the linearized one. 
// CHECK: %[[delinearize:.+]]:3 = affine.delinearize_index %[[i]] - // CHECK-SAME: into (%[[orig_ub_i]], %[[orig_ub_j]], %[[orig_ub_k]]) + // CHECK-SAME: into (42, 56, 3) %1:1 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg1 = %arg0) -> (index){ %0:1 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg2 = %arg1) -> (index) { // CHECK: %[[res:.*]] = "use"(%[[delinearize]]#0, %[[delinearize]]#1, %[[delinearize]]#2) diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir index 19ae1584842aea..52ae53adcea9f9 100644 --- a/mlir/test/Dialect/Affine/ops.mlir +++ b/mlir/test/Dialect/Affine/ops.mlir @@ -275,3 +275,10 @@ func.func @delinearize(%linear_idx: index, %basis0: index, %basis1 :index) -> (i %1:2 = affine.delinearize_index %linear_idx into (%basis0, %basis1) : index, index return %1#0, %1#1 : index, index } + +// CHECK-LABEL: @delinearize_mixed +func.func @delinearize_mixed(%linear_idx: index, %basis1: index) -> (index, index, index) { + // CHECK: affine.delinearize_index %{{.+}} into (2, %{{.+}}, 3) : index, index, index + %1:3 = affine.delinearize_index %linear_idx into (2, %basis1, 3) : index, index, index + return %1#0, %1#1, %1#2 : index, index, index +} diff --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir index 3669cae87408df..4bb099e3401ecf 100644 --- a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir +++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir @@ -11,12 +11,9 @@ func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf3 // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index -// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index -// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index // CHECK-DAG: %[[init:.+]] = tensor.empty() : 
tensor<20x11xf32> // CHECK-DAG: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c20]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]]) -// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]] +// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (3, 5, 7 // CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] : // CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : // CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 11] [1, 1] : @@ -24,12 +21,9 @@ func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf3 // CHECK: return %[[tile]] // FOREACH: func.func @extract_slice_static(%[[arg0:.+]]: -// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index -// FOREACH-DAG: %[[c5:.+]] = arith.constant 5 : index -// FOREACH-DAG: %[[c7:.+]] = arith.constant 7 : index // FOREACH-DAG: %[[init:.+]] = tensor.empty() : tensor<20x11xf32> // FOREACH: %[[tile:.+]] = scf.forall (%[[iv:.+]]) in (20) shared_outs(%[[dest:.+]] = %[[init]]) -// FOREACH: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]] +// FOREACH: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (3, 5, 7 // FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] : // FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : // FOREACH: in_parallel @@ -50,13 +44,10 @@ func.func @extract_slice_static_strided(%input: tensor<3x5x7x11xf32>) -> tensor< // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index -// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index -// 
CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index // CHECK: %[[init:.+]] = tensor.empty() : tensor<10x5xf32> // CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]]) // CHECK: %[[inputIv:.+]] = affine.apply #[[$map0]](%[[iv]]) -// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[c5]], %[[c7]] +// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (3, 5, 7 // CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] : // CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : // CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] : @@ -78,13 +69,12 @@ func.func @extract_slice_dynamic(%input: tensor<3x?x?x11xf32>, %offt: index, %si // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index // CHECK: %[[init:.+]] = tensor.empty(%[[sz]]) : tensor // CHECK-DAG: %[[d1:.+]] = tensor.dim %arg0, %[[c1]] : tensor<3x?x?x11xf32> // CHECK-DAG: %[[d2:.+]] = tensor.dim %arg0, %[[c2]] : tensor<3x?x?x11xf32> // CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[sz]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]]) // CHECK: %[[inputIv:.+]] = affine.apply #[[map0]](%[[iv]])[%[[lb]]] -// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[d1]], %[[d2]]) : +// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (3, %[[d1]], %[[d2]]) : // CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] : // CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : // CHECK: %[[update:.+]] = 
tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] : @@ -105,9 +95,7 @@ func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0 // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index // CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index -// CHECK-DAG: %[[c11:.+]] = arith.constant 11 : index // CHECK: %[[init:.+]] = tensor.empty(%[[sz1]], %[[sz2]]) : tensor // CHECK-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] : // CHECK-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] : @@ -115,9 +103,9 @@ func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0 // CHECK: %[[tile1:.+]] = scf.for %[[iv1:.+]] = %[[c0]] to %[[sz1]] step %[[c1]] iter_args(%[[iterArg1:.+]] = %[[init]]) // CHECK: %[[tile2:.+]] = scf.for %[[iv2:.+]] = %[[c0]] to %[[sz2]] step %[[c1]] iter_args(%[[iterArg2:.+]] = %[[iterArg1]]) // CHECK: %[[inputIv1:.+]] = affine.apply #[[map0:.+]](%[[iv1]])[%[[lb1]]] -// CHECK: %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[inputIv1]] into (%[[c3]], %[[d1]], %[[d2]]) : +// CHECK: %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[inputIv1]] into (3, %[[d1]], %[[d2]]) : // CHECK: %[[inputIv2:.+]] = affine.apply #[[map0:.+]](%[[iv2]])[%[[lb2]]] -// CHECK: %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[inputIv2]] into (%[[c11]], %[[d4]]) : +// CHECK: %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[inputIv2]] into (11, %[[d4]]) : // CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : // CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} : // CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg2]][%[[iv1]], %[[iv2]]] [1, 1] [1, 1] : @@ 
-129,18 +117,16 @@ func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0 // FOREACH: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index) // FOREACH-DAG: %[[c1:.+]] = arith.constant 1 : index // FOREACH-DAG: %[[c2:.+]] = arith.constant 2 : index -// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index // FOREACH-DAG: %[[c4:.+]] = arith.constant 4 : index -// FOREACH-DAG: %[[c11:.+]] = arith.constant 11 : index // FOREACH: %[[init:.+]] = tensor.empty(%[[sz1]], %[[sz2]]) : tensor // FOREACH-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] : // FOREACH-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] : // FOREACH-DAG: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] : // FOREACH: %[[tile1:.+]] = scf.forall (%[[tid1:.+]], %[[tid2:.+]]) in (%[[sz1]], %[[sz2]]) shared_outs(%[[dest:.+]] = %[[init]]) // FOREACH-DAG: %[[iv1:.+]] = affine.apply #[[map1]](%[[tid1]])[%[[lb1]]] -// FOREACH: %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[iv1]] into (%[[c3]], %[[d1]], %[[d2]]) : +// FOREACH: %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[iv1]] into (3, %[[d1]], %[[d2]]) : // FOREACH-DAG: %[[iv2:.+]] = affine.apply #[[map1]](%[[tid2]])[%[[lb2]]] -// FOREACH: %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[iv2]] into (%[[c11]], %[[d4]]) : +// FOREACH: %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[iv2]] into (11, %[[d4]]) : // FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : // FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} : // FOREACH: in_parallel diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index 0544cef3e38281..3acddd6e54639e 100644 --- 
a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -1615,6 +1615,6 @@ func.func @warp_propagate_nd_write(%laneid: index, %dest: memref<4x1024xf32>) { // CHECK-DIST-AND-PROP-SAME: vector<4x1024xf32> // CHECK-DIST-AND-PROP: } -// CHECK-DIST-AND-PROP: %[[IDS:.+]]:2 = affine.delinearize_index %{{.*}} into (%c4, %c8) : index, index +// CHECK-DIST-AND-PROP: %[[IDS:.+]]:2 = affine.delinearize_index %{{.*}} into (4, 8) : index, index // CHECK-DIST-AND-PROP: %[[INNER_ID:.+]] = affine.apply #map()[%[[IDS]]#1] // CHECK-DIST-AND-PROP: vector.transfer_write %[[W]], %{{.*}}[%[[IDS]]#0, %[[INNER_ID]]] {{.*}} : vector<1x128xf32> diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py index 6f39e1348fcd57..58be05a8eb7917 100644 --- a/mlir/test/python/dialects/affine.py +++ b/mlir/test/python/dialects/affine.py @@ -47,11 +47,10 @@ def affine_store_test(arg0): # CHECK-LABEL: TEST: testAffineDelinearizeInfer @constructAndPrintInModule def testAffineDelinearizeInfer(): - # CHECK: %[[C0:.*]] = arith.constant 0 : index c0 = arith.ConstantOp(T.index(), 0) # CHECK: %[[C1:.*]] = arith.constant 1 : index c1 = arith.ConstantOp(T.index(), 1) - # CHECK: %{{.*}}:2 = affine.delinearize_index %[[C1:.*]] into (%[[C1:.*]], %[[C0:.*]]) : index, index + # CHECK: %{{.*}}:2 = affine.delinearize_index %[[C1:.*]] into (1, 0) : index, index two_indices = affine.AffineDelinearizeIndexOp(c1, [c1, c0]) From ad15902093d7d1d7f54b747bd64555c11f491efa Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Mon, 28 Oct 2024 22:37:31 +0000 Subject: [PATCH 2/5] Adjust Python test --- mlir/test/python/dialects/affine.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py index 58be05a8eb7917..73864708b2b221 100644 --- a/mlir/test/python/dialects/affine.py +++ b/mlir/test/python/dialects/affine.py @@ -47,11 +47,10 @@ def 
affine_store_test(arg0): # CHECK-LABEL: TEST: testAffineDelinearizeInfer @constructAndPrintInModule def testAffineDelinearizeInfer(): - c0 = arith.ConstantOp(T.index(), 0) # CHECK: %[[C1:.*]] = arith.constant 1 : index c1 = arith.ConstantOp(T.index(), 1) - # CHECK: %{{.*}}:2 = affine.delinearize_index %[[C1:.*]] into (1, 0) : index, index - two_indices = affine.AffineDelinearizeIndexOp(c1, [c1, c0]) + # CHECK: %{{.*}}:2 = affine.delinearize_index %[[C1:.*]] into (2, 3) : index, index + two_indices = affine.AffineDelinearizeIndexOp(c1, [], [2, 3]) # CHECK-LABEL: TEST: testAffineLoadOp From f3696112de92a381b546ad6bd4f490fe084c991f Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Wed, 30 Oct 2024 15:41:10 +0000 Subject: [PATCH 3/5] Address review feedback, update Python test --- mlir/lib/Dialect/Affine/Utils/Utils.cpp | 62 ++++++++----------- .../AffineToStandard/lower-affine.mlir | 4 +- mlir/test/python/dialects/affine.py | 2 +- 3 files changed, 28 insertions(+), 40 deletions(-) diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index e3b5d26e0ec3c3..2680502bb687d3 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -1931,49 +1931,36 @@ DivModValue mlir::affine::getDivMod(OpBuilder &b, Location loc, Value lhs, return result; } -/// Create IR that computes the product of all elements in the set. 
-static FailureOr getIndexProduct(OpBuilder &b, Location loc, - ArrayRef set) { - if (set.empty()) - return failure(); - OpFoldResult result = set[0]; - AffineExpr s0, s1; - bindSymbols(b.getContext(), s0, s1); - for (unsigned i = 1, e = set.size(); i < e; i++) - result = makeComposedFoldedAffineApply(b, loc, s0 * s1, {result, set[i]}); - return result; -} - -static FailureOr getIndexProduct(OpBuilder &b, Location loc, - ArrayRef set) { - if (set.empty()) - return failure(); - OpFoldResult result = set[0]; +/// Create an affine map that computes `lhs` * `rhs`, composing in any other +/// affine maps. +static FailureOr composedAffineMultiply(OpBuilder &b, + Location loc, + OpFoldResult lhs, + OpFoldResult rhs) { AffineExpr s0, s1; bindSymbols(b.getContext(), s0, s1); - for (unsigned i = 1, e = set.size(); i < e; i++) - result = makeComposedFoldedAffineApply(b, loc, s0 * s1, {result, set[i]}); - return result; + return makeComposedFoldedAffineApply(b, loc, s0 * s1, {lhs, rhs}); } FailureOr> mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef basis) { - unsigned numDims = basis.size(); - + // Note: the divisors are backwards due to the scan. 
SmallVector divisors; - for (unsigned i = 1; i < numDims; i++) { - ArrayRef slice = basis.drop_front(i); - FailureOr prod = getIndexProduct(b, loc, slice); - if (failed(prod)) + OpFoldResult basisProd = b.getIndexAttr(1); + for (OpFoldResult basisElem : llvm::reverse(basis.drop_front())) { + FailureOr nextProd = + composedAffineMultiply(b, loc, basisElem, basisProd); + if (failed(nextProd)) return failure(); - divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod)); + basisProd = *nextProd; + divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, basisProd)); } SmallVector results; results.reserve(divisors.size() + 1); Value residual = linearIndex; - for (Value divisor : divisors) { + for (Value divisor : llvm::reverse(divisors)) { DivModValue divMod = getDivMod(b, loc, residual, divisor); results.push_back(divMod.quotient); residual = divMod.remainder; @@ -1985,21 +1972,22 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, FailureOr> mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef basis) { - unsigned numDims = basis.size(); - + // Note: the divisors are backwards due to the scan. 
SmallVector divisors; - for (unsigned i = 1; i < numDims; i++) { - ArrayRef slice = basis.drop_front(i); - FailureOr prod = getIndexProduct(b, loc, slice); - if (failed(prod)) + OpFoldResult basisProd = b.getIndexAttr(1); + for (OpFoldResult basisElem : llvm::reverse(basis.drop_front())) { + FailureOr nextProd = + composedAffineMultiply(b, loc, basisElem, basisProd); + if (failed(nextProd)) return failure(); - divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod)); + basisProd = *nextProd; + divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, basisProd)); } SmallVector results; results.reserve(divisors.size() + 1); Value residual = linearIndex; - for (Value divisor : divisors) { + for (Value divisor : llvm::reverse(divisors)) { DivModValue divMod = getDivMod(b, loc, residual, divisor); results.push_back(divMod.quotient); residual = divMod.remainder; diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir index 298e82df4f4cea..3781d510897f8f 100644 --- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir +++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir @@ -936,8 +936,8 @@ func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index) } // CHECK-LABEL: func.func @test_dilinearize_index( // CHECK-SAME: %[[VAL_0:.*]]: index) -> (index, index, index) { -// CHECK: %[[VAL_1:.*]] = arith.constant 50176 : index -// CHECK: %[[VAL_2:.*]] = arith.constant 224 : index +// CHECK: %[[VAL_1:.*]] = arith.constant 224 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 50176 : index // CHECK: %[[VAL_3:.*]] = arith.constant 50176 : index // CHECK: %[[VAL_4:.*]] = arith.constant 0 : index // CHECK: %[[VAL_5:.*]] = arith.constant -1 : index diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py index 73864708b2b221..0dc69d7ba522de 100644 --- a/mlir/test/python/dialects/affine.py +++ b/mlir/test/python/dialects/affine.py @@ -157,7 
+157,7 @@ def testAffineForOpErrors(): ) try: - two_indices = affine.AffineDelinearizeIndexOp(c1, [c1, c1]) + two_indices = affine.AffineDelinearizeIndexOp(c1, [], [1, 1]) affine.AffineForOp( two_indices, c2, From 06bad6d474535a3f81dfca6b8bc5bc4224125324 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Mon, 4 Nov 2024 12:34:03 -0600 Subject: [PATCH 4/5] Formatting fixes Co-authored-by: Jakub Kuderski --- mlir/include/mlir/Dialect/Affine/IR/AffineOps.td | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index f53b5d97a7156a..e9480d30c2d701 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -1102,12 +1102,11 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", ]; let extraClassDeclaration = [{ - /// Return a vector with all the static and dynamic basis values. + /// Returns a vector with all the static and dynamic basis values. SmallVector getMixedBasis() { OpBuilder builder(getContext()); return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder); } - }]; let hasVerifier = 1; From 3635e884326aa3315d4c64677bb21392f4fafa30 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Mon, 4 Nov 2024 18:35:19 +0000 Subject: [PATCH 5/5] Style nit --- mlir/include/mlir/Dialect/Affine/Utils.h | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h index d2cfbaa85a60ef..a2bf92323be01b 100644 --- a/mlir/include/mlir/Dialect/Affine/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Utils.h @@ -311,6 +311,7 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs); FailureOr> delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef basis); + FailureOr> delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef basis);