diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index 3cdc02e58aec..e0d71559f1a0 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -35,7 +35,7 @@ def sort_ir(data, index, output): p_index = ib.buffer_ptr(index) p_out = ib.buffer_ptr(output) nthread_tx = max_threads - nthread_bx = (num_anchors + 1) // 2 // max_threads + 1 + nthread_bx = num_anchors // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") bx = tvm.thread_axis("vthread") ib.scope_attr(tx, "thread_extent", nthread_tx) @@ -46,10 +46,8 @@ def sort_ir(data, index, output): with ib.for_range(0, batch, for_type="unroll") as b: start = b * num_anchors - for i in range(2): - bbox_id = tid * 2 + i - with ib.if_scope(bbox_id < num_anchors): - p_out[start + bbox_id] = bbox_id + with ib.if_scope(tid < num_anchors): + p_out[start + tid] = tid # OddEvenTransposeSort with ib.for_range(0, p_index[b]) as k: with ib.if_scope(tid < (p_index[b] + 1) // 2):