Skip to content

Commit

Permalink
[FIX] Fix temporary allocation size in threefry (#7709)
Browse files Browse the repository at this point in the history
* [FIX] Fix temporary allocation size in threefry

* bump sizes
  • Loading branch information
tkonolige committed Mar 23, 2021
1 parent f88c2be commit 6f0a656
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion python/tvm/topi/random/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def mix(a, b, rotation):
return [x, y]

# temporary buffer for holding the results of _PERMUTATIONS
tmp = irb.allocate(out_buf.dtype, out_shape, name="tmp", scope="global")
tmp = irb.allocate(out_buf.dtype, out_shape * nwords, name="tmp", scope="global")
tmp_offset = 0

# Initialize entire key. It is composed of the original key with one
Expand Down
10 changes: 5 additions & 5 deletions tests/python/topi/python/test_topi_prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ def test_threefry_generate(target, ctx):
gen = tvm.relay.random.threefry_key(0).data.asnumpy()

# check that we can generate some data
a, rands = threefry_generate(target, ctx, gen, (100,))
a, rands = threefry_generate(target, ctx, gen, (2048,))
assert (
rands.shape[0] == 100 and len(rands.shape) == 1
rands.shape[0] == 2048 and len(rands.shape) == 1
), "Output shape should match requested shape"

# check that gen out does not equal input
Expand All @@ -99,13 +99,13 @@ def test_threefry_generate(target, ctx):
gen = np.array(
[0, 0, 0, 0, 0, 0, 0, 2 ** 64 - 2, 1 << 63, 0], dtype="uint64"
) # make counter large
a, rands = threefry_generate(target, ctx, gen, (100,))
a, rands = threefry_generate(target, ctx, gen, (2048,))
assert gen[4] != a[4], "Overflow of counter should trigger path change"
assert a[7] == 100, "Overflow of counter should still update counter"
assert a[7] == 2048, "Overflow of counter should still update counter"

# check generate with path at length limit
gen = np.array([0, 0, 0, 0, 0, 0, 0, 2 ** 64 - 2, 0, 0], dtype="uint64") # make counter large
a, rands = threefry_generate(target, ctx, gen, (100,))
a, rands = threefry_generate(target, ctx, gen, (2048,))
assert (
gen[0:4] != a[0:4]
).any(), "Overflowing counter with no space left in path should change state"
Expand Down

0 comments on commit 6f0a656

Please sign in to comment.