diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp index a8cf5ea66eb3..bf4d840798cc 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp @@ -667,7 +667,13 @@ LogicalResult getOpLaunchConfig(linalg::DepthwiseConvInputNHWCFilterHWCFOp op, .getInt(); const int64_t tileSizeX = 32; int64_t tileSizeY = maxWorkgroupSize / tileSizeX; - SmallVector ts = {1, tileSizeY, tileSizeX}; + SmallVector ts; + if (options.usingLinalgOnTensors) { + // There are five parallel loops in depthwise_conv_2d_input_nhwc_filter_hwcf + ts.assign({0, 0, 1, tileSizeY, tileSizeX}); + } else { + ts.assign({1, tileSizeY, tileSizeX}); + } tileSizes.emplace_back(std::move(ts)); config.workgroupSize = {tileSizeX, tileSizeY, 1}; return success(); diff --git a/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp index ee9e20fd3712..431883273003 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp @@ -393,6 +393,7 @@ static void populateTilingConvFilterPatterns( patterns.insert< linalg::LinalgTilingPattern, + linalg::LinalgTilingPattern, linalg::LinalgTilingPattern>( context, tilingOptions, marker); } diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp index 5feb9471dde9..6d80d7c0c1d1 100644 --- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp @@ -188,7 +188,8 @@ static bool isRootOp(Operation *op) { } } return isa(op); + linalg::DepthwiseConvInputNHWCFilterHWCOp, + linalg::DepthwiseConvInputNHWCFilterHWCFOp>(op); } static bool isAlwaysClonedIntoDispatchOp(Operation *op) { diff --git a/iree/test/e2e/xla_ops/convolution.mlir b/iree/test/e2e/xla_ops/convolution.mlir index 6b295e32807d..1acba1328bdb 100644 --- a/iree/test/e2e/xla_ops/convolution.mlir +++ b/iree/test/e2e/xla_ops/convolution.mlir @@ -402,3 +402,27 @@ func @conv2d_1452x2223_dilated_valid() { [-1.122044, -0.41301775, -1.5628793 ]]]]> : tensor<1x2x4x3xf32>) : tensor<1x2x4x3xf32> return } + +func @depthwise_conv_non_1_channel_multiplier() { + %arg0 = iree.unfoldable_constant dense<1.0> : tensor<2x4x5x2xf32> + %arg1 = iree.unfoldable_constant dense<1.0> : tensor<2x2x2x3xf32> + %res = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = { + input_batch_dimension = 0 : i64, + input_feature_dimension = 3 : i64, + input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, + kernel_input_feature_dimension = 2 : i64, + kernel_output_feature_dimension = 3 : i64, + kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, + output_batch_dimension = 0 : i64, + output_feature_dimension = 3 : i64, + output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64> + }, + feature_group_count = 2 : i64, + padding = dense<0> : tensor<2x2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32> + check.expect_almost_eq_const(%res, dense<4.0> : tensor<2x3x4x6xf32>) : tensor<2x3x4x6xf32> + return +}