Skip to content

Commit

Permalink
[TEST] Check return value of atomic_rmw (#4479)
Browse files Browse the repository at this point in the history
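There doesn't seem to be any test coverage for the value returned by `tl.atomic_add`, so the test now stores that return value and checks that it equals the destination buffer's contents from before the kernel ran.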

Co-authored-by: Thomas Raoux <[email protected]>
  • Loading branch information
int3 and ThomasRaoux authored Aug 7, 2024
1 parent c1fbfc5 commit 0c32752
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,26 +1464,33 @@ def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device):
# triton kernel

@triton.jit
def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
off0 = tl.arange(0, SHAPE0)
off1 = tl.arange(0, SHAPE1)
x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
z = tl.sum(x, axis=AXIS)
if AXIS == 1:
tl.atomic_add(Z + off0, z)
old = tl.atomic_add(Z + off0, z)
tl.store(OLD + off0, old)
else:
tl.atomic_add(Z + off1, z)
old = tl.atomic_add(Z + off1, z)
tl.store(OLD + off1, old)

rs = RandomState(17)
x = numpy_random((shape0, shape1), dtype_str=dtype_x_str, rs=rs)
# reference result
z_ref = np.sum(x, axis=axis, keepdims=False)
z_shape = (shape0, ) if axis == 1 else (shape1, )
z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs)
old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str))
# reference results
z_ref = z + np.sum(x, axis=axis, keepdims=False)
old_ref = np.copy(z)
# triton result
x_tri = to_triton(x, device=device)
z_shape = (shape0, ) if axis == 1 else (shape1, )
z_tri = to_triton(np.zeros(z_shape, dtype=getattr(np, dtype_x_str)), device=device)
kernel[(1, )](z_tri, x_tri, axis, shape0, shape1, num_ctas=num_ctas)
z_tri = to_triton(z, device=device)
old_tri = to_triton(old, device=device)
kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, num_ctas=num_ctas)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
np.testing.assert_equal(old_ref, to_numpy(old_tri))


@pytest.mark.interpreter
Expand Down

0 comments on commit 0c32752

Please sign in to comment.