Skip to content

Commit

Permalink
Plumb depthwise convolution in Linalg on tensors. (iree-org#5189)
Browse files Browse the repository at this point in the history
Supports general 2D depthwise convolution in Linalg on tensors.
Previously, only the case of channel_multiplier=1 was supported.

This sets `linalg::DepthwiseConvInputNHWCFilterHWCFOp` to be a
root op and adds an end-to-end test case.
  • Loading branch information
hanhanW authored Mar 22, 2021
1 parent 2864ee4 commit a33f72f
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,13 @@ LogicalResult getOpLaunchConfig(linalg::DepthwiseConvInputNHWCFilterHWCFOp op,
.getInt();
const int64_t tileSizeX = 32;
int64_t tileSizeY = maxWorkgroupSize / tileSizeX;
SmallVector<int64_t, 4> ts = {1, tileSizeY, tileSizeX};
SmallVector<int64_t, 4> ts;
if (options.usingLinalgOnTensors) {
// There are five parallel loops in depthwise_conv_2d_input_nhwc_filter_hwcf
ts.assign({0, 0, 1, tileSizeY, tileSizeX});
} else {
ts.assign({1, tileSizeY, tileSizeX});
}
tileSizes.emplace_back(std::move(ts));
config.workgroupSize = {tileSizeX, tileSizeY, 1};
return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ static void populateTilingConvFilterPatterns(

patterns.insert<
linalg::LinalgTilingPattern<linalg::ConvInputNHWCFilterHWCFOp>,
linalg::LinalgTilingPattern<linalg::DepthwiseConvInputNHWCFilterHWCFOp>,
linalg::LinalgTilingPattern<linalg::DepthwiseConvInputNHWCFilterHWCOp>>(
context, tilingOptions, marker);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ static bool isRootOp(Operation *op) {
}
}
return isa<linalg::ConvInputNHWCFilterHWCFOp,
linalg::DepthwiseConvInputNHWCFilterHWCOp>(op);
linalg::DepthwiseConvInputNHWCFilterHWCOp,
linalg::DepthwiseConvInputNHWCFilterHWCFOp>(op);
}

static bool isAlwaysClonedIntoDispatchOp(Operation *op) {
Expand Down
24 changes: 24 additions & 0 deletions iree/test/e2e/xla_ops/convolution.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,27 @@ func @conv2d_1452x2223_dilated_valid() {
[-1.122044, -0.41301775, -1.5628793 ]]]]> : tensor<1x2x4x3xf32>) : tensor<1x2x4x3xf32>
return
}

// End-to-end check of an mhlo.convolution with feature_group_count = 2 on an
// input that has 2 feature channels; per the test name this exercises a
// depthwise convolution whose channel multiplier is not 1.
//
// Operand layout, as fixed by the dimension numbers below:
//   - %arg0 (input, 2x4x5x2): batch in dim 0, spatial dims 1 and 2, feature
//     in dim 3.
//   - %arg1 (filter, 2x2x2x3): spatial dims 0 and 1, input feature in dim 2,
//     output feature in dim 3.
//   - result (2x3x4x6): batch in dim 0, spatial dims 1 and 2, feature in dim 3.
// Both operands are filled with 1.0.
//
// Padding is 0 and strides/dilation are 1, so a 2x2 window over the 4x5
// spatial extent yields the 3x4 spatial extent seen in the result type.
//
// NOTE(review): XLA's grouped-convolution shape rule would make the filter
// 2x2x1x6 (input feature = 2 / feature_group_count, output feature = 6). The
// 2x2x2x3 type used here has the same element count but a different layout;
// presumably the lowering under test expects it -- confirm before changing.
func @depthwise_conv_non_1_channel_multiplier() {
%arg0 = iree.unfoldable_constant dense<1.0> : tensor<2x4x5x2xf32>
%arg1 = iree.unfoldable_constant dense<1.0> : tensor<2x2x2x3xf32>
%res = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
},
feature_group_count = 2 : i64,
padding = dense<0> : tensor<2x2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32>
// Every output element is expected to be 4.0, which is consistent with each
// one summing a 2x2 window of 1.0 * 1.0 products from a single input channel.
check.expect_almost_eq_const(%res, dense<4.0> : tensor<2x3x4x6xf32>) : tensor<2x3x4x6xf32>
return
}

0 comments on commit a33f72f

Please sign in to comment.