Improve NHWC depthwise convolution for aarch64
We created a default schedule (no auto-tuning or tensorization) named
depthwise_conv2d_nhwc which does a decent job of optimizing depthwise
convolution for NHWC layouts on aarch64.

Change-Id: I01e32903f6c1950623f33eae18484e70244fe0af
Giuseppe Rossini committed Jul 20, 2020
1 parent 414c61d commit 1ca776e
Showing 4 changed files with 133 additions and 8 deletions.
8 changes: 4 additions & 4 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -161,11 +161,11 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
name="depthwise_conv2d_nchw.x86")
elif layout == "NHWC":
assert kernel_layout == "HWOI"
logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.")
#logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.")
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc),
name="depthwise_conv2d_nhwc.generic")
wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
name="depthwise_conv2d_nhwc.arm_cpu")
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {} for arm cpu".
format(layout))
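With this registration in place, a depthwise conv2d in NHWC layout with an HWOI kernel should be dispatched to depthwise_conv2d_nhwc.arm_cpu when compiling for an AArch64 CPU target, instead of falling back to the generic schedule. A minimal sketch (not part of this commit; shapes and the target string are illustrative):

import tvm
from tvm import relay

# Hypothetical NHWC depthwise conv: 32 channels, channel_multiplier 1, 3x3 kernel.
data = relay.var("data", shape=(1, 56, 56, 32), dtype="float32")
weight = relay.var("weight", shape=(3, 3, 32, 1), dtype="float32")
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1),
                       groups=32, channels=32,
                       data_layout="NHWC", kernel_layout="HWOI")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

# An aarch64 CPU target; the arm_cpu strategy above now picks the new implementation.
target = "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu"
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)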
1 change: 0 additions & 1 deletion src/relay/op/tensor/reduce.cc
@@ -295,7 +295,6 @@ bool ReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}

Expr MakeReduce(Expr data, Array<Integer> axis, bool keepdims, bool exclude, String op_name) {
std::cout << "making " << op_name << std::endl;
auto attrs = make_object<ReduceAttrs>();
attrs->axis = std::move(axis);
attrs->keepdims = keepdims;
125 changes: 124 additions & 1 deletion topi/python/topi/arm_cpu/depthwise_conv2d.py
@@ -31,7 +31,6 @@ def depthwise_conv2d_nchw(_, data, kernel, strides, padding, dilation, out_dtype
"""Compute depthwise_conv2d with NCHW layout"""
return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)


@autotvm.register_topi_schedule("depthwise_conv2d_nchw.arm_cpu")
def schedule_depthwise_conv2d_nchw(cfg, outs):
"""Schedule depthwise conv2d
@@ -181,6 +180,130 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila

return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)

@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
"""TOPI compute callback for depthwise_conv2d nhwc
Parameters
----------
cfg: ConfigEntity
The config for this template
data : tvm.te.Tensor
4-D with shape [batch, in_height, in_width, in_channel]
kernel : tvm.te.Tensor
4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
strides : list of two ints
[stride_height, stride_width]
padding : list of two ints
[pad_height, pad_width]
dilation : list of two ints
[dilation_height, dilation_width]
out_dtype: str
The output type. This is used for mixed precision.
Returns
-------
output : tvm.te.Tensor
4-D with shape [batch, out_height, out_width, out_channel]
"""

out_dtype = out_dtype or data.dtype

N, IH, IW, IC = get_const_tuple(data.shape)

if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation

KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)

dilated_kernel_h = (KH - 1) * dilation_h + 1
dilated_kernel_w = (KW - 1) * dilation_w + 1

pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)

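    # Standard convolution output-size formula, using the dilated kernel extent.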
OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1

if pad_top or pad_left:
data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
name="data_pad")
else:
data_pad = data

output_shape = (N, OH, OW, IC*channel_multiplier)

idxdiv = tvm.tir.indexdiv
idxmod = tvm.tir.indexmod

reduce_h = te.reduce_axis((0, KH), name='reduce_h')
reduce_w = te.reduce_axis((0, KW), name='reduce_w')

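    # The flattened output channel c maps back to input channel idxdiv(c, channel_multiplier)
    # and to multiplier index idxmod(c, channel_multiplier) in the kernel.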
out = te.compute(output_shape, lambda n, h, w, c:
te.sum(data_pad[n,
HSTR*h+dilation_h*reduce_h,
w*WSTR+reduce_w*dilation_w,
idxdiv(c, channel_multiplier)].astype(out_dtype) *
kernel[reduce_h,
reduce_w,
idxdiv(c, channel_multiplier),
idxmod(c, channel_multiplier)].astype(out_dtype),
axis=[reduce_h, reduce_w]),
name='depthwise_conv2d_nhwc_output')

return out

@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
def schedule_depthwise_conv2d_nhwc(_, outs):
"""Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])
out = outs[0]

def schedule_conv(conv):
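        # Split the channel axis by 8 and vectorize the inner block; tile the
        # spatial axes by 2 and parallelize the outermost spatial tile.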
n, w, h, c = conv.op.axis
r_h, r_w = conv.op.reduce_axis
co, ci = s[conv].split(c, 8)
wo, wi = s[conv].split(w, 2)
ho, hi = s[conv].split(h, 2)

s[conv].reorder(n, wo, ho, co, wi, hi, r_h, r_w, ci)
s[conv].parallel(wo)
s[conv].vectorize(ci)

def schedule_conv_out(out):
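        # Same 2x2 spatial tiling for the final output stage; the inner channel
        # block is split again by 4 and vectorized, and the returned axis is
        # where the conv stage will be computed.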
n, h, w, c = out.op.axis
co, ci = s[out].split(c, 8)
wo, wi = s[out].split(w, 2)
ho, hi = s[out].split(h, 2)
ci_outer, ci_inner = s[out].split(ci, 4)
s[out].reorder(n, wo, ho, co, wi, hi)
s[out].vectorize(ci_inner)
compute_at_axis = hi
s[out].parallel(wo)
return compute_at_axis

def _callback(op):
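        # When the depthwise output is not the last stage (e.g. a cast or fused
        # elementwise op follows), schedule the final stage and compute the conv
        # at its inner spatial axis; otherwise schedule the conv stage directly.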
if op.name == 'depthwise_conv2d_nhwc_output':
conv = op.output(0)
if conv != out:
compute_at_axis = schedule_conv_out(out)
schedule_conv(conv)
s[conv].compute_at(s[out], compute_at_axis)
else:
schedule_conv(out)

traverse_inline(s, outs[0].op, _callback)
return s

@autotvm.register_topi_schedule("depthwise_conv2d_nchw_spatial_pack.arm_cpu")
def schedule_depthwise_conv2d_nchw_spatial_pack(cfg, outs):
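For reference, a minimal sketch of calling the new compute/schedule pair directly through TOPI, mirroring the dispatch table in the test below (shapes are illustrative, and the target string assumes an AArch64 LLVM toolchain; topi was still a standalone package at this point):

import tvm
from tvm import te
import topi

# Hypothetical NHWC input and HWOI-style depthwise kernel (channel_multiplier = 1).
data = te.placeholder((1, 56, 56, 32), name="data", dtype="float32")
kernel = te.placeholder((3, 3, 32, 1), name="kernel", dtype="float32")

target = tvm.target.create("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu")
with target:
    out = topi.arm_cpu.compute_depthwise_conv2d_nhwc(
        data, kernel, (1, 1), "SAME", (1, 1), "float32")
    s = topi.arm_cpu.schedule_depthwise_conv2d_nhwc([out])
func = tvm.build(s, [data, kernel, out], target=target)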
7 changes: 5 additions & 2 deletions topi/tests/python/test_topi_depthwise_conv2d.py
@@ -40,6 +40,7 @@

_depthwise_conv2d_nhwc_implement = {
"generic": (topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc),
"arm_cpu": (topi.arm_cpu.compute_depthwise_conv2d_nhwc, topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
"gpu": (topi.nn.depthwise_conv2d_nhwc, topi.cuda.schedule_depthwise_conv2d_nhwc),
}

@@ -177,6 +178,9 @@ def check_device(device):
print("Running on target: %s" % device)

fcompute, fschedule = topi.testing.dispatch(device, _depthwise_conv2d_nhwc_implement)
if device == "gpu" and dilation > 1:
# skip because it uses too large shared memory on cuda
return
with tvm.target.create(device):
# declare
DepthwiseConv2d = fcompute(Input, Filter,
@@ -385,8 +389,7 @@ def test_depthwise_conv2d():
depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "VALID")
depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "VALID")
# dilation = 2
# disabled because it uses too large shared memory on cuda
# depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2)
depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2)

# NCHW[x]c
depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "SAME")
