Skip to content

Commit

Permalink
[Codegen] Support dynamic/scalable sizes when folding insert_slice in…
Browse files Browse the repository at this point in the history
…to xfer_write (iree-org#17963)

This enables further optimizations which are currently missed when
targeting scalable vectors.

---------

Signed-off-by: Benjamin Maxwell <[email protected]>
  • Loading branch information
MacDue committed Aug 6, 2024
1 parent 4ff771a commit 98a9ca2
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Utils/Utils.h"

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand Down Expand Up @@ -117,6 +119,63 @@ class FoldExtractSliceIntoTransferRead final
}
};

/// Returns true if `writeOp` fully overwrites its destination.
///
/// Example:
///
/// ```
/// vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]}
/// : vector<4x5xf32>, tensor<4x5xf32>
/// ```
///
/// This is an easy case, `vector<4x5xf32>` fully-overwrites `tensor<4x5xf32>`
/// as the vector is the same size as the tensor. This check also supports
/// dynamic tensors, where it resolves the tensor sizes via value-bounds
/// analysis, and then checks if the vector type fully overwrites the tensor.
static bool isDestinationFullyOverwritten(vector::TransferWriteOp writeOp) {
if (writeOp.hasOutOfBoundsDim())
return false;
if (writeOp.getVectorType().getRank() != writeOp.getShapedType().getRank())
return false;
if (writeOp.getMask())
return false;

std::optional<iree_compiler::VscaleRange> vscaleRange;
auto vecType = writeOp.getVectorType();
if (vecType.isScalable()) {
auto targetAttr =
iree_compiler::IREE::HAL::ExecutableTargetAttr::lookup(writeOp);
vscaleRange = iree_compiler::getDefaultVscaleRange(targetAttr);
}

Value dest = writeOp.getSource();
ArrayRef<int64_t> destShape = writeOp.getShapedType().getShape();

// Attempts to resolve the size of a dim within the destination.
auto resolveDestinationDimSize =
[&](unsigned dimIndex) -> FailureOr<iree_compiler::DimBoundSize> {
auto size = destShape[dimIndex];
// Fixed-size dimensions are simply included in the shape.
if (size != ShapedType::kDynamic)
return iree_compiler::DimBoundSize{size};
// (Attempt to) resolve dynamic dimensions via value-bounds analysis.
return iree_compiler::computeDimUpperBound(dest, dimIndex, vscaleRange);
};

ArrayRef<int64_t> vecShape = vecType.getShape();
ArrayRef<bool> vecScalableFlags = vecType.getScalableDims();
for (unsigned d = 0, e = destShape.size(); d < e; ++d) {
auto dimSize = resolveDestinationDimSize(d);
if (failed(dimSize))
return false;
if (dimSize->scalable && !vecScalableFlags[d])
return false;
if (vecShape[d] != dimSize->baseSize)
return false;
}
return true;
}

/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
/// could directly write to the insert_slice's destination. E.g.:
///
Expand Down Expand Up @@ -150,20 +209,12 @@ class FoldInsertSliceIntoTransferWrite final
// TODO: support 0-d corner case.
if (xferOp.getTransferRank() == 0)
return failure();

if (xferOp.hasOutOfBoundsDim())
return failure();
if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
return failure();
if (xferOp.getMask())
if (!xferOp.getPermutationMap().isIdentity())
return failure();
// Fold only if the TransferWriteOp completely overwrites the `source` with
// a vector. I.e., the result of the TransferWriteOp is a new tensor whose
// content is the data of the vector.
if (!llvm::equal(xferOp.getVectorType().getShape(),
xferOp.getShapedType().getShape()))
return failure();
if (!xferOp.getPermutationMap().isIdentity())
if (!isDestinationFullyOverwritten(xferOp))
return failure();

