diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorizeMemrefCopy.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorizeMemrefCopy.cpp index 4265919158c2..b84f6cf24e8b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/VectorizeMemrefCopy.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/VectorizeMemrefCopy.cpp @@ -17,6 +17,21 @@ namespace mlir::iree_compiler { namespace { +struct ConvertLinalgCopyToMemrefCopy final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(linalg::CopyOp copyOp, + PatternRewriter &rewriter) const override { + if (copyOp.hasPureTensorSemantics()) { + return failure(); + } + rewriter.create(copyOp.getLoc(), + copyOp.getDpsInputOperand(0)->get(), + copyOp.getDpsInitOperand(0)->get()); + rewriter.eraseOp(copyOp); + return success(); + } +}; + struct VectorizeMemrefCopyPass final : impl::VectorizeMemrefCopyPassBase { void getDependentDialects(DialectRegistry ®istry) const override { @@ -28,6 +43,7 @@ struct VectorizeMemrefCopyPass final RewritePatternSet patterns(ctx); patterns.add(&getContext()); + patterns.add(&getContext()); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } }; diff --git a/compiler/src/iree/compiler/Codegen/Common/test/vectorize_memref_copy.mlir b/compiler/src/iree/compiler/Codegen/Common/test/vectorize_memref_copy.mlir index 9e9a4baec70a..ae0948276d41 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/vectorize_memref_copy.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/vectorize_memref_copy.mlir @@ -10,3 +10,16 @@ func.func @memref_copy(%source: memref<2x2xf32>, %dest: memref<2x2xf32>) { // CHECK-SAME: %[[DEST:[A-Za-z0-9]+]]: memref<2x2xf32> // CHECK: %[[RD:.+]] = vector.transfer_read %[[SOURCE]] // CHECK: vector.transfer_write %[[RD]], %[[DEST]] + +// ----- + +func.func @linalg_copy(%source: memref<2x2xf32>, %dest: memref<2x2xf32>) { + linalg.copy ins(%source : memref<2x2xf32>) outs(%dest : memref<2x2xf32>) + return +} + +// CHECK-LABEL: func.func @linalg_copy +// CHECK-SAME: %[[SOURCE:[A-Za-z0-9]+]]: memref<2x2xf32> +// CHECK-SAME: %[[DEST:[A-Za-z0-9]+]]: memref<2x2xf32> +// CHECK: %[[RD:.+]] = vector.transfer_read %[[SOURCE]] +// CHECK: vector.transfer_write %[[RD]], %[[DEST]]