-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][affine] Add static basis support to affine.delinearize #113846
[mlir][affine] Add static basis support to affine.delinearize #113846
Conversation
This commit makes `affine.delinealize` join other indexing operators, like `vector.extract`, which store a mixed static/dynamic set of sizes, offsets, or such. In this case, the `basis` (the set of values that will be used to decompose the linear index) is now stored as an array of index attributes where the basis is statically known, eliminating the need to cretae constants. This commit also adds copies of the delinearize utility in the affine dialect to allow it to take an array of `OpFoldResult`s and extends te DynamicIndexList parser/printer to allow specifying the delimiters in tablegen (this is needed to avoid breaking existing syntax).
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-affine Author: Krzysztof Drewniak (krzysz00) ChangesThis commit makes This commit also adds copies of the delinearize utility in the affine dialect to allow it to take an array of Patch is 40.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113846.diff 15 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 5c75e102c3d404..7c950623f77f48 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -16,11 +16,11 @@
#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
-
namespace mlir {
namespace affine {
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 8773fc5881461a..f53b5d97a7156a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1084,17 +1084,32 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
```
}];
- let arguments = (ins Index:$linear_index, Variadic<Index>:$basis);
+ let arguments = (ins Index:$linear_index,
+ Variadic<Index>:$dynamic_basis,
+ DenseI64ArrayAttr:$static_basis);
let results = (outs Variadic<Index>:$multi_index);
let assemblyFormat = [{
- $linear_index `into` ` ` `(` $basis `)` attr-dict `:` type($multi_index)
+ $linear_index `into` ` `
+ custom<DynamicIndexList>($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren")
+ attr-dict `:` type($multi_index)
}];
let builders = [
- OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>
+ OpBuilder<(ins "Value":$linear_index, "ValueRange":$basis)>,
+ OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>,
+ OpBuilder<(ins "Value":$linear_index, "ArrayRef<int64_t>":$basis)>
];
+ let extraClassDeclaration = [{
+ /// Return a vector with all the static and dynamic basis values.
+ SmallVector<OpFoldResult> getMixedBasis() {
+ OpBuilder builder(getContext());
+ return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+ }
+
+ }];
+
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 9a2767e0ad87f3..d2cfbaa85a60ef 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -311,6 +311,9 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
Value linearIndex,
ArrayRef<Value> basis);
+FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
+ Value linearIndex,
+ ArrayRef<OpFoldResult> basis);
// Generate IR that extracts the linear index from a multi-index according to
// a basis/shape.
OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index d6479143a0a50b..3dcbd2f1af1936 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -109,6 +109,13 @@ void printDynamicIndexList(
ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
TypeRange valueTypes = TypeRange(),
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
+ OperandRange values,
+ ArrayRef<int64_t> integers,
+ AsmParser::Delimiter delimiter) {
+ return printDynamicIndexList(printer, op, values, integers, {}, TypeRange(),
+ delimiter);
+}
inline void printDynamicIndexList(
OpAsmPrinter &printer, Operation *op, OperandRange values,
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
@@ -144,6 +151,15 @@ ParseResult parseDynamicIndexList(
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
SmallVectorImpl<Type> *valueTypes = nullptr,
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+inline ParseResult
+parseDynamicIndexList(OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ DenseI64ArrayAttr &integers,
+ AsmParser::Delimiter delimiter) {
+ DenseBoolArrayAttr scalableVals = {};
+ return parseDynamicIndexList(parser, values, integers, scalableVals, nullptr,
+ delimiter);
+}
inline ParseResult parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 5e7a6b6ca883c3..f384f454bc4726 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"
@@ -4508,32 +4509,50 @@ LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
regions);
- inferredReturnTypes.assign(adaptor.getBasis().size(),
+ inferredReturnTypes.assign(adaptor.getStaticBasis().size(),
IndexType::get(context));
return success();
}
-void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+ OperationState &odsState,
+ Value linearIndex, ValueRange basis) {
+ SmallVector<Value> dynamicBasis;
+ SmallVector<int64_t> staticBasis;
+ dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
+ staticBasis);
+ build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+}
+
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+ OperationState &odsState,
Value linearIndex,
ArrayRef<OpFoldResult> basis) {
- result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
- result.addOperands(linearIndex);
- SmallVector<Value> basisValues =
- llvm::map_to_vector(basis, [&](OpFoldResult ofr) -> Value {
- std::optional<int64_t> staticDim = getConstantIntValue(ofr);
- if (staticDim.has_value())
- return builder.create<arith::ConstantIndexOp>(result.location,
- *staticDim);
- return llvm::dyn_cast_if_present<Value>(ofr);
- });
- result.addOperands(basisValues);
+ SmallVector<Value> dynamicBasis;
+ SmallVector<int64_t> staticBasis;
+ dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
+ build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+}
+
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+ OperationState &odsState,
+ Value linearIndex,
+ ArrayRef<int64_t> basis) {
+ build(odsBuilder, odsState, linearIndex, ValueRange{}, basis);
}
LogicalResult AffineDelinearizeIndexOp::verify() {
- if (getBasis().empty())
+ if (getStaticBasis().empty())
return emitOpError("basis should not be empty");
- if (getNumResults() != getBasis().size())
+ if (getNumResults() != getStaticBasis().size())
return emitOpError("should return an index for each basis element");
+ auto dynamicMarkersCount =
+ llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
+ if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
+ return emitOpError(
+ "mismatch between dynamic and static basis (kDynamic marker but no "
+ "corresponding dynamic basis entry) -- this can only happen due to an "
+ "incorrect fold/rewrite");
return success();
}
@@ -4557,15 +4576,16 @@ struct DropUnitExtentBasis
// Replace all indices corresponding to unit-extent basis with 0.
// Remaining basis can be used to get a new `affine.delinearize_index` op.
- SmallVector<Value> newOperands;
- for (auto [index, basis] : llvm::enumerate(delinearizeOp.getBasis())) {
- if (matchPattern(basis, m_One()))
+ SmallVector<OpFoldResult> newOperands;
+ for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) {
+ std::optional<int64_t> basisVal = getConstantIntValue(basis);
+ if (basisVal && *basisVal == 1)
replacements[index] = getZero();
else
newOperands.push_back(basis);
}
- if (newOperands.size() == delinearizeOp.getBasis().size())
+ if (newOperands.size() == delinearizeOp.getStaticBasis().size())
return failure();
if (!newOperands.empty()) {
@@ -4607,9 +4627,9 @@ struct DropDelinearizeOfSingleLoop
LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
PatternRewriter &rewriter) const override {
- auto basis = delinearizeOp.getBasis();
- if (basis.size() != 1)
+ if (delinearizeOp.getStaticBasis().size() != 1)
return failure();
+ auto basis = delinearizeOp.getMixedBasis();
// Check that the `linear_index` is an induction variable.
auto inductionVar = dyn_cast<BlockArgument>(delinearizeOp.getLinearIndex());
@@ -4634,7 +4654,7 @@ struct DropDelinearizeOfSingleLoop
// Check that the upper-bound is the basis.
auto upperBounds = loopLikeOp.getLoopUpperBounds();
if (!upperBounds || upperBounds->size() != 1 ||
- upperBounds->front() != getAsOpFoldResult(basis.front())) {
+ upperBounds->front() != basis.front()) {
return rewriter.notifyMatchFailure(delinearizeOp,
"`basis` is not upper bound");
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index c6bc3862256a75..d76968d3a71520 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -35,9 +35,8 @@ struct LowerDelinearizeIndexOps
using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
PatternRewriter &rewriter) const override {
- FailureOr<SmallVector<Value>> multiIndex =
- delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
- llvm::to_vector(op.getBasis()));
+ FailureOr<SmallVector<Value>> multiIndex = delinearizeIndex(
+ rewriter, op->getLoc(), op.getLinearIndex(), op.getMixedBasis());
if (failed(multiIndex))
return failure();
rewriter.replaceOp(op, *multiIndex);
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 910ad1733d03e8..e3b5d26e0ec3c3 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1944,6 +1944,18 @@ static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
return result;
}
+static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
+ ArrayRef<OpFoldResult> set) {
+ if (set.empty())
+ return failure();
+ OpFoldResult result = set[0];
+ AffineExpr s0, s1;
+ bindSymbols(b.getContext(), s0, s1);
+ for (unsigned i = 1, e = set.size(); i < e; i++)
+ result = makeComposedFoldedAffineApply(b, loc, s0 * s1, {result, set[i]});
+ return result;
+}
+
FailureOr<SmallVector<Value>>
mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
ArrayRef<Value> basis) {
@@ -1970,6 +1982,32 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
return results;
}
+FailureOr<SmallVector<Value>>
+mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
+ ArrayRef<OpFoldResult> basis) {
+ unsigned numDims = basis.size();
+
+ SmallVector<Value> divisors;
+ for (unsigned i = 1; i < numDims; i++) {
+ ArrayRef<OpFoldResult> slice = basis.drop_front(i);
+ FailureOr<OpFoldResult> prod = getIndexProduct(b, loc, slice);
+ if (failed(prod))
+ return failure();
+ divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod));
+ }
+
+ SmallVector<Value> results;
+ results.reserve(divisors.size() + 1);
+ Value residual = linearIndex;
+ for (Value divisor : divisors) {
+ DivModValue divMod = getDivMod(b, loc, residual, divisor);
+ results.push_back(divMod.quotient);
+ residual = divMod.remainder;
+ }
+ results.push_back(residual);
+ return results;
+}
+
OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
ArrayRef<OpFoldResult> basis,
ImplicitLocOpBuilder &builder) {
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 23e0edd510cbb1..298e82df4f4cea 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -931,53 +931,48 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me
///////////////////////////////////////////////////////////////////////
func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index) {
- %b0 = arith.constant 16 : index
- %b1 = arith.constant 224 : index
- %b2 = arith.constant 224 : index
- %1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index
+ %1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index
return %1#0, %1#1, %1#2 : index, index, index
}
// CHECK-LABEL: func.func @test_dilinearize_index(
// CHECK-SAME: %[[VAL_0:.*]]: index) -> (index, index, index) {
-// CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 50176 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 224 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant 224 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 50176 : index
-// CHECK: %[[VAL_5:.*]] = arith.constant 50176 : index
-// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_7:.*]] = arith.constant -1 : index
-// CHECK: %[[VAL_8:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_6]] : index
-// CHECK: %[[VAL_9:.*]] = arith.subi %[[VAL_7]], %[[VAL_0]] : index
-// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_8]], %[[VAL_9]], %[[VAL_0]] : index
-// CHECK: %[[VAL_11:.*]] = arith.divsi %[[VAL_10]], %[[VAL_5]] : index
-// CHECK: %[[VAL_12:.*]] = arith.subi %[[VAL_7]], %[[VAL_11]] : index
-// CHECK: %[[VAL_13:.*]] = arith.select %[[VAL_8]], %[[VAL_12]], %[[VAL_11]] : index
-// CHECK: %[[VAL_14:.*]] = arith.constant 50176 : index
-// CHECK: %[[VAL_15:.*]] = arith.remsi %[[VAL_0]], %[[VAL_14]] : index
-// CHECK: %[[VAL_16:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_17:.*]] = arith.cmpi slt, %[[VAL_15]], %[[VAL_16]] : index
-// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_14]] : index
-// CHECK: %[[VAL_19:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_15]] : index
-// CHECK: %[[VAL_20:.*]] = arith.constant 50176 : index
-// CHECK: %[[VAL_21:.*]] = arith.remsi %[[VAL_0]], %[[VAL_20]] : index
-// CHECK: %[[VAL_22:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_23:.*]] = arith.cmpi slt, %[[VAL_21]], %[[VAL_22]] : index
-// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_20]] : index
-// CHECK: %[[VAL_25:.*]] = arith.select %[[VAL_23]], %[[VAL_24]], %[[VAL_21]] : index
-// CHECK: %[[VAL_26:.*]] = arith.constant 224 : index
-// CHECK: %[[VAL_27:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_28:.*]] = arith.constant -1 : index
-// CHECK: %[[VAL_29:.*]] = arith.cmpi slt, %[[VAL_25]], %[[VAL_27]] : index
-// CHECK: %[[VAL_30:.*]] = arith.subi %[[VAL_28]], %[[VAL_25]] : index
-// CHECK: %[[VAL_31:.*]] = arith.select %[[VAL_29]], %[[VAL_30]], %[[VAL_25]] : index
-// CHECK: %[[VAL_32:.*]] = arith.divsi %[[VAL_31]], %[[VAL_26]] : index
-// CHECK: %[[VAL_33:.*]] = arith.subi %[[VAL_28]], %[[VAL_32]] : index
-// CHECK: %[[VAL_34:.*]] = arith.select %[[VAL_29]], %[[VAL_33]], %[[VAL_32]] : index
-// CHECK: %[[VAL_35:.*]] = arith.constant 224 : index
-// CHECK: %[[VAL_36:.*]] = arith.remsi %[[VAL_0]], %[[VAL_35]] : index
-// CHECK: %[[VAL_37:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_38:.*]] = arith.cmpi slt, %[[VAL_36]], %[[VAL_37]] : index
-// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_36]], %[[VAL_35]] : index
-// CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_38]], %[[VAL_39]], %[[VAL_36]] : index
-// CHECK: return %[[VAL_13]], %[[VAL_34]], %[[VAL_40]] : index, index, index
+// CHECK: %[[VAL_3:.*]] = arith.constant 50176 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant -1 : index
+// CHECK: %[[VAL_6:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_4]] : index
+// CHECK: %[[VAL_7:.*]] = arith.subi %[[VAL_5]], %[[VAL_0]] : index
+// CHECK: %[[VAL_8:.*]] = arith.select %[[VAL_6]], %[[VAL_7]], %[[VAL_0]] : index
+// CHECK: %[[VAL_9:.*]] = arith.divsi %[[VAL_8]], %[[VAL_3]] : index
+// CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_5]], %[[VAL_9]] : index
+// CHECK: %[[VAL_11:.*]] = arith.select %[[VAL_6]], %[[VAL_10]], %[[VAL_9]] : index
+// CHECK: %[[VAL_12:.*]] = arith.constant 50176 : index
+// CHECK: %[[VAL_13:.*]] = arith.remsi %[[VAL_0]], %[[VAL_12]] : index
+// CHECK: %[[VAL_14:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_15:.*]] = arith.cmpi slt, %[[VAL_13]], %[[VAL_14]] : index
+// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index
+// CHECK: %[[VAL_17:.*]] = arith.select %[[VAL_15]], %[[VAL_16]], %[[VAL_13]] : index
+// CHECK: %[[VAL_18:.*]] = arith.constant 50176 : index
+// CHECK: %[[VAL_19:.*]] = arith.remsi %[[VAL_0]], %[[VAL_18]] : index
+// CHECK: %[[VAL_20:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_21:.*]] = arith.cmpi slt, %[[VAL_19]], %[[VAL_20]] : index
+// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_18]] : index
+// CHECK: %[[VAL_23:.*]] = arith.select %[[VAL_21]], %[[VAL_22]], %[[VAL_19]] : index
+// CHECK: %[[VAL_24:.*]] = arith.constant 224 : index
+// CHECK: %[[VAL_25:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_26:.*]] = arith.constant -1 : index
+// CHECK: %[[VAL_27:.*]] = arith.cmpi slt, %[[VAL_23]], %[[VAL_25]] : index
+// CHECK: %[[VAL_28:.*]] = arith.subi %[[VAL_26]], %[[VAL_23]] : index
+// CHECK: %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_23]] : index
+// CHECK: %[[VAL_30:.*]] = arith.divsi %[[VAL_29]], %[[VAL_24]] : index
+// CHECK: %[[VAL_31:.*]] = arith.subi %[[VAL_26]], %[[VAL_30]] : index
+// CHECK: %[[VAL_32:.*]] = arith.select %[[VAL_27]], %[[VAL_31]], %[[VAL_30]] : index
+...
[truncated]
|
@llvm/pr-subscribers-mlir-tensor Author: Krzysztof Drewniak (krzysz00) ChangesThis commit makes This commit also adds copies of the delinearize utility in the affine dialect to allow it to take an array of Patch is 40.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113846.diff 15 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 5c75e102c3d404..7c950623f77f48 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -16,11 +16,11 @@
#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
-
namespace mlir {
namespace affine {
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 8773fc5881461a..f53b5d97a7156a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1084,17 +1084,32 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
```
}];
- let arguments = (ins Index:$linear_index, Variadic<Index>:$basis);
+ let arguments = (ins Index:$linear_index,
+ Variadic<Index>:$dynamic_basis,
+ DenseI64ArrayAttr:$static_basis);
let results = (outs Variadic<Index>:$multi_index);
let assemblyFormat = [{
- $linear_index `into` ` ` `(` $basis `)` attr-dict `:` type($multi_index)
+ $linear_index `into` ` `
+ custom<DynamicIndexList>($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren")
+ attr-dict `:` type($multi_index)
}];
let builders = [
- OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>
+ OpBuilder<(ins "Value":$linear_index, "ValueRange":$basis)>,
+ OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>,
+ OpBuilder<(ins "Value":$linear_index, "ArrayRef<int64_t>":$basis)>
];
+ let extraClassDeclaration = [{
+ /// Return a vector with all the static and dynamic basis values.
+ SmallVector<OpFoldResult> getMixedBasis() {
+ OpBuilder builder(getContext());
+ return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+ }
+
+ }];
+
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 9a2767e0ad87f3..d2cfbaa85a60ef 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -311,6 +311,9 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
Value linearIndex,
ArrayRef<Value> basis);
+FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
+ Value linearIndex,
+ ArrayRef<OpFoldResult> basis);
// Generate IR that extracts the linear index from a multi-index according to
// a basis/shape.
OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index d6479143a0a50b..3dcbd2f1af1936 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -109,6 +109,13 @@ void printDynamicIndexList(
ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
TypeRange valueTypes = TypeRange(),
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
+ OperandRange values,
+ ArrayRef<int64_t> integers,
+ AsmParser::Delimiter delimiter) {
+ return printDynamicIndexList(printer, op, values, integers, {}, TypeRange(),
+ delimiter);
+}
inline void printDynamicIndexList(
OpAsmPrinter &printer, Operation *op, OperandRange values,
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
@@ -144,6 +151,15 @@ ParseResult parseDynamicIndexList(
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
SmallVectorImpl<Type> *valueTypes = nullptr,
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+inline ParseResult
+parseDynamicIndexList(OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ DenseI64ArrayAttr &integers,
+ AsmParser::Delimiter delimiter) {
+ DenseBoolArrayAttr scalableVals = {};
+ return parseDynamicIndexList(parser, values, integers, scalableVals, nullptr,
+ delimiter);
+}
inline ParseResult parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 5e7a6b6ca883c3..f384f454bc4726 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"
@@ -4508,32 +4509,50 @@ LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
regions);
- inferredReturnTypes.assign(adaptor.getBasis().size(),
+ inferredReturnTypes.assign(adaptor.getStaticBasis().size(),
IndexType::get(context));
return success();
}
-void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+ OperationState &odsState,
+ Value linearIndex, ValueRange basis) {
+ SmallVector<Value> dynamicBasis;
+ SmallVector<int64_t> staticBasis;
+ dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
+ staticBasis);
+ build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+}
+
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+ OperationState &odsState,
Value linearIndex,
ArrayRef<OpFoldResult> basis) {
- result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
- result.addOperands(linearIndex);
- SmallVector<Value> basisValues =
- llvm::map_to_vector(basis, [&](OpFoldResult ofr) -> Value {
- std::optional<int64_t> staticDim = getConstantIntValue(ofr);
- if (staticDim.has_value())
- return builder.create<arith::ConstantIndexOp>(result.location,
- *staticDim);
- return llvm::dyn_cast_if_present<Value>(ofr);
- });
- result.addOperands(basisValues);
+ SmallVector<Value> dynamicBasis;
+ SmallVector<int64_t> staticBasis;
+ dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
+ build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+}
+
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+ OperationState &odsState,
+ Value linearIndex,
+ ArrayRef<int64_t> basis) {
+ build(odsBuilder, odsState, linearIndex, ValueRange{}, basis);
}
LogicalResult AffineDelinearizeIndexOp::verify() {
- if (getBasis().empty())
+ if (getStaticBasis().empty())
return emitOpError("basis should not be empty");
- if (getNumResults() != getBasis().size())
+ if (getNumResults() != getStaticBasis().size())
return emitOpError("should return an index for each basis element");
+ auto dynamicMarkersCount =
+ llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
+ if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
+ return emitOpError(
+ "mismatch between dynamic and static basis (kDynamic marker but no "
+ "corresponding dynamic basis entry) -- this can only happen due to an "
+ "incorrect fold/rewrite");
return success();
}
@@ -4557,15 +4576,16 @@ struct DropUnitExtentBasis
// Replace all indices corresponding to unit-extent basis with 0.
// Remaining basis can be used to get a new `affine.delinearize_index` op.
- SmallVector<Value> newOperands;
- for (auto [index, basis] : llvm::enumerate(delinearizeOp.getBasis())) {
- if (matchPattern(basis, m_One()))
+ SmallVector<OpFoldResult> newOperands;
+ for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) {
+ std::optional<int64_t> basisVal = getConstantIntValue(basis);
+ if (basisVal && *basisVal == 1)
replacements[index] = getZero();
else
newOperands.push_back(basis);
}
- if (newOperands.size() == delinearizeOp.getBasis().size())
+ if (newOperands.size() == delinearizeOp.getStaticBasis().size())
return failure();
if (!newOperands.empty()) {
@@ -4607,9 +4627,9 @@ struct DropDelinearizeOfSingleLoop
LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
PatternRewriter &rewriter) const override {
- auto basis = delinearizeOp.getBasis();
- if (basis.size() != 1)
+ if (delinearizeOp.getStaticBasis().size() != 1)
return failure();
+ auto basis = delinearizeOp.getMixedBasis();
// Check that the `linear_index` is an induction variable.
auto inductionVar = dyn_cast<BlockArgument>(delinearizeOp.getLinearIndex());
@@ -4634,7 +4654,7 @@ struct DropDelinearizeOfSingleLoop
// Check that the upper-bound is the basis.
auto upperBounds = loopLikeOp.getLoopUpperBounds();
if (!upperBounds || upperBounds->size() != 1 ||
- upperBounds->front() != getAsOpFoldResult(basis.front())) {
+ upperBounds->front() != basis.front()) {
return rewriter.notifyMatchFailure(delinearizeOp,
"`basis` is not upper bound");
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index c6bc3862256a75..d76968d3a71520 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -35,9 +35,8 @@ struct LowerDelinearizeIndexOps
using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
PatternRewriter &rewriter) const override {
- FailureOr<SmallVector<Value>> multiIndex =
- delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
- llvm::to_vector(op.getBasis()));
+ FailureOr<SmallVector<Value>> multiIndex = delinearizeIndex(
+ rewriter, op->getLoc(), op.getLinearIndex(), op.getMixedBasis());
if (failed(multiIndex))
return failure();
rewriter.replaceOp(op, *multiIndex);
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 910ad1733d03e8..e3b5d26e0ec3c3 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1944,6 +1944,18 @@ static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
return result;
}
+static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
+ ArrayRef<OpFoldResult> set) {
+ if (set.empty())
+ return failure();
+ OpFoldResult result = set[0];
+ AffineExpr s0, s1;
+ bindSymbols(b.getContext(), s0, s1);
+ for (unsigned i = 1, e = set.size(); i < e; i++)
+ result = makeComposedFoldedAffineApply(b, loc, s0 * s1, {result, set[i]});
+ return result;
+}
+
FailureOr<SmallVector<Value>>
mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
ArrayRef<Value> basis) {
@@ -1970,6 +1982,32 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
return results;
}
+FailureOr<SmallVector<Value>>
+mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
+ ArrayRef<OpFoldResult> basis) {
+ unsigned numDims = basis.size();
+
+ SmallVector<Value> divisors;
+ for (unsigned i = 1; i < numDims; i++) {
+ ArrayRef<OpFoldResult> slice = basis.drop_front(i);
+ FailureOr<OpFoldResult> prod = getIndexProduct(b, loc, slice);
+ if (failed(prod))
+ return failure();
+ divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod));
+ }
+
+ SmallVector<Value> results;
+ results.reserve(divisors.size() + 1);
+ Value residual = linearIndex;
+ for (Value divisor : divisors) {
+ DivModValue divMod = getDivMod(b, loc, residual, divisor);
+ results.push_back(divMod.quotient);
+ residual = divMod.remainder;
+ }
+ results.push_back(residual);
+ return results;
+}
+
OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
ArrayRef<OpFoldResult> basis,
ImplicitLocOpBuilder &builder) {
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 23e0edd510cbb1..298e82df4f4cea 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -931,53 +931,48 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me
///////////////////////////////////////////////////////////////////////
func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index) {
- %b0 = arith.constant 16 : index
- %b1 = arith.constant 224 : index
- %b2 = arith.constant 224 : index
- %1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index
+ %1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index
return %1#0, %1#1, %1#2 : index, index, index
}
// CHECK-LABEL: func.func @test_dilinearize_index(
// CHECK-SAME: %[[VAL_0:.*]]: index) -> (index, index, index) {
-// CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 50176 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 224 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant 224 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 50176 : index
-// CHECK: %[[VAL_5:.*]] = arith.constant 50176 : index
-// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_7:.*]] = arith.constant -1 : index
-// CHECK: %[[VAL_8:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_6]] : index
-// CHECK: %[[VAL_9:.*]] = arith.subi %[[VAL_7]], %[[VAL_0]] : index
-// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_8]], %[[VAL_9]], %[[VAL_0]] : index
-// CHECK: %[[VAL_11:.*]] = arith.divsi %[[VAL_10]], %[[VAL_5]] : index
-// CHECK: %[[VAL_12:.*]] = arith.subi %[[VAL_7]], %[[VAL_11]] : index
-// CHECK: %[[VAL_13:.*]] = arith.select %[[VAL_8]], %[[VAL_12]], %[[VAL_11]] : index
-// CHECK: %[[VAL_14:.*]] = arith.constant 50176 : index
-// CHECK: %[[VAL_15:.*]] = arith.remsi %[[VAL_0]], %[[VAL_14]] : index
-// CHECK: %[[VAL_16:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_17:.*]] = arith.cmpi slt, %[[VAL_15]], %[[VAL_16]] : index
-// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_14]] : index
-// CHECK: %[[VAL_19:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_15]] : index
-// CHECK: %[[VAL_20:.*]] = arith.constant 50176 : index
-// CHECK: %[[VAL_21:.*]] = arith.remsi %[[VAL_0]], %[[VAL_20]] : index
-// CHECK: %[[VAL_22:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_23:.*]] = arith.cmpi slt, %[[VAL_21]], %[[VAL_22]] : index
-// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_20]] : index
-// CHECK: %[[VAL_25:.*]] = arith.select %[[VAL_23]], %[[VAL_24]], %[[VAL_21]] : index
-// CHECK: %[[VAL_26:.*]] = arith.constant 224 : index
-// CHECK: %[[VAL_27:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_28:.*]] = arith.constant -1 : index
-// CHECK: %[[VAL_29:.*]] = arith.cmpi slt, %[[VAL_25]], %[[VAL_27]] : index
-// CHECK: %[[VAL_30:.*]] = arith.subi %[[VAL_28]], %[[VAL_25]] : index
-// CHECK: %[[VAL_31:.*]] = arith.select %[[VAL_29]], %[[VAL_30]], %[[VAL_25]] : index
-// CHECK: %[[VAL_32:.*]] = arith.divsi %[[VAL_31]], %[[VAL_26]] : index
-// CHECK: %[[VAL_33:.*]] = arith.subi %[[VAL_28]], %[[VAL_32]] : index
-// CHECK: %[[VAL_34:.*]] = arith.select %[[VAL_29]], %[[VAL_33]], %[[VAL_32]] : index
-// CHECK: %[[VAL_35:.*]] = arith.constant 224 : index
-// CHECK: %[[VAL_36:.*]] = arith.remsi %[[VAL_0]], %[[VAL_35]] : index
-// CHECK: %[[VAL_37:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_38:.*]] = arith.cmpi slt, %[[VAL_36]], %[[VAL_37]] : index
-// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_36]], %[[VAL_35]] : index
-// CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_38]], %[[VAL_39]], %[[VAL_36]] : index
-// CHECK: return %[[VAL_13]], %[[VAL_34]], %[[VAL_40]] : index, index, index
+// CHECK: %[[VAL_3:.*]] = arith.constant 50176 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant -1 : index
+// CHECK: %[[VAL_6:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_4]] : index
+// CHECK: %[[VAL_7:.*]] = arith.subi %[[VAL_5]], %[[VAL_0]] : index
+// CHECK: %[[VAL_8:.*]] = arith.select %[[VAL_6]], %[[VAL_7]], %[[VAL_0]] : index
+// CHECK: %[[VAL_9:.*]] = arith.divsi %[[VAL_8]], %[[VAL_3]] : index
+// CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_5]], %[[VAL_9]] : index
+// CHECK: %[[VAL_11:.*]] = arith.select %[[VAL_6]], %[[VAL_10]], %[[VAL_9]] : index
+// CHECK: %[[VAL_12:.*]] = arith.constant 50176 : index
+// CHECK: %[[VAL_13:.*]] = arith.remsi %[[VAL_0]], %[[VAL_12]] : index
+// CHECK: %[[VAL_14:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_15:.*]] = arith.cmpi slt, %[[VAL_13]], %[[VAL_14]] : index
+// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index
+// CHECK: %[[VAL_17:.*]] = arith.select %[[VAL_15]], %[[VAL_16]], %[[VAL_13]] : index
+// CHECK: %[[VAL_18:.*]] = arith.constant 50176 : index
+// CHECK: %[[VAL_19:.*]] = arith.remsi %[[VAL_0]], %[[VAL_18]] : index
+// CHECK: %[[VAL_20:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_21:.*]] = arith.cmpi slt, %[[VAL_19]], %[[VAL_20]] : index
+// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_18]] : index
+// CHECK: %[[VAL_23:.*]] = arith.select %[[VAL_21]], %[[VAL_22]], %[[VAL_19]] : index
+// CHECK: %[[VAL_24:.*]] = arith.constant 224 : index
+// CHECK: %[[VAL_25:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_26:.*]] = arith.constant -1 : index
+// CHECK: %[[VAL_27:.*]] = arith.cmpi slt, %[[VAL_23]], %[[VAL_25]] : index
+// CHECK: %[[VAL_28:.*]] = arith.subi %[[VAL_26]], %[[VAL_23]] : index
+// CHECK: %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_23]] : index
+// CHECK: %[[VAL_30:.*]] = arith.divsi %[[VAL_29]], %[[VAL_24]] : index
+// CHECK: %[[VAL_31:.*]] = arith.subi %[[VAL_26]], %[[VAL_30]] : index
+// CHECK: %[[VAL_32:.*]] = arith.select %[[VAL_27]], %[[VAL_31]], %[[VAL_30]] : index
+...
[truncated]
|
"corresponding dynamic basis entry) -- this can only happen due to an " | ||
"incorrect fold/rewrite"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we not construct such an op using Operation::create
(rather than builder)? If we cannot, this should be an assertion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The builder construction is along the lines of b.create<affine::DelinearizeIndexOp>(loc, index, {v1, v2}, {2, 5, 7})
, which has too many dynamic basis entries, or alternatively (index, {}, {2, kDynamic, 7})
which has too few
@@ -1944,6 +1944,18 @@ static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc, | |||
return result; | |||
} | |||
|
|||
static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Document please.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
SmallVector<Value> divisors; | ||
for (unsigned i = 1; i < numDims; i++) { | ||
ArrayRef<OpFoldResult> slice = basis.drop_front(i); | ||
FailureOr<OpFoldResult> prod = getIndexProduct(b, loc, slice); | ||
if (failed(prod)) | ||
return failure(); | ||
divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod)); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be more efficient to do a scan and collect intermediate products than rerun the product every time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done (and applied to the existing function), thanks!
`affine.linearize_index` is the inverse of `affine.delinearize_index` and general useful for representing computations (like those needed to move from N-D to 1-D memrefs) that put together indices. This commit introduces `affine.linearize_index` and one simple canonicalization for it. There are plans to add `affine.linearize_index` and `affine.delinearize_index` pair canonicalizations, but we are saving those for a followup PR (especially since having llvm#113846 landed would make them nicer). Note while `affine` may not be the natural home for this operation, https://discourse.llvm.org/t/better-location-of-affine-delinearize-operation/80565/13 didn't come to any better consensus location.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The implementation LGTM
Co-authored-by: Jakub Kuderski <[email protected]>
- Fix Python Bindings LIT based on upstream update - Reference: llvm/llvm-project#113846
`affine.linearize_index` is the inverse of `affine.delinearize_index` and general useful for representing computations (like those needed to move from N-D to 1-D memrefs) that put together indices. This commit introduces `affine.linearize_index` and one simple canonicalization for it. There are plans to add `affine.linearize_index` and `affine.delinearize_index` pair canonicalizations, but we are saving those for a followup PR (especially since having #113846 landed would make them nicer). Note while `affine` may not be the natural home for this operation, https://discourse.llvm.org/t/better-location-of-affine-delinearize-operation/80565/13 didn't come to any better consensus location. --------- Co-authored-by: Jakub Kuderski <[email protected]>
…13846) This commit makes `affine.delinealize` join other indexing operators, like `vector.extract`, which store a mixed static/dynamic set of sizes, offsets, or such. In this case, the `basis` (the set of values that will be used to decompose the linear index) is now stored as an array of index attributes where the basis is statically known, eliminating the need to cretae constants. This commit also adds copies of the delinearize utility in the affine dialect to allow it to take an array of `OpFoldResult`s and extends te DynamicIndexList parser/printer to allow specifying the delimiters in tablegen (this is needed to avoid breaking existing syntax). --------- Co-authored-by: Jakub Kuderski <[email protected]>
…13846) This commit makes `affine.delinealize` join other indexing operators, like `vector.extract`, which store a mixed static/dynamic set of sizes, offsets, or such. In this case, the `basis` (the set of values that will be used to decompose the linear index) is now stored as an array of index attributes where the basis is statically known, eliminating the need to cretae constants. This commit also adds copies of the delinearize utility in the affine dialect to allow it to take an array of `OpFoldResult`s and extends te DynamicIndexList parser/printer to allow specifying the delimiters in tablegen (this is needed to avoid breaking existing syntax). --------- Co-authored-by: Jakub Kuderski <[email protected]>
`affine.linearize_index` is the inverse of `affine.delinearize_index` and general useful for representing computations (like those needed to move from N-D to 1-D memrefs) that put together indices. This commit introduces `affine.linearize_index` and one simple canonicalization for it. There are plans to add `affine.linearize_index` and `affine.delinearize_index` pair canonicalizations, but we are saving those for a followup PR (especially since having llvm#113846 landed would make them nicer). Note while `affine` may not be the natural home for this operation, https://discourse.llvm.org/t/better-location-of-affine-delinearize-operation/80565/13 didn't come to any better consensus location. --------- Co-authored-by: Jakub Kuderski <[email protected]>
This commit makes
affine.delinealize
join other indexing operators, likevector.extract
, which store a mixed static/dynamic set of sizes, offsets, or such. In this case, thebasis
(the set of values that will be used to decompose the linear index) is now stored as an array of index attributes where the basis is statically known, eliminating the need to cretae constants.This commit also adds copies of the delinearize utility in the affine dialect to allow it to take an array of
OpFoldResult
s and extends te DynamicIndexList parser/printer to allow specifying the delimiters in tablegen (this is needed to avoid breaking existing syntax).