Skip to content

Commit

Permalink
Update sort and nms ir
Browse files: browse the repository at this point in the history
  • Loading branch information
vinx13 committed Jan 31, 2019
1 parent 8dd4bfa commit 7d37fff
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions topi/python/topi/cuda/rcnn/proposal.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -138,7 +138,7 @@ def argsort_ir(data_buf, out_index_buf):
p_data = ib.buffer_ptr(data_buf)
index_out = ib.buffer_ptr(out_index_buf)
nthread_tx = max_threads
nthread_bx = num_bbox // max_threads + 1
nthread_bx = (num_bbox + 1) // 2 // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("vthread")
ib.scope_attr(tx, "thread_extent", nthread_tx)
Expand All @@ -149,9 +149,10 @@ def argsort_ir(data_buf, out_index_buf):

with ib.for_range(0, batch, for_type="unroll") as b:
start = b * num_bbox
with ib.if_scope(tid < num_bbox):
index_out[start + tid] = tid

for i in range(2):
bbox_id = tid * 2 + i
with ib.if_scope(bbox_id < num_bbox):
index_out[start + bbox_id] = bbox_id
with ib.for_range(0, num_bbox) as k:
offset = start + 2 * tid + (k % 2)
with ib.if_scope(
Expand Down Expand Up @@ -213,17 +214,16 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
nthread_bx = num_bbox // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
j = bx * max_threads + tx
i = bx * max_threads + tx
with ib.for_range(0, batch, for_type="unroll", name="n") as b:
start = b * num_bbox
with ib.if_scope(j < num_bbox):
p_out[start + j] = False

with ib.for_range(0, num_bbox - 1) as i:
with ib.if_scope(tvm.all(j < num_bbox, j > i, p_out[start + i] == False)):
iou = calculate_overlap(p_data, (start + i) * 5, (start + j) * 5)
base_idx = b * num_bbox
with ib.if_scope(i < num_bbox):
p_out[base_idx + i] = False
with ib.for_range(0, num_bbox - 1) as l:
with ib.if_scope(tvm.all(i < num_bbox, i > l, p_out[base_idx + l] == False)):
iou = calculate_overlap(p_data, (base_idx + l) * 5, (base_idx + i) * 5)
with ib.if_scope(iou > nms_threshold):
p_out[start + j] = True
p_out[base_idx + i] = True
return ib.get()


Expand Down

0 comments on commit 7d37fff

Please sign in to comment.