Skip to content

Commit

Permalink
[Relay][Topi] Add max mode to ROI align (apache#7440)
Browse files Browse the repository at this point in the history
* ROI align with max on cpu passes

* onnx test file was not running gpu testsgit status!

* all passing

* fix lint

* lint again

* lint

* lint

* typo

* remove import

* fix import

* add inf, -inf to hybridscript and respond to comments

* shorten code

* make atol lower
  • Loading branch information
electriclilies authored and trevor-m committed Mar 2, 2021
1 parent 16baac5 commit 922a5a1
Show file tree
Hide file tree
Showing 13 changed files with 154 additions and 40 deletions.
3 changes: 3 additions & 0 deletions include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ struct ROIAlignAttrs : public tvm::AttrsNode<ROIAlignAttrs> {
double spatial_scale;
int sample_ratio;
std::string layout;
std::string mode;
TVM_DECLARE_ATTRS(ROIAlignAttrs, "relay.attrs.ROIAlignAttrs") {
TVM_ATTR_FIELD(pooled_size).describe("Output size of roi align.");
TVM_ATTR_FIELD(spatial_scale)
Expand All @@ -139,6 +140,8 @@ struct ROIAlignAttrs : public tvm::AttrsNode<ROIAlignAttrs> {
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(mode).set_default("avg").describe(
"Mode for ROI Align. Can be 'avg' or 'max'. The default mode is 'avg'.");
}
};

Expand Down
7 changes: 4 additions & 3 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1665,6 +1665,7 @@ def expand_shape(in_shape, shape):
"""
in_dims = infer_shape(in_shape)[0]
new_dims = infer_shape(shape)[0]

if in_dims < new_dims:
in_shape = _op.concatenate(
[
Expand Down Expand Up @@ -2084,8 +2085,8 @@ def _impl_v1(cls, inputs, attr, params):
rois = inputs[1]
batch_indices = inputs[2]
mode = attr.get("mode", b"avg")
if mode != b"avg":
raise ValueError("RoiAlign in Relay only uses avg mode")
if mode not in (b"avg", b"max"):
raise ValueError("RoiAlign in Relay only uses avg and max modes")
output_height = attr.get("output_height", 1)
output_width = attr.get("output_width", 1)

Expand All @@ -2097,7 +2098,7 @@ def _impl_v1(cls, inputs, attr, params):
rois = _op.concatenate([batch_indices, rois], 1)

return _vision.roi_align(
x, rois, [output_height, output_width], spatial_scale, sampling_ratio
x, rois, [output_height, output_width], spatial_scale, sampling_ratio, mode=mode
)


Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,13 +1041,15 @@ def wrap_compute_roi_align(topi_compute):
def _compute_roi_align(attrs, inputs, out_type):
assert attrs.layout == "NCHW"
pooled_size = get_const_tuple(attrs.pooled_size)
mode = bytes(attrs.mode, "utf-8")
return [
topi_compute(
inputs[0],
inputs[1],
pooled_size=pooled_size,
spatial_scale=attrs.spatial_scale,
sample_ratio=attrs.sample_ratio,
mode=mode,
)
]

Expand Down
7 changes: 5 additions & 2 deletions python/tvm/relay/op/vision/rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from . import _make


def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout="NCHW"):
def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout="NCHW", mode="avg"):
"""ROI align operator.
Parameters
Expand All @@ -40,12 +40,15 @@ def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout="N
sample_ratio : int
Optional sampling ratio of ROI align, using adaptive size by default.
mode : str, Optional
The pooling method. Relay supports two methods, 'avg' and 'max'. Default is 'avg'.
Returns
-------
output : relay.Expr
4-D tensor with shape [num_roi, channel, pooled_size, pooled_size]
"""
return _make.roi_align(data, rois, pooled_size, spatial_scale, sample_ratio, layout)
return _make.roi_align(data, rois, pooled_size, spatial_scale, sample_ratio, layout, mode)


def roi_pool(data, rois, pooled_size, spatial_scale, layout="NCHW"):
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/te/hybrid/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,17 @@ def max_num_threads(func_id, args):
_internal_assert(isinstance(args[0], _expr.IntImm), "In tvm bool should be uint")
res = Target.current(args[0].value).max_num_threads
return convert(res)


def inf(func_id, args):
"""Infinity"""
_internal_assert(func_id == "inf", "This function cannot be directly invoked!")
_internal_assert(args.__len__() == 1, "One argument accepted!")
return tvm.tir.max_value(args[0])


def ninf(func_id, args):
"""Negative infinity"""
_internal_assert(func_id == "ninf", "This function cannot be directly invoked!")
_internal_assert(args.__len__() == 1, "One argument accepted!")
return tvm.tir.min_value(args[0])
10 changes: 10 additions & 0 deletions python/tvm/te/hybrid/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ def max_num_threads(allow_none=True):
return Target.current(allow_none).max_num_threads


def inf(dtype):
return numpy.iinfo(dtype).max


def ninf(dtype):
return numpy.iinfo(dtype).min


HYBRID_GLOBALS = {
"unroll": range,
"vectorize": range,
Expand Down Expand Up @@ -142,6 +150,8 @@ def max_num_threads(allow_none=True):
"float64": numpy.float64,
"ceil_div": lambda a, b: (a + b - 1) // b,
"max_num_threads": max_num_threads,
"inf": inf,
"ninf": inf,
}


Expand Down
18 changes: 13 additions & 5 deletions python/tvm/topi/testing/roi_align_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
import numpy as np


def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio):
def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg"):
"""Roi align in python"""
avg_mode = mode in (b"avg", "avg", 0)
max_mode = mode in (b"max", "max", 1)
assert avg_mode or max_mode, "Mode must be average or max. Please pass a valid mode."
_, channel, height, width = a_np.shape
num_roi = rois_np.shape[0]
b_np = np.zeros((num_roi, channel, pooled_size, pooled_size), dtype=a_np.dtype)

if isinstance(pooled_size, int):
pooled_size_h = pooled_size_w = pooled_size
else:
Expand Down Expand Up @@ -76,11 +78,17 @@ def _bilinear(n, c, y, x):
for c in range(channel):
for ph in range(pooled_size_h):
for pw in range(pooled_size_w):
total = 0.0
if avg_mode:
total = 0.0
if max_mode:
total = float("-inf")
for iy in range(roi_bin_grid_h):
for ix in range(roi_bin_grid_w):
y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h
x = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w
total += _bilinear(batch_index, c, y, x)
b_np[i, c, ph, pw] = total / count
if avg_mode:
total += _bilinear(batch_index, c, y, x) / count
if max_mode:
total = max(total, _bilinear(batch_index, c, y, x))
b_np[i, c, ph, pw] = total
return b_np
26 changes: 22 additions & 4 deletions python/tvm/topi/vision/rcnn/roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ...cpp.utils import bilinear_sample_nchw


def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
def roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1):
"""ROI align operator in NCHW layout.
Parameters
Expand All @@ -41,6 +41,10 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal
of total stride in convolutional layers, which should be in range (0.0, 1.0]
mode : int or str
There are two modes, average and max. For the average mode, you can pass b'avg' or 0, and
for the max mode, you can pass b'max' or 1.
sample_ratio : int
Optional sampling ratio of ROI align, using adaptive size by default.
Expand All @@ -49,6 +53,9 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
output : tvm.te.Tensor
4-D with shape [num_roi, channel, pooled_size, pooled_size]
"""
avg_mode = mode in (b"avg", 0)
max_mode = mode in (b"max", 1)
assert avg_mode or max_mode, "Mode must be avg or max. Please pass in a valid mode."
dtype = rois.dtype
_, channel, height, width = get_const_tuple(data.shape)
num_roi, _ = get_const_tuple(rois.shape)
Expand Down Expand Up @@ -92,14 +99,25 @@ def _sample(i, c, ph, pw):
rw = te.reduce_axis((0, roi_bin_grid_w))
roi_start_h += ph * bin_h
roi_start_w += pw * bin_w
return te.sum(
if avg_mode:
return te.sum(
_bilinear(
batch_index,
c,
roi_start_h + (rh + 0.5) * bin_h / roi_bin_grid_h,
roi_start_w + (rw + 0.5) * bin_w / roi_bin_grid_w,
)
/ count,
axis=[rh, rw],
)
# max mode
return te.max(
_bilinear(
batch_index,
c,
roi_start_h + (rh + 0.5) * bin_h / roi_bin_grid_h,
roi_start_w + (rw + 0.5) * bin_w / roi_bin_grid_w,
)
/ count,
),
axis=[rh, rw],
)

Expand Down
47 changes: 37 additions & 10 deletions python/tvm/topi/x86/roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements
"""Non-maximum suppression operator for intel cpu"""
import math
import tvm

import tvm
from tvm.te import hybrid
from ..tensor import full
from ..utils import get_const_tuple


@hybrid.script
def roi_align_nchw_ir(data, rois, num_rois, w_pc, pos_pc, pooled_size, spatial_scale, sample_ratio):
def roi_align_nchw_ir(
data, rois, num_rois, w_pc, pos_pc, pooled_size, spatial_scale, sample_ratio, mode
):
"""Hybrid routing fo ROI align operator in NCHW layout.
Parameters
Expand Down Expand Up @@ -57,6 +59,10 @@ def roi_align_nchw_ir(data, rois, num_rois, w_pc, pos_pc, pooled_size, spatial_s
sample_ratio : tvm.tir.const
Sampling ratio of ROI align, using adaptive size by default.
mode : tvm.tir.const
Mode of RoiAlign. A value of 0 corrensponds to b'avg', while a value of 1 corresponds to
b'max'.
Returns
-------
output : tvm.te.Tensor or numpy NDArray
Expand Down Expand Up @@ -160,10 +166,12 @@ def roi_align_nchw_ir(data, rois, num_rois, w_pc, pos_pc, pooled_size, spatial_s
pre_calc_index = 0
for ph in range(pooled_size_h):
for pw in range(pooled_size_w):
output_val = 0.0
output_val = 0.0 # Avg mode
if mode == 1: # Max mode
output_val = ninf("float32")
for iy in range(roi_bin_grid_h):
for ix in range(roi_bin_grid_w):
output_val += (
bilinear_val = (
w_pc[n, pre_calc_index, 0]
* data[
roi_batch_index,
Expand Down Expand Up @@ -194,14 +202,15 @@ def roi_align_nchw_ir(data, rois, num_rois, w_pc, pos_pc, pooled_size, spatial_s
]
)
pre_calc_index += 1

output_val /= count
output[n, c, ph, pw] = output_val

if mode == 0: # Avg mode
output_val += bilinear_val / count
if mode == 1: # Max mode
output_val = max(output_val, bilinear_val)
output[n, c, ph, pw] = output_val
return output


def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
def roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1):
"""ROI align operator in NCHW layout.
Parameters
Expand All @@ -220,6 +229,9 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal
of total stride in convolutional layers, which should be in range (0.0, 1.0]
mode : str
Mode of RoiAlign. Should be b'max' or b'avg'.
sample_ratio : int
Optional sampling ratio of ROI align, using adaptive size by default.
Expand Down Expand Up @@ -250,6 +262,21 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
pooled_size = tvm.runtime.convert(pooled_size)
spatial_scale = tvm.tir.const(spatial_scale, "float32")
sample_ratio = tvm.tir.const(sample_ratio, "int32")
if mode in (b"avg", 0):
mode = tvm.tir.const(0, dtype="float32")
elif mode in (b"max", 1):
mode = tvm.tir.const(1, dtype="float32")
else:
raise ValueError(mode, "Value %s passed in for mode not supported", mode)

return roi_align_nchw_ir(
data, rois, num_rois, w_pc_buffer, pos_pc_buffer, pooled_size, spatial_scale, sample_ratio
data,
rois,
num_rois,
w_pc_buffer,
pos_pc_buffer,
pooled_size,
spatial_scale,
sample_ratio,
mode,
)
3 changes: 2 additions & 1 deletion src/relay/op/vision/rcnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,13 @@ Array<Array<Layout> > ROIAlignInferCorrectLayout(const Attrs& attrs,
}

Expr MakeROIAlign(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spatial_scale,
int sample_ratio, String layout) {
int sample_ratio, String layout, String mode) {
auto attrs = make_object<ROIAlignAttrs>();
attrs->pooled_size = pooled_size;
attrs->spatial_scale = spatial_scale;
attrs->sample_ratio = sample_ratio;
attrs->layout = layout;
attrs->mode = mode;
static const Op& op = Op::Get("vision.roi_align");
return Call(op, {data, rois}, Attrs(attrs), {});
}
Expand Down
12 changes: 10 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3437,15 +3437,21 @@ def verify_topk(input_dims, K, axis=-1):
@tvm.testing.uses_gpu
def test_roi_align():
def verify_roi_align(
input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0
input_dims,
num_roi,
output_height,
output_width,
sampling_ratio=0,
spatial_scale=1.0,
mode="avg",
):
output_dims = [num_roi, input_dims[1], output_height, output_width]

node = helper.make_node(
"RoiAlign",
inputs=["X", "rois", "batch_indicies"],
outputs=["Y"],
mode="avg",
mode=mode,
output_height=output_height,
output_width=output_width,
sampling_ratio=sampling_ratio,
Expand Down Expand Up @@ -3490,6 +3496,8 @@ def verify_roi_align(
verify_roi_align((5, 4, 16, 14), 32, 7, 7, sampling_ratio=1, spatial_scale=1.0)
verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=2, spatial_scale=1.0)

# ONNX implementation of roi_align with max mode is incorrect, so we don't compare outputs here.


# @tvm.testing.uses_gpu
def test_non_max_suppression():
Expand Down
Loading

0 comments on commit 922a5a1

Please sign in to comment.