From 8802b8cf0608aeeea574a3c34c73df7c04c0c85e Mon Sep 17 00:00:00 2001 From: Anastasia Stulova Date: Fri, 30 Jul 2021 11:04:33 +0100 Subject: [PATCH] [Relay][TOPI] Misc fixes for depthwise conv2d Mali/Bifrost. - Fix assert for Bifrost. - Set reasonable default axis splits to avoid using tophub. - Fixed typo: arm cpu -> Mali. --- python/tvm/relay/op/strategy/bifrost.py | 2 +- python/tvm/topi/mali/depthwise_conv2d.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/op/strategy/bifrost.py b/python/tvm/relay/op/strategy/bifrost.py index cfc676680b448..ec3edab2c8b19 100644 --- a/python/tvm/relay/op/strategy/bifrost.py +++ b/python/tvm/relay/op/strategy/bifrost.py @@ -84,7 +84,7 @@ def conv2d_strategy_bifrost(attrs, inputs, out_type, target): name="depthwise_conv2d_nchw.bifrost", ) elif layout == "NHWC": - assert kernel_layout == "HWIO" + assert kernel_layout == "HWOI" # For now just reuse general Mali strategy. strategy.add_implementation( wrap_compute_conv2d(topi.mali.depthwise_conv2d_nhwc), diff --git a/python/tvm/topi/mali/depthwise_conv2d.py b/python/tvm/topi/mali/depthwise_conv2d.py index 83c4cc8294b29..b5c47041d4942 100644 --- a/python/tvm/topi/mali/depthwise_conv2d.py +++ b/python/tvm/topi/mali/depthwise_conv2d.py @@ -30,7 +30,7 @@ def depthwise_conv2d_nchw(cfg, data, kernel, strides, padding, dilation, out_dty return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype) -# register customized schedule for arm cpu. +# register customized schedule for Mali. @autotvm.register_topi_schedule("depthwise_conv2d_nchw.mali") def schedule_depthwise_conv2d_nchw(cfg, outs): """Schedule depthwise conv2d @@ -70,7 +70,7 @@ def depthwise_conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dty return nn.depthwise_conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype) -# register customized schedule for arm cpu. +# register customized schedule for Mali. @autotvm.register_topi_schedule("depthwise_conv2d_nhwc.mali") def schedule_depthwise_conv2d_nhwc(cfg, outs): """Schedule depthwise conv2d @@ -124,8 +124,13 @@ def _schedule(cfg, s, pad_data, kernel, conv, layout): # fallback support if cfg.is_fallback: - ref_log = autotvm.tophub.load_reference_log("mali", "rk3399", "depthwise_conv2d_nchw.mali") - cfg.fallback_with_reference_log(ref_log) + if layout == "NCHW": + ref_log = autotvm.tophub.load_reference_log("mali", "rk3399", "depthwise_conv2d_nchw.mali") + cfg.fallback_with_reference_log(ref_log) + else: + cfg.fallback_split("tile_c", [-1, 4, 2]) + cfg.fallback_split("tile_y", [-1, 4, 2]) + cfg.fallback_split("tile_x", [-1, 4, 2]) ###### space definition end ###### # schedule padding