From 2f84cbcda30d8e5df22ae05e8a548ab5ccc2e773 Mon Sep 17 00:00:00 2001 From: masahi Date: Wed, 30 Dec 2020 18:00:22 +0900 Subject: [PATCH] [TOPI] Parallelize GPU NMS inner loop (#7172) * make NMS inner loop parallel * use one block two avoid global sync issue * temp disable write by only thread 0 * leave a TODO on write by only one thread * add some comments, remove check the check on negative class id * minor improvement when topk is available * fix write by a single thread --- python/tvm/topi/cuda/nms.py | 50 ++++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 020cf9b5bc63..dd9d3f8a1d0e 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -512,26 +512,44 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.new_scope(): nthread_by = batch_size + nthread_tx = max_threads + by = te.thread_axis("blockIdx.y") + tx = te.thread_axis("threadIdx.x") ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(tx, "thread_extent", nthread_tx) + i = by + base_idx = i * num_anchors * box_data_length num_valid_boxes_local = ib.allocate( "int32", (1,), name="num_valid_boxes_local", scope="local" ) num_valid_boxes_local[0] = 0 + nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) def nms_inner_loop(ib, j): + # The box j is valid, invalidate other boxes that overlap with j above iou_threshold + + # When return_indices is False, no need to populate box_indices + if return_indices: + with ib.if_scope(tx + 0 == 0): + orig_idx = sorted_index[i * num_anchors + j] + box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] + + num_valid_boxes_local[0] += 1 + offset_j = j * box_data_length + num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx) - with ib.for_range(0, j) as k: + with ib.for_range(0, num_iter_per_thread) as _k: + k = j + 1 + _k * nthread_tx + tx offset_k = k * box_data_length with ib.if_scope( tvm.tir.all( - out[base_idx + offset_j + score_index] > -1.0, # if already surpressed - out[base_idx + offset_k + score_index] > 0, - tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), + k < nkeep, + out[base_idx + offset_k + score_index] > 0, # is the box k still valid? tvm.tir.any( force_suppress > 0, id_index < 0, @@ -546,27 +564,22 @@ def nms_inner_loop(ib, j): base_idx + offset_k + coord_start, ) with ib.if_scope(iou >= iou_threshold): - out[base_idx + offset_j + score_index] = -1.0 + # invalidate the box k + out[base_idx + offset_k + score_index] = -1.0 with ib.if_scope(id_index >= 0): - out[base_idx + offset_j + id_index] = -1.0 + out[base_idx + offset_k + id_index] = -1.0 - # Has the box j survived IOU tests? - with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0): - # When return_indices is False, no need to populate box_indices - if return_indices: - orig_idx = sorted_index[i * num_anchors + j] - box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] - num_valid_boxes_local[0] += 1 + # Make sure to do the next loop in a lock step + ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) if isinstance(max_output_size, int): max_output_size = tvm.tir.const(max_output_size) with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Apply nms - with ib.for_range(0, valid_count[i]) as j: - with ib.if_scope( - tvm.tir.any(id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0) - ): + with ib.for_range(0, nkeep) as j: + # Proceed to the inner loop if the box j is still valid + with ib.if_scope(out[base_idx + (j * box_data_length) + score_index] > -1.0): with ib.if_scope(max_output_size > 0): # No need to do more iteration if we already reach max_output_size boxes with ib.if_scope(num_valid_boxes_local[0] < max_output_size): @@ -574,7 +587,8 @@ def nms_inner_loop(ib, j): with ib.else_scope(): nms_inner_loop(ib, j) - num_valid_boxes[i] = num_valid_boxes_local[0] + with ib.if_scope(tx + 0 == 0): + num_valid_boxes[i] = num_valid_boxes_local[0] with ib.else_scope(): num_valid_boxes[i] = 0