From f8afb1e7f698e4946e59565359c298d1e0888b4d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 26 Dec 2020 07:59:19 +0900 Subject: [PATCH 1/7] make NMS inner loop parallel --- python/tvm/topi/cuda/nms.py | 82 ++++++++++++++++++++++--------------- 1 file changed, 49 insertions(+), 33 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 020cf9b5bc63..56f63ce021c3 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -512,9 +512,18 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.new_scope(): nthread_by = batch_size + nthread_tx = max_threads + nthread_bx = ceil_div(num_anchors, max_threads) + by = te.thread_axis("blockIdx.y") + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(tx, "thread_extent", nthread_tx) + i = by + k = bx * nthread_tx + tx base_idx = i * num_anchors * box_data_length num_valid_boxes_local = ib.allocate( "int32", (1,), name="num_valid_boxes_local", scope="local" @@ -522,41 +531,43 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): num_valid_boxes_local[0] = 0 def nms_inner_loop(ib, j): - offset_j = j * box_data_length + # box j is valid, invalidate other boxes that overlap with j above iou_threshold - with ib.for_range(0, j) as k: - offset_k = k * box_data_length - - with ib.if_scope( - tvm.tir.all( - out[base_idx + offset_j + score_index] > -1.0, # if already surpressed - out[base_idx + offset_k + score_index] > 0, - tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), - tvm.tir.any( - force_suppress > 0, - id_index < 0, - out[base_idx + offset_k + id_index] - == out[base_idx + offset_j + id_index], - ), - ) - ): - iou = calculate_overlap( - out, - base_idx + offset_j + coord_start, - base_idx + offset_k + coord_start, - ) - with ib.if_scope(iou >= iou_threshold): - out[base_idx + offset_j + 
score_index] = -1.0 - with ib.if_scope(id_index >= 0): - out[base_idx + offset_j + id_index] = -1.0 - - # Has the box j survived IOU tests? - with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0): - # When return_indices is False, no need to populate box_indices - if return_indices: + # When return_indices is False, no need to populate box_indices + if return_indices: + # Only one thread needs to this write + with ib.if_scope(k == 0): orig_idx = sorted_index[i * num_anchors + j] box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] - num_valid_boxes_local[0] += 1 + + num_valid_boxes_local[0] += 1 + + offset_j = j * box_data_length + offset_k = k * box_data_length + + with ib.if_scope( + tvm.tir.all( + j < k, + out[base_idx + offset_k + score_index] > 0, + tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), + tvm.tir.any( + force_suppress > 0, + id_index < 0, + out[base_idx + offset_k + id_index] == out[base_idx + offset_j + id_index], + ), + ) + ): + iou = calculate_overlap( + out, + base_idx + offset_j + coord_start, + base_idx + offset_k + coord_start, + ) + with ib.if_scope(iou >= iou_threshold): + out[base_idx + offset_k + score_index] = -1.0 + with ib.if_scope(id_index >= 0): + out[base_idx + offset_k + id_index] = -1.0 + + ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) if isinstance(max_output_size, int): max_output_size = tvm.tir.const(max_output_size) @@ -565,7 +576,12 @@ def nms_inner_loop(ib, j): # Apply nms with ib.for_range(0, valid_count[i]) as j: with ib.if_scope( - tvm.tir.any(id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0) + tvm.tir.all( + out[base_idx + (j * box_data_length) + score_index] > -1.0, + tvm.tir.any( + id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0 + ), + ) ): with ib.if_scope(max_output_size > 0): # No need to do more iteration if we already reach max_output_size boxes From cabf8fe0907e021f46e1159c0c41f9f34767f85f Mon Sep 17 
00:00:00 2001 From: Masahiro Masuda Date: Sat, 26 Dec 2020 10:06:43 +0900 Subject: [PATCH 2/7] use one block to avoid global sync issue --- python/tvm/topi/cuda/nms.py | 64 ++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 56f63ce021c3..65f7e3950e1c 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -513,17 +513,14 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.new_scope(): nthread_by = batch_size nthread_tx = max_threads - nthread_bx = ceil_div(num_anchors, max_threads) by = te.thread_axis("blockIdx.y") tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(by, "thread_extent", nthread_by) ib.scope_attr(by, "thread_extent", nthread_by) ib.scope_attr(tx, "thread_extent", nthread_tx) i = by - k = bx * nthread_tx + tx + base_idx = i * num_anchors * box_data_length num_valid_boxes_local = ib.allocate( "int32", (1,), name="num_valid_boxes_local", scope="local" @@ -531,43 +528,49 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): num_valid_boxes_local[0] = 0 def nms_inner_loop(ib, j): - # box j is valid, invalidate other boxes that overlap with j above iou_threshold + # the box j is valid, invalidate other boxes that overlap with j above iou_threshold # When return_indices is False, no need to populate box_indices if return_indices: # Only one thread needs to this write - with ib.if_scope(k == 0): + with ib.if_scope(tx == 0): orig_idx = sorted_index[i * num_anchors + j] box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] num_valid_boxes_local[0] += 1 offset_j = j * box_data_length - offset_k = k * box_data_length + num_iter_per_thread = ceil_div(num_anchors - (j + 1), nthread_tx) - with ib.if_scope( - tvm.tir.all( - j < k, - out[base_idx + offset_k + score_index] > 0, - tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), - tvm.tir.any( - 
force_suppress > 0, - id_index < 0, - out[base_idx + offset_k + id_index] == out[base_idx + offset_j + id_index], - ), - ) - ): - iou = calculate_overlap( - out, - base_idx + offset_j + coord_start, - base_idx + offset_k + coord_start, - ) - with ib.if_scope(iou >= iou_threshold): - out[base_idx + offset_k + score_index] = -1.0 - with ib.if_scope(id_index >= 0): - out[base_idx + offset_k + id_index] = -1.0 + with ib.for_range(0, num_iter_per_thread) as _k: + k = j + 1 + _k * nthread_tx + tx + offset_k = k * box_data_length + + with ib.if_scope( + tvm.tir.all( + k < num_anchors, + out[base_idx + offset_k + score_index] > 0, # is the box k still valid? + tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), + tvm.tir.any( + force_suppress > 0, + id_index < 0, + out[base_idx + offset_k + id_index] + == out[base_idx + offset_j + id_index], + ), + ) + ): + iou = calculate_overlap( + out, + base_idx + offset_j + coord_start, + base_idx + offset_k + coord_start, + ) + with ib.if_scope(iou >= iou_threshold): + # invalidate the box k + out[base_idx + offset_k + score_index] = -1.0 + with ib.if_scope(id_index >= 0): + out[base_idx + offset_k + id_index] = -1.0 - ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) + ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) if isinstance(max_output_size, int): max_output_size = tvm.tir.const(max_output_size) @@ -590,7 +593,8 @@ def nms_inner_loop(ib, j): with ib.else_scope(): nms_inner_loop(ib, j) - num_valid_boxes[i] = num_valid_boxes_local[0] + with ib.if_scope(tx == 0): + num_valid_boxes[i] = num_valid_boxes_local[0] with ib.else_scope(): num_valid_boxes[i] = 0 From 1398eb4aa2f61b6d3ed1e223b008c3a47e1c70be Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 28 Dec 2020 19:13:04 +0900 Subject: [PATCH 3/7] temp disable write by only thread 0 --- python/tvm/topi/cuda/nms.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) 
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 65f7e3950e1c..210d5a5b1c76 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -530,12 +530,15 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): def nms_inner_loop(ib, j): # the box j is valid, invalidate other boxes that overlap with j above iou_threshold - # When return_indices is False, no need to populate box_indices - if return_indices: - # Only one thread needs to this write - with ib.if_scope(tx == 0): - orig_idx = sorted_index[i * num_anchors + j] - box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] + # # When return_indices is False, no need to populate box_indices + # if return_indices: + # # Only one thread needs to this write + # with ib.if_scope(tx == 0): + # orig_idx = sorted_index[i * num_anchors + j] + # box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] + + orig_idx = sorted_index[i * num_anchors + j] + box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] num_valid_boxes_local[0] += 1 @@ -593,8 +596,8 @@ def nms_inner_loop(ib, j): with ib.else_scope(): nms_inner_loop(ib, j) - with ib.if_scope(tx == 0): - num_valid_boxes[i] = num_valid_boxes_local[0] + # with ib.if_scope(tx == 0): + num_valid_boxes[i] = num_valid_boxes_local[0] with ib.else_scope(): num_valid_boxes[i] = 0 From 5df9160ec09fd571c9bafa011edeee5dc4aa369e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 28 Dec 2020 19:39:49 +0900 Subject: [PATCH 4/7] leave a TODO on write by only one thread --- python/tvm/topi/cuda/nms.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 210d5a5b1c76..916c2c758d44 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -530,20 +530,21 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): def nms_inner_loop(ib, j): # the box j is valid, invalidate other boxes that 
overlap with j above iou_threshold - # # When return_indices is False, no need to populate box_indices - # if return_indices: - # # Only one thread needs to this write - # with ib.if_scope(tx == 0): - # orig_idx = sorted_index[i * num_anchors + j] - # box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] + # When return_indices is False, no need to populate box_indices + if return_indices: + orig_idx = sorted_index[i * num_anchors + j] + box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] - orig_idx = sorted_index[i * num_anchors + j] - box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] + # TODO(masahi): Want to do this instead of above, but the following is eliminated during codegen + # # Only one thread needs to this write + # with ib.if_scope(tx == 0): + # orig_idx = sorted_index[i * num_anchors + j] + # box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] num_valid_boxes_local[0] += 1 offset_j = j * box_data_length - num_iter_per_thread = ceil_div(num_anchors - (j + 1), nthread_tx) + num_iter_per_thread = ceil_div(valid_count[i] - (j + 1), nthread_tx) with ib.for_range(0, num_iter_per_thread) as _k: k = j + 1 + _k * nthread_tx + tx @@ -552,7 +553,7 @@ def nms_inner_loop(ib, j): with ib.if_scope( tvm.tir.all( k < num_anchors, - out[base_idx + offset_k + score_index] > 0, # is the box k still valid? + out[base_idx + offset_k + score_index] > 0, # is the box k still valid? 
tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), tvm.tir.any( force_suppress > 0, @@ -596,8 +597,10 @@ def nms_inner_loop(ib, j): with ib.else_scope(): nms_inner_loop(ib, j) - # with ib.if_scope(tx == 0): num_valid_boxes[i] = num_valid_boxes_local[0] + # TODO(masahi): Want to do this instead of above, but the following is eliminated during codegen + # with ib.if_scope(tx == 0): + # num_valid_boxes[i] = num_valid_boxes_local[0] with ib.else_scope(): num_valid_boxes[i] = 0 From 2073995ffe211c28f35a482905cb4722dcfac8c0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 28 Dec 2020 19:50:36 +0900 Subject: [PATCH 5/7] add some comments, remove the check on negative class id --- python/tvm/topi/cuda/nms.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 916c2c758d44..b8c76d5a880f 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -528,14 +528,15 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): num_valid_boxes_local[0] = 0 def nms_inner_loop(ib, j): - # the box j is valid, invalidate other boxes that overlap with j above iou_threshold + # The box j is valid, invalidate other boxes that overlap with j above iou_threshold # When return_indices is False, no need to populate box_indices if return_indices: orig_idx = sorted_index[i * num_anchors + j] box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] - # TODO(masahi): Want to do this instead of above, but the following is eliminated during codegen + # TODO(masahi): Want to do this instead of above, but the following is eliminated + # during codegen # # Only one thread needs to this write # with ib.if_scope(tx == 0): # orig_idx = sorted_index[i * num_anchors + j] @@ -554,7 +555,6 @@ def nms_inner_loop(ib, j): tvm.tir.all( k < num_anchors, out[base_idx + offset_k + score_index] > 0, # is the box k still valid? 
- tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), tvm.tir.any( force_suppress > 0, id_index < 0, @@ -574,6 +574,7 @@ def nms_inner_loop(ib, j): with ib.if_scope(id_index >= 0): out[base_idx + offset_k + id_index] = -1.0 + # Make sure to do the next loop in a lock step ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) if isinstance(max_output_size, int): @@ -582,14 +583,8 @@ def nms_inner_loop(ib, j): with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Apply nms with ib.for_range(0, valid_count[i]) as j: - with ib.if_scope( - tvm.tir.all( - out[base_idx + (j * box_data_length) + score_index] > -1.0, - tvm.tir.any( - id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0 - ), - ) - ): + # Proceed to the inner loop if the box j is still valid + with ib.if_scope(out[base_idx + (j * box_data_length) + score_index] > -1.0): with ib.if_scope(max_output_size > 0): # No need to do more iteration if we already reach max_output_size boxes with ib.if_scope(num_valid_boxes_local[0] < max_output_size): @@ -598,7 +593,8 @@ def nms_inner_loop(ib, j): nms_inner_loop(ib, j) num_valid_boxes[i] = num_valid_boxes_local[0] - # TODO(masahi): Want to do this instead of above, but the following is eliminated during codegen + # TODO(masahi): Want to do this instead of above, but the following is eliminated + # during codegen # with ib.if_scope(tx == 0): # num_valid_boxes[i] = num_valid_boxes_local[0] From 919f40b69084085204d5896a2dc8e381f4fa305d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 29 Dec 2020 04:27:33 +0900 Subject: [PATCH 6/7] minor improvement when topk is available --- python/tvm/topi/cuda/nms.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index b8c76d5a880f..79654166fb91 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -526,6 +526,7 @@ def 
calculate_overlap(out_tensor, box_a_idx, box_b_idx): "int32", (1,), name="num_valid_boxes_local", scope="local" ) num_valid_boxes_local[0] = 0 + nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) def nms_inner_loop(ib, j): # The box j is valid, invalidate other boxes that overlap with j above iou_threshold @@ -545,7 +546,7 @@ def nms_inner_loop(ib, j): num_valid_boxes_local[0] += 1 offset_j = j * box_data_length - num_iter_per_thread = ceil_div(valid_count[i] - (j + 1), nthread_tx) + num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx) with ib.for_range(0, num_iter_per_thread) as _k: k = j + 1 + _k * nthread_tx + tx @@ -553,7 +554,7 @@ def nms_inner_loop(ib, j): with ib.if_scope( tvm.tir.all( - k < num_anchors, + k < nkeep, out[base_idx + offset_k + score_index] > 0, # is the box k still valid? tvm.tir.any( force_suppress > 0, @@ -582,7 +583,7 @@ def nms_inner_loop(ib, j): with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Apply nms - with ib.for_range(0, valid_count[i]) as j: + with ib.for_range(0, nkeep) as j: # Proceed to the inner loop if the box j is still valid with ib.if_scope(out[base_idx + (j * box_data_length) + score_index] > -1.0): with ib.if_scope(max_output_size > 0): From f317ee23bb20ce44bd2564e70a8861c6e30b0219 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 30 Dec 2020 10:09:57 +0900 Subject: [PATCH 7/7] fix write by a single thread --- python/tvm/topi/cuda/nms.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 79654166fb91..dd9d3f8a1d0e 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -533,15 +533,9 @@ def nms_inner_loop(ib, j): # When return_indices is False, no need to populate box_indices if return_indices: - orig_idx = sorted_index[i * num_anchors + j] - box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] - - # TODO(masahi): 
Want to do this instead of above, but the following is eliminated - # during codegen - # # Only one thread needs to this write - # with ib.if_scope(tx == 0): - # orig_idx = sorted_index[i * num_anchors + j] - # box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] + with ib.if_scope(tx + 0 == 0): + orig_idx = sorted_index[i * num_anchors + j] + box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] num_valid_boxes_local[0] += 1 @@ -593,11 +587,8 @@ def nms_inner_loop(ib, j): with ib.else_scope(): nms_inner_loop(ib, j) - num_valid_boxes[i] = num_valid_boxes_local[0] - # TODO(masahi): Want to do this instead of above, but the following is eliminated - # during codegen - # with ib.if_scope(tx == 0): - # num_valid_boxes[i] = num_valid_boxes_local[0] + with ib.if_scope(tx + 0 == 0): + num_valid_boxes[i] = num_valid_boxes_local[0] with ib.else_scope(): num_valid_boxes[i] = 0