diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 753a17605667..b0b5a569f9e0 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -964,7 +964,8 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
 
 
 @script
-def _conv_shape_func(dshape, kshape, strides, padding, dilation):
+def _conv_shape_func_nchw(dshape, kshape, strides, padding, dilation):
+    """Shape function for conv*d op with nchw & oihw layout."""
     out = output_tensor((dshape.shape[0],), "int64")
     out[0] = dshape[0]
     out[1] = kshape[0]
@@ -975,23 +976,52 @@ def _conv_shape_func(dshape, kshape, strides, padding, dilation):
     return out
 
 
+@script
+def _conv_shape_func_nhwc_hwio(dshape, kshape, strides, padding, dilation):
+    """Shape function for conv*d op with nhwc & hwio layout."""
+    out = output_tensor((dshape.shape[0],), "int64")
+    out[0] = dshape[0]
+    out[dshape.shape[0] - 1] = kshape[kshape.shape[0] - 1]
+
+    for i in const_range(dshape.shape[0] - 2):
+        dilated_k = (kshape[i] - 1) * dilation[i] + 1
+        out[i + 1] = (dshape[i + 1] + 2 * padding[i] - dilated_k) // strides[i] + 1
+    return out
+
+
+@script
+def _conv_shape_func_nhwc_hwoi(dshape, kshape, strides, padding, dilation):
+    """Shape function for conv*d op with nhwc & hwoi layout."""
+    out = output_tensor((dshape.shape[0],), "int64")
+    out[0] = dshape[0]
+    out[dshape.shape[0] - 1] = kshape[kshape.shape[0] - 2]
+
+    for i in const_range(dshape.shape[0] - 2):
+        dilated_k = (kshape[i] - 1) * dilation[i] + 1
+        out[i + 1] = (dshape[i + 1] + 2 * padding[i] - dilated_k) // strides[i] + 1
+    return out
+
+
 def conv_shape_func(attrs, inputs, _):
-    """
-    Shape function for contrib_conv2d_NCHWc op.
-    """
+    """Shape function for conv*d op."""
     strides = get_const_tuple(attrs.strides)
     padding = get_const_tuple(attrs.padding)
     dilation = get_const_tuple(attrs.dilation)
 
-    return [
-        _conv_shape_func(
-            inputs[0],
-            inputs[1],
-            convert(strides),
-            convert(padding),
-            convert(dilation),
+    shape_func = None
+    if attrs["data_layout"] == "NCHW" and attrs["kernel_layout"] == "OIHW":
+        shape_func = _conv_shape_func_nchw
+    elif attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWIO":
+        shape_func = _conv_shape_func_nhwc_hwio
+    elif attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWOI":
+        shape_func = _conv_shape_func_nhwc_hwoi
+    else:
+        raise ValueError(
+            "Unsupported data/kernel layout: %s, %s"
+            % (attrs["data_layout"], attrs["kernel_layout"])
         )
-    ]
+
+    return [shape_func(inputs[0], inputs[1], convert(strides), convert(padding), convert(dilation))]
 
 
 reg.register_shape_func("nn.conv1d", False, conv_shape_func)
@@ -1307,4 +1337,5 @@ def dilate_shape_func(attrs, inputs, _):
 
 reg.register_shape_func("nn.bias_add", False, elemwise_shape_func)
 reg.register_shape_func("nn.softmax", False, elemwise_shape_func)
+reg.register_shape_func("nn.fast_softmax", False, elemwise_shape_func)
 reg.register_shape_func("nn.relu", False, elemwise_shape_func)
diff --git a/python/tvm/topi/arm_cpu/conv2d_spatial_pack.py b/python/tvm/topi/arm_cpu/conv2d_spatial_pack.py
index 7651305ab2dd..4eed56a22572 100644
--- a/python/tvm/topi/arm_cpu/conv2d_spatial_pack.py
+++ b/python/tvm/topi/arm_cpu/conv2d_spatial_pack.py
@@ -273,7 +273,9 @@ def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_
         data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0])
 
     # ==================== define configuration space ====================
-    n, oc, oh, ow = cfg.axis(N), cfg.axis(OC), cfg.axis(OH), cfg.axis(OW)
+    # If the batch dimension is dynamic, fix its split factor to 1
+    n = cfg.axis(N) if isinstance(N, int) else cfg.axis(1)
+    oc, oh, ow = cfg.axis(OC), cfg.axis(OH), cfg.axis(OW)
     ic, kh, kw = cfg.reduce_axis(IC), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
 
     if num_tile == 2:  # for arm cpu
diff --git a/python/tvm/topi/cuda/conv2d_nhwc.py b/python/tvm/topi/cuda/conv2d_nhwc.py
index 991585587bbf..e4361e30b5c3 100644
--- a/python/tvm/topi/cuda/conv2d_nhwc.py
+++ b/python/tvm/topi/cuda/conv2d_nhwc.py
@@ -43,12 +43,15 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
     AL = s.cache_read(AA, "local", [OL])
     WL = s.cache_read(WW, "local", [OL])
 
+    # Currently Conv2d NHWC only supports dynamic shape in batch
+    dynamic_batch = isinstance(s[output].op.axis[0].dom.extent, tvm.tir.expr.Var)
+
     # Schedule for autotvm
-    cfg.define_knob("tile_n", [2, 4, 8])
+    cfg.define_knob("tile_n", [1] if dynamic_batch else [2, 4, 8])
     cfg.define_knob("tile_c", [2, 4, 8])
-    cfg.define_knob("num_thread_n", [4, 8, 16])
+    cfg.define_knob("num_thread_n", [1] if dynamic_batch else [4, 8, 16])
     cfg.define_knob("num_thread_c", [4, 8, 16])
-    cfg.define_knob("vthread_n", [1, 2])
+    cfg.define_knob("vthread_n", [1] if dynamic_batch else [1, 2])
     cfg.define_knob("vthread_c", [1, 2])
     cfg.define_knob("step", [16, 3, 32, 64])
 
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index e94b5145ccc2..f871c67ad703 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -496,13 +496,24 @@ def verify_any_conv2d(
     dilation,
     static_data_shape,
     ref_out_shape,
+    data_layout="NCHW",
+    kernel_layout="OIHW",
     use_cudnn=False,
 ):
     mod = tvm.IRModule()
     dtype = "float32"
     data = relay.var("data", shape=data_shape, dtype=dtype)
     kernel = relay.var("kernel", shape=kernel_shape, dtype=dtype)
-    y = relay.nn.conv2d(data, kernel, strides, padding, dilation, kernel_size=kernel_shape[2:4])
+    y = relay.nn.conv2d(
+        data,
+        kernel,
+        strides,
+        padding,
+        dilation,
+        kernel_size=kernel_shape[2:4] if kernel_layout == "OIHW" else kernel_shape[0:2],
+        data_layout=data_layout,
+        kernel_layout=kernel_layout,
+    )
     mod["main"] = relay.Function([data, kernel], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)
@@ -545,6 +556,28 @@ def test_any_conv2d():
         (1, 64, 224, 224),
         use_cudnn=True,
     )
+    verify_any_conv2d(
+        (relay.Any(), 224, 224, 64),
+        (3, 3, 64, 64),
+        (1, 1),
+        (1, 1),
+        (1, 1),
+        (1, 224, 224, 64),
+        (1, 224, 224, 64),
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+    )
+    verify_any_conv2d(
+        (relay.Any(), 224, 224, 64),
+        (3, 3, 64, 64),
+        (1, 1),
+        (1, 1),
+        (2, 2),
+        (2, 224, 224, 64),
+        (2, 222, 222, 64),
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+    )
 
 
 def verify_any_conv2d_NCHWc(
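
Note: both new NHWC shape functions compute each spatial output extent as
(d + 2 * p - dilated_k) // s + 1, take the output channel count from the last
kernel axis for HWIO (second-to-last for HWOI), and pass the batch dimension
through untouched, which is what lets a dynamic batch flow through. A minimal
plain-Python sketch of that arithmetic; the helper name nhwc_hwio_out_shape
is illustrative and not part of the patch:

    def nhwc_hwio_out_shape(dshape, kshape, strides, padding, dilation):
        """dshape is (N, H, W, C); kshape is (KH, KW, I, O) for HWIO."""
        out = [dshape[0]]  # batch passes through, possibly dynamic
        for i in range(len(dshape) - 2):  # spatial axes H and W
            dilated_k = (kshape[i] - 1) * dilation[i] + 1
            out.append((dshape[i + 1] + 2 * padding[i] - dilated_k) // strides[i] + 1)
        out.append(kshape[-1])  # output channels come from the O axis of HWIO
        return tuple(out)

    # Mirrors the second new test case: padding 1 with dilation 2 shrinks 224 to 222.
    assert nhwc_hwio_out_shape(
        (2, 224, 224, 64), (3, 3, 64, 64), (1, 1), (1, 1), (2, 2)
    ) == (2, 222, 222, 64)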