Skip to content

Commit

Permalink
[TOPI, Relay] Support roi_align NHWC layout (#7463)
Browse files Browse the repository at this point in the history
* begin nhwc roi align

* integrate mode change from upstream

* adding test

* support nhwc shape func

* update strategy

* refactoring test

* refactor test

* refactoring

* fix lint

* update relay op tests
  • Loading branch information
masahi authored Feb 18, 2021
1 parent 143c88e commit b7e0cfb
Show file tree
Hide file tree
Showing 8 changed files with 398 additions and 118 deletions.
20 changes: 14 additions & 6 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,12 +945,20 @@ def roi_align_strategy_cuda(attrs, inputs, out_type, target):
"""roi_align cuda strategy"""
strategy = _op.OpStrategy()
layout = attrs.layout
assert layout == "NCHW", "only support nchw for now"
strategy.add_implementation(
wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw),
wrap_topi_schedule(topi.cuda.schedule_roi_align),
name="roi_align_nchw.cuda",
)

if layout == "NCHW":
strategy.add_implementation(
wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw),
wrap_topi_schedule(topi.cuda.schedule_roi_align),
name="roi_align_nchw.cuda",
)
else:
assert layout == "NHWC", "layout must be NCHW or NHWC."
strategy.add_implementation(
wrap_compute_roi_align(topi.vision.rcnn.roi_align_nhwc),
wrap_topi_schedule(topi.cuda.schedule_roi_align),
name="roi_align_nhwc.cuda",
)
return strategy


Expand Down
20 changes: 13 additions & 7 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,6 @@ def wrap_compute_roi_align(topi_compute):
"""wrap roi_align topi compute"""

def _compute_roi_align(attrs, inputs, out_type):
assert attrs.layout == "NCHW"
pooled_size = get_const_tuple(attrs.pooled_size)
mode = bytes(attrs.mode, "utf-8")
return [
Expand All @@ -1061,12 +1060,19 @@ def roi_align_strategy(attrs, inputs, out_type, target):
"""roi_align generic strategy"""
strategy = _op.OpStrategy()
layout = attrs.layout
assert layout == "NCHW", "only support nchw for now"
strategy.add_implementation(
wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw),
wrap_topi_schedule(topi.generic.schedule_roi_align),
name="roi_align.generic",
)
if layout == "NCHW":
strategy.add_implementation(
wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw),
wrap_topi_schedule(topi.generic.schedule_roi_align),
name="roi_align.generic",
)
else:
assert layout == "NHWC", "layout must be NCHW or NHWC."
strategy.add_implementation(
wrap_compute_roi_align(topi.vision.rcnn.roi_align_nhwc),
wrap_topi_schedule(topi.generic.schedule_roi_align),
name="roi_align.generic",
)
return strategy


Expand Down
19 changes: 13 additions & 6 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,12 +481,19 @@ def roi_align_strategy_cpu(attrs, inputs, out_type, target):
"""roi_align x86 strategy"""
strategy = _op.OpStrategy()
layout = attrs.layout
assert layout == "NCHW", "only support nchw for now"
strategy.add_implementation(
wrap_compute_roi_align(topi.x86.roi_align_nchw),
wrap_topi_schedule(topi.generic.schedule_roi_align),
name="roi_align.x86",
)
if layout == "NCHW":
strategy.add_implementation(
wrap_compute_roi_align(topi.x86.roi_align_nchw),
wrap_topi_schedule(topi.generic.schedule_roi_align),
name="roi_align.x86",
)
else:
assert layout == "NHWC", "layout must be NCHW or NHWC."
strategy.add_implementation(
wrap_compute_roi_align(topi.vision.rcnn.roi_align_nhwc),
wrap_topi_schedule(topi.generic.schedule_roi_align),
name="roi_align.x86",
)
return strategy


Expand Down
17 changes: 15 additions & 2 deletions python/tvm/relay/op/vision/_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def nms_shape_func(attrs, inputs, _):


@script
def _roi_align_shape_func(data_shape, rois_shape, pooled_size):
def _roi_align_shape_func_nchw(data_shape, rois_shape, pooled_size):
out = output_tensor((4,), "int64")
out[0] = rois_shape[0]
out[1] = data_shape[1]
Expand All @@ -95,6 +95,19 @@ def _roi_align_shape_func(data_shape, rois_shape, pooled_size):
return out


@script
def _roi_align_shape_func_nhwc(data_shape, rois_shape, pooled_size):
out = output_tensor((4,), "int64")
out[0] = rois_shape[0]
out[1] = int64(pooled_size[0])
out[2] = int64(pooled_size[1])
out[3] = data_shape[3]
return out


@reg.register_shape_func("vision.roi_align", False)
def roi_align_shape_func(attrs, inputs, _):
return [_roi_align_shape_func(inputs[0], inputs[1], convert(attrs.pooled_size))]
if attrs.layout == "NCHW":
return [_roi_align_shape_func_nchw(inputs[0], inputs[1], convert(attrs.pooled_size))]
assert attrs.layout == "NHWC", "layout must be NCHW or NHWC."
return [_roi_align_shape_func_nhwc(inputs[0], inputs[1], convert(attrs.pooled_size))]
2 changes: 1 addition & 1 deletion python/tvm/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from .bilinear_resize_python import bilinear_resize_python
from .trilinear_resize3d_python import trilinear_resize3d_python
from .reorg_python import reorg_python
from .roi_align_python import roi_align_nchw_python
from .roi_align_python import roi_align_nchw_python, roi_align_nhwc_python
from .roi_pool_python import roi_pool_nchw_python
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
Expand Down
153 changes: 117 additions & 36 deletions python/tvm/topi/testing/roi_align_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,42 +20,51 @@
import numpy as np


