Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][affine] Add static basis support to affine.delinearize #113846

Merged
merged 5 commits into from
Nov 4, 2024

Conversation

krzysz00
Copy link
Contributor

This commit makes affine.delinealize join other indexing operators, like vector.extract, which store a mixed static/dynamic set of sizes, offsets, or such. In this case, the basis (the set of values that will be used to decompose the linear index) is now stored as an array of index attributes where the basis is statically known, eliminating the need to cretae constants.

This commit also adds copies of the delinearize utility in the affine dialect to allow it to take an array of OpFoldResults and extends te DynamicIndexList parser/printer to allow specifying the delimiters in tablegen (this is needed to avoid breaking existing syntax).

This commit makes `affine.delinealize` join other indexing operators,
like `vector.extract`, which store a mixed static/dynamic set of
sizes, offsets, or such. In this case, the `basis` (the set of values
that will be used to decompose the linear index) is now stored as an
array of index attributes where the basis is statically known,
eliminating the need to cretae constants.

This commit also adds copies of the delinearize utility in the affine
dialect to allow it to take an array of `OpFoldResult`s and extends te
DynamicIndexList parser/printer to allow specifying the delimiters in
tablegen (this is needed to avoid breaking existing syntax).
@llvmbot
Copy link

llvmbot commented Oct 27, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-affine

Author: Krzysztof Drewniak (krzysz00)

Changes

This commit makes affine.delinealize join other indexing operators, like vector.extract, which store a mixed static/dynamic set of sizes, offsets, or such. In this case, the basis (the set of values that will be used to decompose the linear index) is now stored as an array of index attributes where the basis is statically known, eliminating the need to cretae constants.

This commit also adds copies of the delinearize utility in the affine dialect to allow it to take an array of OpFoldResults and extends te DynamicIndexList parser/printer to allow specifying the delimiters in tablegen (this is needed to avoid breaking existing syntax).


Patch is 40.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113846.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+18-3)
  • (modified) mlir/include/mlir/Dialect/Affine/Utils.h (+3)
  • (modified) mlir/include/mlir/Interfaces/ViewLikeInterface.h (+16)
  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+42-22)
  • (modified) mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp (+2-3)
  • (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+38)
  • (modified) mlir/test/Conversion/AffineToStandard/lower-affine.mlir (+39-44)
  • (modified) mlir/test/Dialect/Affine/affine-expand-index-ops.mlir (+1-4)
  • (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+4-7)
  • (modified) mlir/test/Dialect/Affine/loop-coalescing.mlir (+5-22)
  • (modified) mlir/test/Dialect/Affine/ops.mlir (+7)
  • (modified) mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir (+8-22)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+1-1)
  • (modified) mlir/test/python/dialects/affine.py (+1-2)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 5c75e102c3d404..7c950623f77f48 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -16,11 +16,11 @@
 
 #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
