[CUDA] [Codegen] Ensuring at least one thread block for dynamism (#7273)
anijain2305 committed Jan 14, 2021
1 parent c11959d commit 8d3c0e7
Showing 2 changed files with 11 additions and 4 deletions.
6 changes: 5 additions & 1 deletion src/runtime/thread_storage_scope.h
@@ -215,7 +215,11 @@ class ThreadAxisConfig
     ThreadWorkLoad w;
     std::fill(w.work_size, w.work_size + 6, 1);
     for (size_t i = 0; i < arg_index_map_.size(); ++i) {
-      w.work_size[arg_index_map_[i]] = static_cast<size_t>(x.values[base_ + i].v_int64);
+      // Dynamic shapes can result in a 0 dim size. Guard to ensure that the dim size is at least 1.
+      size_t size = static_cast<size_t>(x.values[base_ + i].v_int64);
+      if (size > 0) {
+        w.work_size[arg_index_map_[i]] = size;
+      }
     }
     return w;
   }
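The guard deliberately leaves the default extent of 1 (seeded by std::fill) in place when a dynamic dimension evaluates to 0, since CUDA rejects a launch configuration with any grid or block extent of 0. A standalone sketch of the same clamping rule, in plain Python rather than TVM code (extract_launch_extents is a hypothetical name, and the arg_index_map_ indirection is omitted):

    def extract_launch_extents(dim_values):
        # Mirrors std::fill(w.work_size, w.work_size + 6, 1): all six
        # launch extents default to 1.
        work_size = [1] * 6
        for i, v in enumerate(dim_values):
            if v > 0:  # the guard added in this commit
                work_size[i] = v
        return work_size

    # An empty tensor along the first axis yields a 0 extent; it stays 1.
    print(extract_launch_extents([0, 128, 1]))  # -> [1, 128, 1, 1, 1, 1]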
9 changes: 6 additions & 3 deletions tests/python/relay/test_any.py
@@ -845,7 +845,7 @@ def test_any_softmax():
     verify_any_softmax(any_dims(4), 2, (13, 11, 3, 1), (13, 11, 3, 1))
 
 
-def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False):
+def verify_any_topk(data_shape, kval, np_dshape, dtype, ret_type="indices", const_k=False):
     mod = tvm.IRModule()
     data = relay.var("data", shape=data_shape, dtype=dtype)
     np_data = np.random.uniform(size=np_dshape).astype(dtype)
@@ -857,7 +857,9 @@ def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False):
         k = relay.var("k", shape=(), dtype="int32")
         args = [data, k]
         in_vals = [np_data, kval]
-    out = relay.topk(data, k, ret_type="indices")
+    out = relay.topk(data, k, ret_type=ret_type)
+    if ret_type == "both":
+        out = out[0]
     mod["main"] = relay.Function(args, out)
 
     sorted = np.argsort(-np_data)
@@ -873,7 +875,8 @@ def test_any_topk():
 def test_any_topk():
     verify_any_topk(any_dims(1), 5, (10,), "float32")
     verify_any_topk(any_dims(2), 2, (6, 3), "int32")
-    verify_any_topk(any_dims(2), 3, (6, 3), "float32", True)
+    verify_any_topk(any_dims(2), 3, (6, 3), "float32", const_k=True)
+    verify_any_topk(any_dims(1), 0, (0,), "float32", ret_type="both")
 
 
 @tvm.testing.uses_gpu
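Note that the existing const_k case now passes True by keyword: the new ret_type parameter precedes const_k in the signature, so a positional True would bind to the wrong argument. The added case drives the zero-size path end to end. A minimal sketch of the module that case builds (mirroring verify_any_topk against the same-era relay API; construction only, GPU execution elided):

    import numpy as np
    import tvm
    from tvm import relay

    dtype = "float32"
    # Rank-1 tensor with a dynamic extent, as any_dims(1) produces.
    data = relay.var("data", shape=(relay.Any(),), dtype=dtype)
    k = relay.var("k", shape=(), dtype="int32")
    out = relay.topk(data, k, ret_type="both")[0]  # keep the values tensor
    mod = tvm.IRModule()
    mod["main"] = relay.Function([data, k], out)

    np_data = np.zeros((0,), dtype=dtype)  # the empty (0,) input, with k = 0
    # On CUDA this yields a 0-extent launch dimension at runtime, which the
    # runtime guard above now clamps to 1 instead of producing an invalid
    # kernel launch.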
