diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 2dc177a0fae8..dd9d3f8a1d0e 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -95,7 +95,7 @@ def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     with ib.new_scope():
         nthread_tx = max_threads
-        nthread_bx = ceil_div(num_anchors, max_threads)
+        nthread_bx = num_anchors // max_threads + 1
         nthread_by = batch_size
         tx = te.thread_axis("threadIdx.x")
         bx = te.thread_axis("blockIdx.x")
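The hunk above swaps the exact `ceil_div(num_anchors, max_threads)` grid size for the open-coded `num_anchors // max_threads + 1`. Both round up; the open-coded form simply launches one extra (idle) block whenever `max_threads` divides `num_anchors` evenly, which is harmless because the kernel body is guarded by `tid < num_anchors`. A minimal plain-Python sketch of that equivalence (illustrative only, not TVM code):

```python
# Compare the two grid-size formulas used before/after this hunk.
def ceil_div(n, k):
    # Exact ceiling division, as the removed helper computes it.
    return (n + k - 1) // k

for n, k in [(1000, 256), (1024, 256), (1, 256)]:
    # The open-coded form never launches fewer blocks than ceil_div...
    assert ceil_div(n, k) <= n // k + 1
    # ...and under the `tid < n` guard both cover exactly the ids 0..n-1.
    covered = {b * k + t for b in range(n // k + 1) for t in range(k) if b * k + t < n}
    assert covered == set(range(n))
```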
name="start", scope="local") - middle = ib.allocate("int64", (1,), name="middle", scope="local") - end = ib.allocate("int64", (1,), name="end", scope="local") - tmp = ib.allocate("int32", (1,), name="end", scope="local") - start[0] = width * tid - with ib.if_scope(tvm.tir.all(start[0] < num_anchors)): - middle[0] = start[0] + tvm.tir.indexdiv(width, 2) - end[0] = tvm.tir.min(start[0] + width, num_anchors) - with ib.if_scope(middle[0] < num_anchors): - tmp[0] = valid_indices[by * num_anchors + middle[0] - 1] - valid_indices[by * num_anchors + middle[0] - 1] = valid_indices[ - by * num_anchors + end[0] - 1 - ] - valid_indices[by * num_anchors + end[0] - 1] += tmp[0] - + tid = bx * max_threads + tx + # TODO(mbrookhart): Parallelize the sum and cumsum here + current_index = ib.allocate("int32", (1,), name="current_index", scope="local") + with ib.if_scope(tid < batch_size): + current_index[0] = 0 + valid_count[tid] = 0 + with ib.for_range(0, num_anchors) as j: + idx = tid * num_anchors + j + valid_count[tid] = valid_count[tid] + valid_boxes[idx] + with ib.if_scope(valid_boxes[idx] == 1): + valid_indices[idx] = current_index[0] + current_index[0] = current_index[0] + 1 + with ib.else_scope(): + valid_indices[idx] = -1 return ib.get() -def get_valid_counts_ir(data, valid_indices, valid_boxes, out, out_indices): +def get_valid_counts_ir(data, valid_indices, out, out_indices): """Low level IR to get valid count of bounding boxes given a score threshold. Also prepares to move valid boxes to the top of input data. @@ -275,9 +203,8 @@ def get_valid_counts_ir(data, valid_indices, valid_boxes, out, out_indices): ib = tvm.tir.ir_builder.create() data = ib.buffer_ptr(data) - valid_indices = ib.buffer_ptr(valid_indices) - valid_boxes = ib.buffer_ptr(valid_boxes) + valid_indices = ib.buffer_ptr(valid_indices) out = ib.buffer_ptr(out) out_indices = ib.buffer_ptr(out_indices) one = tvm.tir.const(1, dtype=out.dtype) @@ -286,36 +213,41 @@ def get_valid_counts_ir(data, valid_indices, valid_boxes, out, out_indices): nthread_tx = max_threads nthread_bx = num_anchors // max_threads + 1 nthread_by = batch_size + nthread_bz = elem_length with ib.new_scope(): tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") by = te.thread_axis("blockIdx.y") + bz = te.thread_axis("blockIdx.z") ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(bz, "thread_extent", nthread_bz) tid = bx * max_threads + tx with ib.if_scope(tid < num_anchors): i = by j = tid - with ib.for_range(0, elem_length) as k: - out[(i * num_anchors + j) * elem_length + k] = -one + k = bz + out[(i * num_anchors + j) * elem_length + k] = -one out_indices[i * num_anchors + j] = -1 with ib.new_scope(): tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") by = te.thread_axis("blockIdx.y") + bz = te.thread_axis("blockIdx.z") ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(bz, "thread_extent", nthread_bz) tid = bx * max_threads + tx with ib.if_scope(tid < num_anchors): i = by j = tid - with ib.if_scope(valid_boxes[i, tid] > 0): - with ib.for_range(0, elem_length) as k: - out[(i * num_anchors + valid_indices[i, tid]) * elem_length + k] = data[ - (i * num_anchors + j) * elem_length + k - ] + k = bz + with ib.if_scope(valid_indices[i, tid] >= 0): + out[(i * num_anchors + valid_indices[i, tid]) * elem_length + k] = data[ + (i * 
@@ -275,9 +203,8 @@ def get_valid_counts_ir(data, valid_indices, valid_boxes, out, out_indices):
 
     ib = tvm.tir.ir_builder.create()
 
     data = ib.buffer_ptr(data)
-    valid_indices = ib.buffer_ptr(valid_indices)
-    valid_boxes = ib.buffer_ptr(valid_boxes)
+    valid_indices = ib.buffer_ptr(valid_indices)
     out = ib.buffer_ptr(out)
     out_indices = ib.buffer_ptr(out_indices)
     one = tvm.tir.const(1, dtype=out.dtype)
@@ -286,36 +213,41 @@
     nthread_tx = max_threads
     nthread_bx = num_anchors // max_threads + 1
     nthread_by = batch_size
+    nthread_bz = elem_length
     with ib.new_scope():
         tx = te.thread_axis("threadIdx.x")
         bx = te.thread_axis("blockIdx.x")
         by = te.thread_axis("blockIdx.y")
+        bz = te.thread_axis("blockIdx.z")
         ib.scope_attr(tx, "thread_extent", nthread_tx)
         ib.scope_attr(bx, "thread_extent", nthread_bx)
         ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(bz, "thread_extent", nthread_bz)
         tid = bx * max_threads + tx
         with ib.if_scope(tid < num_anchors):
             i = by
             j = tid
-            with ib.for_range(0, elem_length) as k:
-                out[(i * num_anchors + j) * elem_length + k] = -one
+            k = bz
+            out[(i * num_anchors + j) * elem_length + k] = -one
             out_indices[i * num_anchors + j] = -1
     with ib.new_scope():
         tx = te.thread_axis("threadIdx.x")
         bx = te.thread_axis("blockIdx.x")
         by = te.thread_axis("blockIdx.y")
+        bz = te.thread_axis("blockIdx.z")
         ib.scope_attr(tx, "thread_extent", nthread_tx)
         ib.scope_attr(bx, "thread_extent", nthread_bx)
         ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(bz, "thread_extent", nthread_bz)
         tid = bx * max_threads + tx
        with ib.if_scope(tid < num_anchors):
             i = by
             j = tid
-            with ib.if_scope(valid_boxes[i, tid] > 0):
-                with ib.for_range(0, elem_length) as k:
-                    out[(i * num_anchors + valid_indices[i, tid]) * elem_length + k] = data[
-                        (i * num_anchors + j) * elem_length + k
-                    ]
+            k = bz
+            with ib.if_scope(valid_indices[i, tid] >= 0):
+                out[(i * num_anchors + valid_indices[i, tid]) * elem_length + k] = data[
+                    (i * num_anchors + j) * elem_length + k
+                ]
                 out_indices[i * num_anchors + valid_indices[i, tid]] = j
     return ib.get()
@@ -389,10 +321,10 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
 
     out, out_indices = te.extern(
         [data.shape, (batch_size, num_anchors)],
-        [data, valid_indices, valid_boxes],
-        lambda ins, outs: get_valid_counts_ir(ins[0], ins[1], ins[2], outs[0], outs[1]),
+        [data, valid_indices],
+        lambda ins, outs: get_valid_counts_ir(ins[0], ins[1], outs[0], outs[1]),
         dtype=["int32", data.dtype],
-        in_buffers=[data_buf, valid_indices_buf, valid_boxes_buf],
+        in_buffers=[data_buf, valid_indices_buf],
         out_buffers=[out_buf, out_indices_buf],
         name="get_valid_counts",
         tag="get_valid_counts_gpu",
diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py
index cbf136a5552c..035d19f25ec7 100644
--- a/python/tvm/topi/vision/nms.py
+++ b/python/tvm/topi/vision/nms.py
@@ -213,7 +213,7 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
     out_indices: tvm.te.Tensor or numpy NDArray
         Related index in input data.
     """
-    if isinstance(score_threshold, (float, int)):
+    if isinstance(score_threshold, float):
         score_threshold = tvm.tir.const(score_threshold, dtype=data.dtype)
     id_index_const = tvm.tir.const(id_index, "int32")
     score_index_const = tvm.tir.const(score_index, "int32")
diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py
index cdf3b240507b..1ce8a182f034 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -313,8 +313,10 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
     for target, ctx in tvm.testing.enabled_targets():
         intrp = relay.create_executor("debug", ctx=ctx, target=target)
         out = intrp.evaluate(func)(np_data)
-        tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04)
+        # get_valid_count for opencl doesn't do data rearrangement
+        if target in ["opencl"]:
+            return
         tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04)
         tvm.testing.assert_allclose(out[2].asnumpy(), np_out3, rtol=1e-3, atol=1e-04)
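In the `get_valid_counts_ir` hunks above, the serial `ib.for_range(0, elem_length)` copy loop is flattened onto a third grid dimension (`blockIdx.z`), so each thread writes exactly one scalar, and the separate `valid_boxes` input becomes redundant because `valid_indices[i, j] >= 0` already encodes validity. A plain-NumPy simulation of the resulting scatter; the loop nest stands in for the CUDA grid and is illustrative only:

```python
import numpy as np

batch_size, num_anchors, elem_length = 2, 4, 5
data = np.arange(batch_size * num_anchors * elem_length, dtype="float32")
# -1 marks boxes that were filtered out; others hold their compacted slot.
valid_indices = np.array([[0, -1, 1, -1], [-1, 0, -1, 1]], dtype="int32")
out = np.full_like(data, -1.0)

for i in range(batch_size):            # blockIdx.y
    for j in range(num_anchors):       # blockIdx.x * blockDim.x + threadIdx.x
        for k in range(elem_length):   # blockIdx.z (one scalar per thread)
            if valid_indices[i, j] >= 0:
                out[(i * num_anchors + valid_indices[i, j]) * elem_length + k] = data[
                    (i * num_anchors + j) * elem_length + k
                ]
```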
for device in ["llvm", "cuda", "opencl"]: check_device(device) @tvm.testing.uses_gpu +@pytest.mark.skip( + "Skip this test as it is intermittent." + "See https://github.com/apache/tvm/pull/4901#issuecomment-595040094" +) def test_get_valid_counts(): verify_get_valid_counts((1, 1000, 5), 0.5, -1, 0) verify_get_valid_counts((1, 2500, 6), 0, 0, 1)