-
 namespace mlir {
 namespace affine {
 
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 8773fc5881461a..f53b5d97a7156a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1084,17 +1084,32 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
     ```
   }];
 
-  let arguments = (ins Index:$linear_index, Variadic<Index>:$basis);
+  let arguments = (ins Index:$linear_index,
+    Variadic<Index>:$dynamic_basis,
+    DenseI64ArrayAttr:$static_basis);
   let results = (outs Variadic<Index>:$multi_index);
 
   let assemblyFormat = [{
-    $linear_index `into` ` ` `(` $basis `)` attr-dict `:` type($multi_index)
+    $linear_index `into` ` `
+    custom<DynamicIndexList>($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren")
+    attr-dict `:` type($multi_index)
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>
+    OpBuilder<(ins "Value":$linear_index, "ValueRange":$basis)>,
+    OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>,
+    OpBuilder<(ins "Value":$linear_index, "ArrayRef<int64_t>":$basis)>
   ];
 
+  let extraClassDeclaration = [{
+    /// Return a vector with all the static and dynamic basis values.
+    SmallVector<OpFoldResult> getMixedBasis() {
+      OpBuilder builder(getContext());
+      return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+    }
+
+  }];
+
   let hasVerifier = 1;
   let hasCanonicalizer = 1;
 }
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 9a2767e0ad87f3..d2cfbaa85a60ef 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -311,6 +311,9 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
 FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
                                                Value linearIndex,
                                                ArrayRef<Value> basis);
+FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
+                                               Value linearIndex,
+                                               ArrayRef<OpFoldResult> basis);
 // Generate IR that extracts the linear index from a multi-index according to
 // a basis/shape.
 OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index d6479143a0a50b..3dcbd2f1af1936 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -109,6 +109,13 @@ void printDynamicIndexList(
     ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
     TypeRange valueTypes = TypeRange(),
     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
+                                  OperandRange values,
+                                  ArrayRef<int64_t> integers,
+                                  AsmParser::Delimiter delimiter) {
+  return printDynamicIndexList(printer, op, values, integers, {}, TypeRange(),
+                               delimiter);
+}
 inline void printDynamicIndexList(
     OpAsmPrinter &printer, Operation *op, OperandRange values,
     ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
@@ -144,6 +151,15 @@ ParseResult parseDynamicIndexList(
     DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
     SmallVectorImpl<Type> *valueTypes = nullptr,
     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+inline ParseResult
+parseDynamicIndexList(OpAsmParser &parser,
+                      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                      DenseI64ArrayAttr &integers,
+                      AsmParser::Delimiter delimiter) {
+  DenseBoolArrayAttr scalableVals = {};
+  return parseDynamicIndexList(parser, values, integers, scalableVals, nullptr,
+                               delimiter);
+}
 inline ParseResult parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 5e7a6b6ca883c3..f384f454bc4726 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/IntegerSet.h"
@@ -4508,32 +4509,50 @@ LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
     RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
   AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
                                           regions);
-  inferredReturnTypes.assign(adaptor.getBasis().size(),
+  inferredReturnTypes.assign(adaptor.getStaticBasis().size(),
                              IndexType::get(context));
   return success();
 }
 
-void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                     OperationState &odsState,
+                                     Value linearIndex, ValueRange basis) {
+  SmallVector<Value> dynamicBasis;
+  SmallVector<int64_t> staticBasis;
+  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
+                             staticBasis);
+  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+}
+
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                     OperationState &odsState,
                                      Value linearIndex,
                                      ArrayRef<OpFoldResult> basis) {
-  result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
-  result.addOperands(linearIndex);
-  SmallVector<Value> basisValues =
-      llvm::map_to_vector(basis, [&](OpFoldResult ofr) -> Value {
-        std::optional<int64_t> staticDim = getConstantIntValue(ofr);
-        if (staticDim.has_value())
-          return builder.create<arith::ConstantIndexOp>(result.location,
-                                                        *staticDim);
-        return llvm::dyn_cast_if_present<Value>(ofr);
-      });
-  result.addOperands(basisValues);
+  SmallVector<Value> dynamicBasis;
+  SmallVector<int64_t> staticBasis;
+  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
+  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+}
+
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                     OperationState &odsState,
+                                     Value linearIndex,
+                                     ArrayRef<int64_t> basis) {
+  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis);
 }
 
 LogicalResult AffineDelinearizeIndexOp::verify() {
-  if (getBasis().empty())
+  if (getStaticBasis().empty())
     return emitOpError("basis should not be empty");
-  if (getNumResults() != getBasis().size())
+  if (getNumResults() != getStaticBasis().size())
     return emitOpError("should return an index for each basis element");
+  auto dynamicMarkersCount =
+      llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
+  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
+    return emitOpError(
+        "mismatch between dynamic and static basis (kDynamic marker but no "
+        "corresponding dynamic basis entry) -- this can only happen due to an "
+        "incorrect fold/rewrite");
   return success();
 }
 
@@ -4557,15 +4576,16 @@ struct DropUnitExtentBasis
 
     // Replace all indices corresponding to unit-extent basis with 0.
     // Remaining basis can be used to get a new `affine.delinearize_index` op.
-    SmallVector<Value> newOperands;
-    for (auto [index, basis] : llvm::enumerate(delinearizeOp.getBasis())) {
-      if (matchPattern(basis, m_One()))
+    SmallVector<OpFoldResult> newOperands;
+    for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) {
+      std::optional<int64_t> basisVal = getConstantIntValue(basis);
+      if (basisVal && *basisVal == 1)
         replacements[index] = getZero();
       else
         newOperands.push_back(basis);
     }
 
-    if (newOperands.size() == delinearizeOp.getBasis().size())
+    if (newOperands.size() == delinearizeOp.getStaticBasis().size())
       return failure();
 
     if (!newOperands.empty()) {
@@ -4607,9 +4627,9 @@ struct DropDelinearizeOfSingleLoop
 
   LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                 PatternRewriter &rewriter) const override {
-    auto basis = delinearizeOp.getBasis();
-    if (basis.size() != 1)
+    if (delinearizeOp.getStaticBasis().size() != 1)
       return failure();
+    auto basis = delinearizeOp.getMixedBasis();
 
     // Check that the `linear_index` is an induction variable.
     auto inductionVar = dyn_cast<BlockArgument>(delinearizeOp.getLinearIndex());
@@ -4634,7 +4654,7 @@ struct DropDelinearizeOfSingleLoop
     // Check that the upper-bound is the basis.
     auto upperBounds = loopLikeOp.getLoopUpperBounds();
     if (!upperBounds || upperBounds->size() != 1 ||
-        upperBounds->front() != getAsOpFoldResult(basis.front())) {
+        upperBounds->front() != basis.front()) {
       return rewriter.notifyMatchFailure(delinearizeOp,
                                          "`basis` is not upper bound");
     }
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index c6bc3862256a75..d76968d3a71520 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -35,9 +35,8 @@ struct LowerDelinearizeIndexOps
   using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<SmallVector<Value>> multiIndex =
-        delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
-                         llvm::to_vector(op.getBasis()));
+    FailureOr<SmallVector<Value>> multiIndex = delinearizeIndex(
+        rewriter, op->getLoc(), op.getLinearIndex(), op.getMixedBasis());
     if (failed(multiIndex))
       return failure();
     rewriter.replaceOp(op, *multiIndex);
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 910ad1733d03e8..e3b5d26e0ec3c3 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1944,6 +1944,18 @@ static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
   return result;
 }
 
+static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
+                                               ArrayRef<OpFoldResult> set) {
+  if (set.empty())
+    return failure();
+  OpFoldResult result = set[0];
+  AffineExpr s0, s1;
+  bindSymbols(b.getContext(), s0, s1);
+  for (unsigned i = 1, e = set.size(); i < e; i++)
+    result = makeComposedFoldedAffineApply(b, loc, s0 * s1, {result, set[i]});
+  return result;
+}
+
 FailureOr<SmallVector<Value>>
 mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
                                ArrayRef<Value> basis) {
@@ -1970,6 +1982,32 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
   return results;
 }
 
+FailureOr<SmallVector<Value>>
+mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
+                               ArrayRef<OpFoldResult> basis) {
+  unsigned numDims = basis.size();
+
+  SmallVector<Value> divisors;
+  for (unsigned i = 1; i < numDims; i++) {
+    ArrayRef<OpFoldResult> slice = basis.drop_front(i);
+    FailureOr<OpFoldResult> prod = getIndexProduct(b, loc, slice);
+    if (failed(prod))
+      return failure();
+    divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod));
+  }
+
+  SmallVector<Value> results;
+  results.reserve(divisors.size() + 1);
+  Value residual = linearIndex;
+  for (Value divisor : divisors) {
+    DivModValue divMod = getDivMod(b, loc, residual, divisor);
+    results.push_back(divMod.quotient);
+    residual = divMod.remainder;
+  }
+  results.push_back(residual);
+  return results;
+}
+
 OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
                                           ArrayRef<OpFoldResult> basis,
                                           ImplicitLocOpBuilder &builder) {
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 23e0edd510cbb1..298e82df4f4cea 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -931,53 +931,48 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me
 ///////////////////////////////////////////////////////////////////////
 
 func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index) {
-  %b0 = arith.constant 16 : index
-  %b1 = arith.constant 224 : index
-  %b2 = arith.constant 224 : index
-  %1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index
+  %1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index
   return %1#0, %1#1, %1#2 : index, index, index
 }
 // CHECK-LABEL:   func.func @test_dilinearize_index(
 // CHECK-SAME:                                      %[[VAL_0:.*]]: index) -> (index, index, index) {
-// CHECK:           %[[VAL_1:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 50176 : index
 // CHECK:           %[[VAL_2:.*]] = arith.constant 224 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant 224 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 50176 : index
-// CHECK:           %[[VAL_5:.*]] = arith.constant 50176 : index
-// CHECK:           %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_7:.*]] = arith.constant -1 : index
-// CHECK:           %[[VAL_8:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_6]] : index
-// CHECK:           %[[VAL_9:.*]] = arith.subi %[[VAL_7]], %[[VAL_0]] : index
-// CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_8]], %[[VAL_9]], %[[VAL_0]] : index
-// CHECK:           %[[VAL_11:.*]] = arith.divsi %[[VAL_10]], %[[VAL_5]] : index
-// CHECK:           %[[VAL_12:.*]] = arith.subi %[[VAL_7]], %[[VAL_11]] : index
-// CHECK:           %[[VAL_13:.*]] = arith.select %[[VAL_8]], %[[VAL_12]], %[[VAL_11]] : index
-// CHECK:           %[[VAL_14:.*]] = arith.constant 50176 : index
-// CHECK:           %[[VAL_15:.*]] = arith.remsi %[[VAL_0]], %[[VAL_14]] : index
-// CHECK:           %[[VAL_16:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_17:.*]] = arith.cmpi slt, %[[VAL_15]], %[[VAL_16]] : index
-// CHECK:           %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_14]] : index
-// CHECK:           %[[VAL_19:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_15]] : index
-// CHECK:           %[[VAL_20:.*]] = arith.constant 50176 : index
-// CHECK:           %[[VAL_21:.*]] = arith.remsi %[[VAL_0]], %[[VAL_20]] : index
-// CHECK:           %[[VAL_22:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_23:.*]] = arith.cmpi slt, %[[VAL_21]], %[[VAL_22]] : index
-// CHECK:           %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_20]] : index
-// CHECK:           %[[VAL_25:.*]] = arith.select %[[VAL_23]], %[[VAL_24]], %[[VAL_21]] : index
-// CHECK:           %[[VAL_26:.*]] = arith.constant 224 : index
-// CHECK:           %[[VAL_27:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_28:.*]] = arith.constant -1 : index
-// CHECK:           %[[VAL_29:.*]] = arith.cmpi slt, %[[VAL_25]], %[[VAL_27]] : index
-// CHECK:           %[[VAL_30:.*]] = arith.subi %[[VAL_28]], %[[VAL_25]] : index
-// CHECK:           %[[VAL_31:.*]] = arith.select %[[VAL_29]], %[[VAL_30]], %[[VAL_25]] : index
-// CHECK:           %[[VAL_32:.*]] = arith.divsi %[[VAL_31]], %[[VAL_26]] : index
-// CHECK:           %[[VAL_33:.*]] = arith.subi %[[VAL_28]], %[[VAL_32]] : index
-// CHECK:           %[[VAL_34:.*]] = arith.select %[[VAL_29]], %[[VAL_33]], %[[VAL_32]] : index
-// CHECK:           %[[VAL_35:.*]] = arith.constant 224 : index
-// CHECK:           %[[VAL_36:.*]] = arith.remsi %[[VAL_0]], %[[VAL_35]] : index
-// CHECK:           %[[VAL_37:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_38:.*]] = arith.cmpi slt, %[[VAL_36]], %[[VAL_37]] : index
-// CHECK:           %[[VAL_39:.*]] = arith.addi %[[VAL_36]], %[[VAL_35]] : index
-// CHECK:           %[[VAL_40:.*]] = arith.select %[[VAL_38]], %[[VAL_39]], %[[VAL_36]] : index
-// CHECK:           return %[[VAL_13]], %[[VAL_34]], %[[VAL_40]] : index, index, index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 50176 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant -1 : index
+// CHECK:           %[[VAL_6:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_4]] : index
+// CHECK:           %[[VAL_7:.*]] = arith.subi %[[VAL_5]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_8:.*]] = arith.select %[[VAL_6]], %[[VAL_7]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_9:.*]] = arith.divsi %[[VAL_8]], %[[VAL_3]] : index
+// CHECK:           %[[VAL_10:.*]] = arith.subi %[[VAL_5]], %[[VAL_9]] : index
+// CHECK:           %[[VAL_11:.*]] = arith.select %[[VAL_6]], %[[VAL_10]], %[[VAL_9]] : index
+// CHECK:           %[[VAL_12:.*]] = arith.constant 50176 : index
+// CHECK:           %[[VAL_13:.*]] = arith.remsi %[[VAL_0]], %[[VAL_12]] : index
+// CHECK:           %[[VAL_14:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_15:.*]] = arith.cmpi slt, %[[VAL_13]], %[[VAL_14]] : index
+// CHECK:           %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index
+// CHECK:           %[[VAL_17:.*]] = arith.select %[[VAL_15]], %[[VAL_16]], %[[VAL_13]] : index
+// CHECK:           %[[VAL_18:.*]] = arith.constant 50176 : index
+// CHECK:           %[[VAL_19:.*]] = arith.remsi %[[VAL_0]], %[[VAL_18]] : index
+// CHECK:           %[[VAL_20:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_21:.*]] = arith.cmpi slt, %[[VAL_19]], %[[VAL_20]] : index
+// CHECK:           %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_18]] : index
+// CHECK:           %[[VAL_23:.*]] = arith.select %[[VAL_21]], %[[VAL_22]], %[[VAL_19]] : index
+// CHECK:           %[[VAL_24:.*]] = arith.constant 224 : index
+// CHECK:           %[[VAL_25:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_26:.*]] = arith.constant -1 : index
+// CHECK:           %[[VAL_27:.*]] = arith.cmpi slt, %[[VAL_23]], %[[VAL_25]] : index
+// CHECK:           %[[VAL_28:.*]] = arith.subi %[[VAL_26]], %[[VAL_23]] : index
+// CHECK:           %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_23]] : index
+// CHECK:           %[[VAL_30:.*]] = arith.divsi %[[VAL_29]], %[[VAL_24]] : index
+// CHECK:           %[[VAL_31:.*]] = arith.subi %[[VAL_26]], %[[VAL_30]] : index
+// CHECK:           %[[VAL_32:.*]] = arith.select %[[VAL_27]], %[[VAL_31]], %[[VAL_30]] : index
+...
[truncated]

@llvmbot
Copy link

llvmbot commented Oct 27, 2024

@llvm/pr-subscribers-mlir-tensor

Author: Krzysztof Drewniak (krzysz00)

Changes

This commit makes affine.delinealize join other indexing operators, like vector.extract, which store a mixed static/dynamic set of sizes, offsets, or such. In this case, the basis (the set of values that will be used to decompose the linear index) is now stored as an array of index attributes where the basis is statically known, eliminating the need to cretae constants.

This commit also adds copies of the delinearize utility in the affine dialect to allow it to take an array of OpFoldResults and extends te DynamicIndexList parser/printer to allow specifying the delimiters in tablegen (this is needed to avoid breaking existing syntax).


Patch is 40.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113846.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+18-3)
  • (modified) mlir/include/mlir/Dialect/Affine/Utils.h (+3)
  • (modified) mlir/include/mlir/Interfaces/ViewLikeInterface.h (+16)
  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+42-22)
  • (modified) mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp (+2-3)
  • (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+38)
  • (modified) mlir/test/Conversion/AffineToStandard/lower-affine.mlir (+39-44)
  • (modified) mlir/test/Dialect/Affine/affine-expand-index-ops.mlir (+1-4)
  • (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+4-7)
  • (modified) mlir/test/Dialect/Affine/loop-coalescing.mlir (+5-22)
  • (modified) mlir/test/Dialect/Affine/ops.mlir (+7)
  • (modified) mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir (+8-22)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+1-1)
  • (modified) mlir/test/python/dialects/affine.py (+1-2)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 5c75e102c3d404..7c950623f77f48 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -16,11 +16,11 @@
 
 #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
-
 namespace mlir {
 namespace affine {
 
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 8773fc5881461a..f53b5d97a7156a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1084,17 +1084,32 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
     ```
   }];
 