def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg"):
"""Roi align in python"""
avg_mode = mode in (b"avg", "avg", 0)
max_mode = mode in (b"max", "max", 1)
assert avg_mode or max_mode, "Mode must be average or max. Please pass a valid mode."
_, channel, height, width = a_np.shape
num_roi = rois_np.shape[0]
b_np = np.zeros((num_roi, channel, pooled_size, pooled_size), dtype=a_np.dtype)
if isinstance(pooled_size, int):
pooled_size_h = pooled_size_w = pooled_size
else:
pooled_size_h, pooled_size_w = pooled_size

def _bilinear(n, c, y, x):
if y < -1 or y > height or x < -1 or x > width:
return 0
def _bilinear(a_np, n, c, y, x, height, width, layout):
if y < -1 or y > height or x < -1 or x > width:
return 0

y = min(max(y, 0), height - 1)
x = min(max(x, 0), width - 1)
y = min(max(y, 0), height - 1)
x = min(max(x, 0), width - 1)

y_low = int(math.floor(y))
x_low = int(math.floor(x))
y_high = y_low + 1
x_high = x_low + 1
y_low = int(math.floor(y))
x_low = int(math.floor(x))
y_high = y_low + 1
x_high = x_low + 1

wy_h = y - y_low
wx_h = x - x_low
wy_l = 1 - wy_h
wx_l = 1 - wx_h
wy_h = y - y_low
wx_h = x - x_low
wy_l = 1 - wy_h
wx_l = 1 - wx_h

val = 0
for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
if 0 <= yp < height and 0 <= xp < width:
val = 0
for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
if 0 <= yp < height and 0 <= xp < width:
if layout == "NCHW":
val += wx * wy * a_np[n, c, yp, xp]
return val
else:
val += wx * wy * a_np[n, yp, xp, c]
return val


def roi_align_common(
a_np,
b_np,
rois_np,
channel,
pooled_size_h,
pooled_size_w,
spatial_scale,
sample_ratio,
avg_mode,
max_mode,
height,
width,
layout,
):
"""Common code used by roi align NCHW and NHWC"""
num_roi = rois_np.shape[0]

for i in range(num_roi):
roi = rois_np[i]
Expand All @@ -70,8 +79,8 @@ def _bilinear(n, c, y, x):
if sample_ratio > 0:
roi_bin_grid_h = roi_bin_grid_w = int(sample_ratio)
else:
roi_bin_grid_h = int(math.ceil(roi_h / pooled_size))
roi_bin_grid_w = int(math.ceil(roi_w / pooled_size))
roi_bin_grid_h = int(math.ceil(roi_h / pooled_size_h))
roi_bin_grid_w = int(math.ceil(roi_w / pooled_size_w))

count = roi_bin_grid_h * roi_bin_grid_w

Expand All @@ -87,8 +96,80 @@ def _bilinear(n, c, y, x):
y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h
x = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w
if avg_mode:
total += _bilinear(batch_index, c, y, x) / count
total += (
_bilinear(a_np, batch_index, c, y, x, height, width, layout)
/ count
)
if max_mode:
total = max(total, _bilinear(batch_index, c, y, x))
b_np[i, c, ph, pw] = total
total = max(
total,
_bilinear(a_np, batch_index, c, y, x, height, width, layout),
)

if layout == "NCHW":
b_np[i, c, ph, pw] = total
else:
b_np[i, ph, pw, c] = total
return b_np


def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg"):
"""Roi align NCHW in python"""
avg_mode = mode in (b"avg", "avg", 0)
max_mode = mode in (b"max", "max", 1)
assert avg_mode or max_mode, "Mode must be average or max. Please pass a valid mode."
_, channel, height, width = a_np.shape
if isinstance(pooled_size, int):
pooled_size_h = pooled_size_w = pooled_size
else:
pooled_size_h, pooled_size_w = pooled_size

b_np = np.zeros((rois_np.shape[0], channel, pooled_size_h, pooled_size_w), dtype=a_np.dtype)

return roi_align_common(
a_np,
b_np,
rois_np,
channel,
pooled_size_h,
pooled_size_w,
spatial_scale,
sample_ratio,
avg_mode,
max_mode,
height,
width,
"NCHW",
)


def roi_align_nhwc_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg"):
"""Roi align NHWC in python"""
avg_mode = mode in (b"avg", "avg", 0)
max_mode = mode in (b"max", "max", 1)
assert avg_mode or max_mode, "Mode must be average or max. Please pass a valid mode."
_, height, width, channel = a_np.shape
num_roi = rois_np.shape[0]

if isinstance(pooled_size, int):
pooled_size_h = pooled_size_w = pooled_size
else:
pooled_size_h, pooled_size_w = pooled_size

b_np = np.zeros((num_roi, pooled_size_h, pooled_size_w, channel), dtype=a_np.dtype)

return roi_align_common(
a_np,
b_np,
rois_np,
channel,
pooled_size_h,
pooled_size_w,
spatial_scale,
sample_ratio,
avg_mode,
max_mode,
height,
width,
"NHWC",
)
Loading

0 comments on commit b7e0cfb

Please sign in to comment.