diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 1917096bb24c..c0393600b60c 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -215,7 +215,11 @@ class ThreadAxisConfig { ThreadWorkLoad w; std::fill(w.work_size, w.work_size + 6, 1); for (size_t i = 0; i < arg_index_map_.size(); ++i) { - w.work_size[arg_index_map_[i]] = static_cast<size_t>(x.values[base_ + i].v_int64); + // Dynamic shapes can result in 0 dim size. Guard to ensure that the dim size is at least 1. + size_t size = static_cast<size_t>(x.values[base_ + i].v_int64); + if (size > 0) { + w.work_size[arg_index_map_[i]] = size; + } } return w; } diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index cb3b5d42e553..d30e7873dae7 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -845,7 +845,7 @@ def test_any_softmax(): verify_any_softmax(any_dims(4), 2, (13, 11, 3, 1), (13, 11, 3, 1)) -def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False): +def verify_any_topk(data_shape, kval, np_dshape, dtype, ret_type="indices", const_k=False): mod = tvm.IRModule() data = relay.var("data", shape=data_shape, dtype=dtype) np_data = np.random.uniform(size=np_dshape).astype(dtype) @@ -857,7 +857,9 @@ def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False): k = relay.var("k", shape=(), dtype="int32") args = [data, k] in_vals = [np_data, kval] - out = relay.topk(data, k, ret_type="indices") + out = relay.topk(data, k, ret_type=ret_type) + if ret_type == "both": + out = out[0] mod["main"] = relay.Function(args, out) sorted = np.argsort(-np_data) @@ -873,7 +875,8 @@ def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False): def test_any_topk(): verify_any_topk(any_dims(1), 5, (10,), "float32") verify_any_topk(any_dims(2), 2, (6, 3), "int32") - verify_any_topk(any_dims(2), 3, (6, 3), "float32", True) + verify_any_topk(any_dims(2), 3, (6, 3), "float32", const_k=True) 
+ verify_any_topk(any_dims(1), 0, (0,), "float32", ret_type="both") @tvm.testing.uses_gpu