-  let arguments = (ins Index:$linear_index, Variadic<Index>:$basis);
+  let arguments = (ins Index:$linear_index,
+    Variadic<Index>:$dynamic_basis,
+    DenseI64ArrayAttr:$static_basis);
   let results = (outs Variadic<Index>:$multi_index);
 
   let assemblyFormat = [{
-    $linear_index `into` ` ` `(` $basis `)` attr-dict `:` type($multi_index)
+    $linear_index `into` ` `
+    custom<DynamicIndexList>($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren")
+    attr-dict `:` type($multi_index)
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>
+    OpBuilder<(ins "Value":$linear_index, "ValueRange":$basis)>,
+    OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>,
+    OpBuilder<(ins "Value":$linear_index, "ArrayRef<int64_t>":$basis)>
   ];
 
+  let extraClassDeclaration = [{
+    /// Return a vector with all the static and dynamic basis values.
+    SmallVector<OpFoldResult> getMixedBasis() {
+      OpBuilder builder(getContext());
+      return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+    }
+
+  }];
+
   let hasVerifier = 1;
   let hasCanonicalizer = 1;
 }
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 9a2767e0ad87f3..d2cfbaa85a60ef 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -311,6 +311,9 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
 FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
                                                Value linearIndex,
                                                ArrayRef<Value> basis);
+FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
+                                               Value linearIndex,
+                                               ArrayRef<OpFoldResult> basis);
 // Generate IR that extracts the linear index from a multi-index according to
 // a basis/shape.
 OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index d6479143a0a50b..3dcbd2f1af1936 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -109,6 +109,13 @@ void printDynamicIndexList(
     ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
     TypeRange valueTypes = TypeRange(),
     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
+                                  OperandRange values,
+                                  ArrayRef<int64_t> integers,
+                                  AsmParser::Delimiter delimiter) {
+  return printDynamicIndexList(printer, op, values, integers, {}, TypeRange(),
+                               delimiter);
+}
 inline void printDynamicIndexList(
     OpAsmPrinter &printer, Operation *op, OperandRange values,
     ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
@@ -144,6 +151,15 @@ ParseResult parseDynamicIndexList(
     DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
     SmallVectorImpl<Type> *valueTypes = nullptr,
     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+inline ParseResult
+parseDynamicIndexList(OpAsmParser &parser,
+                      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                      DenseI64ArrayAttr &integers,
+                      AsmParser::Delimiter delimiter) {
+  DenseBoolArrayAttr scalableVals = {};
+  return parseDynamicIndexList(parser, values, integers, scalableVals, nullptr,
+                               delimiter);
+}
 inline ParseResult parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 5e7a6b6ca883c3..f384f454bc4726 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/IntegerSet.h"
@@ -4508,32 +4509,50 @@ LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
     RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
   AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
                                           regions);
-  inferredReturnTypes.assign(adaptor.getBasis().size(),
+  inferredReturnTypes.assign(adaptor.getStaticBasis().size(),
                              IndexType::get(context));
   return success();
 }
 
-void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                     OperationState &odsState,
+                                     Value linearIndex, ValueRange basis) {
+  SmallVector<Value> dynamicBasis;
+  SmallVector<int64_t> staticBasis;
+  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
+                             staticBasis);
+  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+}
+
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                     OperationState &odsState,
                                      Value linearIndex,
                                      ArrayRef<OpFoldResult> basis) {
-  result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
-  result.addOperands(linearIndex);
-  SmallVector<Value> basisValues =
-      llvm::map_to_vector(basis, [&](OpFoldResult ofr) -> Value {
-        std::optional<int64_t> staticDim = getConstantIntValue(ofr);
-        if (staticDim.has_value())
-          return builder.create<arith::ConstantIndexOp>(result.location,
-                                                        *staticDim);
-        return llvm::dyn_cast_if_present<Value>(ofr);
-      });
-  result.addOperands(basisValues);
+  SmallVector<Value> dynamicBasis;
+  SmallVector<int64_t> staticBasis;
+  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
+  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+}
+
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                     OperationState &odsState,
+                                     Value linearIndex,
+                                     ArrayRef<int64_t> basis) {
+  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis);
 }
 
 LogicalResult AffineDelinearizeIndexOp::verify() {
-  if (getBasis().empty())
+  if (getStaticBasis().empty())
     return emitOpError("basis should not be empty");
-  if (getNumResults() != getBasis().size())
+  if (getNumResults() != getStaticBasis().size())
     return emitOpError("should return an index for each basis element");
+  auto dynamicMarkersCount =
+      llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
+  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
+    return emitOpError(
+        "mismatch between dynamic and static basis (kDynamic marker but no "
+        "corresponding dynamic basis entry) -- this can only happen due to an "
+        "incorrect fold/rewrite");
   return success();
 }
 
