Skip to content

Commit

Permalink
[VM] Fix the shape function of conv nhwc (#8480)
Browse files Browse the repository at this point in the history
* Add dynamic support for conv2d nhwc
  • Loading branch information
jcf94 committed Jul 16, 2021
1 parent 2b57907 commit cba9cf3
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 17 deletions.
55 changes: 43 additions & 12 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,8 @@ def compute_space_to_depth(attrs, inputs, out_dtype):


@script
def _conv_shape_func(dshape, kshape, strides, padding, dilation):
def _conv_shape_func_nchw(dshape, kshape, strides, padding, dilation):
"""Shape function for conv*d op with nchw & oihw layout."""
out = output_tensor((dshape.shape[0],), "int64")
out[0] = dshape[0]
out[1] = kshape[0]
Expand All @@ -975,23 +976,52 @@ def _conv_shape_func(dshape, kshape, strides, padding, dilation):
return out


@script
def _conv_shape_func_nhwc_hwio(dshape, kshape, strides, padding, dilation):
"""Shape function for conv*d op with nhwc & hwio layout."""
out = output_tensor((dshape.shape[0],), "int64")
out[0] = dshape[0]
out[dshape.shape[0] - 1] = kshape[kshape.shape[0] - 1]

for i in const_range(dshape.shape[0] - 2):
dilated_k = (kshape[i] - 1) * dilation[i] + 1
out[i + 1] = (dshape[i + 1] + 2 * padding[i] - dilated_k) // strides[i] + 1
return out


@script
def _conv_shape_func_nhwc_hwoi(dshape, kshape, strides, padding, dilation):
"""Shape function for conv*d op with nhwc & hwoi layout."""
out = output_tensor((dshape.shape[0],), "int64")
out[0] = dshape[0]
out[dshape.shape[0] - 1] = kshape[kshape.shape[0] - 2]

for i in const_range(dshape.shape[0] - 2):
dilated_k = (kshape[i] - 1) * dilation[i] + 1
out[i + 1] = (dshape[i + 1] + 2 * padding[i] - dilated_k) // strides[i] + 1
return out


def conv_shape_func(attrs, inputs, _):
"""
Shape function for contrib_conv2d_NCHWc op.
"""
"""Shape function for conv*d op."""
strides = get_const_tuple(attrs.strides)
padding = get_const_tuple(attrs.padding)
dilation = get_const_tuple(attrs.dilation)

return [
_conv_shape_func(
inputs[0],
inputs[1],
convert(strides),
convert(padding),
convert(dilation),
shape_func = None
if attrs["data_layout"] == "NCHW" and attrs["kernel_layout"] == "OIHW":
shape_func = _conv_shape_func_nchw
elif attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWIO":
shape_func = _conv_shape_func_nhwc_hwio
elif attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWOI":
shape_func = _conv_shape_func_nhwc_hwoi
else:
raise ValueError(
"Unsupported data/kernel layout: %s, %s"
% (attrs["data_layout"], attrs["kernel_layout"])
)
]

return [shape_func(inputs[0], inputs[1], convert(strides), convert(padding), convert(dilation))]


reg.register_shape_func("nn.conv1d", False, conv_shape_func)
Expand Down Expand Up @@ -1307,4 +1337,5 @@ def dilate_shape_func(attrs, inputs, _):

reg.register_shape_func("nn.bias_add", False, elemwise_shape_func)
reg.register_shape_func("nn.softmax", False, elemwise_shape_func)
reg.register_shape_func("nn.fast_softmax", False, elemwise_shape_func)
reg.register_shape_func("nn.relu", False, elemwise_shape_func)
4 changes: 3 additions & 1 deletion python/tvm/topi/arm_cpu/conv2d_spatial_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,9 @@ def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_
data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0])

# ==================== define configuration space ====================
n, oc, oh, ow = cfg.axis(N), cfg.axis(OC), cfg.axis(OH), cfg.axis(OW)
# If it has dynamic shape in batch, we fix the split factor to 1
n = cfg.axis(N) if isinstance(N, int) else cfg.axis(1)
oc, oh, ow = cfg.axis(OC), cfg.axis(OH), cfg.axis(OW)
ic, kh, kw = cfg.reduce_axis(IC), cfg.reduce_axis(KH), cfg.reduce_axis(KW)

if num_tile == 2: # for arm cpu
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/topi/cuda/conv2d_nhwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,15 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
AL = s.cache_read(AA, "local", [OL])
WL = s.cache_read(WW, "local", [OL])

# Currently Conv2d NHWC only support dynamic shpe in batch
dynamic_batch = isinstance(s[output].op.axis[0].dom.extent, tvm.tir.expr.Var)

# Schedule for autotvm
cfg.define_knob("tile_n", [2, 4, 8])
cfg.define_knob("tile_n", [1] if dynamic_batch else [2, 4, 8])
cfg.define_knob("tile_c", [2, 4, 8])
cfg.define_knob("num_thread_n", [4, 8, 16])
cfg.define_knob("num_thread_n", [1] if dynamic_batch else [4, 8, 16])
cfg.define_knob("num_thread_c", [4, 8, 16])
cfg.define_knob("vthread_n", [1, 2])
cfg.define_knob("vthread_n", [1] if dynamic_batch else [1, 2])
cfg.define_knob("vthread_c", [1, 2])
cfg.define_knob("step", [16, 3, 32, 64])

Expand Down
35 changes: 34 additions & 1 deletion tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,13 +496,24 @@ def verify_any_conv2d(
dilation,
static_data_shape,
ref_out_shape,
data_layout="NCHW",
kernel_layout="OIHW",
use_cudnn=False,
):
mod = tvm.IRModule()
dtype = "float32"
data = relay.var("data", shape=data_shape, dtype=dtype)
kernel = relay.var("kernel", shape=kernel_shape, dtype=dtype)
y = relay.nn.conv2d(data, kernel, strides, padding, dilation, kernel_size=kernel_shape[2:4])
y = relay.nn.conv2d(
data,
kernel,
strides,
padding,
dilation,
kernel_size=kernel_shape[2:4] if kernel_layout == "OIHW" else kernel_shape[0:2],
data_layout=data_layout,
kernel_layout=kernel_layout,
)
mod["main"] = relay.Function([data, kernel], y)
data_np = np.random.uniform(size=static_data_shape).astype(dtype)
kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)
Expand Down Expand Up @@ -545,6 +556,28 @@ def test_any_conv2d():
(1, 64, 224, 224),
use_cudnn=True,
)
verify_any_conv2d(
(relay.Any(), 224, 224, 64),
(3, 3, 64, 64),
(1, 1),
(1, 1),
(1, 1),
(1, 224, 224, 64),
(1, 224, 224, 64),
data_layout="NHWC",
kernel_layout="HWIO",
)
verify_any_conv2d(
(relay.Any(), 224, 224, 64),
(3, 3, 64, 64),
(1, 1),
(1, 1),
(2, 2),
(2, 224, 224, 64),
(2, 222, 222, 64),
data_layout="NHWC",
kernel_layout="HWIO",
)


def verify_any_conv2d_NCHWc(
Expand Down

0 comments on commit cba9cf3

Please sign in to comment.