diff --git a/python/tvm/topi/random/kernel.py b/python/tvm/topi/random/kernel.py index 728cd682fa42..a09a5f3f4ae3 100644 --- a/python/tvm/topi/random/kernel.py +++ b/python/tvm/topi/random/kernel.py @@ -141,7 +141,7 @@ def mix(a, b, rotation): return [x, y] # temporary buffer for holding the results of _PERMUTATIONS - tmp = irb.allocate(out_buf.dtype, out_shape, name="tmp", scope="global") + tmp = irb.allocate(out_buf.dtype, out_shape * nwords, name="tmp", scope="global") tmp_offset = 0 # Initialize entire key. It is composed of the original key with one diff --git a/tests/python/topi/python/test_topi_prng.py b/tests/python/topi/python/test_topi_prng.py index 649e5410c147..102e93f3b245 100644 --- a/tests/python/topi/python/test_topi_prng.py +++ b/tests/python/topi/python/test_topi_prng.py @@ -87,9 +87,9 @@ def test_threefry_generate(target, ctx): gen = tvm.relay.random.threefry_key(0).data.asnumpy() # check that we can generate some data - a, rands = threefry_generate(target, ctx, gen, (100,)) + a, rands = threefry_generate(target, ctx, gen, (2048,)) assert ( - rands.shape[0] == 100 and len(rands.shape) == 1 + rands.shape[0] == 2048 and len(rands.shape) == 1 ), "Output shape should match requested shape" # check that gen out does not equal input @@ -99,13 +99,13 @@ def test_threefry_generate(target, ctx): gen = np.array( [0, 0, 0, 0, 0, 0, 0, 2 ** 64 - 2, 1 << 63, 0], dtype="uint64" ) # make counter large - a, rands = threefry_generate(target, ctx, gen, (100,)) + a, rands = threefry_generate(target, ctx, gen, (2048,)) assert gen[4] != a[4], "Overflow of counter should trigger path change" - assert a[7] == 100, "Overflow of counter should still update counter" + assert a[7] == 2048, "Overflow of counter should still update counter" # check generate with path at length limit gen = np.array([0, 0, 0, 0, 0, 0, 0, 2 ** 64 - 2, 0, 0], dtype="uint64") # make counter large - a, rands = threefry_generate(target, ctx, gen, (100,)) + a, rands = threefry_generate(target, ctx, gen, (2048,)) assert ( gen[0:4] != a[0:4] ).any(), "Overflowing counter with no space left in path should change state"