@@ -4557,15 +4576,16 @@ struct DropUnitExtentBasis
 
     // Replace all indices corresponding to unit-extent basis with 0.
     // Remaining basis can be used to get a new `affine.delinearize_index` op.
-    SmallVector<Value> newOperands;
-    for (auto [index, basis] : llvm::enumerate(delinearizeOp.getBasis())) {
-      if (matchPattern(basis, m_One()))
+    SmallVector<OpFoldResult> newOperands;
+    for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) {
+      std::optional<int64_t> basisVal = getConstantIntValue(basis);
+      if (basisVal && *basisVal == 1)
         replacements[index] = getZero();
       else
         newOperands.push_back(basis);
     }
 
-    if (newOperands.size() == delinearizeOp.getBasis().size())
+    if (newOperands.size() == delinearizeOp.getStaticBasis().size())
       return failure();
 
     if (!newOperands.empty()) {
@@ -4607,9 +4627,9 @@ struct DropDelinearizeOfSingleLoop
 
   LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                 PatternRewriter &rewriter) const override {
-    auto basis = delinearizeOp.getBasis();
-    if (basis.size() != 1)
+    if (delinearizeOp.getStaticBasis().size() != 1)
       return failure();
+    auto basis = delinearizeOp.getMixedBasis();
 
     // Check that the `linear_index` is an induction variable.
     auto inductionVar = dyn_cast<BlockArgument>(delinearizeOp.getLinearIndex());
@@ -4634,7 +4654,7 @@ struct DropDelinearizeOfSingleLoop
     // Check that the upper-bound is the basis.
     auto upperBounds = loopLikeOp.getLoopUpperBounds();
     if (!upperBounds || upperBounds->size() != 1 ||
-        upperBounds->front() != getAsOpFoldResult(basis.front())) {
+        upperBounds->front() != basis.front()) {
       return rewriter.notifyMatchFailure(delinearizeOp,
                                          "`basis` is not upper bound");
     }
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index c6bc3862256a75..d76968d3a71520 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -35,9 +35,8 @@ struct LowerDelinearizeIndexOps
   using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<SmallVector<Value>> multiIndex =
-        delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
-                         llvm::to_vector(op.getBasis()));
+    FailureOr<SmallVector<Value>> multiIndex = delinearizeIndex(
+        rewriter, op->getLoc(), op.getLinearIndex(), op.getMixedBasis());
     if (failed(multiIndex))
       return failure();
     rewriter.replaceOp(op, *multiIndex);
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 910ad1733d03e8..e3b5d26e0ec3c3 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1944,6 +1944,18 @@ static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
   return result;
 }
 