// Bail on illegal rank-reduction: we need to check that the rank-reduced
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,123 @@ func.func @fold_extract_slice_consumer_into_xfer_write_3(%arg0: vector<1x64x128x

// -----

func.func @fold_insert_slice_into_transfer_write_static(%v: vector<4x5xf32>, %t1: tensor<4x5xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]} : vector<4x5xf32>, tensor<4x5xf32>
%1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func.func @fold_insert_slice_into_transfer_write_static
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[T1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[T2:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[A:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[B:[a-zA-Z0-9]+]]
// CHECK-NEXT: %[[WRITE:.+]] = vector.transfer_write %[[VEC]], %[[T2]][%[[A]], %[[B]]] {in_bounds = [true, true]} : vector<4x5xf32>, tensor<?x?xf32>
// CHECK-NEXT: return %[[WRITE]]

// -----

#aarch64_sve = #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64", {cpu_features = "+sve", target_triple = "aarch64-none-elf"}>

func.func @fold_insert_slice_into_transfer_write_scalable(%v: vector<4x[4]xf32>, %t1: tensor<?x?xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index) -> tensor<?x?xf32>
attributes {hal.executable.target = #aarch64_sve}
{
%vscale = vector.vscale
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c4_vscale = arith.muli %c4, %vscale : index
%extract_slice = tensor.extract_slice %t1[0, 0] [4, %c4_vscale] [1, 1] : tensor<?x?xf32> to tensor<4x?xf32>
%0 = vector.transfer_write %v, %extract_slice[%c0, %c0] {in_bounds = [true, true]} : vector<4x[4]xf32>, tensor<4x?xf32>
%1 = tensor.insert_slice %0 into %t2[%a, %b] [4, %c4_vscale] [1, 1] : tensor<4x?xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func.func @fold_insert_slice_into_transfer_write_scalable
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[T1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[T2:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[A:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[B:[a-zA-Z0-9]+]]
// CHECK-NEXT: %[[WRITE:.+]] = vector.transfer_write %[[VEC]], %[[T2]][%[[A]], %[[B]]] {in_bounds = [true, true]} : vector<4x[4]xf32>, tensor<?x?xf32>
// CHECK-NEXT: return %[[WRITE]]

// -----

func.func @fold_insert_slice_into_transfer_write_dynamic(%v: vector<4x8xf32>, %t1: tensor<?x?xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index, %size: index) -> tensor<?x?xf32>
{
%c0 = arith.constant 0 : index
%slice_size = affine.min affine_map<(d0) -> (d0, 8)>(%size)
%extract_slice = tensor.extract_slice %t1[0, 0] [4, %slice_size] [1, 1] : tensor<?x?xf32> to tensor<4x?xf32>
%0 = vector.transfer_write %v, %extract_slice[%c0, %c0] {in_bounds = [true, true]} : vector<4x8xf32>, tensor<4x?xf32>
%1 = tensor.insert_slice %0 into %t2[%a, %b] [4, %slice_size] [1, 1] : tensor<4x?xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func.func @fold_insert_slice_into_transfer_write_dynamic
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[T1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[T2:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[A:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[B:[a-zA-Z0-9]+]]
// CHECK-NEXT: %[[WRITE:.+]] = vector.transfer_write %[[VEC]], %[[T2]][%[[A]], %[[B]]] {in_bounds = [true, true]} : vector<4x8xf32>, tensor<?x?xf32>
// CHECK-NEXT: return %[[WRITE]]

// -----

func.func @negative_fold_insert_slice_into_transfer_write_static(%v: vector<3x5xf32>, %t1: tensor<4x5xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]} : vector<3x5xf32>, tensor<4x5xf32>
%1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func.func @negative_fold_insert_slice_into_transfer_write_static
// CHECK: %[[WRITE:.*]] = vector.transfer_write
// CHECK: tensor.insert_slice %[[WRITE]]

// -----

#aarch64_sve = #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64", {cpu_features = "+sve", target_triple = "aarch64-none-elf"}>

func.func @negative_fold_insert_slice_into_transfer_write_scalable(%v: vector<4x[2]xf32>, %t1: tensor<?x?xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index) -> tensor<?x?xf32>
attributes {hal.executable.target = #aarch64_sve}
{
%vscale = vector.vscale
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c4_vscale = arith.muli %c4, %vscale : index
%extract_slice = tensor.extract_slice %t1[0, 0] [4, %c4_vscale] [1, 1] : tensor<?x?xf32> to tensor<4x?xf32>
%0 = vector.transfer_write %v, %extract_slice[%c0, %c0] {in_bounds = [true, true]} : vector<4x[2]xf32>, tensor<4x?xf32>
%1 = tensor.insert_slice %0 into %t2[%a, %b] [4, %c4_vscale] [1, 1] : tensor<4x?xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func.func @negative_fold_insert_slice_into_transfer_write_scalable
// CHECK: %[[WRITE:.*]] = vector.transfer_write
// CHECK: tensor.insert_slice %[[WRITE]]

// -----

func.func @negative_fold_insert_slice_into_transfer_write_dynamic(%v: vector<4x7xf32>, %t1: tensor<?x?xf32>, %t2: tensor<?x?xf32>, %a: index, %b: index, %size: index) -> tensor<?x?xf32>
{
%c0 = arith.constant 0 : index
%slice_size = affine.min affine_map<(d0) -> (d0, 8)>(%size)
%extract_slice = tensor.extract_slice %t1[0, 0] [4, %slice_size] [1, 1] : tensor<?x?xf32> to tensor<4x?xf32>
%0 = vector.transfer_write %v, %extract_slice[%c0, %c0] {in_bounds = [true, true]} : vector<4x7xf32>, tensor<4x?xf32>
%1 = tensor.insert_slice %0 into %t2[%a, %b] [4, %slice_size] [1, 1] : tensor<4x?xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func.func @negative_fold_insert_slice_into_transfer_write_dynamic
// CHECK: %[[WRITE:.*]] = vector.transfer_write
// CHECK: tensor.insert_slice %[[WRITE]]

// -----

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>

#map = affine_map<()[s0] -> (s0 * 64)>
#map1 = affine_map<()[s0] -> (s0 * 128)>
#map2 = affine_map<()[s0] -> (s0 * -64 + 968, 64)>
Expand Down

0 comments on commit 98a9ca2

Please sign in to comment.