Skip to content

Commit

Permalink
make the scan loop exclusive
Browse files Browse the repository at this point in the history
  • Loading branch information
masa authored and mbrookhart committed Dec 18, 2020
1 parent 9d729ca commit 95c0f61
Showing 1 changed file with 22 additions and 28 deletions.
50 changes: 22 additions & 28 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,20 @@ def ceil_div(a, b):
# Copy boxes to valid_indices
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(batch_size * num_anchors, max_threads)
nthread_bx = ceil_div(num_anchors, max_threads)
nthread_by = batch_size
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size * num_anchors):
valid_indices[tid] = valid_boxes[tid]
ib.scope_attr(by, "thread_extent", nthread_by)
tid = bx * nthread_tx + tx
with ib.if_scope(tid == 0):
valid_indices[by, 0] = 0
with ib.else_scope():
with ib.if_scope(tid < num_anchors):
valid_indices[by, tid] = valid_boxes[by, tid - 1]

nthread_tx = max_threads
nthread_bx = ceil_div(num_anchors, max_threads)
Expand Down Expand Up @@ -304,29 +310,16 @@ def ceil_div(a, b):
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size):
valid_count[tid] = valid_indices[tid * num_anchors + num_anchors - 1]

## Remove invalid indices
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(batch_size * num_anchors, max_threads)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size * num_anchors):
with ib.if_scope(valid_boxes[tid] < 1):
# if this is an invalid box, mark -1
valid_indices[tid] = -1
with ib.else_scope():
# if this is a valid box, subtract 1 to get 0-based indexing
valid_indices[tid] += -1
# Add valid_boxes[tid, num_anchors - 1] because valid_indices is
# an exclusive scan of valid_boxes
valid_count[tid] = (
valid_indices[tid, num_anchors - 1] + valid_boxes[tid, num_anchors - 1]
)

return ib.get()


def get_valid_counts_ir(data, valid_indices, out, out_indices):
def get_valid_counts_ir(data, valid_indices, valid_boxes, out, out_indices):
"""Low level IR to get valid count of bounding boxes
given a score threshold. Also prepares to move valid boxes to the
top of input data.
Expand Down Expand Up @@ -354,8 +347,9 @@ def get_valid_counts_ir(data, valid_indices, out, out_indices):
ib = tvm.tir.ir_builder.create()

data = ib.buffer_ptr(data)

valid_indices = ib.buffer_ptr(valid_indices)
valid_boxes = ib.buffer_ptr(valid_boxes)

out = ib.buffer_ptr(out)
out_indices = ib.buffer_ptr(out_indices)
one = tvm.tir.const(1, dtype=out.dtype)
Expand Down Expand Up @@ -395,7 +389,7 @@ def get_valid_counts_ir(data, valid_indices, out, out_indices):
i = by
j = tid
k = bz
with ib.if_scope(valid_indices[i, tid] >= 0):
with ib.if_scope(valid_boxes[i, tid] > 0):
out[(i * num_anchors + valid_indices[i, tid]) * elem_length + k] = data[
(i * num_anchors + j) * elem_length + k
]
Expand Down Expand Up @@ -472,10 +466,10 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):

out, out_indices = te.extern(
[data.shape, (batch_size, num_anchors)],
[data, valid_indices],
lambda ins, outs: get_valid_counts_ir(ins[0], ins[1], outs[0], outs[1]),
[data, valid_indices, valid_boxes],
lambda ins, outs: get_valid_counts_ir(ins[0], ins[1], ins[2], outs[0], outs[1]),
dtype=["int32", data.dtype],
in_buffers=[data_buf, valid_indices_buf],
in_buffers=[data_buf, valid_indices_buf, valid_boxes_buf],
out_buffers=[out_buf, out_indices_buf],
name="get_valid_counts",
tag="get_valid_counts_gpu",
Expand Down

0 comments on commit 95c0f61

Please sign in to comment.