diff --git a/topi/python/topi/generic/conv2d.py b/topi/python/topi/generic/conv2d.py index 69984a169ac67..2d9f78b645db7 100644 --- a/topi/python/topi/generic/conv2d.py +++ b/topi/python/topi/generic/conv2d.py @@ -144,7 +144,8 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out, parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih) s[data_vec].parallel(parallel_axis) - oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis + # conv2d_nchwc_int8 has 7D kernel + oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block) oc_bn = cfg["tile_oc"].size[-1] if oc_bn > 1: @@ -189,13 +190,26 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out, s[CC].unroll(oc_f_inner) if C != O: - batch, oc_chunk, oh, ow, oc_block = s[O].op.axis - ow_chunk, ow_block = s[O].split(ow, factor=reg_n) - s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) - parallel_axis = s[O].fuse(batch, oc_chunk, oh) - s[C].compute_at(s[O], parallel_axis) - s[O].vectorize(oc_block) - s[O].parallel(parallel_axis) + out_ndim = len(s[O].op.axis) + if out_ndim == 5: + batch, oc_chunk, oh, ow, oc_block = s[O].op.axis + ow_chunk, ow_block = s[O].split(ow, factor=reg_n) + s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) + parallel_axis = s[O].fuse(batch, oc_chunk, oh) + s[C].compute_at(s[O], parallel_axis) + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + elif out_ndim == 4: + batch, oc, oh, ow = s[O].op.axis + ow_chunk, ow_block = s[O].split(ow, factor=reg_n) + oc_chunk, oc_block = s[O].split(oc, factor=oc_bn) + s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) + parallel_axis = s[O].fuse(batch, oc_chunk, oh) + s[C].compute_at(s[O], parallel_axis) + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + else: + raise ValueError("Unsupported output ndim: %s" % out_ndim) return s @@ -234,7 +248,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out, parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih) s[data_vec].parallel(parallel_axis) - oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis + # Conv2d int8 schedule has 7D kernel + oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block) oc_bn = cfg["tile_oc"].size[-1] if oc_bn > 1: @@ -277,14 +292,29 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out, s[CC].unroll(oh_inner) if C != O: - batch, oc_chunk, oh, ow, oc_block = s[O].op.axis - oh_outer, oh_inner = s[O].split(oh, factor=oh_factor) - ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) - s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) - - parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) - s[C].compute_at(s[O], parallel_axis) - s[O].vectorize(oc_block) - s[O].parallel(parallel_axis) + out_ndim = len(s[O].op.axis) + if out_ndim == 5: + batch, oc_chunk, oh, ow, oc_block = s[O].op.axis + oh_outer, oh_inner = s[O].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) + s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + + parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) + s[C].compute_at(s[O], parallel_axis) + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + elif out_ndim == 4: + batch, oc, oh, ow = s[O].op.axis + oc_chunk, oc_block = s[O].split(oc, factor=oc_bn) + oh_outer, oh_inner = s[O].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) + s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + + parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) + s[C].compute_at(s[O], parallel_axis) + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + else: + raise ValueError("Unsupported output ndim: %s" % out_ndim) return s