Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI] Parallelize GPU NMS inner loop #7172

Merged
merged 7 commits into from
Dec 30, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 32 additions & 18 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,26 +512,44 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):

with ib.new_scope():
nthread_by = batch_size
nthread_tx = max_threads

by = te.thread_axis("blockIdx.y")
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(tx, "thread_extent", nthread_tx)

i = by

base_idx = i * num_anchors * box_data_length
num_valid_boxes_local = ib.allocate(
"int32", (1,), name="num_valid_boxes_local", scope="local"
)
num_valid_boxes_local[0] = 0
nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i])

def nms_inner_loop(ib, j):
# The box j is valid, invalidate other boxes that overlap with j above iou_threshold

# When return_indices is False, no need to populate box_indices
if return_indices:
with ib.if_scope(tx + 0 == 0):
orig_idx = sorted_index[i * num_anchors + j]
box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx]

masahi marked this conversation as resolved.
Show resolved Hide resolved
num_valid_boxes_local[0] += 1

offset_j = j * box_data_length
num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx)

with ib.for_range(0, j) as k:
with ib.for_range(0, num_iter_per_thread) as _k:
k = j + 1 + _k * nthread_tx + tx
offset_k = k * box_data_length

with ib.if_scope(
tvm.tir.all(
out[base_idx + offset_j + score_index] > -1.0, # if already surpressed
out[base_idx + offset_k + score_index] > 0,
tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
k < nkeep,
out[base_idx + offset_k + score_index] > 0, # is the box k still valid?
tvm.tir.any(
force_suppress > 0,
id_index < 0,
Expand All @@ -546,35 +564,31 @@ def nms_inner_loop(ib, j):
base_idx + offset_k + coord_start,
)
with ib.if_scope(iou >= iou_threshold):
out[base_idx + offset_j + score_index] = -1.0
# invalidate the box k
out[base_idx + offset_k + score_index] = -1.0
with ib.if_scope(id_index >= 0):
out[base_idx + offset_j + id_index] = -1.0
out[base_idx + offset_k + id_index] = -1.0

# Has the box j survived IOU tests?
with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0):
# When return_indices is False, no need to populate box_indices
if return_indices:
orig_idx = sorted_index[i * num_anchors + j]
box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx]
num_valid_boxes_local[0] += 1
# Make sure to do the next loop in a lock step
ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))

if isinstance(max_output_size, int):
max_output_size = tvm.tir.const(max_output_size)

with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
# Apply nms
with ib.for_range(0, valid_count[i]) as j:
with ib.if_scope(
tvm.tir.any(id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0)
):
with ib.for_range(0, nkeep) as j:
# Proceed to the inner loop if the box j is still valid
with ib.if_scope(out[base_idx + (j * box_data_length) + score_index] > -1.0):
with ib.if_scope(max_output_size > 0):
# No need to do more iteration if we already reach max_output_size boxes
with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
nms_inner_loop(ib, j)
with ib.else_scope():
nms_inner_loop(ib, j)

num_valid_boxes[i] = num_valid_boxes_local[0]
with ib.if_scope(tx + 0 == 0):
num_valid_boxes[i] = num_valid_boxes_local[0]

with ib.else_scope():
num_valid_boxes[i] = 0
Expand Down