From 670d3a0ed293fe244c9ab1424b898a4025bda266 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Thu, 7 Mar 2019 14:46:35 -0800 Subject: [PATCH] revert PR#2420 nms changes --- topi/python/topi/cuda/nms.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index 3cdc02e58aec..e0d71559f1a0 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -35,7 +35,7 @@ def sort_ir(data, index, output): p_index = ib.buffer_ptr(index) p_out = ib.buffer_ptr(output) nthread_tx = max_threads - nthread_bx = (num_anchors + 1) // 2 // max_threads + 1 + nthread_bx = num_anchors // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") bx = tvm.thread_axis("vthread") ib.scope_attr(tx, "thread_extent", nthread_tx) @@ -46,10 +46,8 @@ def sort_ir(data, index, output): with ib.for_range(0, batch, for_type="unroll") as b: start = b * num_anchors - for i in range(2): - bbox_id = tid * 2 + i - with ib.if_scope(bbox_id < num_anchors): - p_out[start + bbox_id] = bbox_id + with ib.if_scope(tid < num_anchors): + p_out[start + tid] = tid # OddEvenTransposeSort with ib.for_range(0, p_index[b]) as k: with ib.if_scope(tid < (p_index[b] + 1) // 2):