+static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
+                                               ArrayRef<OpFoldResult> set) {
+  if (set.empty())
+    return failure();
+  OpFoldResult result = set[0];
+  AffineExpr s0, s1;
+  bindSymbols(b.getContext(), s0, s1);
+  for (unsigned i = 1, e = set.size(); i < e; i++)
+    result = makeComposedFoldedAffineApply(b, loc, s0 * s1, {result, set[i]});
+  return result;
+}
+
 FailureOr<SmallVector<Value>>
 mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
                                ArrayRef<Value> basis) {
@@ -1970,6 +1982,32 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
   return results;
 }
 
+FailureOr<SmallVector<Value>>
+mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
+                               ArrayRef<OpFoldResult> basis) {
+  unsigned numDims = basis.size();
+
+  SmallVector<Value> divisors;
+  for (unsigned i = 1; i < numDims; i++) {
+    ArrayRef<OpFoldResult> slice = basis.drop_front(i);
+    FailureOr<OpFoldResult> prod = getIndexProduct(b, loc, slice);
+    if (failed(prod))
+      return failure();
+    divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod));
+  }
+
+  SmallVector<Value> results;
+  results.reserve(divisors.size() + 1);
+  Value residual = linearIndex;
+  for (Value divisor : divisors) {
+    DivModValue divMod = getDivMod(b, loc, residual, divisor);
+    results.push_back(divMod.quotient);
+    residual = divMod.remainder;
+  }
+  results.push_back(residual);
+  return results;
+}
+
 OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
                                           ArrayRef<OpFoldResult> basis,
                                           ImplicitLocOpBuilder &builder) {
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 23e0edd510cbb1..298e82df4f4cea 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -931,53 +931,48 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me
 ///////////////////////////////////////////////////////////////////////
 
 func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index) {
-  %b0 = arith.constant 16 : index
-  %b1 = arith.constant 224 : index
-  %b2 = arith.constant 224 : index
-  %1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index
+  %1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index
   return %1#0, %1#1, %1#2 : index, index, index
 }
 // CHECK-LABEL:   func.func @test_dilinearize_index(
 // CHECK-SAME:                                      %[[VAL_0:.*]]: index) -> (index, index, index) {
-// CHECK:           %[[VAL_1:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 50176 : index
 // CHECK:           %[[VAL_2:.*]] = arith.constant 224 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant 224 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 50176 : index
