Skip to content

Commit

Permalink
[mlir][vector] Restrict vector.shape_cast (scalable vectors) (#100331)
Browse files Browse the repository at this point in the history
Updates the verifier for `vector.shape_cast` so that incorrect cases
where "scalability" is dropped are immediately rejected. For example:
```mlir
  vector.shape_cast %vec : vector<1x1x[4]xindex> to vector<4xindex>
```

Also, as a separate PR, I've prepared a fix for the Linalg vectorizer to
avoid generating such shape casts (*):
* #100325

(*) Note, that's just one specific case that I've identified so far.
  • Loading branch information
banach-space authored Jul 25, 2024
1 parent c7a3346 commit 98c73d5
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 0 deletions.
5 changes: 5 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,11 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
return !llvm::is_contained(getScalableDims(), false);
}

/// Get the number of scalable dimensions.
int64_t getNumScalableDims() const {
return llvm::count(getScalableDims(), true);
}

/// Get or create a new VectorType with the same shape as `this` and an
/// element type of bitwidth scaled by `scale`.
/// Return null if the scaled element type cannot be represented.
Expand Down
10 changes: 10 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5238,6 +5238,16 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
if (!isValidShapeCast(resultShape, sourceShape))
return op->emitOpError("invalid shape cast");
}

// Check that (non-)scalability is preserved
int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
if (sourceNScalableDims != resultNScalableDims)
return op->emitOpError("different number of scalable dims at source (")
<< sourceNScalableDims << ") and result (" << resultNScalableDims
<< ")";
sourceVectorType.getNumDynamicDims();

return success();
}

Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,20 @@ func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {

// -----

func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
// expected-error@+1 {{different number of scalable dims at source (1) and result (0)}}
%0 = vector.shape_cast %arg0 : vector<15x[2]xf32> to vector<30xf32>
}

// -----

func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<2x[15]x[2]xf32>) {
// expected-error@+1 {{different number of scalable dims at source (2) and result (1)}}
%0 = vector.shape_cast %arg0 : vector<2x[15]x[2]xf32> to vector<30x[2]xf32>
}

// -----

func.func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) {
// expected-error@+1 {{'vector.bitcast' invalid kind of type specified}}
%0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to f32
Expand Down

0 comments on commit 98c73d5

Please sign in to comment.