Plumb pooling ops through Linalg on tensors path. (#5136)
1) Mark pooling ops as root ops.
2) Convert linalg.init_tensor ops to memref.alloc ops.
3) Set the launch configuration for pooling ops on the Linalg on tensors path.

The alloc op in (2) exists only to provide a shaped operand (the pooling
window); it is deleted once the op is lowered to loops, matching the Linalg on
buffers path. A sketch follows.
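As a minimal sketch of (2), distilled from the bufferization test added below (the value names %input, %fill, %win and their buffer counterparts are illustrative, not from the patch): the linalg.init_tensor feeding a pooling op's window operand carries no data, only the window shape, so bufferization maps it to a scratch allocation:

  // Before: the pooling window is a shape-only tensor operand.
  %win = linalg.init_tensor [2, 3] : tensor<2x3xf32>
  %res = linalg.pooling_nhwc_sum
      {dilations = dense<1> : vector<2xi64>, strides = dense<[2, 3]> : vector<2xi64>}
      ins(%input, %win : tensor<1x4x6x1xf32>, tensor<2x3xf32>)
      outs(%fill : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>

  // After bufferization: the window becomes a placeholder memref.alloc,
  // deleted once the pooling op is lowered to loops.
  %win_buf = memref.alloc() : memref<2x3xf32>
  linalg.pooling_nhwc_sum
      {dilations = dense<1> : vector<2xi64>, strides = dense<[2, 3]> : vector<2xi64>}
      ins(%input_buf, %win_buf : memref<1x4x6x1xf32>, memref<2x3xf32>)
      outs(%out_buf : memref<1x2x2x1xf32>)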

Part of #5043
hanhanW authored Mar 22, 2021
1 parent a33f72f commit 963beb9
Showing 7 changed files with 80 additions and 10 deletions.
19 changes: 19 additions & 0 deletions iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
@@ -232,6 +232,22 @@ static LogicalResult convertConstantOp(OpBuilder &b, ConstantOp constantOp,
   return success();
 }
 
+/// Converts a linalg.init_tensor op to a memref.alloc op. This provides a shaped
+/// operand for pooling ops. The op will be deleted after lowering to loops.
+static LogicalResult convertInitTensorOp(
+    OpBuilder &b, WorkgroupMemoryAllocationFn allocationFn,
+    linalg::InitTensorOp initTensorOp, BlockAndValueMapping &bvm) {
+  if (bvm.contains(initTensorOp.getResult())) return success();
+  RankedTensorType tensorType = initTensorOp.getType();
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPointAfter(initTensorOp);
+  Value alloc = allocationFn(b, initTensorOp.getLoc(), tensorType.getShape(),
+                             tensorType.getElementType(),
+                             llvm::to_vector<4>(initTensorOp.sizes()));
+  bvm.map(initTensorOp.getResult(), alloc);
+  return success();
+}
+
 /// Avoids creating an allocation if the result tensor can just be aliased to
 /// use the same buffer (`inputBuffer`) that `srcTensor` is mapped to. This can
 /// be done if `srcTensor` has a single use, which is the operation which is
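Note that allocationFn is the same workgroup-memory allocation callback the pass's other conversions use (e.g. convertTensorReshapeOp in the next hunk), and llvm::to_vector<4>(initTensorOp.sizes()) forwards the init_tensor's dynamic sizes, so a dynamically shaped window would get a correctly sized allocation as well.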
@@ -668,6 +684,9 @@ void LinalgBufferizePass::runOnFunction() {
         .Case<linalg::TensorReshapeOp>([&](linalg::TensorReshapeOp reshapeOp) {
           return convertTensorReshapeOp(b, allocationFn, reshapeOp, bvm);
         })
+        .Case<linalg::InitTensorOp>([&](linalg::InitTensorOp initTensorOp) {
+          return convertInitTensorOp(b, allocationFn, initTensorOp, bvm);
+        })
         .Case<tensor::ExtractOp>([&](tensor::ExtractOp extractOp) {
           return convertTensorExtractOp(b, extractOp, bvm);
         })
41 changes: 41 additions & 0 deletions iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
@@ -719,3 +719,44 @@ hal.interface @legacy_io attributes {sym_visibility = "private"} {
 // CHECK: linalg.indexed_generic
 // CHECK: %[[VAL:.+]] = memref.load %[[ARG0]]
 // CHECK: linalg.yield %[[VAL]]
+
+// -----
+
+func @pooling_nhwc_sum() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = hal.interface.binding.subspan @legacy_io::@ro0[%c0] : !flow.dispatch.tensor<readonly:f32>
+  %1 = hal.interface.binding.subspan @legacy_io::@ro1[%c0] : !flow.dispatch.tensor<readonly:1x4x6x1xf32>
+  %2 = hal.interface.binding.subspan @legacy_io::@wo2[%c0] : !flow.dispatch.tensor<writeonly:1x2x2x1xf32>
+  %3 = linalg.init_tensor [2, 3] : tensor<2x3xf32>
+  %4 = flow.dispatch.tensor.load %0 : !flow.dispatch.tensor<readonly:f32> -> tensor<f32>
+  %5 = tensor.extract %4[] : tensor<f32>
+  %6 = flow.dispatch.tensor.load %1 : !flow.dispatch.tensor<readonly:1x4x6x1xf32> -> tensor<1x4x6x1xf32>
+  %7 = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32>
+  %8 = linalg.fill(%7, %5) : tensor<1x2x2x1xf32>, f32 -> tensor<1x2x2x1xf32>
+  %9 = linalg.pooling_nhwc_sum {
+    dilations = dense<1> : vector<2xi64>,
+    strides = dense<[2, 3]> : vector<2xi64>
+  } ins(%6, %3 : tensor<1x4x6x1xf32>, tensor<2x3xf32>)
+    outs(%8 : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
+  flow.dispatch.tensor.store %9, %2 : tensor<1x2x2x1xf32> -> !flow.dispatch.tensor<writeonly:1x2x2x1xf32>
+  return
+}
+hal.interface @legacy_io attributes {sym_visibility = "private"} {
+  hal.interface.binding @ro0, set=0, binding=0, type="StorageBuffer", access="Read"
+  hal.interface.binding @ro1, set=0, binding=1, type="StorageBuffer", access="Read"
+  hal.interface.binding @wo2, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+}
+// CHECK-LABEL: func @pooling_nhwc_sum
+// CHECK-DAG: %[[INPUT:.+]] = hal.interface.binding.subspan @legacy_io::@ro1[%c0] : memref<1x4x6x1xf32>
+// CHECK-DAG: %[[INIT:.+]] = hal.interface.binding.subspan @legacy_io::@ro0[%c0] : memref<f32>
+// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @legacy_io::@wo2[%c0] : memref<1x2x2x1xf32>
+// CHECK: %[[WINDOW:.+]] = memref.alloc() : memref<2x3xf32>
+// CHECK: %[[INIT_VAL:.+]] = memref.load %[[INIT]][] : memref<f32>
+// CHECK: linalg.fill(%[[RET0]], %[[INIT_VAL]]) : memref<1x2x2x1xf32>, f32
+// CHECK: linalg.pooling_nhwc_sum
+// CHECK-SAME: dilations = dense<1> : vector<2xi64>
+// CHECK-SAME: strides = dense<[2, 3]> : vector<2xi64>
+// CHECK-SAME: ins(%[[INPUT]], %[[WINDOW]] : memref<1x4x6x1xf32>, memref<2x3xf32>)
+// CHECK-SAME: outs(%[[RET0]] : memref<1x2x2x1xf32>)
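The CHECK lines capture both halves of the change: the window tensor %3 bufferizes to the standalone %[[WINDOW]] = memref.alloc(), while linalg.fill writes directly into the output binding %[[RET0]] rather than a temporary, since the fill result's only use is the pooling op's outs operand.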
@@ -129,9 +129,13 @@ LogicalResult getInputOutputTypesForAllTiles(
     linalg::LinalgOp rootOp, SmallVectorImpl<Type> &inputTypes,
     SmallVectorImpl<Type> &outputTypes) {
   for (Value inputBuffer : rootOp.getInputBuffers()) {
-    auto subviewOp = inputBuffer.getDefiningOp<memref::SubViewOp>();
-    if (!subviewOp) return failure();
-    inputTypes.push_back(subviewOp.getViewSource().getType());
+    if (auto subviewOp = inputBuffer.getDefiningOp<memref::SubViewOp>()) {
+      inputTypes.push_back(subviewOp.getViewSource().getType());
+    } else if (auto allocOp = inputBuffer.getDefiningOp<memref::AllocOp>()) {
+      inputTypes.push_back(allocOp.getType());
+    } else {
+      return failure();
+    }
   }
 
   for (Value outputBuffer : rootOp.getOutputBuffers()) {
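The new memref::AllocOp case is what lets this utility handle the pooling window: unlike the tiled input/output bindings, the window buffer is a standalone allocation rather than a memref.subview of a larger buffer.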
10 changes: 6 additions & 4 deletions iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -691,12 +691,14 @@ static LogicalResult getPoolingOpLaunchConfig(
   // be able to figure out which dimensions of the output correspond to the
   // pooled dimension and which are not. Need to fix that, but for now just use
   // a working heuristic.
-  SmallVector<int64_t, 4> ts(std::min<int64_t>(
-      op.getOutput(0).getType().template cast<ShapedType>().getRank(), 3));
   const int64_t tileSizeX = 32;
   int64_t tileSizeY = maxWorkgroupSize / tileSizeX;
-  ts[ts.size() - 2] = tileSizeY;
-  ts[ts.size() - 1] = tileSizeX;
+  SmallVector<int64_t, 4> ts;
+  if (options.usingLinalgOnTensors) {
+    ts.assign({0, tileSizeY, tileSizeX, 1});
+  } else {
+    ts.assign({0, tileSizeY, tileSizeX});
+  }
   tileSizes.emplace_back(std::move(ts));
   config.workgroupSize = {tileSizeX, tileSizeY, 1};
   return success();
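To make the heuristic concrete, assume maxWorkgroupSize = 128 (an illustrative value, not from the patch): tileSizeX = 32 and tileSizeY = 128 / 32 = 4, so the tensors path produces tile sizes {0, 4, 32, 1} (batch dimension untiled, a 4x32 tile over H and W, channel dimension tiled by 1), consistent with the {32, 4, 1} workgroup size.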
@@ -189,7 +189,9 @@ static bool isRootOp(Operation *op) {
   }
   return isa<linalg::ConvInputNHWCFilterHWCFOp,
              linalg::DepthwiseConvInputNHWCFilterHWCOp,
-             linalg::DepthwiseConvInputNHWCFilterHWCFOp>(op);
+             linalg::DepthwiseConvInputNHWCFilterHWCFOp,
+             linalg::PoolingNHWCSumOp, linalg::PoolingNHWCMaxOp,
+             linalg::PoolingNHWCMinOp>(op);
 }
 
 static bool isAlwaysClonedIntoDispatchOp(Operation *op) {
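Marking the NHWC pooling ops as root ops means dispatch region formation on the tensors path treats them like the convolution ops above: each pooling op anchors its own dispatch region instead of only being fused into some other root.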
4 changes: 2 additions & 2 deletions iree/test/e2e/xla_ops/BUILD
@@ -220,7 +220,7 @@ iree_check_single_backend_test_suite(
         # https://github.com/google/iree/issues/4079
         "pad.mlir",
         "reduce.mlir",
-        # "reduce_window.mlir",
+        "reduce_window.mlir",
         "remainder.mlir",
         "reshape.mlir",
         "reverse.mlir",
@@ -279,7 +279,7 @@
         "negate.mlir",
         "pad.mlir",
         "reduce.mlir",
-        # "reduce_window.mlir",
+        "reduce_window.mlir",
         "remainder.mlir",
         "reshape.mlir",
         "reverse.mlir",
2 changes: 2 additions & 0 deletions iree/test/e2e/xla_ops/CMakeLists.txt
@@ -203,6 +203,7 @@ iree_check_single_backend_test_suite(
     "negate.mlir"
     "pad.mlir"
     "reduce.mlir"
+    "reduce_window.mlir"
     "remainder.mlir"
     "reshape.mlir"
     "reverse.mlir"
@@ -257,6 +258,7 @@
    "negate.mlir"
    "pad.mlir"
    "reduce.mlir"
+    "reduce_window.mlir"
    "remainder.mlir"
    "reshape.mlir"
    "reverse.mlir"