-// CHECK:           %[[VAL_5:.*]] = arith.constant 50176 : index
-// CHECK:           %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_7:.*]] = arith.constant -1 : index
-// CHECK:           %[[VAL_8:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_6]] : index
-// CHECK:           %[[VAL_9:.*]] = arith.subi %[[VAL_7]], %[[VAL_0]] : index
-// CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_8]], %[[VAL_9]], %[[VAL_0]] : index
-// CHECK:           %[[VAL_11:.*]] = arith.divsi %[[VAL_10]], %[[VAL_5]] : index
-// CHECK:           %[[VAL_12:.*]] = arith.subi %[[VAL_7]], %[[VAL_11]] : index
-// CHECK:           %[[VAL_13:.*]] = arith.select %[[VAL_8]], %[[VAL_12]], %[[VAL_11]] : index
-// CHECK:           %[[VAL_14:.*]] = arith.constant 50176 : index
-// CHECK:           %[[VAL_15:.*]] = arith.remsi %[[VAL_0]], %[[VAL_14]] : index
-// CHECK:           %[[VAL_16:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_17:.*]] = arith.cmpi slt, %[[VAL_15]], %[[VAL_16]] : index
-// CHECK:           %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_14]] : index
-// CHECK:           %[[VAL_19:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_15]] : index
-// CHECK:           %[[VAL_20:.*]] = arith.constant 50176 : index
-// CHECK:           %[[VAL_21:.*]] = arith.remsi %[[VAL_0]], %[[VAL_20]] : index
-// CHECK:           %[[VAL_22:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_23:.*]] = arith.cmpi slt, %[[VAL_21]], %[[VAL_22]] : index
-// CHECK:           %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_20]] : index
-// CHECK:           %[[VAL_25:.*]] = arith.select %[[VAL_23]], %[[VAL_24]], %[[VAL_21]] : index
-// CHECK:           %[[VAL_26:.*]] = arith.constant 224 : index
-// CHECK:           %[[VAL_27:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_28:.*]] = arith.constant -1 : index
-// CHECK:           %[[VAL_29:.*]] = arith.cmpi slt, %[[VAL_25]], %[[VAL_27]] : index
-// CHECK:           %[[VAL_30:.*]] = arith.subi %[[VAL_28]], %[[VAL_25]] : index
-// CHECK:           %[[VAL_31:.*]] = arith.select %[[VAL_29]], %[[VAL_30]], %[[VAL_25]] : index
-// CHECK:           %[[VAL_32:.*]] = arith.divsi %[[VAL_31]], %[[VAL_26]] : index
-// CHECK:           %[[VAL_33:.*]] = arith.subi %[[VAL_28]], %[[VAL_32]] : index
-// CHECK:           %[[VAL_34:.*]] = arith.select %[[VAL_29]], %[[VAL_33]], %[[VAL_32]] : index
-// CHECK:           %[[VAL_35:.*]] = arith.constant 224 : index
-// CHECK:           %[[VAL_36:.*]] = arith.remsi %[[VAL_0]], %[[VAL_35]] : index
-// CHECK:           %[[VAL_37:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_38:.*]] = arith.cmpi slt, %[[VAL_36]], %[[VAL_37]] : index
-// CHECK:           %[[VAL_39:.*]] = arith.addi %[[VAL_36]], %[[VAL_35]] : index
-// CHECK:           %[[VAL_40:.*]] = arith.select %[[VAL_38]], %[[VAL_39]], %[[VAL_36]] : index
-// CHECK:           return %[[VAL_13]], %[[VAL_34]], %[[VAL_40]] : index, index, index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 50176 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant -1 : index
+// CHECK:           %[[VAL_6:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_4]] : index
+// CHECK:           %[[VAL_7:.*]] = arith.subi %[[VAL_5]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_8:.*]] = arith.select %[[VAL_6]], %[[VAL_7]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_9:.*]] = arith.divsi %[[VAL_8]], %[[VAL_3]] : index
+// CHECK:           %[[VAL_10:.*]] = arith.subi %[[VAL_5]], %[[VAL_9]] : index
+// CHECK:           %[[VAL_11:.*]] = arith.select %[[VAL_6]], %[[VAL_10]], %[[VAL_9]] : index
+// CHECK:           %[[VAL_12:.*]] = arith.constant 50176 : index
+// CHECK:           %[[VAL_13:.*]] = arith.remsi %[[VAL_0]], %[[VAL_12]] : index
+// CHECK:           %[[VAL_14:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_15:.*]] = arith.cmpi slt, %[[VAL_13]], %[[VAL_14]] : index
+// CHECK:           %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index
+// CHECK:           %[[VAL_17:.*]] = arith.select %[[VAL_15]], %[[VAL_16]], %[[VAL_13]] : index
+// CHECK:           %[[VAL_18:.*]] = arith.constant 50176 : index
+// CHECK:           %[[VAL_19:.*]] = arith.remsi %[[VAL_0]], %[[VAL_18]] : index
+// CHECK:           %[[VAL_20:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_21:.*]] = arith.cmpi slt, %[[VAL_19]], %[[VAL_20]] : index
+// CHECK:           %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_18]] : index
+// CHECK:           %[[VAL_23:.*]] = arith.select %[[VAL_21]], %[[VAL_22]], %[[VAL_19]] : index
+// CHECK:           %[[VAL_24:.*]] = arith.constant 224 : index
+// CHECK:           %[[VAL_25:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_26:.*]] = arith.constant -1 : index
+// CHECK:           %[[VAL_27:.*]] = arith.cmpi slt, %[[VAL_23]], %[[VAL_25]] : index
+// CHECK:           %[[VAL_28:.*]] = arith.subi %[[VAL_26]], %[[VAL_23]] : index
+// CHECK:           %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_23]] : index
+// CHECK:           %[[VAL_30:.*]] = arith.divsi %[[VAL_29]], %[[VAL_24]] : index
+// CHECK:           %[[VAL_31:.*]] = arith.subi %[[VAL_26]], %[[VAL_30]] : index
+// CHECK:           %[[VAL_32:.*]] = arith.select %[[VAL_27]], %[[VAL_31]], %[[VAL_30]] : index
+...
[truncated]

