diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index 8143cc56495a..0c4edbb410e9 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -167,11 +167,10 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                     name="depthwise_conv2d_nchw.x86")
             elif layout == "NHWC":
                 assert kernel_layout == "HWOI"
-                logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.")
                 strategy.add_implementation(
-                    wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
-                    wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc),
-                    name="depthwise_conv2d_nhwc.generic")
+                    wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc),
+                    wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
+                    name="depthwise_conv2d_nhwc.arm_cpu")
             else:
                 raise RuntimeError("Unsupported depthwise_conv2d layout {} for arm cpu".
                                    format(layout))
diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py
index af5072ef74cd..62bee302984d 100644
--- a/python/tvm/relay/qnn/op/legalizations.py
+++ b/python/tvm/relay/qnn/op/legalizations.py
@@ -248,7 +248,13 @@ def is_aarch64_arm():
 @qnn_conv2d_legalize.register('arm_cpu')
 def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
     # ARM prefers the dtypes to be same.
-    if (is_aarch64_arm() and attrs["data_layout"] == "NHWC") or is_fast_int8_on_arm():
+    is_depthwise = relay.op.strategy.is_depthwise_conv2d(types[0].shape,
+                                                         attrs['data_layout'],
+                                                         types[1].shape,
+                                                         attrs['kernel_layout'],
+                                                         attrs['groups'])
+    use_int8_on_arm = (not is_depthwise) and is_aarch64_arm() and attrs["data_layout"] == "NHWC"
+    if use_int8_on_arm or is_fast_int8_on_arm():
         return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
     return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
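Note on the legalization change above: `relay.op.strategy.is_depthwise_conv2d` boils down to a channel-count comparison, so depthwise workloads are now kept out of the same-dtype int8 legalization path and fall through to the new NHWC schedule instead. A minimal sketch of that check, assuming NHWC data and HWOI kernels as in this code path (the helper below is illustrative, not the TVM API):

def _is_depthwise_nhwc_hwoi(data_shape, kernel_shape, groups):
    # Illustrative re-statement: a conv2d is depthwise when the input
    # channels, output channels, and group count all coincide.
    in_channels = data_shape[3]                        # NHWC: channels last
    out_channels = kernel_shape[2] * kernel_shape[3]   # HWOI: channels * multiplier
    return in_channels == out_channels == groups

# A 64-channel 3x3 depthwise conv with channel multiplier 1:
assert _is_depthwise_nhwc_hwoi((1, 56, 56, 64), (3, 3, 64, 1), groups=64)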
diff --git a/python/tvm/topi/arm_cpu/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/depthwise_conv2d.py
index 802b3df19530..07749ee72394 100644
--- a/python/tvm/topi/arm_cpu/depthwise_conv2d.py
+++ b/python/tvm/topi/arm_cpu/depthwise_conv2d.py
@@ -20,6 +20,7 @@
 import tvm
 from tvm import te
 from tvm import autotvm
+from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 
 from .. import nn
 from ..util import traverse_inline, get_const_tuple, get_const_int
@@ -31,7 +32,6 @@ def depthwise_conv2d_nchw(_, data, kernel, strides, padding, dilation, out_dtype
     """Compute depthwise_conv2d with NCHW layout"""
     return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
 
-
 @autotvm.register_topi_schedule("depthwise_conv2d_nchw.arm_cpu")
 def schedule_depthwise_conv2d_nchw(cfg, outs):
     """Schedule depthwise conv2d
@@ -181,6 +181,171 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype,
                               num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left or pad_down or pad_right:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(cfg, outs):
+    """Create the schedule for depthwise_conv2d_nhwc"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    ##### space definition begin #####
+    n, h, w, c = s[out].op.axis
+    cfg.define_split('tile_c', c, num_outputs=2)
+    _, hi = cfg.define_split('tile_h', h, num_outputs=2)
+    _, wi = cfg.define_split('tile_w', w, num_outputs=2)
+    cfg.define_knob('locate_output', [0, 1])
+
+    # fallback support
+    if cfg.is_fallback:
+        cfg['tile_c'] = SplitEntity([-1, 8])
+        cfg['tile_h'] = SplitEntity([-1, 2])
+        cfg['tile_w'] = SplitEntity([-1, 2])
+        cfg['locate_output'] = OtherOptionEntity(1)
+    ##### space definition end #####
+
+    def schedule_conv(conv):
+        conv_data = conv.op.input_tensors[0]
+
+        n, h, w, c = conv.op.axis
+        r_h, r_w = conv.op.reduce_axis
+        ho, hi = cfg['tile_h'].apply(s, conv, h)
+        wo, wi = cfg['tile_w'].apply(s, conv, w)
+        co, ci = cfg['tile_c'].apply(s, conv, c)
+
+        if conv_data.name == "data_pad":
+            assert isinstance(conv_data.op, tvm.te.ComputeOp)
+            # Define a policy for padding computation
+            cfg.define_knob('data_pad_inline', [1, 2, 3])
+            if cfg.is_fallback:
+                cfg['data_pad_inline'] = OtherOptionEntity(3)
+            if cfg['data_pad_inline'].val == 1:
+                s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
+                s[conv_data].compute_at(s[conv], ho)
+            if cfg['data_pad_inline'].val == 2:
+                s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
+                s[conv_data].compute_at(s[conv], wo)
+            if cfg['data_pad_inline'].val == 3:
+                s[conv_data].compute_inline()
+
+        s[conv].reorder(n, ho, wo, co, hi, wi, r_h, r_w, ci)
+        fused_n_ho = s[conv].fuse(n, ho)
+        s[conv].vectorize(ci)
+        return fused_n_ho
+
+    def schedule_conv_out(out):
+        n, h, w, c = out.op.axis
+        co, ci = cfg['tile_c'].apply(s, out, c)
+        wo, wi = cfg['tile_w'].apply(s, out, w)
+        ho, hi = cfg['tile_h'].apply(s, out, h)
+        s[out].reorder(n, ho, wo, co, hi, wi)
+
+        if out.dtype in ['int8', 'uint8']:
+            # In case of quantized convolution further split the channel in batches of 4 elements
+            # so that we can use arm intrinsics to run fixed_point_multiplication
+            ci_outer, ci_inner = s[out].split(ci, 4)
+            s[out].vectorize(ci_inner)
+
+        fused_n_ho = s[out].fuse(n, ho)
+        return hi, wi, fused_n_ho
+
+    def _callback(op):
+        if op.name == 'depthwise_conv2d_nhwc_output':
+            conv = op.output(0)
+            if conv != out:
+                hi, wi, p_axis = schedule_conv_out(out)
+                schedule_conv(conv)
+                if cfg['locate_output'].val == 0:
+                    s[conv].compute_at(s[out], hi)
+                if cfg['locate_output'].val == 1:
+                    s[conv].compute_at(s[out], wi)
+            else:
+                p_axis = schedule_conv(out)
+
+            s[out].parallel(p_axis)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
 
 @autotvm.register_topi_schedule("depthwise_conv2d_nchw_spatial_pack.arm_cpu")
 def schedule_depthwise_conv2d_nchw_spatial_pack(cfg, outs):
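Note for reviewers: the new compute/schedule pair can be exercised directly, without a tuning log, in which case the fallback configuration defined above applies. A minimal sketch, assuming a recent TVM build with LLVM; the shapes and target triple are illustrative, and actually running the compiled module needs an AArch64 machine or RPC device:

import tvm
from tvm import te, topi

# Illustrative workload: 1x32x32x16 fp32 input, 3x3 depthwise kernel, multiplier 1.
data = te.placeholder((1, 32, 32, 16), name="data", dtype="float32")
kernel = te.placeholder((3, 3, 16, 1), name="kernel", dtype="float32")

target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu")
with target:
    # The autotvm wrappers inject cfg and reject keyword arguments, so
    # strides, padding, dilation, and out_dtype are passed positionally.
    out = topi.arm_cpu.compute_depthwise_conv2d_nhwc(
        data, kernel, (1, 1), (1, 1), (1, 1), "float32")
    s = topi.arm_cpu.schedule_depthwise_conv2d_nhwc([out])

# Cross-compiles on any host whose LLVM has the AArch64 backend.
func = tvm.build(s, [data, kernel, out], target)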
diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d.py b/tests/python/topi/python/test_topi_depthwise_conv2d.py
index 397861713f73..93a166d3e216 100644
--- a/tests/python/topi/python/test_topi_depthwise_conv2d.py
+++ b/tests/python/topi/python/test_topi_depthwise_conv2d.py
@@ -40,6 +40,7 @@
 _depthwise_conv2d_nhwc_implement = {
     "generic": (topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc),
+    "arm_cpu": (topi.arm_cpu.compute_depthwise_conv2d_nhwc, topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
     "gpu": (topi.nn.depthwise_conv2d_nhwc, topi.cuda.schedule_depthwise_conv2d_nhwc),
 }
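Since "depthwise_conv2d_nhwc.arm_cpu" is a registered AutoTVM template, the standard tuning flow applies on top of this patch. A hypothetical tuning sketch; the task arguments mirror the compute signature (data, kernel, strides, padding, dilation, out_dtype), and the shapes, trial count, and log file name are illustrative:

from tvm import autotvm

task = autotvm.task.create(
    "depthwise_conv2d_nhwc.arm_cpu",
    args=(("TENSOR", (1, 32, 32, 16), "float32"),   # data
          ("TENSOR", (3, 3, 16, 1), "float32"),     # kernel
          (1, 1), (1, 1), (1, 1), "float32"),       # strides, padding, dilation, out_dtype
    target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu")

# When cross-tuning on a real board, swap LocalRunner for autotvm.RPCRunner.
measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(number=10))

tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=64,
           measure_option=measure_option,
           callbacks=[autotvm.callback.log_to_file("dwconv_nhwc.log")])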