Add schedule and test for group convolution (apache#5)
* group conv pass all

* pass mobilenet
merrymercy authored and tmoreau89 committed Mar 20, 2019
1 parent 1c8a0aa commit f3aeda7
Showing 9 changed files with 578 additions and 23 deletions.
1 change: 1 addition & 0 deletions python/tvm/contrib/util.py
@@ -143,6 +143,7 @@ def which(exec_name):
            return full_path
    return None

+
def get_lower_ir(s):
    """Get lower ir code of a schedule.
    This is useful for debugging, since you don't have to find all inputs/outputs
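For context, a minimal usage sketch of get_lower_ir (illustration only; it assumes the truncated function body lowers the schedule and returns the IR as a printable string, and the toy compute definition below is not from this commit):

import tvm
from tvm.contrib.util import get_lower_ir

# Toy elementwise schedule, only to have something to lower.
n = 1024
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
s = tvm.create_schedule(B.op)

# Only the schedule is passed; the helper discovers inputs/outputs itself.
print(get_lower_ir(s))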
1 change: 1 addition & 0 deletions topi/python/topi/testing/__init__.py
@@ -8,6 +8,7 @@
from .conv2d_nchw_python import conv2d_nchw_python
from .conv2d_nhwc_python import conv2d_nhwc_python
from .conv2d_transpose_nchw_python import conv2d_transpose_nchw_python
+from .group_conv2d import group_conv2d_nchw_python
from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
from .dilate_python import dilate_python
from .softmax_python import softmax_python, log_softmax_python
74 changes: 74 additions & 0 deletions topi/python/topi/testing/group_conv2d.py
@@ -0,0 +1,74 @@
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-branches
"""Convolution in python"""
import numpy as np
import scipy.signal


def group_conv2d_nchw_python(a_np, w_np, stride, padding, groups):
    """Group convolution operator in NCHW layout.

    Parameters
    ----------
    a_np : numpy.ndarray
        4-D with shape [batch, in_channel, in_height, in_width]
    w_np : numpy.ndarray
        4-D with shape [num_filter, in_channel // groups, filter_height, filter_width]
    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]
    padding : int or str or a list/tuple of two ints
        Padding size, or ['VALID', 'SAME'], or [pad_height, pad_width]
    groups : int
        Number of groups

    Returns
    -------
    b_np : np.ndarray
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    batch, in_channel, in_height, in_width = a_np.shape
    num_filter, ci_g, kernel_h, kernel_w = w_np.shape
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride
    if isinstance(padding, int):
        pad_h = pad_w = padding * 2
    elif isinstance(padding, (list, tuple)):
        pad_h, pad_w = padding[0] * 2, padding[1] * 2
    else:
        pad_h = 0 if padding == 'VALID' else kernel_h - 1
        pad_w = 0 if padding == 'VALID' else kernel_w - 1
    pad_top = int(np.ceil(float(pad_h) / 2))
    pad_bottom = pad_h - pad_top
    pad_left = int(np.ceil(float(pad_w) / 2))
    pad_right = pad_w - pad_left
    # compute the output shape
    out_channel = num_filter
    out_height = (in_height - kernel_h + pad_h) // stride_h + 1
    out_width = (in_width - kernel_w + pad_w) // stride_w + 1
    b_np = np.zeros((batch, out_channel, out_height, out_width))

    assert ci_g * groups == in_channel

    # group computation
    for n in range(batch):
        for f in range(out_channel):
            for c in range(ci_g):
                base = f // (out_channel // groups) * ci_g
                if pad_h > 0 or pad_w > 0:
                    apad = np.zeros((in_height + pad_h, in_width + pad_w))
                    if pad_h == 0:
                        apad[:, pad_left:-pad_right] = a_np[n, base + c]
                    elif pad_w == 0:
                        apad[pad_top:-pad_bottom, :] = a_np[n, base + c]
                    else:
                        apad[pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, base + c]
                else:
                    apad = a_np[n, base + c]
                out = scipy.signal.convolve2d(
                    apad, np.rot90(np.rot90(w_np[f, c])), mode='valid')
                b_np[n, f] += out[::stride_h, ::stride_w]
    return b_np
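A quick sanity check of the reference implementation (shapes are illustrative and not taken from the commit's test file):

import numpy as np
from topi.testing import group_conv2d_nchw_python

# 1 image, 32 input channels, 8 groups, 64 filters, 3x3 kernel.
a_np = np.random.uniform(size=(1, 32, 56, 56)).astype("float32")
w_np = np.random.uniform(size=(64, 32 // 8, 3, 3)).astype("float32")  # [out_ch, in_ch // groups, kh, kw]

b_np = group_conv2d_nchw_python(a_np, w_np, stride=1, padding=1, groups=8)

# padding=1 gives pad_h = pad_w = 2 in total, so a stride-1 output keeps the spatial size.
assert b_np.shape == (1, 64, 56, 56)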
8 changes: 4 additions & 4 deletions vta/config/vta_config.json
@@ -8,14 +8,14 @@
"GEMM_II" : 1,
"TALU_II" : 2,
"LOG_INP_WIDTH" : 3,
"LOG_WGT_WIDTH" : 1,
"LOG_WGT_WIDTH" : 3,
"LOG_ACC_WIDTH" : 5,
"LOG_OUT_WIDTH" : 3,
"LOG_BATCH" : 0,
"LOG_BLOCK_IN" : 5,
"LOG_BLOCK_OUT" : 5,
"LOG_BLOCK_IN" : 4,
"LOG_BLOCK_OUT" : 4,
"LOG_UOP_BUFF_SIZE" : 15,
"LOG_INP_BUFF_SIZE" : 16,
"LOG_INP_BUFF_SIZE" : 15,
"LOG_WGT_BUFF_SIZE" : 18,
"LOG_ACC_BUFF_SIZE" : 17
}
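The LOG_* fields are base-2 exponents of the actual hardware parameters, so this change widens weights from 2 to 8 bits while shrinking the GEMM tile and the input buffer. A small sketch of the arithmetic (interpretation only, not code from the commit):

# Each LOG_* entry is log2 of the real parameter.
new_cfg = {"LOG_WGT_WIDTH": 3, "LOG_BLOCK_IN": 4, "LOG_BLOCK_OUT": 4, "LOG_INP_BUFF_SIZE": 15}

wgt_width = 1 << new_cfg["LOG_WGT_WIDTH"]      # 8-bit weights (previously 1 << 1 = 2 bits)
block_in = 1 << new_cfg["LOG_BLOCK_IN"]        # GEMM input tile of 16 (previously 32)
block_out = 1 << new_cfg["LOG_BLOCK_OUT"]      # GEMM output tile of 16 (previously 32)
inp_buff = 1 << new_cfg["LOG_INP_BUFF_SIZE"]   # 32 KiB input buffer (previously 64 KiB)
print(wgt_width, block_in, block_out, inp_buff)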
4 changes: 3 additions & 1 deletion vta/python/vta/top/__init__.py
@@ -1,6 +1,8 @@
"""TVM TOPI connector, eventually most of these should go to TVM repo"""

-from .vta_conv2d import packed_conv2d, schedule_packed_conv2d
from . import vta_conv2d
from . import arm_conv2d

from .bitpack import bitpack
+from .vta_conv2d import packed_conv2d, schedule_packed_conv2d
+from .vta_group_conv2d import packed_group_conv2d, schedule_packed_group_conv2d
82 changes: 82 additions & 0 deletions vta/python/vta/top/arm_conv2d.py
@@ -5,6 +5,88 @@
from topi.nn import conv2d, conv2d_alter_layout
from topi import generic

_WORKLOADS = [
# resnet 18
Workload('float32', 'float32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
Workload('int8', 'int32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
Workload('int8', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
Workload('int8', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
Workload('int8', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
Workload('int8', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
Workload('int8', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
Workload('int8', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('int8', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),

# mobilenet float32
Workload('float32', 'float32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
Workload('float32', 'float32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),

# mobilenet int8
Workload('float32', 'float32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),
]

_SCHEDULES = [
# float32 imagenet
SpatialPack(1, 8, 4, 1, 4, True),
SpatialPack(1, 8, 4, 1, 4, True),
SpatialPack(1, 7, 4, 2, 4, True),
SpatialPack(1, 4, 8, 4, 1, True),
SpatialPack(1, 4, 4, 1, 16, False),
SpatialPack(1, 4, 8, 4, 8, False),
SpatialPack(1, 7, 4, 3, 8, True),
SpatialPack(1, 2, 8, 1, 8, True),
SpatialPack(2, 1, 16, 1, 4, True),
SpatialPack(1, 7, 4, 1, 1, True),
Im2ColPack(7, 4, 1, 16, True),
Im2ColPack(7, 4, 1, 8, False),
Im2ColPack(7, 4, 1, 16, False),

# float32 mobilenet
SpatialPack(2, 2, 4, 28, 1, True),
SpatialPack(1, 4, 8, 14, 1, False),
SpatialPack(1, 2, 16, 8, 1, True),
SpatialPack(1, 4, 8, 8, 8, True),
SpatialPack(2, 2, 8, 1, 1, False),
SpatialPack(1, 4, 8, 4, 8, False),
SpatialPack(2, 2, 8, 1, 4, False),
SpatialPack(2, 2, 8, 1, 8, False),
Im2ColPack(7, 4, 1, 16, False),
Im2ColPack(7, 4, 1, 4, True),

# int8 mobilenet
SpatialPack(2, 2, 4, 28, 1, True),
SpatialPack(1, 4, 8, 14, 1, False),
SpatialPack(1, 2, 16, 8, 1, True),
SpatialPack(1, 4, 8, 8, 8, True),
SpatialPack(2, 2, 8, 1, 1, False),
SpatialPack(1, 4, 8, 4, 8, False),
SpatialPack(2, 2, 8, 1, 4, False),
SpatialPack(2, 2, 8, 1, 8, False),
Im2ColPack(7, 4, 1, 16, False),
Im2ColPack(7, 4, 1, 4, True),
]

@conv2d.register(["vtacpu", "vta"])
def compute(*args, **kwargs):
    with tvm.target.arm_cpu("vtacpu"):
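The two lists appear to be parallel: the workload at index i is tuned by the schedule at index i. A hedged sketch of that lookup (the Workload field names and the _get_schedule helper are inferred for illustration and are not shown in this hunk):

from collections import namedtuple

# Field layout inferred from the positional values used in _WORKLOADS above.
Workload = namedtuple("Workload",
                      ["in_dtype", "out_dtype", "height", "width",
                       "in_filter", "out_filter", "hkernel", "wkernel",
                       "hpad", "wpad", "hstride", "wstride"])

def _get_schedule(wkl, workloads, schedules):
    """Map a workload to its hand-tuned schedule by list position."""
    if wkl not in workloads:
        raise ValueError("no schedule for workload: " + str(wkl))
    return schedules[workloads.index(wkl)]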
46 changes: 28 additions & 18 deletions vta/python/vta/top/vta_conv2d.py
@@ -11,6 +11,8 @@
from nnvm.top import nn as _nn
from ..environment import get_env
from ..ptr_alias import reinterpret
+from .vta_group_conv2d import packed_group_conv2d, schedule_packed_group_conv2d


Workload = namedtuple("Conv2DWorkload",
                      ['batch', 'height', 'width', 'in_filter', 'out_filter',
@@ -262,22 +264,26 @@ def compute_conv2d(attrs, inputs, out):

    assert dilation == (1, 1), "not support dilate now"
    if is_packed_layout(layout):
-        assert groups == 1
-        env = get_env()
-        assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now"
-        assert env.LOG_OUT_WIDTH == 3, "only support 8bit inp for now"
-        inputs = list(inputs)
-        w_pack_factor = 1 << (3 - env.LOG_WGT_WIDTH)
-        assert inputs[1].dtype == "int8"
-
-        # Apply bit packing if necessary
-        if w_pack_factor != 1:
-            kshape = list(topi.util.get_const_tuple(inputs[1].shape))
-            kshape[-1] *= w_pack_factor
-            inputs[1] = reinterpret(inputs[1], kshape, dtype=env.wgt_dtype)
-
-        return packed_conv2d(inputs[0], inputs[1],
-                             padding, strides, out_dtype=out_dtype)
+        if groups == 1:
+            assert groups == 1
+            env = get_env()
+            assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now"
+            assert env.LOG_OUT_WIDTH == 3, "only support 8bit inp for now"
+            inputs = list(inputs)
+            w_pack_factor = 1 << (3 - env.LOG_WGT_WIDTH)
+            assert inputs[1].dtype == "int8"
+
+            # Apply bit packing if necessary
+            if w_pack_factor != 1:
+                kshape = list(topi.util.get_const_tuple(inputs[1].shape))
+                kshape[-1] *= w_pack_factor
+                inputs[1] = reinterpret(inputs[1], kshape, dtype=env.wgt_dtype)
+
+            return packed_conv2d(inputs[0], inputs[1],
+                                 padding, strides, out_dtype=out_dtype)
+        else:
+            return packed_group_conv2d(inputs[0], inputs[1],
+                                       padding, strides, groups, out_dtype=out_dtype)
    return _nn.compute_conv2d(attrs, inputs, out)


@@ -286,12 +292,16 @@ def schedule_conv2d(attrs, outs, target):
""" 2D convolution schedule.
"""
layout = attrs["layout"]
groups = attrs.get_int('groups')

if is_packed_layout(layout):
target = tvm.target.create(target)
if target.device_name == "vta":
return schedule_packed_conv2d(outs)
if str(target).startswith("llvm"):
if groups == 1:
return schedule_packed_conv2d(outs)
else:
return schedule_packed_group_conv2d(outs)
elif str(target).startswith("llvm"):
return tvm.create_schedule([x.op for x in outs])
raise RuntimeError("not support target %s" % target)
return _nn.schedule_conv2d(attrs, outs, target)
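With the compute and schedule dispatch above, a grouped convolution declared through NNVM can reach the VTA backend. A minimal sketch of such a declaration (standard NNVM conv2d attributes; shapes are illustrative, and the full VTA build flow with graph packing and quantization is omitted):

import nnvm.symbol as sym

# Grouped 3x3 convolution: 256 channels split into 8 groups (illustrative only).
data = sym.Variable("data")
net = sym.conv2d(data=data, channels=256, kernel_size=(3, 3),
                 strides=(1, 1), padding=(1, 1), groups=8,
                 layout="NCHW", use_bias=False, name="group_conv")

# Once the graph is packed for VTA, compute_conv2d routes groups != 1 to
# packed_group_conv2d, and schedule_conv2d picks schedule_packed_group_conv2d.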