diff --git a/python/tvm/topi/x86/depthwise_conv2d.py b/python/tvm/topi/x86/depthwise_conv2d.py index 0976c33bbb928..acbe0f70b1d92 100644 --- a/python/tvm/topi/x86/depthwise_conv2d.py +++ b/python/tvm/topi/x86/depthwise_conv2d.py @@ -122,13 +122,18 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, strides = strides if isinstance(strides, (tuple, list)) else (strides, strides) HSTR, WSTR = strides - pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (filter_height, filter_width)) dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) - assert (dh, dw) == (1, 1), "Does not support dilation" - out_height = (in_height - filter_height + pad_top + pad_down) // HSTR + 1 - out_width = (in_width - filter_width + pad_left + pad_right) // WSTR + 1 + dilated_kernel_h = (filter_height - 1) * dh + 1 + dilated_kernel_w = (filter_width - 1) * dw + 1 + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) + HPAD = pad_top + pad_down + WPAD = pad_left + pad_right + + out_height = (in_height + HPAD - dilated_kernel_h) // HSTR + 1 + out_width = (in_width + WPAD - dilated_kernel_w) // WSTR + 1 cfg.define_split("tile_ic", in_channel, num_outputs=2) cfg.define_split("tile_oc", out_channel, num_outputs=2) @@ -140,7 +145,7 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, te.placeholder((batch, in_channel, in_height, in_width), dtype=data.dtype), te.placeholder((out_channel, channel_multiplier, filter_height, filter_width), dtype=kernel.dtype), - strides, padding, out_dtype) + strides, (pad_top, pad_down), out_dtype) if cfg.is_fallback: _fallback_schedule(cfg, wkl) @@ -172,6 +177,7 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, else: data_pad = data + # depthconv stage idxdiv = tvm.tir.indexdiv idxmod = tvm.tir.indexmod @@ -184,7 +190,7 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, (data_pad[ b, idxdiv(idxdiv(oco * out_channel_block + oci, channel_multiplier), in_channel_block), - oh*HSTR+kh, ow*WSTR+kw, + oh*HSTR+kh*dh, ow*WSTR+kw*dw, idxmod(idxdiv(oco * out_channel_block + oci, channel_multiplier), in_channel_block)] .astype(out_dtype) * kernel[oco, 0, kh, kw, 0, oci].astype(out_dtype)), diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d.py b/tests/python/topi/python/test_topi_depthwise_conv2d.py index 397861713f732..a3b8852b9cc35 100644 --- a/tests/python/topi/python/test_topi_depthwise_conv2d.py +++ b/tests/python/topi/python/test_topi_depthwise_conv2d.py @@ -268,7 +268,6 @@ def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_m filter_width = filter_height stride_h = stride_w = stride - assert dilation == 1, "depthwise_conv2d_NCHWc currently does not support dilation." assert channel_multiplier == 1, "depthwise_conv2d_NCHWc currently does not support channel multiplier > 1." pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width)) padding_args = (pad_h, pad_w) @@ -306,7 +305,7 @@ def check_device(device): # declare DepthwiseConv2d = topi.x86.depthwise_conv2d_NCHWc(Input, Filter, (stride_h, stride_w), - padding_args, + padding, (dilation, dilation), in_layout, out_layout, dtype) @@ -329,8 +328,9 @@ def get_ref_data(): input_np = np.random.uniform(size=input_shape).astype(dtype) filter_np = np.random.uniform(size=filter_shape).astype(dtype) # correctness with scipy + dw_np = tvm.topi.testing.dilate_python(filter_np, (1, 1, dilation, dilation)).astype(dtype) depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nchw( - input_np, filter_np, stride, padding) + input_np, dw_np, stride, padding) relu_scipy = np.maximum(depthwise_conv2d_scipy, 0) return (_transform_data(input_np, ic_block), _transform_kernel(filter_np, oc_block), @@ -389,6 +389,7 @@ def test_depthwise_conv2d(): # depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2) # NCHW[x]c + depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "SAME", dilation=2) depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "SAME") depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "VALID")