From be0794b9e2d89f2ecc89a1cbe575b3419b45b988 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 9 Feb 2021 14:43:07 -0500 Subject: [PATCH 01/13] ROI align with max on cpu passes --- include/tvm/relay/attrs/vision.h | 4 + python/tvm/relay/frontend/onnx.py | 7 +- python/tvm/relay/op/strategy/generic.py | 2 + python/tvm/relay/op/vision/rcnn.py | 7 +- python/tvm/te/hybrid/parser.py | 2 +- python/tvm/topi/vision/rcnn/roi_align.py | 31 +++-- python/tvm/topi/x86/roi_align.py | 126 ++++++++++++++------- src/relay/op/vision/rcnn_op.cc | 3 +- tests/python/frontend/onnx/test_forward.py | 15 ++- 9 files changed, 137 insertions(+), 60 deletions(-) diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index ca2c4a2b837d..e030fb6a689c 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -124,6 +124,7 @@ struct ROIAlignAttrs : public tvm::AttrsNode { double spatial_scale; int sample_ratio; std::string layout; + std::string mode; TVM_DECLARE_ATTRS(ROIAlignAttrs, "relay.attrs.ROIAlignAttrs") { TVM_ATTR_FIELD(pooled_size).describe("Output size of roi align."); TVM_ATTR_FIELD(spatial_scale) @@ -139,6 +140,9 @@ struct ROIAlignAttrs : public tvm::AttrsNode { "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Convolution is applied on the 'H' and" "'W' dimensions."); + TVM_ATTR_FIELD(mode).set_default("avg").describe( + "Mode for ROI Align. Can be 'avg' or 'max'. The default mode is 'avg'." + ); } }; diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index c423598a2ee7..bd5936cb9bac 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1614,6 +1614,7 @@ def expand_shape(in_shape, shape): """ in_dims = infer_shape(in_shape)[0] new_dims = infer_shape(shape)[0] + if in_dims < new_dims: in_shape = _op.concatenate( [ @@ -2034,8 +2035,8 @@ def _impl_v1(cls, inputs, attr, params): rois = inputs[1] batch_indices = inputs[2] mode = attr.get("mode", b"avg") - if mode != b"avg": - raise ValueError("RoiAlign in Relay only uses avg mode") + if mode != b"avg" and mode != b"max": + raise ValueError("RoiAlign in Relay only uses avg and max modes") output_height = attr.get("output_height", 1) output_width = attr.get("output_width", 1) @@ -2047,7 +2048,7 @@ def _impl_v1(cls, inputs, attr, params): rois = _op.concatenate([batch_indices, rois], 1) return _vision.roi_align( - x, rois, [output_height, output_width], spatial_scale, sampling_ratio + x, rois, [output_height, output_width], spatial_scale, sampling_ratio, mode=mode ) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 3ad75faf4bc1..17dcd4e54a46 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1028,6 +1028,7 @@ def wrap_compute_roi_align(topi_compute): def _compute_roi_align(attrs, inputs, out_type): assert attrs.layout == "NCHW" pooled_size = get_const_tuple(attrs.pooled_size) + mode = bytes(attrs.mode, 'utf-8') return [ topi_compute( inputs[0], @@ -1035,6 +1036,7 @@ def _compute_roi_align(attrs, inputs, out_type): pooled_size=pooled_size, spatial_scale=attrs.spatial_scale, sample_ratio=attrs.sample_ratio, + mode=mode ) ] diff --git a/python/tvm/relay/op/vision/rcnn.py b/python/tvm/relay/op/vision/rcnn.py index b87eb07d7563..ffa9487c4ffa 100644 --- a/python/tvm/relay/op/vision/rcnn.py +++ b/python/tvm/relay/op/vision/rcnn.py @@ -18,7 +18,7 @@ from . import _make -def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout="NCHW"): +def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout="NCHW", mode="avg"): """ROI align operator. Parameters @@ -39,13 +39,16 @@ def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout="N sample_ratio : int Optional sampling ratio of ROI align, using adaptive size by default. + + mode : str, Optional + The pooling method. Relay supports two methods, 'avg' and 'max'. Default is 'avg'. Returns ------- output : relay.Expr 4-D tensor with shape [num_roi, channel, pooled_size, pooled_size] """ - return _make.roi_align(data, rois, pooled_size, spatial_scale, sample_ratio, layout) + return _make.roi_align(data, rois, pooled_size, spatial_scale, sample_ratio, layout, mode) def roi_pool(data, rois, pooled_size, spatial_scale, layout="NCHW"): diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 7bb85e3da83c..9d73b1d61943 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -270,7 +270,7 @@ def visit_Name(self, node): # Do I need any assertion here? return entry - def visit_Num(self, node): + def visit_Constant(self, node): if isinstance(node.n, numbers.Integral): dtype = "int32" elif isinstance(node.n, float): diff --git a/python/tvm/topi/vision/rcnn/roi_align.py b/python/tvm/topi/vision/rcnn/roi_align.py index 30824770b7b2..b4cc40102f7a 100644 --- a/python/tvm/topi/vision/rcnn/roi_align.py +++ b/python/tvm/topi/vision/rcnn/roi_align.py @@ -22,7 +22,7 @@ from ...cpp.utils import bilinear_sample_nchw -def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1): +def roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1): """ROI align operator in NCHW layout. Parameters @@ -92,17 +92,26 @@ def _sample(i, c, ph, pw): rw = te.reduce_axis((0, roi_bin_grid_w)) roi_start_h += ph * bin_h roi_start_w += pw * bin_w - return te.sum( - _bilinear( - batch_index, - c, - roi_start_h + (rh + 0.5) * bin_h / roi_bin_grid_h, - roi_start_w + (rw + 0.5) * bin_w / roi_bin_grid_w, + if mode == b'avg': + return te.sum( + _bilinear( + batch_index, + c, + roi_start_h + (rh + 0.5) * bin_h / roi_bin_grid_h, + roi_start_w + (rw + 0.5) * bin_w / roi_bin_grid_w, + ) + / count, + axis=[rh, rw], + ) + elif mode == b'max': + return te.max( + _bilinear( + batch_index, + c, + roi_start_h + (rh + 0.5) * bin_h / roi_bin_grid_h, + roi_start_w + (rw + 0.5) * bin_w / roi_bin_grid_w), + axis=[rh, rw] ) - / count, - axis=[rh, rw], - ) - return te.compute( (num_roi, channel, pooled_size_h, pooled_size_w), _sample, tag="pool,roi_align_nchw" ) diff --git a/python/tvm/topi/x86/roi_align.py b/python/tvm/topi/x86/roi_align.py index ac2146b558f9..e98c34c0acec 100644 --- a/python/tvm/topi/x86/roi_align.py +++ b/python/tvm/topi/x86/roi_align.py @@ -19,13 +19,14 @@ import math import tvm +from tvm import relay from tvm.te import hybrid from ..tensor import full from ..utils import get_const_tuple @hybrid.script -def roi_align_nchw_ir(data, rois, num_rois, w_pc, pos_pc, pooled_size, spatial_scale, sample_ratio): +def roi_align_nchw_ir(data, rois, num_rois, w_pc, pos_pc, pooled_size, spatial_scale, sample_ratio, mode): """Hybrid routing fo ROI align operator in NCHW layout. Parameters @@ -57,6 +58,10 @@ def roi_align_nchw_ir(data, rois, num_rois, w_pc, pos_pc, pooled_size, spatial_s sample_ratio : tvm.tir.const Sampling ratio of ROI align, using adaptive size by default. + mode : tvm.tir.const + Mode of RoiAlign. A value of 0 corrensponds to b'avg', while a value of 1 corresponds to + b'max'. + Returns ------- output : tvm.te.Tensor or numpy NDArray @@ -161,47 +166,81 @@ def roi_align_nchw_ir(data, rois, num_rois, w_pc, pos_pc, pooled_size, spatial_s for ph in range(pooled_size_h): for pw in range(pooled_size_w): output_val = 0.0 - for iy in range(roi_bin_grid_h): - for ix in range(roi_bin_grid_w): - output_val += ( - w_pc[n, pre_calc_index, 0] - * data[ - roi_batch_index, - c, - pos_pc[n, pre_calc_index, 2], - pos_pc[n, pre_calc_index, 0], - ] - + w_pc[n, pre_calc_index, 1] - * data[ - roi_batch_index, - c, - pos_pc[n, pre_calc_index, 2], - pos_pc[n, pre_calc_index, 1], - ] - + w_pc[n, pre_calc_index, 2] - * data[ - roi_batch_index, - c, - pos_pc[n, pre_calc_index, 3], - pos_pc[n, pre_calc_index, 0], - ] - + w_pc[n, pre_calc_index, 3] - * data[ - roi_batch_index, - c, - pos_pc[n, pre_calc_index, 3], - pos_pc[n, pre_calc_index, 1], - ] - ) - pre_calc_index += 1 - - output_val /= count - output[n, c, ph, pw] = output_val - + if mode == 0: + for iy in range(roi_bin_grid_h): + for ix in range(roi_bin_grid_w): + output_val += ( + w_pc[n, pre_calc_index, 0] + * data[ + roi_batch_index, + c, + pos_pc[n, pre_calc_index, 2], + pos_pc[n, pre_calc_index, 0], + ] + + w_pc[n, pre_calc_index, 1] + * data[ + roi_batch_index, + c, + pos_pc[n, pre_calc_index, 2], + pos_pc[n, pre_calc_index, 1], + ] + + w_pc[n, pre_calc_index, 2] + * data[ + roi_batch_index, + c, + pos_pc[n, pre_calc_index, 3], + pos_pc[n, pre_calc_index, 0], + ] + + w_pc[n, pre_calc_index, 3] + * data[ + roi_batch_index, + c, + pos_pc[n, pre_calc_index, 3], + pos_pc[n, pre_calc_index, 1], + ] + ) + pre_calc_index += 1 + + output_val /= count + output[n, c, ph, pw] = output_val + elif mode == 1: + output_val = 0.0 + for iy in range(roi_bin_grid_h): + for ix in range(roi_bin_grid_w): + output_val = max(output_val, w_pc[n, pre_calc_index, 0] + * data[ + roi_batch_index, + c, + pos_pc[n, pre_calc_index, 2], + pos_pc[n, pre_calc_index, 0], + ]) + output_val = max(output_val, w_pc[n, pre_calc_index, 1] + * data[ + roi_batch_index, + c, + pos_pc[n, pre_calc_index, 2], + pos_pc[n, pre_calc_index, 1], + ]) + output_val = max(output_val, w_pc[n, pre_calc_index, 2] + * data[ + roi_batch_index, + c, + pos_pc[n, pre_calc_index, 3], + pos_pc[n, pre_calc_index, 0], + ]) + output_val = max(output_val, w_pc[n, pre_calc_index, 3] + * data[ + roi_batch_index, + c, + pos_pc[n, pre_calc_index, 3], + pos_pc[n, pre_calc_index, 1], + ]) + pre_calc_index += 1 + output[n, c, ph, pw] = output_val return output -def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1): +def roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1): """ROI align operator in NCHW layout. Parameters @@ -219,6 +258,9 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1): spatial_scale : float Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal of total stride in convolutional layers, which should be in range (0.0, 1.0] + + mode : str + Mode of RoiAlign. Should be b'max' or b'avg'. sample_ratio : int Optional sampling ratio of ROI align, using adaptive size by default. @@ -250,6 +292,10 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1): pooled_size = tvm.runtime.convert(pooled_size) spatial_scale = tvm.tir.const(spatial_scale, "float32") sample_ratio = tvm.tir.const(sample_ratio, "int32") + if mode == b'avg': + mode = tvm.tir.const(0, dtype='float32') + elif mode == b'max': + mode = tvm.tir.const(1, dtype='float32') return roi_align_nchw_ir( - data, rois, num_rois, w_pc_buffer, pos_pc_buffer, pooled_size, spatial_scale, sample_ratio + data, rois, num_rois, w_pc_buffer, pos_pc_buffer, pooled_size, spatial_scale, sample_ratio, mode ) diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc index f7bbf378d09c..c899681733f8 100644 --- a/src/relay/op/vision/rcnn_op.cc +++ b/src/relay/op/vision/rcnn_op.cc @@ -76,12 +76,13 @@ Array > ROIAlignInferCorrectLayout(const Attrs& attrs, } Expr MakeROIAlign(Expr data, Expr rois, Array pooled_size, double spatial_scale, - int sample_ratio, String layout) { + int sample_ratio, String layout, String mode) { auto attrs = make_object(); attrs->pooled_size = pooled_size; attrs->spatial_scale = spatial_scale; attrs->sample_ratio = sample_ratio; attrs->layout = layout; + attrs->mode = mode; static const Op& op = Op::Get("vision.roi_align"); return Call(op, {data, rois}, Attrs(attrs), {}); } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 515fc32ef88d..7b11c8dfa53f 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3437,7 +3437,7 @@ def verify_topk(input_dims, K, axis=-1): @tvm.testing.uses_gpu def test_roi_align(): def verify_roi_align( - input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0 + input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0, mode="avg" ): output_dims = [num_roi, input_dims[1], output_height, output_width] @@ -3445,7 +3445,7 @@ def verify_roi_align( "RoiAlign", inputs=["X", "rois", "batch_indicies"], outputs=["Y"], - mode="avg", + mode=mode, output_height=output_height, output_width=output_width, sampling_ratio=sampling_ratio, @@ -3489,6 +3489,17 @@ def verify_roi_align( verify_roi_align((3, 4, 12, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.5) verify_roi_align((5, 4, 16, 14), 32, 7, 7, sampling_ratio=1, spatial_scale=1.0) verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=2, spatial_scale=1.0) + + verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0, mode="max") + verify_roi_align((4, 4, 16, 32), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0, mode="max") + verify_roi_align((1, 8, 16, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0, mode="max") + verify_roi_align((1, 4, 8, 8), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0, mode="max") + verify_roi_align((1, 4, 16, 16), 16, 5, 7, sampling_ratio=0, spatial_scale=1.0, mode="max") + verify_roi_align((1, 4, 16, 12), 8, 7, 3, sampling_ratio=0, spatial_scale=1.0, mode="max") + verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=0.5, mode="max") + verify_roi_align((3, 4, 12, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.5, mode="max") + verify_roi_align((5, 4, 16, 14), 32, 7, 7, sampling_ratio=1, spatial_scale=1.0, mode="max") + verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=2, spatial_scale=1.0, mode="max") # @tvm.testing.uses_gpu From 6a1c6a0488784fd854da189912e367ccbee44b7d Mon Sep 17 00:00:00 2001 From: electriclilies Date: Wed, 10 Feb 2021 10:32:33 -0800 Subject: [PATCH 02/13] onnx test file was not running gpu testsgit status! --- python/tvm/te/hybrid/parser.py | 2 +- python/tvm/topi/testing/roi_align_python.py | 38 +++-- python/tvm/topi/vision/rcnn/roi_align.py | 13 +- python/tvm/topi/x86/roi_align.py | 8 +- tests/python/frontend/onnx/test_forward.py | 142 +++++++++---------- tests/python/topi/python/test_topi_vision.py | 37 +++-- 6 files changed, 141 insertions(+), 99 deletions(-) diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 9d73b1d61943..7bb85e3da83c 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -270,7 +270,7 @@ def visit_Name(self, node): # Do I need any assertion here? return entry - def visit_Constant(self, node): + def visit_Num(self, node): if isinstance(node.n, numbers.Integral): dtype = "int32" elif isinstance(node.n, float): diff --git a/python/tvm/topi/testing/roi_align_python.py b/python/tvm/topi/testing/roi_align_python.py index abef25f0b994..13e76aa2c870 100644 --- a/python/tvm/topi/testing/roi_align_python.py +++ b/python/tvm/topi/testing/roi_align_python.py @@ -20,12 +20,18 @@ import numpy as np -def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio): +def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b'avg'): """Roi align in python""" + avg_mode = (mode == b'avg' or mode == 0) + max_mode = (mode == b'max' or mode == 1) + assert avg_mode or max_mode, "Mode must be average or max. Please pass a valid mode." _, channel, height, width = a_np.shape num_roi = rois_np.shape[0] b_np = np.zeros((num_roi, channel, pooled_size, pooled_size), dtype=a_np.dtype) - + if (avg_mode): + print("average mode") + if (max_mode): + print(max_mode) if isinstance(pooled_size, int): pooled_size_h = pooled_size_w = pooled_size else: @@ -52,7 +58,10 @@ def _bilinear(n, c, y, x): for wx, xp in zip((wx_l, wx_h), (x_low, x_high)): for wy, yp in zip((wy_l, wy_h), (y_low, y_high)): if 0 <= yp < height and 0 <= xp < width: - val += wx * wy * a_np[n, c, yp, xp] + if avg_mode: + val += wx * wy * a_np[n, c, yp, xp] + elif max_mode: + val = max(val, wx * wy * a_np[n, c, yp, xp]) return val for i in range(num_roi): @@ -76,11 +85,20 @@ def _bilinear(n, c, y, x): for c in range(channel): for ph in range(pooled_size_h): for pw in range(pooled_size_w): - total = 0.0 - for iy in range(roi_bin_grid_h): - for ix in range(roi_bin_grid_w): - y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h - x = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w - total += _bilinear(batch_index, c, y, x) - b_np[i, c, ph, pw] = total / count + if avg_mode: + total = 0.0 + for iy in range(roi_bin_grid_h): + for ix in range(roi_bin_grid_w): + y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h + x = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w + total += _bilinear(batch_index, c, y, x) + b_np[i, c, ph, pw] = total / count + elif max_mode: + total = 0.0 + for iy in range(roi_bin_grid_h): + for ix in range(roi_bin_grid_w): + y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h + x = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w + total = max(total, _bilinear(batch_index, c, y, x)) + b_np[i, c, ph, pw] = total return b_np diff --git a/python/tvm/topi/vision/rcnn/roi_align.py b/python/tvm/topi/vision/rcnn/roi_align.py index b4cc40102f7a..fbae4dab5fc4 100644 --- a/python/tvm/topi/vision/rcnn/roi_align.py +++ b/python/tvm/topi/vision/rcnn/roi_align.py @@ -41,6 +41,10 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1 Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal of total stride in convolutional layers, which should be in range (0.0, 1.0] + mode : int or str + There are two modes, average and max. For the average mode, you can pass b'avg' or 0, and + for the max mode, you can pass b'max' or 1. + sample_ratio : int Optional sampling ratio of ROI align, using adaptive size by default. @@ -49,6 +53,11 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1 output : tvm.te.Tensor 4-D with shape [num_roi, channel, pooled_size, pooled_size] """ + print("rcnn roi align called") + print("rcnn roi align mode: ", mode) + avg_mode = (mode == b'avg' or mode == 0) + max_mode = (mode == b'max' or mode == 1) + assert (avg_mode or max_mode), "Mode must be avg or max. Please pass in a valid mode." dtype = rois.dtype _, channel, height, width = get_const_tuple(data.shape) num_roi, _ = get_const_tuple(rois.shape) @@ -92,7 +101,7 @@ def _sample(i, c, ph, pw): rw = te.reduce_axis((0, roi_bin_grid_w)) roi_start_h += ph * bin_h roi_start_w += pw * bin_w - if mode == b'avg': + if avg_mode: return te.sum( _bilinear( batch_index, @@ -103,7 +112,7 @@ def _sample(i, c, ph, pw): / count, axis=[rh, rw], ) - elif mode == b'max': + elif max_mode: return te.max( _bilinear( batch_index, diff --git a/python/tvm/topi/x86/roi_align.py b/python/tvm/topi/x86/roi_align.py index e98c34c0acec..8e229e640167 100644 --- a/python/tvm/topi/x86/roi_align.py +++ b/python/tvm/topi/x86/roi_align.py @@ -270,6 +270,7 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1 output : tvm.te.Tensor 4-D with shape [num_roi, channel, pooled_size, pooled_size] """ + print("x86 roi align") if not isinstance(pooled_size, (tuple, list)): pooled_size = (pooled_size, pooled_size) @@ -292,10 +293,13 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1 pooled_size = tvm.runtime.convert(pooled_size) spatial_scale = tvm.tir.const(spatial_scale, "float32") sample_ratio = tvm.tir.const(sample_ratio, "int32") - if mode == b'avg': + if mode == b'avg' or mode == 0: mode = tvm.tir.const(0, dtype='float32') - elif mode == b'max': + elif mode == b'max' or mode == 1: mode = tvm.tir.const(1, dtype='float32') + else: + raise ValueError(mode, "Value %s passed in for mode not supported", mode) + print("Mode from x86 roi align is: ", mode) return roi_align_nchw_ir( data, rois, num_rois, w_pc_buffer, pos_pc_buffer, pooled_size, spatial_scale, sample_ratio, mode ) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 7b11c8dfa53f..1aa97ae7fa21 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -67,15 +67,13 @@ def get_tvm_output( graph_def, input_data, target, ctx, output_shape=None, output_dtype="float32", opset=None ): """ Generic function to execute and get tvm output""" - target = "llvm" input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data) mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset) - with tvm.transform.PassContext(opt_level=1): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, target, params=params) - ctx = tvm.cpu(0) m = graph_runtime.create(graph, lib, ctx) # set inputs if isinstance(input_data, list): @@ -141,6 +139,7 @@ def verify_with_ort_with_inputs( targets = [tgt for (tgt, _) in tvm.testing.enabled_targets()] for target in targets: + print("target: ", target) ctx = tvm.context(target, 0) if use_vm: tvm_out = get_tvm_output_with_vm( @@ -3976,76 +3975,77 @@ def verify_softplus(indata): if __name__ == "__main__": - test_flatten() - test_reshape() - test_shape() - test_expand() - test_power() - test_squeeze() - test_unsqueeze() - test_slice() - test_floor() - test_ceil() - test_round() - test_isinf() - test_isnan() - test_clip() - test_clip_min_max_as_inputs() - test_onehot() - test_matmul() - test_gather() - test_gatherelements() - test_gather_nd() - test_scatter() - test_lrn() - test_instance_norm() - test_upsample() - test_forward_min() - test_forward_max() - test_forward_mean() - test_forward_hardsigmoid() - test_forward_arg_min_max() - test_softmax() - test_constantofshape() - test_all_reduce_funcs() - test_pad() - test_split() - test_binary_ops() - test_unary_ops() - test_leaky_relu() - test_elu() - test_selu() - test_prelu() - test_ThresholdedRelu() - test_LogSoftmax() - test_resnet() - test_inception() - test_densenet() - test_sign() - test_not() - test_and() - test_tile() - test_erf() - test_where() - test_or() - test_depth_to_space() - test_space_to_depth() - test_batch_norm() - test_batch_norm_dynamic_subgraph() - test_conv() - test_convtranspose() - test_unsqueeze_constant() - test_pooling() - test_lppool() - test_lstm() - test_gru() - test_resize() - test_nonzero() - test_topk() - test_mod() - test_xor() + # test_flatten() + # test_reshape() + # test_shape() + # test_expand() + # test_power() + # test_squeeze() + # test_unsqueeze() + # test_slice() + # test_floor() + # test_ceil() + # test_round() + # test_isinf() + # test_isnan() + # test_clip() + # test_clip_min_max_as_inputs() + # test_onehot() + # test_matmul() + # test_gather() + # test_gatherelements() + # test_gather_nd() + # test_scatter() + # test_lrn() + # test_instance_norm() + # test_upsample() + # test_forward_min() + # test_forward_max() + # test_forward_mean() + # test_forward_hardsigmoid() + # test_forward_arg_min_max() + # test_softmax() + # test_constantofshape() + # test_all_reduce_funcs() + # test_pad() + # test_split() + # test_binary_ops() + # test_unary_ops() + # test_leaky_relu() + # test_elu() + # test_selu() + # test_prelu() + # test_ThresholdedRelu() + # test_LogSoftmax() + # test_resnet() + # test_inception() + # test_densenet() + # test_sign() + # test_not() + # test_and() + # test_tile() + # test_erf() + # test_where() + # test_or() + # test_depth_to_space() + # test_space_to_depth() + # test_batch_norm() + # test_batch_norm_dynamic_subgraph() + # test_conv() + # test_convtranspose() + # test_unsqueeze_constant() + # test_pooling() + # test_lppool() + # test_lstm() + # test_gru() + # test_resize() + # test_nonzero() + # test_topk() + # test_mod() + # test_xor() test_max_roi_pool() test_roi_align() + exit() test_range() test_loop() test_size() diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index 697ef8a24f67..5ec289ec6390 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -418,7 +418,7 @@ def check_device(device): check_device(device) -def verify_roi_align(batch, in_channel, in_size, num_roi, pooled_size, spatial_scale, sample_ratio): +def verify_roi_align(batch, in_channel, in_size, num_roi, pooled_size, spatial_scale, sample_ratio, mode): # 0 = avg, 1 = max a_shape = (batch, in_channel, in_size, in_size) rois_shape = (num_roi, 5) @@ -430,12 +430,16 @@ def get_ref_data(): a_np = np.random.uniform(size=a_shape).astype("float32") rois_np = np.random.uniform(size=rois_shape).astype("float32") * in_size rois_np[:, 0] = np.random.randint(low=0, high=batch, size=num_roi) + print(type(a_np)) + print(type(sample_ratio)) + print("type mode: ", type(mode)) b_np = tvm.topi.testing.roi_align_nchw_python( a_np, rois_np, pooled_size=pooled_size, spatial_scale=spatial_scale, sample_ratio=sample_ratio, + mode=mode ) return a_np, rois_np, b_np @@ -447,8 +451,7 @@ def check_device(device): if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return - print("Running on target: %s" % device) - + print("Mode in check device: ", mode) with tvm.target.Target(device): fcompute, fschedule = tvm.topi.testing.dispatch(device, _roi_align_implement) b = fcompute( @@ -457,6 +460,7 @@ def check_device(device): pooled_size=pooled_size, spatial_scale=spatial_scale, sample_ratio=sample_ratio, + mode=mode ) s = fschedule(b) @@ -465,7 +469,10 @@ def check_device(device): tvm_b = tvm.nd.array(np.zeros(get_const_tuple(b.shape), dtype=b.dtype), ctx=ctx) f = tvm.build(s, [a, rois, b], device) f(tvm_a, tvm_rois, tvm_b) - tvm.testing.assert_allclose(tvm_b.asnumpy(), b_np, rtol=1e-3) + tvm_val = tvm_b.asnumpy() + #print("Tvm val: ", tvm_val) + #print("B_np: ", b_np) + tvm.testing.assert_allclose(tvm_val, b_np, rtol=1e-3) for device in ["llvm", "cuda", "opencl"]: check_device(device) @@ -473,10 +480,14 @@ def check_device(device): @tvm.testing.uses_gpu def test_roi_align(): - verify_roi_align(1, 16, 32, 64, 7, 1.0, -1) - verify_roi_align(4, 16, 32, 64, 7, 0.5, 2) - verify_roi_align(1, 32, 32, 80, 8, 0.0625, 2) - verify_roi_align(1, 32, 500, 80, 8, 0.0625, 2) + verify_roi_align(1, 16, 32, 64, 7, 1.0, -1, 0) + verify_roi_align(4, 16, 32, 64, 7, 0.5, 2, 0) + verify_roi_align(1, 32, 32, 80, 8, 0.0625, 2, 0) + verify_roi_align(1, 32, 500, 80, 8, 0.0625, 2, 0) + verify_roi_align(1, 16, 32, 64, 7, 1.0, -1, 1) + verify_roi_align(4, 16, 32, 64, 7, 0.5, 2, 1) + verify_roi_align(1, 32, 32, 80, 8, 0.0625, 2, 1) + verify_roi_align(1, 32, 500, 80, 8, 0.0625, 2, 1) def verify_roi_pool(batch, in_channel, in_size, num_roi, pooled_size, spatial_scale): @@ -617,10 +628,10 @@ def test_proposal(): if __name__ == "__main__": - test_get_valid_counts() - test_multibox_prior() - test_multibox_detection() + #test_get_valid_counts() + #test_multibox_prior() + #test_multibox_detection() test_roi_align() test_roi_pool() - test_proposal() - test_non_max_suppression() + #test_proposal() + #test_non_max_suppression() From 79252bd71d3966bd92299b7e6d41aae9f5229377 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Wed, 10 Feb 2021 14:53:29 -0800 Subject: [PATCH 03/13] all passing --- python/tvm/relay/op/strategy/generic.py | 4 +- python/tvm/topi/testing/roi_align_python.py | 15 +- python/tvm/topi/vision/rcnn/roi_align.py | 14 +- python/tvm/topi/x86/roi_align.py | 87 +++++----- tests/python/frontend/onnx/test_forward.py | 163 +++++++++---------- tests/python/relay/test_op_level5.py | 18 +- tests/python/topi/python/test_topi_vision.py | 24 ++- 7 files changed, 167 insertions(+), 158 deletions(-) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 17dcd4e54a46..f91a5824c031 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1028,7 +1028,7 @@ def wrap_compute_roi_align(topi_compute): def _compute_roi_align(attrs, inputs, out_type): assert attrs.layout == "NCHW" pooled_size = get_const_tuple(attrs.pooled_size) - mode = bytes(attrs.mode, 'utf-8') + mode = bytes(attrs.mode, "utf-8") return [ topi_compute( inputs[0], @@ -1036,7 +1036,7 @@ def _compute_roi_align(attrs, inputs, out_type): pooled_size=pooled_size, spatial_scale=attrs.spatial_scale, sample_ratio=attrs.sample_ratio, - mode=mode + mode=mode, ) ] diff --git a/python/tvm/topi/testing/roi_align_python.py b/python/tvm/topi/testing/roi_align_python.py index 13e76aa2c870..ad8dc09abe24 100644 --- a/python/tvm/topi/testing/roi_align_python.py +++ b/python/tvm/topi/testing/roi_align_python.py @@ -20,18 +20,14 @@ import numpy as np -def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b'avg'): +def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg"): """Roi align in python""" - avg_mode = (mode == b'avg' or mode == 0) - max_mode = (mode == b'max' or mode == 1) + avg_mode = mode == b"avg" or mode == "avg" or mode == 0 + max_mode = mode == b"max" or mode == "max" or mode == 1 assert avg_mode or max_mode, "Mode must be average or max. Please pass a valid mode." _, channel, height, width = a_np.shape num_roi = rois_np.shape[0] b_np = np.zeros((num_roi, channel, pooled_size, pooled_size), dtype=a_np.dtype) - if (avg_mode): - print("average mode") - if (max_mode): - print(max_mode) if isinstance(pooled_size, int): pooled_size_h = pooled_size_w = pooled_size else: @@ -58,10 +54,7 @@ def _bilinear(n, c, y, x): for wx, xp in zip((wx_l, wx_h), (x_low, x_high)): for wy, yp in zip((wy_l, wy_h), (y_low, y_high)): if 0 <= yp < height and 0 <= xp < width: - if avg_mode: - val += wx * wy * a_np[n, c, yp, xp] - elif max_mode: - val = max(val, wx * wy * a_np[n, c, yp, xp]) + val += wx * wy * a_np[n, c, yp, xp] return val for i in range(num_roi): diff --git a/python/tvm/topi/vision/rcnn/roi_align.py b/python/tvm/topi/vision/rcnn/roi_align.py index fbae4dab5fc4..324a71f3bdf4 100644 --- a/python/tvm/topi/vision/rcnn/roi_align.py +++ b/python/tvm/topi/vision/rcnn/roi_align.py @@ -53,11 +53,9 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1 output : tvm.te.Tensor 4-D with shape [num_roi, channel, pooled_size, pooled_size] """ - print("rcnn roi align called") - print("rcnn roi align mode: ", mode) - avg_mode = (mode == b'avg' or mode == 0) - max_mode = (mode == b'max' or mode == 1) - assert (avg_mode or max_mode), "Mode must be avg or max. Please pass in a valid mode." + avg_mode = mode == b"avg" or mode == 0 + max_mode = mode == b"max" or mode == 1 + assert avg_mode or max_mode, "Mode must be avg or max. Please pass in a valid mode." dtype = rois.dtype _, channel, height, width = get_const_tuple(data.shape) num_roi, _ = get_const_tuple(rois.shape) @@ -118,9 +116,11 @@ def _sample(i, c, ph, pw): batch_index, c, roi_start_h + (rh + 0.5) * bin_h / roi_bin_grid_h, - roi_start_w + (rw + 0.5) * bin_w / roi_bin_grid_w), - axis=[rh, rw] + roi_start_w + (rw + 0.5) * bin_w / roi_bin_grid_w, + ), + axis=[rh, rw], ) + return te.compute( (num_roi, channel, pooled_size_h, pooled_size_w), _sample, tag="pool,roi_align_nchw" ) diff --git a/python/tvm/topi/x86/roi_align.py b/python/tvm/topi/x86/roi_align.py index 8e229e640167..738be505acb3 100644 --- a/python/tvm/topi/x86/roi_align.py +++ b/python/tvm/topi/x86/roi_align.py @@ -26,7 +26,9 @@ @hybrid.script -def roi_align_nchw_ir(data, rois, num_rois, w_pc, pos_pc, pooled_size, spatial_scale, sample_ratio, mode): +def roi_align_nchw_ir( + data, rois, num_rois, w_pc, pos_pc, pooled_size, spatial_scale, sample_ratio, mode +): """Hybrid routing fo ROI align operator in NCHW layout. Parameters @@ -204,38 +206,40 @@ def roi_align_nchw_ir(data, rois, num_rois, w_pc, pos_pc, pooled_size, spatial_s output_val /= count output[n, c, ph, pw] = output_val elif mode == 1: - output_val = 0.0 for iy in range(roi_bin_grid_h): for ix in range(roi_bin_grid_w): - output_val = max(output_val, w_pc[n, pre_calc_index, 0] - * data[ - roi_batch_index, - c, - pos_pc[n, pre_calc_index, 2], - pos_pc[n, pre_calc_index, 0], - ]) - output_val = max(output_val, w_pc[n, pre_calc_index, 1] - * data[ - roi_batch_index, - c, - pos_pc[n, pre_calc_index, 2], - pos_pc[n, pre_calc_index, 1], - ]) - output_val = max(output_val, w_pc[n, pre_calc_index, 2] - * data[ - roi_batch_index, - c, - pos_pc[n, pre_calc_index, 3], - pos_pc[n, pre_calc_index, 0], - ]) - output_val = max(output_val, w_pc[n, pre_calc_index, 3] - * data[ - roi_batch_index, - c, - pos_pc[n, pre_calc_index, 3], - pos_pc[n, pre_calc_index, 1], - ]) + bilinear_val = ( + w_pc[n, pre_calc_index, 0] + * data[ + roi_batch_index, + c, + pos_pc[n, pre_calc_index, 2], + pos_pc[n, pre_calc_index, 0], + ] + + w_pc[n, pre_calc_index, 1] + * data[ + roi_batch_index, + c, + pos_pc[n, pre_calc_index, 2], + pos_pc[n, pre_calc_index, 1], + ] + + w_pc[n, pre_calc_index, 2] + * data[ + roi_batch_index, + c, + pos_pc[n, pre_calc_index, 3], + pos_pc[n, pre_calc_index, 0], + ] + + w_pc[n, pre_calc_index, 3] + * data[ + roi_batch_index, + c, + pos_pc[n, pre_calc_index, 3], + pos_pc[n, pre_calc_index, 1], + ] + ) pre_calc_index += 1 + output_val = max(output_val, bilinear_val) output[n, c, ph, pw] = output_val return output @@ -258,7 +262,7 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1 spatial_scale : float Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal of total stride in convolutional layers, which should be in range (0.0, 1.0] - + mode : str Mode of RoiAlign. Should be b'max' or b'avg'. @@ -270,7 +274,6 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1 output : tvm.te.Tensor 4-D with shape [num_roi, channel, pooled_size, pooled_size] """ - print("x86 roi align") if not isinstance(pooled_size, (tuple, list)): pooled_size = (pooled_size, pooled_size) @@ -293,13 +296,21 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1 pooled_size = tvm.runtime.convert(pooled_size) spatial_scale = tvm.tir.const(spatial_scale, "float32") sample_ratio = tvm.tir.const(sample_ratio, "int32") - if mode == b'avg' or mode == 0: - mode = tvm.tir.const(0, dtype='float32') - elif mode == b'max' or mode == 1: - mode = tvm.tir.const(1, dtype='float32') + if mode == b"avg" or mode == 0: + mode = tvm.tir.const(0, dtype="float32") + elif mode == b"max" or mode == 1: + mode = tvm.tir.const(1, dtype="float32") else: raise ValueError(mode, "Value %s passed in for mode not supported", mode) - print("Mode from x86 roi align is: ", mode) + return roi_align_nchw_ir( - data, rois, num_rois, w_pc_buffer, pos_pc_buffer, pooled_size, spatial_scale, sample_ratio, mode + data, + rois, + num_rois, + w_pc_buffer, + pos_pc_buffer, + pooled_size, + spatial_scale, + sample_ratio, + mode, ) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 1aa97ae7fa21..3dc8472c3443 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -67,13 +67,15 @@ def get_tvm_output( graph_def, input_data, target, ctx, output_shape=None, output_dtype="float32", opset=None ): """ Generic function to execute and get tvm output""" + target = "llvm" input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data) mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset) - with tvm.transform.PassContext(opt_level=3): + with tvm.transform.PassContext(opt_level=1): graph, lib, params = relay.build(mod, target, params=params) + ctx = tvm.cpu(0) m = graph_runtime.create(graph, lib, ctx) # set inputs if isinstance(input_data, list): @@ -139,7 +141,6 @@ def verify_with_ort_with_inputs( targets = [tgt for (tgt, _) in tvm.testing.enabled_targets()] for target in targets: - print("target: ", target) ctx = tvm.context(target, 0) if use_vm: tvm_out = get_tvm_output_with_vm( @@ -3436,7 +3437,13 @@ def verify_topk(input_dims, K, axis=-1): @tvm.testing.uses_gpu def test_roi_align(): def verify_roi_align( - input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0, mode="avg" + input_dims, + num_roi, + output_height, + output_width, + sampling_ratio=0, + spatial_scale=1.0, + mode="avg", ): output_dims = [num_roi, input_dims[1], output_height, output_width] @@ -3488,17 +3495,8 @@ def verify_roi_align( verify_roi_align((3, 4, 12, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.5) verify_roi_align((5, 4, 16, 14), 32, 7, 7, sampling_ratio=1, spatial_scale=1.0) verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=2, spatial_scale=1.0) - - verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0, mode="max") - verify_roi_align((4, 4, 16, 32), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0, mode="max") - verify_roi_align((1, 8, 16, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0, mode="max") - verify_roi_align((1, 4, 8, 8), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0, mode="max") - verify_roi_align((1, 4, 16, 16), 16, 5, 7, sampling_ratio=0, spatial_scale=1.0, mode="max") - verify_roi_align((1, 4, 16, 12), 8, 7, 3, sampling_ratio=0, spatial_scale=1.0, mode="max") - verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=0.5, mode="max") - verify_roi_align((3, 4, 12, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.5, mode="max") - verify_roi_align((5, 4, 16, 14), 32, 7, 7, sampling_ratio=1, spatial_scale=1.0, mode="max") - verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=2, spatial_scale=1.0, mode="max") + + # ONNX implementation of roi_align with max mode is incorrect, so we don't compare outputs here. # @tvm.testing.uses_gpu @@ -3975,77 +3973,76 @@ def verify_softplus(indata): if __name__ == "__main__": - # test_flatten() - # test_reshape() - # test_shape() - # test_expand() - # test_power() - # test_squeeze() - # test_unsqueeze() - # test_slice() - # test_floor() - # test_ceil() - # test_round() - # test_isinf() - # test_isnan() - # test_clip() - # test_clip_min_max_as_inputs() - # test_onehot() - # test_matmul() - # test_gather() - # test_gatherelements() - # test_gather_nd() - # test_scatter() - # test_lrn() - # test_instance_norm() - # test_upsample() - # test_forward_min() - # test_forward_max() - # test_forward_mean() - # test_forward_hardsigmoid() - # test_forward_arg_min_max() - # test_softmax() - # test_constantofshape() - # test_all_reduce_funcs() - # test_pad() - # test_split() - # test_binary_ops() - # test_unary_ops() - # test_leaky_relu() - # test_elu() - # test_selu() - # test_prelu() - # test_ThresholdedRelu() - # test_LogSoftmax() - # test_resnet() - # test_inception() - # test_densenet() - # test_sign() - # test_not() - # test_and() - # test_tile() - # test_erf() - # test_where() - # test_or() - # test_depth_to_space() - # test_space_to_depth() - # test_batch_norm() - # test_batch_norm_dynamic_subgraph() - # test_conv() - # test_convtranspose() - # test_unsqueeze_constant() - # test_pooling() - # test_lppool() - # test_lstm() - # test_gru() - # test_resize() - # test_nonzero() - # test_topk() - # test_mod() - # test_xor() + test_flatten() + test_reshape() + test_shape() + test_expand() + test_power() + test_squeeze() + test_unsqueeze() + test_slice() + test_floor() + test_ceil() + test_round() + test_isinf() + test_isnan() + test_clip() + test_clip_min_max_as_inputs() + test_onehot() + test_matmul() + test_gather() + test_gatherelements() + test_gather_nd() + test_scatter() + test_lrn() + test_instance_norm() + test_upsample() + test_forward_min() + test_forward_max() + test_forward_mean() + test_forward_hardsigmoid() + test_forward_arg_min_max() + test_softmax() + test_constantofshape() + test_all_reduce_funcs() + test_pad() + test_split() + test_binary_ops() + test_unary_ops() + test_leaky_relu() + test_elu() + test_selu() + test_prelu() + test_ThresholdedRelu() + test_LogSoftmax() + test_resnet() + test_inception() + test_densenet() + test_sign() + test_not() + test_and() + test_tile() + test_erf() + test_where() + test_or() + test_depth_to_space() + test_space_to_depth() + test_batch_norm() + test_batch_norm_dynamic_subgraph() + test_conv() + test_convtranspose() + test_unsqueeze_constant() + test_pooling() + test_lppool() + test_lstm() + test_gru() + test_resize() + test_nonzero() + test_topk() + test_mod() + test_xor() test_max_roi_pool() test_roi_align() - exit() test_range() test_loop() test_size() diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 6d7d401d706b..95cd537091f5 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -583,7 +583,7 @@ def test_threshold(): @tvm.testing.uses_gpu def test_roi_align(): - def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_ratio): + def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_ratio, mode): data = relay.var("data", relay.ty.TensorType(data_shape, "float32")) rois = relay.var("rois", relay.ty.TensorType(rois_shape, "float32")) z = relay.vision.roi_align( @@ -592,6 +592,7 @@ def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_ pooled_size=(pooled_size, pooled_size), spatial_scale=spatial_scale, sample_ratio=sample_ratio, + mode=mode, layout="NCHW", ) zz = run_infer_type(z) @@ -612,6 +613,7 @@ def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_ pooled_size=pooled_size, spatial_scale=spatial_scale, sample_ratio=sample_ratio, + mode=mode, ) for target, ctx in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) @@ -621,8 +623,18 @@ def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_ op_res2 = intrp2.evaluate(func)(np_data, np_rois) tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-4) - verify_roi_align((1, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=1.0, sample_ratio=-1) - verify_roi_align((4, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=0.5, sample_ratio=2) + verify_roi_align( + (1, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=1.0, sample_ratio=-1, mode="avg" + ) + verify_roi_align( + (4, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=0.5, sample_ratio=2, mode="avg" + ) + verify_roi_align( + (1, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=1.0, sample_ratio=-1, mode="max" + ) + verify_roi_align( + (4, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=0.5, sample_ratio=2, mode="max" + ) @tvm.testing.uses_gpu diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index 5ec289ec6390..0b16c75a33ea 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -418,7 +418,9 @@ def check_device(device): check_device(device) -def verify_roi_align(batch, in_channel, in_size, num_roi, pooled_size, spatial_scale, sample_ratio, mode): # 0 = avg, 1 = max +def verify_roi_align( + batch, in_channel, in_size, num_roi, pooled_size, spatial_scale, sample_ratio, mode +): # 0 = avg, 1 = max a_shape = (batch, in_channel, in_size, in_size) rois_shape = (num_roi, 5) @@ -430,16 +432,13 @@ def get_ref_data(): a_np = np.random.uniform(size=a_shape).astype("float32") rois_np = np.random.uniform(size=rois_shape).astype("float32") * in_size rois_np[:, 0] = np.random.randint(low=0, high=batch, size=num_roi) - print(type(a_np)) - print(type(sample_ratio)) - print("type mode: ", type(mode)) b_np = tvm.topi.testing.roi_align_nchw_python( a_np, rois_np, pooled_size=pooled_size, spatial_scale=spatial_scale, sample_ratio=sample_ratio, - mode=mode + mode=mode, ) return a_np, rois_np, b_np @@ -451,7 +450,6 @@ def check_device(device): if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return - print("Mode in check device: ", mode) with tvm.target.Target(device): fcompute, fschedule = tvm.topi.testing.dispatch(device, _roi_align_implement) b = fcompute( @@ -460,7 +458,7 @@ def check_device(device): pooled_size=pooled_size, spatial_scale=spatial_scale, sample_ratio=sample_ratio, - mode=mode + mode=mode, ) s = fschedule(b) @@ -470,8 +468,6 @@ def check_device(device): f = tvm.build(s, [a, rois, b], device) f(tvm_a, tvm_rois, tvm_b) tvm_val = tvm_b.asnumpy() - #print("Tvm val: ", tvm_val) - #print("B_np: ", b_np) tvm.testing.assert_allclose(tvm_val, b_np, rtol=1e-3) for device in ["llvm", "cuda", "opencl"]: @@ -628,10 +624,10 @@ def test_proposal(): if __name__ == "__main__": - #test_get_valid_counts() - #test_multibox_prior() - #test_multibox_detection() + test_get_valid_counts() + test_multibox_prior() + test_multibox_detection() test_roi_align() test_roi_pool() - #test_proposal() - #test_non_max_suppression() + test_proposal() + test_non_max_suppression() From 1901fc6abd83e95f28a58f5a26fc1176226615df Mon Sep 17 00:00:00 2001 From: electriclilies Date: Wed, 10 Feb 2021 18:46:19 -0800 Subject: [PATCH 04/13] fix lint --- include/tvm/relay/attrs/vision.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index e030fb6a689c..4a96d391430e 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -141,8 +141,7 @@ struct ROIAlignAttrs : public tvm::AttrsNode { "dimensions respectively. Convolution is applied on the 'H' and" "'W' dimensions."); TVM_ATTR_FIELD(mode).set_default("avg").describe( - "Mode for ROI Align. Can be 'avg' or 'max'. The default mode is 'avg'." - ); + "Mode for ROI Align. Can be 'avg' or 'max'. The default mode is 'avg'."); } }; From b5c12cb6b1f2f40a37eba4037459c6e2842bad6c Mon Sep 17 00:00:00 2001 From: electriclilies Date: Wed, 10 Feb 2021 19:37:48 -0800 Subject: [PATCH 05/13] lint again --- python/tvm/relay/op/vision/rcnn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/vision/rcnn.py b/python/tvm/relay/op/vision/rcnn.py index ffa9487c4ffa..d25c5de89cee 100644 --- a/python/tvm/relay/op/vision/rcnn.py +++ b/python/tvm/relay/op/vision/rcnn.py @@ -39,7 +39,7 @@ def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout="N sample_ratio : int Optional sampling ratio of ROI align, using adaptive size by default. - + mode : str, Optional The pooling method. Relay supports two methods, 'avg' and 'max'. Default is 'avg'. From 86d257f79da51057862e87e92e8727bf8b59ac12 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Wed, 10 Feb 2021 20:56:02 -0800 Subject: [PATCH 06/13] lint --- python/tvm/topi/vision/rcnn/roi_align.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/tvm/topi/vision/rcnn/roi_align.py b/python/tvm/topi/vision/rcnn/roi_align.py index 324a71f3bdf4..6d139214081c 100644 --- a/python/tvm/topi/vision/rcnn/roi_align.py +++ b/python/tvm/topi/vision/rcnn/roi_align.py @@ -110,16 +110,16 @@ def _sample(i, c, ph, pw): / count, axis=[rh, rw], ) - elif max_mode: - return te.max( - _bilinear( - batch_index, - c, - roi_start_h + (rh + 0.5) * bin_h / roi_bin_grid_h, - roi_start_w + (rw + 0.5) * bin_w / roi_bin_grid_w, - ), - axis=[rh, rw], - ) + # max mode + return te.max( + _bilinear( + batch_index, + c, + roi_start_h + (rh + 0.5) * bin_h / roi_bin_grid_h, + roi_start_w + (rw + 0.5) * bin_w / roi_bin_grid_w, + ), + axis=[rh, rw], + ) return te.compute( (num_roi, channel, pooled_size_h, pooled_size_w), _sample, tag="pool,roi_align_nchw" From ca4f62ce9f9db91ebf5fbef97461268aa5def299 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 11 Feb 2021 08:10:03 -0800 Subject: [PATCH 07/13] lint --- python/tvm/relay/frontend/onnx.py | 2 +- python/tvm/topi/testing/roi_align_python.py | 4 ++-- python/tvm/topi/vision/rcnn/roi_align.py | 4 ++-- python/tvm/topi/x86/roi_align.py | 5 ++--- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index bd5936cb9bac..abcbab063b64 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2035,7 +2035,7 @@ def _impl_v1(cls, inputs, attr, params): rois = inputs[1] batch_indices = inputs[2] mode = attr.get("mode", b"avg") - if mode != b"avg" and mode != b"max": + if not in (b"avg", b"max"): raise ValueError("RoiAlign in Relay only uses avg and max modes") output_height = attr.get("output_height", 1) output_width = attr.get("output_width", 1) diff --git a/python/tvm/topi/testing/roi_align_python.py b/python/tvm/topi/testing/roi_align_python.py index ad8dc09abe24..1d56ddd35e4a 100644 --- a/python/tvm/topi/testing/roi_align_python.py +++ b/python/tvm/topi/testing/roi_align_python.py @@ -22,8 +22,8 @@ def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg"): """Roi align in python""" - avg_mode = mode == b"avg" or mode == "avg" or mode == 0 - max_mode = mode == b"max" or mode == "max" or mode == 1 + avg_mode = mode in (b"avg", "avg", 0) + max_mode = mode in (b"max", "max", 1) assert avg_mode or max_mode, "Mode must be average or max. Please pass a valid mode." _, channel, height, width = a_np.shape num_roi = rois_np.shape[0] diff --git a/python/tvm/topi/vision/rcnn/roi_align.py b/python/tvm/topi/vision/rcnn/roi_align.py index 6d139214081c..95f350084ba5 100644 --- a/python/tvm/topi/vision/rcnn/roi_align.py +++ b/python/tvm/topi/vision/rcnn/roi_align.py @@ -53,8 +53,8 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1 output : tvm.te.Tensor 4-D with shape [num_roi, channel, pooled_size, pooled_size] """ - avg_mode = mode == b"avg" or mode == 0 - max_mode = mode == b"max" or mode == 1 + avg_mode = mode in (b"avg", 0) + max_mode = mode in (b"max", 1) assert avg_mode or max_mode, "Mode must be avg or max. Please pass in a valid mode." dtype = rois.dtype _, channel, height, width = get_const_tuple(data.shape) diff --git a/python/tvm/topi/x86/roi_align.py b/python/tvm/topi/x86/roi_align.py index 738be505acb3..b7799774883c 100644 --- a/python/tvm/topi/x86/roi_align.py +++ b/python/tvm/topi/x86/roi_align.py @@ -17,7 +17,6 @@ # pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements """Non-maximum suppression operator for intel cpu""" import math -import tvm from tvm import relay from tvm.te import hybrid @@ -296,9 +295,9 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1 pooled_size = tvm.runtime.convert(pooled_size) spatial_scale = tvm.tir.const(spatial_scale, "float32") sample_ratio = tvm.tir.const(sample_ratio, "int32") - if mode == b"avg" or mode == 0: + if mode in (b"avg", 0): mode = tvm.tir.const(0, dtype="float32") - elif mode == b"max" or mode == 1: + elif mode in (b"max", 1): mode = tvm.tir.const(1, dtype="float32") else: raise ValueError(mode, "Value %s passed in for mode not supported", mode) From f8ff52c0654afeabe467969d0d5681061f503fb4 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 11 Feb 2021 09:58:47 -0800 Subject: [PATCH 08/13] typo --- python/tvm/relay/frontend/onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index abcbab063b64..07b3913f8dba 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2035,7 +2035,7 @@ def _impl_v1(cls, inputs, attr, params): rois = inputs[1] batch_indices = inputs[2] mode = attr.get("mode", b"avg") - if not in (b"avg", b"max"): + if mode not in (b"avg", b"max"): raise ValueError("RoiAlign in Relay only uses avg and max modes") output_height = attr.get("output_height", 1) output_width = attr.get("output_width", 1) From f12907310392a1596ddeda00b74db08d9aa1e5c3 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 11 Feb 2021 11:22:53 -0800 Subject: [PATCH 09/13] remove import --- python/tvm/topi/x86/roi_align.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/topi/x86/roi_align.py b/python/tvm/topi/x86/roi_align.py index b7799774883c..95fd863d418b 100644 --- a/python/tvm/topi/x86/roi_align.py +++ b/python/tvm/topi/x86/roi_align.py @@ -18,7 +18,6 @@ """Non-maximum suppression operator for intel cpu""" import math -from tvm import relay from tvm.te import hybrid from ..tensor import full from ..utils import get_const_tuple From 298d6d933d8ef46e98fc39d89f214b4f4f64aacf Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 11 Feb 2021 15:06:02 -0800 Subject: [PATCH 10/13] fix import --- python/tvm/topi/x86/roi_align.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/topi/x86/roi_align.py b/python/tvm/topi/x86/roi_align.py index 95fd863d418b..8116f88208f6 100644 --- a/python/tvm/topi/x86/roi_align.py +++ b/python/tvm/topi/x86/roi_align.py @@ -18,6 +18,7 @@ """Non-maximum suppression operator for intel cpu""" import math +import tvm from tvm.te import hybrid from ..tensor import full from ..utils import get_const_tuple From 61c6e0ebd24bc22cb5cc5670ba047041e5b66161 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 12 Feb 2021 11:38:14 -0800 Subject: [PATCH 11/13] add inf, -inf to hybridscript and respond to comments --- python/tvm/te/hybrid/calls.py | 14 +++ python/tvm/te/hybrid/runtime.py | 10 ++ python/tvm/topi/testing/roi_align_python.py | 23 ++-- python/tvm/topi/x86/roi_align.py | 111 +++++++------------ tests/python/topi/python/test_topi_vision.py | 8 +- 5 files changed, 78 insertions(+), 88 deletions(-) diff --git a/python/tvm/te/hybrid/calls.py b/python/tvm/te/hybrid/calls.py index 6785457c3bd7..462066106a9d 100644 --- a/python/tvm/te/hybrid/calls.py +++ b/python/tvm/te/hybrid/calls.py @@ -167,3 +167,17 @@ def max_num_threads(func_id, args): _internal_assert(isinstance(args[0], _expr.IntImm), "In tvm bool should be uint") res = Target.current(args[0].value).max_num_threads return convert(res) + + +def inf(func_id, args): + """Infinity""" + _internal_assert(func_id == "inf", "This function cannot be directly invoked!") + _internal_assert(args.__len__() == 1, "One argument accepted!") + return tvm.tir.max_value(args[0]) + + +def ninf(func_id, args): + """Negative infinity""" + _internal_assert(func_id == "ninf", "This function cannot be directly invoked!") + _internal_assert(args.__len__() == 1, "One argument accepted!") + return tvm.tir.min_value(args[0]) diff --git a/python/tvm/te/hybrid/runtime.py b/python/tvm/te/hybrid/runtime.py index 7b90f8729014..615bd7e43a7d 100644 --- a/python/tvm/te/hybrid/runtime.py +++ b/python/tvm/te/hybrid/runtime.py @@ -111,6 +111,14 @@ def max_num_threads(allow_none=True): return Target.current(allow_none).max_num_threads +def inf(dtype): + return numpy.iinfo(dtype).max + + +def ninf(dtype): + return numpy.iinfo(dtype).min + + HYBRID_GLOBALS = { "unroll": range, "vectorize": range, @@ -142,6 +150,8 @@ def max_num_threads(allow_none=True): "float64": numpy.float64, "ceil_div": lambda a, b: (a + b - 1) // b, "max_num_threads": max_num_threads, + "inf": inf, + "ninf": inf, } diff --git a/python/tvm/topi/testing/roi_align_python.py b/python/tvm/topi/testing/roi_align_python.py index 1d56ddd35e4a..81e7bfcf1815 100644 --- a/python/tvm/topi/testing/roi_align_python.py +++ b/python/tvm/topi/testing/roi_align_python.py @@ -80,18 +80,17 @@ def _bilinear(n, c, y, x): for pw in range(pooled_size_w): if avg_mode: total = 0.0 - for iy in range(roi_bin_grid_h): - for ix in range(roi_bin_grid_w): - y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h - x = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w + if max_mode: + total = float("-inf") + for iy in range(roi_bin_grid_h): + for ix in range(roi_bin_grid_w): + y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h + x = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w + if avg_mode: total += _bilinear(batch_index, c, y, x) - b_np[i, c, ph, pw] = total / count - elif max_mode: - total = 0.0 - for iy in range(roi_bin_grid_h): - for ix in range(roi_bin_grid_w): - y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h - x = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w + if max_mode: total = max(total, _bilinear(batch_index, c, y, x)) - b_np[i, c, ph, pw] = total + if avg_mode: + total = total / count + b_np[i, c, ph, pw] = total return b_np diff --git a/python/tvm/topi/x86/roi_align.py b/python/tvm/topi/x86/roi_align.py index 8116f88208f6..336a336f50e5 100644 --- a/python/tvm/topi/x86/roi_align.py +++ b/python/tvm/topi/x86/roi_align.py @@ -166,78 +166,45 @@ def roi_align_nchw_ir( pre_calc_index = 0 for ph in range(pooled_size_h): for pw in range(pooled_size_w): - output_val = 0.0 - if mode == 0: - for iy in range(roi_bin_grid_h): - for ix in range(roi_bin_grid_w): - output_val += ( - w_pc[n, pre_calc_index, 0] - * data[ - roi_batch_index, - c, - pos_pc[n, pre_calc_index, 2], - pos_pc[n, pre_calc_index, 0], - ] - + w_pc[n, pre_calc_index, 1] - * data[ - roi_batch_index, - c, - pos_pc[n, pre_calc_index, 2], - pos_pc[n, pre_calc_index, 1], - ] - + w_pc[n, pre_calc_index, 2] - * data[ - roi_batch_index, - c, - pos_pc[n, pre_calc_index, 3], - pos_pc[n, pre_calc_index, 0], - ] - + w_pc[n, pre_calc_index, 3] - * data[ - roi_batch_index, - c, - pos_pc[n, pre_calc_index, 3], - pos_pc[n, pre_calc_index, 1], - ] - ) - pre_calc_index += 1 - - output_val /= count - output[n, c, ph, pw] = output_val - elif mode == 1: - for iy in range(roi_bin_grid_h): - for ix in range(roi_bin_grid_w): - bilinear_val = ( - w_pc[n, pre_calc_index, 0] - * data[ - roi_batch_index, - c, - pos_pc[n, pre_calc_index, 2], - pos_pc[n, pre_calc_index, 0], - ] - + w_pc[n, pre_calc_index, 1] - * data[ - roi_batch_index, - c, - pos_pc[n, pre_calc_index, 2], - pos_pc[n, pre_calc_index, 1], - ] - + w_pc[n, pre_calc_index, 2] - * data[ - roi_batch_index, - c, - pos_pc[n, pre_calc_index, 3], - pos_pc[n, pre_calc_index, 0], - ] - + w_pc[n, pre_calc_index, 3] - * data[ - roi_batch_index, - c, - pos_pc[n, pre_calc_index, 3], - pos_pc[n, pre_calc_index, 1], - ] - ) - pre_calc_index += 1 + output_val = 0.0 # Avg mode + if mode == 1: # Max mode + output_val = ninf("float32") + for iy in range(roi_bin_grid_h): + for ix in range(roi_bin_grid_w): + bilinear_val = ( + w_pc[n, pre_calc_index, 0] + * data[ + roi_batch_index, + c, + pos_pc[n, pre_calc_index, 2], + pos_pc[n, pre_calc_index, 0], + ] + + w_pc[n, pre_calc_index, 1] + * data[ + roi_batch_index, + c, + pos_pc[n, pre_calc_index, 2], + pos_pc[n, pre_calc_index, 1], + ] + + w_pc[n, pre_calc_index, 2] + * data[ + roi_batch_index, + c, + pos_pc[n, pre_calc_index, 3], + pos_pc[n, pre_calc_index, 0], + ] + + w_pc[n, pre_calc_index, 3] + * data[ + roi_batch_index, + c, + pos_pc[n, pre_calc_index, 3], + pos_pc[n, pre_calc_index, 1], + ] + ) + pre_calc_index += 1 + if mode == 0: # Avg mode + output_val += bilinear_val / count + if mode == 1: # Max mode output_val = max(output_val, bilinear_val) output[n, c, ph, pw] = output_val return output diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index 0b16c75a33ea..aa4a6173034f 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -420,7 +420,7 @@ def check_device(device): def verify_roi_align( batch, in_channel, in_size, num_roi, pooled_size, spatial_scale, sample_ratio, mode -): # 0 = avg, 1 = max +): # For mode, 0 = avg, 1 = max a_shape = (batch, in_channel, in_size, in_size) rois_shape = (num_roi, 5) @@ -429,8 +429,8 @@ def verify_roi_align( @memoize("topi.tests.test_topi_vision.verify_roi_align") def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype("float32") - rois_np = np.random.uniform(size=rois_shape).astype("float32") * in_size + a_np = np.random.uniform(-1, 1, size=a_shape).astype("float32") + rois_np = np.random.uniform(-1, 1, size=rois_shape).astype("float32") * in_size rois_np[:, 0] = np.random.randint(low=0, high=batch, size=num_roi) b_np = tvm.topi.testing.roi_align_nchw_python( a_np, @@ -468,7 +468,7 @@ def check_device(device): f = tvm.build(s, [a, rois, b], device) f(tvm_a, tvm_rois, tvm_b) tvm_val = tvm_b.asnumpy() - tvm.testing.assert_allclose(tvm_val, b_np, rtol=1e-3) + tvm.testing.assert_allclose(tvm_val, b_np, rtol=1e-3, atol=1e-5) for device in ["llvm", "cuda", "opencl"]: check_device(device) From 14a364f9845e15f80c6b4bbf7ad5d4a9b65afe7b Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 12 Feb 2021 11:43:20 -0800 Subject: [PATCH 12/13] shorten code --- python/tvm/topi/testing/roi_align_python.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/topi/testing/roi_align_python.py b/python/tvm/topi/testing/roi_align_python.py index 81e7bfcf1815..643a954b101b 100644 --- a/python/tvm/topi/testing/roi_align_python.py +++ b/python/tvm/topi/testing/roi_align_python.py @@ -87,10 +87,8 @@ def _bilinear(n, c, y, x): y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h x = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w if avg_mode: - total += _bilinear(batch_index, c, y, x) + total += _bilinear(batch_index, c, y, x) / count if max_mode: total = max(total, _bilinear(batch_index, c, y, x)) - if avg_mode: - total = total / count b_np[i, c, ph, pw] = total return b_np From 33d677f9f0fed1e33a8b2873ca0f5ed2ea0d7117 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 16 Feb 2021 09:14:46 -0800 Subject: [PATCH 13/13] make atol lower --- tests/python/topi/python/test_topi_vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index aa4a6173034f..839356892ab1 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -468,7 +468,7 @@ def check_device(device): f = tvm.build(s, [a, rois, b], device) f(tvm_a, tvm_rois, tvm_b) tvm_val = tvm_b.asnumpy() - tvm.testing.assert_allclose(tvm_val, b_np, rtol=1e-3, atol=1e-5) + tvm.testing.assert_allclose(tvm_val, b_np, rtol=1e-3, atol=1e-4) for device in ["llvm", "cuda", "opencl"]: check_device(device)