Skip to content

Commit

Permalink
[Relay][TOPI] Misc fixes for depthwise conv2d Mali/Bifrost.
Browse files Browse the repository at this point in the history
- Fix assert for Bifrost.
- Set reasonable default axis splits to avoid using tophub.
- Fixed typo: arm cpu -> Mali.
  • Loading branch information
Anastasia Stulova committed Aug 17, 2021
1 parent 0b1ff16 commit 8802b8c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/bifrost.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def conv2d_strategy_bifrost(attrs, inputs, out_type, target):
name="depthwise_conv2d_nchw.bifrost",
)
elif layout == "NHWC":
assert kernel_layout == "HWIO"
assert kernel_layout == "HWOI"
# For now just reuse general Mali strategy.
strategy.add_implementation(
wrap_compute_conv2d(topi.mali.depthwise_conv2d_nhwc),
Expand Down
13 changes: 9 additions & 4 deletions python/tvm/topi/mali/depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def depthwise_conv2d_nchw(cfg, data, kernel, strides, padding, dilation, out_dty
return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)


# register customized schedule for arm cpu.
# register customized schedule for Mali.
@autotvm.register_topi_schedule("depthwise_conv2d_nchw.mali")
def schedule_depthwise_conv2d_nchw(cfg, outs):
"""Schedule depthwise conv2d
Expand Down Expand Up @@ -70,7 +70,7 @@ def depthwise_conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dty
return nn.depthwise_conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)


# register customized schedule for arm cpu.
# register customized schedule for Mali.
@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.mali")
def schedule_depthwise_conv2d_nhwc(cfg, outs):
"""Schedule depthwise conv2d
Expand Down Expand Up @@ -124,8 +124,13 @@ def _schedule(cfg, s, pad_data, kernel, conv, layout):

# fallback support
if cfg.is_fallback:
ref_log = autotvm.tophub.load_reference_log("mali", "rk3399", "depthwise_conv2d_nchw.mali")
cfg.fallback_with_reference_log(ref_log)
if layout == "NCHW":
ref_log = autotvm.tophub.load_reference_log("mali", "rk3399", "depthwise_conv2d_nchw.mali")
cfg.fallback_with_reference_log(ref_log)
else:
cfg.fallback_split("tile_c", [-1, 4, 2])
cfg.fallback_split("tile_y", [-1, 4, 2])
cfg.fallback_split("tile_x", [-1, 4, 2])
###### space definition end ######

# schedule padding
Expand Down

0 comments on commit 8802b8c

Please sign in to comment.