Comment on lines +4554 to +4555
"corresponding dynamic basis entry) -- this can only happen due to an "
"incorrect fold/rewrite");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we not construct such an op using Operation::create (rather than builder)? If we cannot, this should be an assertion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The builder construction is along the lines of b.create<affine::DelinearizeIndexOp>(loc, index, {v1, v2}, {2, 5, 7}), which has too many dynamic basis entries, or alternatively (index, {}, {2, kDynamic, 7}) which has too few

@@ -1944,6 +1944,18 @@ static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
return result;
}

static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document please.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment on lines 1990 to 1997
SmallVector<Value> divisors;
for (unsigned i = 1; i < numDims; i++) {
ArrayRef<OpFoldResult> slice = basis.drop_front(i);
FailureOr<OpFoldResult> prod = getIndexProduct(b, loc, slice);
if (failed(prod))
return failure();
divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod));
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be more efficient to do a scan and collect intermediate products than rerun the product every time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done (and applied to the existing function), thanks!

@krzysz00 krzysz00 requested a review from ftynse October 31, 2024 16:12
krzysz00 added a commit to krzysz00/llvm-project that referenced this pull request Oct 31, 2024
`affine.linearize_index` is the inverse of `affine.delinearize_index`
and general useful for representing computations (like those needed to
move from N-D to 1-D memrefs) that put together indices.

This commit introduces `affine.linearize_index` and one simple
canonicalization for it.

There are plans to add `affine.linearize_index` and
`affine.delinearize_index` pair canonicalizations, but we are saving
those for a followup PR (especially since having llvm#113846 landed would
make them nicer).

Note while `affine` may not be the natural home for this operation,
https://discourse.llvm.org/t/better-location-of-affine-delinearize-operation/80565/13
didn't come to any better consensus location.
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The implementation LGTM

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/Affine/IR/AffineOps.td Outdated Show resolved Hide resolved
krzysz00 and others added 2 commits November 4, 2024 12:34
Co-authored-by: Jakub Kuderski <[email protected]>
@krzysz00 krzysz00 merged commit 704808c into llvm:main Nov 4, 2024
6 of 7 checks passed
zincnode added a commit to zincnode/mlir-he that referenced this pull request Nov 5, 2024
- Fix Python Bindings LIT based on upstream update
- Reference: llvm/llvm-project#113846
krzysz00 added a commit that referenced this pull request Nov 5, 2024
`affine.linearize_index` is the inverse of `affine.delinearize_index`
and general useful for representing computations (like those needed to
move from N-D to 1-D memrefs) that put together indices.

This commit introduces `affine.linearize_index` and one simple
canonicalization for it.

There are plans to add `affine.linearize_index` and
`affine.delinearize_index` pair canonicalizations, but we are saving
those for a followup PR (especially since having #113846 landed would
make them nicer).

Note while `affine` may not be the natural home for this operation,
https://discourse.llvm.org/t/better-location-of-affine-delinearize-operation/80565/13
didn't come to any better consensus location.

---------

Co-authored-by: Jakub Kuderski <[email protected]>
PhilippRados pushed a commit to PhilippRados/llvm-project that referenced this pull request Nov 6, 2024
…13846)

This commit makes `affine.delinealize` join other indexing operators,
like `vector.extract`, which store a mixed static/dynamic set of sizes,
offsets, or such. In this case, the `basis` (the set of values that will
be used to decompose the linear index) is now stored as an array of
index attributes where the basis is statically known, eliminating the
need to cretae constants.

This commit also adds copies of the delinearize utility in the affine
dialect to allow it to take an array of `OpFoldResult`s and extends te
DynamicIndexList parser/printer to allow specifying the delimiters in
tablegen (this is needed to avoid breaking existing syntax).

---------

Co-authored-by: Jakub Kuderski <[email protected]>
krzysz00 added a commit to krzysz00/llvm-project that referenced this pull request Nov 8, 2024
…13846)

This commit makes `affine.delinealize` join other indexing operators,
like `vector.extract`, which store a mixed static/dynamic set of sizes,
offsets, or such. In this case, the `basis` (the set of values that will
be used to decompose the linear index) is now stored as an array of
index attributes where the basis is statically known, eliminating the
need to cretae constants.

This commit also adds copies of the delinearize utility in the affine
dialect to allow it to take an array of `OpFoldResult`s and extends te
DynamicIndexList parser/printer to allow specifying the delimiters in
tablegen (this is needed to avoid breaking existing syntax).

---------

Co-authored-by: Jakub Kuderski <[email protected]>
krzysz00 added a commit to krzysz00/llvm-project that referenced this pull request Nov 8, 2024
`affine.linearize_index` is the inverse of `affine.delinearize_index`
and general useful for representing computations (like those needed to
move from N-D to 1-D memrefs) that put together indices.

This commit introduces `affine.linearize_index` and one simple
canonicalization for it.

There are plans to add `affine.linearize_index` and
`affine.delinearize_index` pair canonicalizations, but we are saving
those for a followup PR (especially since having llvm#113846 landed would
make them nicer).

Note while `affine` may not be the natural home for this operation,
https://discourse.llvm.org/t/better-location-of-affine-delinearize-operation/80565/13
didn't come to any better consensus location.

---------

Co-authored-by: Jakub Kuderski <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants