Skip to content

Commit

Permalink
Rebase, and update responding to some comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Wheest committed Jan 6, 2021
1 parent 88c3809 commit 3cf4db1
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 18 deletions.
10 changes: 5 additions & 5 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,11 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
assert kernel_layout == "HWIO"
if not is_auto_scheduler_enabled():
logger.warning("group_conv2d is not optimized for x86 with autotvm.")
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.group_conv2d_nhwc, has_groups=True),
wrap_topi_schedule(topi.generic.schedule_group_conv2d_nhwc),
name="group_conv2d_nhwc.generic",
)
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.group_conv2d_nhwc, has_groups=True),
wrap_topi_schedule(topi.generic.schedule_group_conv2d_nhwc),
name="group_conv2d_nhwc.generic",
)
else:
raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
return strategy
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/topi/arm_cpu/group_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
from tvm import te
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity

from ..util import get_const_tuple
from ..utils import get_const_tuple
from ..nn.pad import pad
from .. import tag

from ..nn.util import infer_pad
from ..nn.utils import infer_pad
from ..nn.conv2d import _get_workload as _get_conv2d_workload


Expand Down Expand Up @@ -62,8 +62,8 @@ def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype,
def _fallback_schedule(cfg, wkl):
simd_width = 4 # assume ARM SIMD Width is 4
pad_left, pad_right = wkl.padl, wkl.padr
stride_w = wkl.wstride
out_width = (wkl.width + pad_left + pad_right - wkl.wkernel) // stride_w + 1
stride_w = wkl.stride_w
out_width = (wkl.width + pad_left + pad_right - wkl.kernel_w) // stride_w + 1
groups = wkl.groups
kernels_per_group = wkl.out_filter // groups
kernel_depth = wkl.in_filter // groups
Expand Down
4 changes: 0 additions & 4 deletions python/tvm/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,10 @@ def _get_workload(data, kernel, stride, padding, dilation, out_dtype, data_layou
else:
KH, KW, CIG, CO = get_const_tuple(kernel.shape)

<<<<<<< HEAD
pt, pl, pb, pr = get_pad_tuple(padding, (get_const_int(KH), get_const_int(KW)))
dilation_h, dilation_w = (
dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
)
=======
HPAD, WPAD, _, _ = get_pad_tuple(padding, (get_const_int(KH), get_const_int(KW)))
>>>>>>> c57bd780a (Update conv2d.py)
GRPS = CI // CIG
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/topi/x86/group_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from tvm import te
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity

from .util import get_fp32_len
from ..util import get_const_tuple
from .utils import get_fp32_len
from ..utils import get_const_tuple
from ..nn.pad import pad
from .. import tag

from ..nn.util import infer_pad
from ..nn.utils import infer_pad
from ..nn.conv2d import _get_workload as _get_conv2d_workload


Expand Down Expand Up @@ -63,8 +63,8 @@ def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype,
def _fallback_schedule(cfg, wkl):
simd_width = get_fp32_len()
pad_left, pad_right = wkl.padl, wkl.padr
stride_w = wkl.wstride
out_width = (wkl.width + pad_left + pad_right - wkl.wkernel) // stride_w + 1
stride_w = wkl.stride_w
out_width = (wkl.width + pad_left + pad_right - wkl.kernel_w) // stride_w + 1
groups = wkl.groups
kernels_per_group = wkl.out_filter // groups
kernel_depth = wkl.in_filter // groups
Expand Down

0 comments on commit 3cf4db1

Please sign in to comment.