Skip to content

Commit

Permalink
Fix clone producers for use in count region (#14250)
Browse files Browse the repository at this point in the history
Use of cloned values in the `count` region were cloned into the dispatch
region. We should check that their use is in the correct region before
cloning.
  • Loading branch information
rsuderman authored Jun 28, 2023
1 parent 19bd0c4 commit ee53836
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,18 @@ Flow::clonePrecedingOpIntoDispatchRegion(RewriterBase &rewriter,
// Gather all uses of `target`.
SmallVector<OpOperand *> usesInsideOfRegion;
for (OpOperand &use : target->getUses()) {
if (regionOp->isProperAncestor(use.getOwner()))
Operation *parentOperation = use.getOwner();
Region *parentRegion = parentOperation->getParentRegion();

while ((parentOperation = parentOperation->getParentOp())) {
if (regionOp.getOperation() == parentOperation)
break;
parentRegion = parentOperation->getParentRegion();
}

if (parentOperation && &parentRegion->front() == &body) {
usesInsideOfRegion.push_back(&use);
}
}

// Clone op into dispatch region.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ iree_lit_test_suite(
"capture_dispatch_dynamic_dims.mlir",
"cleanup_numeric_narrowing.mlir",
"cleanup_tensor_shapes.mlir",
"clone_producers_into_dispath_regions.mlir",
"clone_producers_into_dispatch_regions.mlir",
"collapse_reduction.mlir",
"conv1x1_to_matmul.mlir",
"convert_region_to_workgroups.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ iree_lit_test_suite(
"capture_dispatch_dynamic_dims.mlir"
"cleanup_numeric_narrowing.mlir"
"cleanup_tensor_shapes.mlir"
"clone_producers_into_dispath_regions.mlir"
"clone_producers_into_dispatch_regions.mlir"
"collapse_linalg_generic_on_tensors.mlir"
"collapse_reduction.mlir"
"conv1x1_to_matmul.mlir"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,37 @@ func.func @complex_create(%real : f32, %imag : f32, %input: tensor<4x2xcomplex<f
// CHECK: complex.mul
// CHECK: linalg.yield
// CHECK: flow.return

// ----

#map = affine_map<() -> ()>
func.func @use_in_dispatch_count(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view {
%c1 = arith.constant 1 : index
%c2_i32 = arith.constant 2 : i32
%c0 = arith.constant 0 : index
%0 = hal.tensor.import %arg0 "input 0" : !hal.buffer_view -> tensor<1xi32>
%1 = hal.tensor.import %arg1 "input 1" : !hal.buffer_view -> tensor<1xi32>
%2 = tensor.empty() : tensor<i32>
%extracted = tensor.extract %0[%c1] : tensor<1xi32>
%4 = flow.dispatch.region -> (tensor<i32>) {
%6 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%2 : tensor<i32>) {
^bb0(%out: i32):
%7 = arith.addi %extracted, %c2_i32 : i32
linalg.yield %7 : i32
} -> tensor<i32>
flow.return %6 : tensor<i32>
} count() -> (index, index, index) {
flow.return %c1, %c1, %c1 : index, index, index
}
%5 = hal.tensor.export %4 "output 0" : tensor<i32> -> !hal.buffer_view
return %5 : !hal.buffer_view
}


// CHECK-LABEL: @use_in_dispatch_count
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: flow.dispatch.region
// CHECK: %[[C1_2:.+]] = arith.constant 1 : index
// CHECK: linalg.generic
// CHECK: count()
// CHECK: flow.return %[[C1]], %[[C1]], %[[C1]]

0 comments on commit ee53836

Please sign in to comment.