[mlir][linalg] Fix scalable vectorisation of tensor.extract #100325

Conversation

banach-space
Contributor

This PR fixes one very specific aspect of vectorising `tensor.extract`
Ops when targeting scalable vectors. Namely, it makes sure that the
scalable flag is correctly propagated when creating
`vector::ShapeCastOp`.

BEFORE:
```mlir
vector.shape_cast %idx_vec : vector<1x1x[4]xindex> to vector<4xindex>
```

AFTER:
```mlir
vector.shape_cast %idx_vec : vector<1x1x[4]xindex> to vector<[4]xindex>
```

This particular ShapeCastOp is created when generating an index for
`vector.transfer_read` operations. Strictly speaking, casting is not
really required. However, it makes the subsequent address calculation
much simpler (*).

The following test is updated to demonstrate the use of
`vector.shape_cast` by the vectoriser:
  * @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous

A similar test with scalable vectors is also added.

(*) At this point in the vectoriser it is known that all leading dims in the
index vector are "1".
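
For illustration, here is a minimal C++ sketch of the type construction the
fix boils down to. It assumes MLIR's `VectorType` and `RewriterBase` APIs;
the helper name is hypothetical and not part of the patch.

```cpp
#include "mlir/IR/BuiltinTypes.h"  // VectorType
#include "mlir/IR/PatternMatch.h"  // RewriterBase

// Hypothetical helper: build the 1-D index-vector type used as the
// shape-cast target, preserving the scalability of the trailing dim.
static mlir::VectorType get1DIndexVectorType(mlir::VectorType resultType,
                                             mlir::RewriterBase &rewriter) {
  int64_t trailingDim = resultType.getShape().back();
  bool trailingScalable = resultType.getScalableDims().back();
  // Yields vector<[4]xindex> when the trailing dim is scalable and
  // vector<4xindex> otherwise; previously the flag was silently dropped.
  return mlir::VectorType::get(trailingDim, rewriter.getIndexType(),
                               trailingScalable);
}
```

Carrying the trailing scalable flag alongside the trailing dim size is what
makes `vector<1x1x[4]xindex>` cast to `vector<[4]xindex>` rather than
`vector<4xindex>`.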
@llvmbot
Collaborator

llvmbot commented Jul 24, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Full diff: https://github.com/llvm/llvm-project/pull/100325.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+5-4)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir (+98-13)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 7f7168eb86832..86a53377e81e2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1069,19 +1069,20 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
   //    (0th) element and use that.
   SmallVector<Value> transferReadIdxs;
-  auto resTrailingDim = resultType.getShape().back();
   auto zero = rewriter.create<arith::ConstantOp>(
       loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
   for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
-    auto idx = bvm.lookup(extractOp.getIndices()[i]);
+    Value idx = bvm.lookup(extractOp.getIndices()[i]);
     if (idx.getType().isIndex()) {
       transferReadIdxs.push_back(idx);
       continue;
     }
 
     auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
-        loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
-        bvm.lookup(extractOp.getIndices()[i]));
+        loc,
+        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
+                        resultType.getScalableDims().back()),
+        idx);
     transferReadIdxs.push_back(
         rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
   }
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
index f042753780013..964565620fd01 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -1,29 +1,52 @@
 // RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
 
-func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<80x16xf32>, %arg0: index, %extracted_slice : tensor<1x3xf32>) -> tensor<1x3xf32> {
+func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
+    %src: tensor<80x16xf32>,
+    %output : tensor<1x3xf32>,
+    %idx: index) -> tensor<1x3xf32> {
+
   %c79 = arith.constant 79 : index
   %1 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]
-  } outs(%extracted_slice : tensor<1x3xf32>) {
+  } outs(%output : tensor<1x3xf32>) {
   ^bb0(%out: f32):
     %2 = linalg.index 1 : index
-    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg0)
-    %extracted = tensor.extract %6[%c79, %3] : tensor<80x16xf32>
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<80x16xf32>
     linalg.yield %extracted : f32
   } -> tensor<1x3xf32>
   return %1 : tensor<1x3xf32>
 }
 
 // CHECK-LABEL:   func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 3 : index
-// CHECK:           %[[VAL_8:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_5]] : vector<1x4xi1>
-// CHECK:           %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<1x3xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK:           %[[VAL_11:.*]] = vector.broadcast {{.*}} : index to vector<4xindex>
-// CHECK:           %[[VAL_12:.*]] = arith.addi {{.*}} : vector<4xindex>
-// CHECK:           %[[VAL_20:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK:           %[[VAL_22:.*]] = vector.mask %[[VAL_8]] { vector.transfer_write {{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x3xf32> } : vector<1x4xi1> -> tensor<1x3xf32>
+// CHECK-SAME:      %[[SRC:.*]]: tensor<80x16xf32>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x3xf32>,
+// CHECK-SAME:      %[[IDX_IN:.*]]: index) -> tensor<1x3xf32> {
+
+/// Create the mask
+// CHECK-DAG:       %[[DIM_0:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[DIM_1:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[C79:.*]] = arith.constant 79 : index
+// CHECK:           %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x4xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK:           vector.mask %[[MASK]] { vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<1x3xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+
+/// Calculate the index vector
+// CHECK:           %[[STEP:.*]] = vector.step : vector<4xindex>
+// CHECK:           %[[IDX_BC:.*]] = vector.broadcast %[[IDX_IN]] : index to vector<4xindex>
+// CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
+// CHECK:           %[[C0:.*]] = arith.constant 0 : i32
+// CHECK:           %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
+
+/// Extract the starting point from the index vector
+// CHECK:           %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
+
+// Final read and write
+// CHECK:           %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+// CHECK:           %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK:           vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{\[}}%[[C0_1]], %[[C0_1]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x3xf32> } : vector<1x4xi1> -> tensor<1x3xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -33,7 +56,69 @@ module attributes {transform.with_named_sequence} {
    }
 }
 
- // -----
+// -----
+
+// Identical to the above, but with scalable vectors.
+
+func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+    %src: tensor<80x16xf32>,
+    %output : tensor<1x3xf32>,
+    %idx: index) -> tensor<1x3xf32> {
+
+  %c79 = arith.constant 79 : index
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%output : tensor<1x3xf32>) {
+  ^bb0(%out: f32):
+    %2 = linalg.index 1 : index
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<80x16xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<1x3xf32>
+
+  return %1 : tensor<1x3xf32>
+}
+
+// CHECK-LABEL:   func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable
+// CHECK-SAME:      %[[SRC:.*]]: tensor<80x16xf32>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x3xf32>,
+// CHECK-SAME:      %[[IDX_IN:.*]]: index) -> tensor<1x3xf32> {
+
+/// Create the mask
+// CHECK-DAG:       %[[DIM_0:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[DIM_1:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[C79:.*]] = arith.constant 79 : index
+// CHECK:           %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x[4]xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK:           vector.mask %[[MASK]] { vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<1x3xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+
+/// Calculate the index vector
+// CHECK:           %[[STEP:.*]] = vector.step : vector<[4]xindex>
+// CHECK:           %[[IDX_BC:.*]] = vector.broadcast %[[IDX_IN]] : index to vector<[4]xindex>
+// CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
+// CHECK:           %[[C0:.*]] = arith.constant 0 : i32
+// CHECK:           %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
+
+/// Extract the starting point from the index vector
+// CHECK:           %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
+
+// Final read and write
+// CHECK:           %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+// CHECK:           %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK:           vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{\[}}%[[C0_1]], %[[C0_1]]] {in_bounds = [true, true]} : vector<1x[4]xf32>, tensor<1x3xf32> } : vector<1x[4]xi1> -> tensor<1x3xf32>
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+     transform.structured.vectorize %0 vector_sizes [1, [4]] {vectorize_nd_extract} : !transform.any_op
+     transform.yield
+   }
+}
+
+// -----
 
 func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<?x?xf32>, %arg0: index, %extracted_slice : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c79 = arith.constant 79 : index

@llvmbot
Collaborator

llvmbot commented Jul 24, 2024

@llvm/pr-subscribers-mlir


banach-space added a commit to banach-space/llvm-project that referenced this pull request Jul 24, 2024
Updates the verifier for `vector.shape_cast` so that the following
incorrect cases are immediately rejected:
```mlir
  vector.shape_cast %vec : vector<1x1x[4]xindex> to vector<4xindex>
```

Separately, here's a fix for the Linalg vectorizer that prevents it from
generating such shape casts (*):
* llvm#100325

(*) Note, that's just one specific case that I've identified so far.
@MacDue (Member) left a comment

LGTM

banach-space merged commit c7a3346 into llvm:main Jul 25, 2024
10 checks passed
banach-space added a commit that referenced this pull request Jul 25, 2024
Updates the verifier for `vector.shape_cast` so that incorrect cases
where "scalability" is dropped are immediately rejected. For example:
```mlir
  vector.shape_cast %vec : vector<1x1x[4]xindex> to vector<4xindex>
```

Also, as a separate PR, I've prepared a fix for the Linalg vectorizer to
avoid generating such shape casts (*):
* #100325

(*) Note, that's just one specific case that I've identified so far.
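
As a rough sketch only, assuming the same MLIR APIs and not reproducing the
actual upstream verifier code, the new verification amounts to rejecting
shape casts whose source and result types disagree on how many dims are
scalable:

```cpp
#include "mlir/IR/BuiltinTypes.h"  // VectorType
#include "llvm/ADT/STLExtras.h"    // llvm::count

// Sketch: a cast such as vector<1x1x[4]xindex> -> vector<4xindex> drops
// scalability, so the scalable-dim counts of source and result differ.
static bool dropsScalability(mlir::VectorType sourceType,
                             mlir::VectorType resultType) {
  return llvm::count(sourceType.getScalableDims(), true) !=
         llvm::count(resultType.getScalableDims(), true);
}
```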
banach-space deleted the andrzej/extend_tensor_extract_vectorisation branch July 25, 2024 11:40
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
Differential Revision: https://phabricator.intern.facebook.com/D60250742
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
Differential Revision: https://phabricator.intern.facebook.com/D60250589