diff --git a/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp b/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp
index 7f67b76a1413..be6fd7a2a265 100644
--- a/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp
@@ -5,6 +5,8 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -117,6 +119,63 @@ class FoldExtractSliceIntoTransferRead final
   }
 };
 
+/// Returns true if `writeOp` fully overwrites its destination.
+///
+/// Example:
+///
+/// ```
+/// vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]}
+///   : vector<4x5xf32>, tensor<4x5xf32>
+/// ```
+///
+/// This is an easy case: `vector<4x5xf32>` fully overwrites `tensor<4x5xf32>`,
+/// as the vector is the same size as the tensor. This check also supports
+/// dynamic tensors, where it resolves the tensor sizes via value-bounds
+/// analysis and then checks if the vector type fully overwrites the tensor.
+static bool isDestinationFullyOverwritten(vector::TransferWriteOp writeOp) {
+  if (writeOp.hasOutOfBoundsDim())
+    return false;
+  if (writeOp.getVectorType().getRank() != writeOp.getShapedType().getRank())
+    return false;
+  if (writeOp.getMask())
+    return false;
+
+  std::optional<vector::VscaleRange> vscaleRange;
+  auto vecType = writeOp.getVectorType();
+  if (vecType.isScalable()) {
+    auto targetAttr =
+        iree_compiler::IREE::HAL::ExecutableTargetAttr::lookup(writeOp);
+    vscaleRange = iree_compiler::getDefaultVscaleRange(targetAttr);
+  }
+
+  Value dest = writeOp.getSource();
+  ArrayRef<int64_t> destShape = writeOp.getShapedType().getShape();
+
+  // Attempts to resolve the size of a dim within the destination.
+  auto resolveDestinationDimSize =
+      [&](unsigned dimIndex) -> FailureOr<iree_compiler::DimBoundSize> {
+    auto size = destShape[dimIndex];
+    // Fixed-size dimensions are simply included in the shape.
+    if (size != ShapedType::kDynamic)
+      return iree_compiler::DimBoundSize{size};
+    // (Attempt to) resolve dynamic dimensions via value-bounds analysis.
+    return iree_compiler::computeDimUpperBound(dest, dimIndex, vscaleRange);
+  };
+
+  ArrayRef<int64_t> vecShape = vecType.getShape();
+  ArrayRef<bool> vecScalableFlags = vecType.getScalableDims();
+  for (unsigned d = 0, e = destShape.size(); d < e; ++d) {
+    auto dimSize = resolveDestinationDimSize(d);
+    if (failed(dimSize))
+      return false;
+    if (dimSize->scalable && !vecScalableFlags[d])
+      return false;
+    if (vecShape[d] != dimSize->baseSize)
+      return false;
+  }
+  return true;
+}
+
 /// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
 /// could directly write to the insert_slice's destination. E.g.:
 ///
@@ -150,20 +209,12 @@ class FoldInsertSliceIntoTransferWrite final
     // TODO: support 0-d corner case.
     if (xferOp.getTransferRank() == 0)
       return failure();
-
-    if (xferOp.hasOutOfBoundsDim())
-      return failure();
-    if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
-      return failure();
-    if (xferOp.getMask())
+    if (!xferOp.getPermutationMap().isIdentity())
       return failure();
 
     // Fold only if the TransferWriteOp completely overwrites the `source` with
     // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
     // content is the data of the vector.
-    if (!llvm::equal(xferOp.getVectorType().getShape(),
-                     xferOp.getShapedType().getShape()))
-      return failure();
-    if (!xferOp.getPermutationMap().isIdentity())
+    if (!isDestinationFullyOverwritten(xferOp))
       return failure();
 
     // Bail on illegal rank-reduction: we need to check that the rank-reduced
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir b/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
index b07b2b533527..125545300857 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
@@ -63,6 +63,115 @@ func.func @fold_extract_slice_consumer_into_xfer_write_3(%arg0: vector<1x64x128x
 
 // -----
 
+func.func @fold_insert_slice_into_transfer_write_static(%v: vector<4x5xf32>, %t1: tensor<4x5xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]} : vector<4x5xf32>, tensor<4x5xf32>
+  %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func.func @fold_insert_slice_into_transfer_write_static
+// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[T1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[T2:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[A:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[B:[a-zA-Z0-9]+]]
+// CHECK-NEXT: %[[WRITE:.+]] = vector.transfer_write %[[VEC]], %[[T2]][%[[A]], %[[B]]] {in_bounds = [true, true]} : vector<4x5xf32>, tensor<?x?xf32>
+// CHECK-NEXT: return %[[WRITE]]
+
+// -----
+
+#aarch64_sve = #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64", {cpu_features = "+sve", target_triple = "aarch64-none-elf"}>
+
+func.func @fold_insert_slice_into_transfer_write_scalable(%v: vector<4x[4]xf32>, %t1: tensor<?x?xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index) -> tensor<?x?xf32>
+  attributes {hal.executable.target = #aarch64_sve}
+{
+  %vscale = vector.vscale
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c4_vscale = arith.muli %c4, %vscale : index
+  %extract_slice = tensor.extract_slice %t1[0, 0] [4, %c4_vscale] [1, 1] : tensor<?x?xf32> to tensor<4x?xf32>
+  %0 = vector.transfer_write %v, %extract_slice[%c0, %c0] {in_bounds = [true, true]} : vector<4x[4]xf32>, tensor<4x?xf32>
+  %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, %c4_vscale] [1, 1] : tensor<4x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func.func @fold_insert_slice_into_transfer_write_scalable
+// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[T1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[T2:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[A:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[B:[a-zA-Z0-9]+]]
+// CHECK-NEXT: %[[WRITE:.+]] = vector.transfer_write %[[VEC]], %[[T2]][%[[A]], %[[B]]] {in_bounds = [true, true]} : vector<4x[4]xf32>, tensor<?x?xf32>
+// CHECK-NEXT: return %[[WRITE]]
+
+// -----
+
+func.func @fold_insert_slice_into_transfer_write_dynamic(%v: vector<4x8xf32>, %t1: tensor<?x?xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index, %size: index) -> tensor<?x?xf32>
+{
+  %c0 = arith.constant 0 : index
+  %slice_size = affine.min affine_map<(d0) -> (d0, 8)>(%size)
+  %extract_slice = tensor.extract_slice %t1[0, 0] [4, %slice_size] [1, 1] : tensor<?x?xf32> to tensor<4x?xf32>
+  %0 = vector.transfer_write %v, %extract_slice[%c0, %c0] {in_bounds = [true, true]} : vector<4x8xf32>, tensor<4x?xf32>
+  %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, %slice_size] [1, 1] : tensor<4x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func.func @fold_insert_slice_into_transfer_write_dynamic
+// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[T1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[T2:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[A:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[B:[a-zA-Z0-9]+]]
+// CHECK-NEXT: %[[WRITE:.+]] = vector.transfer_write %[[VEC]], %[[T2]][%[[A]], %[[B]]] {in_bounds = [true, true]} : vector<4x8xf32>, tensor<?x?xf32>
+// CHECK-NEXT: return %[[WRITE]]
+
+// -----
+
+func.func @negative_fold_insert_slice_into_transfer_write_static(%v: vector<3x5xf32>, %t1: tensor<4x5xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]} : vector<3x5xf32>, tensor<4x5xf32>
+  %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func.func @negative_fold_insert_slice_into_transfer_write_static
+// CHECK: %[[WRITE:.*]] = vector.transfer_write
+// CHECK: tensor.insert_slice %[[WRITE]]
+
+// -----
+
+#aarch64_sve = #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64", {cpu_features = "+sve", target_triple = "aarch64-none-elf"}>
+
+func.func @negative_fold_insert_slice_into_transfer_write_scalable(%v: vector<4x[2]xf32>, %t1: tensor<?x?xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index) -> tensor<?x?xf32>
+  attributes {hal.executable.target = #aarch64_sve}
+{
+  %vscale = vector.vscale
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c4_vscale = arith.muli %c4, %vscale : index
+  %extract_slice = tensor.extract_slice %t1[0, 0] [4, %c4_vscale] [1, 1] : tensor<?x?xf32> to tensor<4x?xf32>
+  %0 = vector.transfer_write %v, %extract_slice[%c0, %c0] {in_bounds = [true, true]} : vector<4x[2]xf32>, tensor<4x?xf32>
+  %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, %c4_vscale] [1, 1] : tensor<4x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func.func @negative_fold_insert_slice_into_transfer_write_scalable
+// CHECK: %[[WRITE:.*]] = vector.transfer_write
+// CHECK: tensor.insert_slice %[[WRITE]]
+
+// -----
+
+func.func @negative_fold_insert_slice_into_transfer_write_dynamic(%v: vector<4x7xf32>, %t1: tensor<?x?xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index, %size: index) -> tensor<?x?xf32>
+{
+  %c0 = arith.constant 0 : index
+  %slice_size = affine.min affine_map<(d0) -> (d0, 8)>(%size)
+  %extract_slice = tensor.extract_slice %t1[0, 0] [4, %slice_size] [1, 1] : tensor<?x?xf32> to tensor<4x?xf32>
+  %0 = vector.transfer_write %v, %extract_slice[%c0, %c0] {in_bounds = [true, true]} : vector<4x7xf32>, tensor<4x?xf32>
+  %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, %slice_size] [1, 1] : tensor<4x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func.func @negative_fold_insert_slice_into_transfer_write_dynamic
+// CHECK: %[[WRITE:.*]] = vector.transfer_write
+// CHECK: tensor.insert_slice %[[WRITE]]
+
+// -----
+
 #pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -70,6 +179,7 @@ func.func @fold_extract_slice_consumer_into_xfer_write_3(%arg0: vector<1x64x128x
     #hal.descriptor_set.binding<2, storage_buffer>
   ]>
 ]>
+
 #map = affine_map<()[s0] -> (s0 * 64)>
 #map1 = affine_map<()[s0] -> (s0 * 128)>
 #map2 = affine_map<()[s0] -> (s0 * -64 + 968, 64)>
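As an informal illustration of what the new `isDestinationFullyOverwritten` check enables, the scalable test case above folds roughly as sketched below. The "after" IR is paraphrased from the test's CHECK lines; SSA names and any leftover dead ops are illustrative only, not verbatim pass output.

```mlir
// Before the pass: the write targets an extract_slice whose size is
// exactly 4 x (4 * vscale), and the result is re-inserted into %t2.
%slice = tensor.extract_slice %t1[0, 0] [4, %c4_vscale] [1, 1]
    : tensor<?x?xf32> to tensor<4x?xf32>
%w = vector.transfer_write %v, %slice[%c0, %c0] {in_bounds = [true, true]}
    : vector<4x[4]xf32>, tensor<4x?xf32>
%r = tensor.insert_slice %w into %t2[%a, %b] [4, %c4_vscale] [1, 1]
    : tensor<4x?xf32> into tensor<?x?xf32>

// After the pass: value-bounds analysis proves the slice is fully
// overwritten by vector<4x[4]xf32>, so the write targets %t2 directly
// at the insert_slice offsets and the insert_slice goes away.
%r = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
    : vector<4x[4]xf32>, tensor<?x?xf32>
```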