Improve NHWC depthwise convolution for AArch64 (apache#6095)
* Improve NHWC depthwise convolution for AArch64

We add a default schedule (no auto-tuning or tensorization) named
depthwise_conv2d_nhwc, which does a decent job of optimizing depthwise
convolution for NHWC layouts on AArch64.

Change-Id: I01e32903f6c1950623f33eae18484e70244fe0af

* Add tuning knobs in depthwise schedule

Change-Id: I15080e7f12b16e6c6aba99a04e42023845eeabf1

* Introduce padding policy

Change-Id: If12a6d05dce9153861550ddef1ee5216809dd1e1

* Vectorize padding

Change-Id: I7e2062a40358bf111c0366a449945eb077fb2e30

* Legalize depthwise convolution (2x improvement) and fix tuning issue

Change-Id: I4b82c58b167e40b0b7747d28293bbb488c505dd9

* Adding assert on padding

Change-Id: Idf8eeaaface5eb7799109cd00f437e404778b9cd

* Fix python linting

Change-Id: Iac16a8daea1268f0eb331fe4ec18a62408106cf9

* Removing commented code

Change-Id: I1412f22ad9864273d77a7bf38a6768694339b7f0

* Revert test file to make CI pass

Change-Id: Ica3eff8f9f0fd4c6f32f7ae80adc922f8b16cec9

* Enabling only arm_cpu tests

Change-Id: Icbaafcb39e892a5d1a4685133c1699e4d1a8e07e

* Rebasing

Change-Id: Ibb23f1d4e0d0107e4e3b3571437161cdc2ee2909
Giuseppe Rossini authored and Trevor Morris committed Aug 26, 2020
1 parent 90c6dcd commit 0f95522
Showing 4 changed files with 177 additions and 6 deletions.
7 changes: 3 additions & 4 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -167,11 +167,10 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
name="depthwise_conv2d_nchw.x86")
elif layout == "NHWC":
assert kernel_layout == "HWOI"
logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.")
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc),
name="depthwise_conv2d_nhwc.generic")
wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
name="depthwise_conv2d_nhwc.arm_cpu")
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {} for arm cpu".
format(layout))
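
With this change, an NHWC depthwise conv2d compiled for an AArch64 target is routed to the new depthwise_conv2d_nhwc.arm_cpu implementation instead of the generic fallback. A minimal sketch of exercising it; the shapes and target string are illustrative, not part of the diff:

import tvm
from tvm import relay

# A 1x56x56x64 NHWC input with a 3x3 depthwise kernel in HWOI layout
# (channel_multiplier = 1); all sizes here are hypothetical.
data = relay.var("data", shape=(1, 56, 56, 64), dtype="float32")
weight = relay.var("weight", shape=(3, 3, 64, 1), dtype="float32")
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=64,
                       groups=64, padding=(1, 1),
                       data_layout="NHWC", kernel_layout="HWOI")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

# The strategy above should now select depthwise_conv2d_nhwc.arm_cpu
# for this operator when building for an AArch64 CPU target.
target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+neon")
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)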
8 changes: 7 additions & 1 deletion python/tvm/relay/qnn/op/legalizations.py
@@ -248,7 +248,13 @@ def is_aarch64_arm():
 @qnn_conv2d_legalize.register('arm_cpu')
 def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
     # ARM prefers the dtypes to be same.
-    if (is_aarch64_arm() and attrs["data_layout"] == "NHWC") or is_fast_int8_on_arm():
+    is_depthwise = relay.op.strategy.is_depthwise_conv2d(types[0].shape,
+                                                         attrs['data_layout'],
+                                                         types[1].shape,
+                                                         attrs['kernel_layout'],
+                                                         attrs['groups'])
+    use_int8_on_arm = (not is_depthwise) and is_aarch64_arm() and attrs["data_layout"] == "NHWC"
+    if use_int8_on_arm or is_fast_int8_on_arm():
         return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
     return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)

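
The legalization now keeps depthwise convolutions out of the same-dtype conversion that feeds the AArch64 fast-int8 path (the commit message reports a 2x improvement from this). For reference, relay.op.strategy.is_depthwise_conv2d checks that the convolution has one group per input channel; a simplified stand-in, assuming NHWC data and ignoring the layout resolution the real helper performs:

def is_depthwise_sketch(data_shape, kernel_shape, groups):
    """Simplified stand-in for relay.op.strategy.is_depthwise_conv2d,
    assuming NHWC data; the upstream helper resolves arbitrary layouts
    and also validates the kernel shape."""
    in_channels = data_shape[3]  # C is the last axis in NHWC
    return groups == in_channels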
167 changes: 166 additions & 1 deletion python/tvm/topi/arm_cpu/depthwise_conv2d.py
@@ -20,6 +20,7 @@
import tvm
from tvm import te
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity

from .. import nn
from ..util import traverse_inline, get_const_tuple, get_const_int
@@ -31,7 +31,6 @@ def depthwise_conv2d_nchw(_, data, kernel, strides, padding, dilation, out_dtype):
"""Compute depthwise_conv2d with NCHW layout"""
return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)


@autotvm.register_topi_schedule("depthwise_conv2d_nchw.arm_cpu")
def schedule_depthwise_conv2d_nchw(cfg, outs):
"""Schedule depthwise conv2d
@@ -181,6 +181,171 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype):

    return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)


@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
    """TOPI compute callback for depthwise_conv2d nhwc

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template
    data : tvm.te.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]
    kernel : tvm.te.Tensor
        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
    strides : list of two ints
        [stride_height, stride_width]
    padding : list of two ints
        [pad_height, pad_width]
    dilation : list of two ints
        [dilation_height, dilation_width]
    out_dtype: str
        The output type. This is used for mixed precision.

    Returns
    -------
    output : tvm.te.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]
    """

    out_dtype = out_dtype or data.dtype

    N, IH, IW, IC = get_const_tuple(data.shape)

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)

    dilated_kernel_h = (KH - 1) * dilation_h + 1
    dilated_kernel_w = (KW - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)

    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
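    # e.g. IH = 28 with pad_top = pad_down = 1, a 3x3 kernel at dilation 1 and
    # stride 1 gives OH = (28 + 2 - 3) // 1 + 1 = 28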

    if pad_top or pad_left or pad_down or pad_right:
        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
                          name="data_pad")
    else:
        data_pad = data

    output_shape = (N, OH, OW, IC * channel_multiplier)

    idxdiv = tvm.tir.indexdiv
    idxmod = tvm.tir.indexmod

    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
    reduce_w = te.reduce_axis((0, KW), name='reduce_w')

    out = te.compute(output_shape, lambda n, h, w, c:
                     te.sum(data_pad[n,
                                     HSTR * h + dilation_h * reduce_h,
                                     WSTR * w + dilation_w * reduce_w,
                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
                            kernel[reduce_h,
                                   reduce_w,
                                   idxdiv(c, channel_multiplier),
                                   idxmod(c, channel_multiplier)].astype(out_dtype),
                            axis=[reduce_h, reduce_w]),
                     name='depthwise_conv2d_nhwc_output')
    return out
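
Since the compute is registered through autotvm, it can be called directly and the config is injected from the current dispatch context (falling back to defaults when untuned). A quick shape sanity check in the style of the topi tests, with illustrative sizes:

import tvm
from tvm import te, topi

# Hypothetical sizes: 28x28 NHWC input, 32 channels, channel_multiplier = 2.
data = te.placeholder((1, 28, 28, 32), name="data", dtype="float32")
kernel = te.placeholder((3, 3, 32, 2), name="kernel", dtype="float32")

with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu"):
    out = topi.arm_cpu.compute_depthwise_conv2d_nhwc(
        data, kernel, (1, 1), (1, 1), (1, 1), "float32")
print(out.shape)  # expect [1, 28, 28, 64]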

@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
def schedule_depthwise_conv2d_nhwc(cfg, outs):
"""Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])
out = outs[0]

##### space definition begin #####
n, h, w, c = s[out].op.axis
cfg.define_split('tile_c', c, num_outputs=2)
_, hi = cfg.define_split('tile_h', h, num_outputs=2)
_, wi = cfg.define_split('tile_w', w, num_outputs=2)
cfg.define_knob('locate_output', [0, 1])

# fallback support
if cfg.is_fallback:
cfg['tile_c'] = SplitEntity([-1, 8])
cfg['tile_h'] = SplitEntity([-1, 2])
cfg['tile_w'] = SplitEntity([-1, 2])
cfg['locate_output'] = OtherOptionEntity(1)
##### space definition end #####

def schedule_conv(conv):
conv_data = conv.op.input_tensors[0]

n, w, h, c = conv.op.axis
r_h, r_w = conv.op.reduce_axis
ho, hi = cfg['tile_h'].apply(s, conv, h)
wo, wi = cfg['tile_w'].apply(s, conv, w)
co, ci = cfg['tile_c'].apply(s, conv, c)

if conv_data.name == "data_pad":
assert isinstance(conv_data.op, tvm.te.ComputeOp)
# Define a policy for padding computation
cfg.define_knob('data_pad_inline', [1, 2, 3])
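            # The three policies, as applied below:
            #   1: compute the (vectorized) padding at each outer-height tile
            #   2: compute the (vectorized) padding at each outer-width tile
            #   3: inline the padding into the convolution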
            if cfg.is_fallback:
                cfg['data_pad_inline'] = OtherOptionEntity(3)
            if cfg['data_pad_inline'].val == 1:
                s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
                s[conv_data].compute_at(s[conv], ho)
            if cfg['data_pad_inline'].val == 2:
                s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
                s[conv_data].compute_at(s[conv], wo)
            if cfg['data_pad_inline'].val == 3:
                s[conv_data].compute_inline()

        s[conv].reorder(n, ho, wo, co, hi, wi, r_h, r_w, ci)
        fused_n_ho = s[conv].fuse(n, ho)
        s[conv].vectorize(ci)
        return fused_n_ho

    def schedule_conv_out(out):
        n, h, w, c = out.op.axis
        co, ci = cfg['tile_c'].apply(s, out, c)
        wo, wi = cfg['tile_w'].apply(s, out, w)
        ho, hi = cfg['tile_h'].apply(s, out, h)
        s[out].reorder(n, ho, wo, co, hi, wi)

        if out.dtype in ['int8', 'uint8']:
            # For quantized convolution, further split the channel axis into
            # batches of 4 elements so that we can use arm intrinsics to run
            # the fixed-point multiplication
            ci_outer, ci_inner = s[out].split(ci, 4)
            s[out].vectorize(ci_inner)

        fused_n_ho = s[out].fuse(n, ho)
        return hi, wi, fused_n_ho

    def _callback(op):
        if op.name == 'depthwise_conv2d_nhwc_output':
            conv = op.output(0)
            if conv != out:
                hi, wi, p_axis = schedule_conv_out(out)
                schedule_conv(conv)
                if cfg['locate_output'].val == 0:
                    s[conv].compute_at(s[out], hi)
                if cfg['locate_output'].val == 1:
                    s[conv].compute_at(s[out], wi)
            else:
                p_axis = schedule_conv(out)

            s[out].parallel(p_axis)

    traverse_inline(s, outs[0].op, _callback)
    return s
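
The knobs defined above (tile_c, tile_h, tile_w, locate_output, plus data_pad_inline when padding is present) make the schedule tunable with autotvm. A sketch of creating the tuning task directly, with illustrative shapes (autotvm serializes the placeholders into the task arguments):

import tvm
from tvm import te, autotvm

data = te.placeholder((1, 56, 56, 64), name="data", dtype="float32")
kernel = te.placeholder((3, 3, 64, 1), name="kernel", dtype="float32")

task = autotvm.task.create(
    "depthwise_conv2d_nhwc.arm_cpu",
    args=(data, kernel, (1, 1), (1, 1), (1, 1), "float32"),
    target="llvm -mtriple=aarch64-linux-gnu -mattr=+neon")
print(task.config_space)  # tile_c / tile_h / tile_w / locate_output knobs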

@autotvm.register_topi_schedule("depthwise_conv2d_nchw_spatial_pack.arm_cpu")
def schedule_depthwise_conv2d_nchw_spatial_pack(cfg, outs):
1 change: 1 addition & 0 deletions tests/python/topi/python/test_topi_depthwise_conv2d.py
@@ -40,6 +40,7 @@

_depthwise_conv2d_nhwc_implement = {
    "generic": (topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc),
    "arm_cpu": (topi.arm_cpu.compute_depthwise_conv2d_nhwc, topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
    "gpu": (topi.nn.depthwise_conv2d_nhwc, topi.cuda.schedule_depthwise_conv2d_nhwc),
}
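
The topi tests pick the (compute, schedule) pair for the active target from this table; roughly, assuming the dispatch helper in tvm.topi.testing:

import tvm.topi.testing

target = tvm.target.Target("llvm -device=arm_cpu")  # any target whose keys include "arm_cpu"
fcompute, fschedule = tvm.topi.testing.dispatch(target, _depthwise_conv2d_nhwc_implement)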

