Skip to content

Commit

Permalink
[Codegen] Allow vectorizing linalg.copy ops on memrefs (#18672)
Browse files Browse the repository at this point in the history
The upstream patterns for copy vectorization only support memref.copy.
This adds a pattern to first convert linalg.copy to memref.copy so the
vectorization pattern can kick in for the VectorizeMemrefCopyPass.
  • Loading branch information
qedawkins authored Oct 2, 2024
1 parent 903ab0a commit 206c1f2
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
16 changes: 16 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/VectorizeMemrefCopy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,21 @@ namespace mlir::iree_compiler {

namespace {

struct ConvertLinalgCopyToMemrefCopy final : OpRewritePattern<linalg::CopyOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
PatternRewriter &rewriter) const override {
if (copyOp.hasPureTensorSemantics()) {
return failure();
}
rewriter.create<memref::CopyOp>(copyOp.getLoc(),
copyOp.getDpsInputOperand(0)->get(),
copyOp.getDpsInitOperand(0)->get());
rewriter.eraseOp(copyOp);
return success();
}
};

struct VectorizeMemrefCopyPass final
: impl::VectorizeMemrefCopyPassBase<VectorizeMemrefCopyPass> {
void getDependentDialects(DialectRegistry &registry) const override {
Expand All @@ -28,6 +43,7 @@ struct VectorizeMemrefCopyPass final

RewritePatternSet patterns(ctx);
patterns.add<linalg::CopyVectorizationPattern>(&getContext());
patterns.add<ConvertLinalgCopyToMemrefCopy>(&getContext());
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,16 @@ func.func @memref_copy(%source: memref<2x2xf32>, %dest: memref<2x2xf32>) {
// CHECK-SAME: %[[DEST:[A-Za-z0-9]+]]: memref<2x2xf32>
// CHECK: %[[RD:.+]] = vector.transfer_read %[[SOURCE]]
// CHECK: vector.transfer_write %[[RD]], %[[DEST]]

// -----

func.func @linalg_copy(%source: memref<2x2xf32>, %dest: memref<2x2xf32>) {
linalg.copy ins(%source : memref<2x2xf32>) outs(%dest : memref<2x2xf32>)
return
}

// CHECK-LABEL: func.func @linalg_copy
// CHECK-SAME: %[[SOURCE:[A-Za-z0-9]+]]: memref<2x2xf32>
// CHECK-SAME: %[[DEST:[A-Za-z0-9]+]]: memref<2x2xf32>
// CHECK: %[[RD:.+]] = vector.transfer_read %[[SOURCE]]
// CHECK: vector.transfer_write %[[RD]], %[[DEST]]

0 comments on commit 206c1f2

Please sign in to comment.