[DOCS] Improve docs and minor fixes for testing.py and core.py (#4104)
Update the docs and a test case to cover more dtypes.
lancerts authored Jun 10, 2024
1 parent 6251a6e commit 71b8d33
Showing 4 changed files with 31 additions and 13 deletions.
1 change: 1 addition & 0 deletions docs/python-api/triton.testing.rst
@@ -11,3 +11,4 @@ triton.testing
     do_bench
     do_bench_cudagraph
     perf_report
+    assert_close
16 changes: 9 additions & 7 deletions python/test/unit/language/test_core.py
@@ -1440,11 +1440,13 @@ def kernel(X):


 @pytest.mark.interpreter
-@pytest.mark.parametrize("shape, axis, num_ctas", [(shape, axis, num_ctas)
-                                                   for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
-                                                   for axis in [0, 1]
-                                                   for num_ctas in num_ctas_list])
-def test_tensor_atomic_rmw(shape, axis, num_ctas, device):
+@pytest.mark.parametrize("shape, axis, num_ctas, dtype_x_str",
+                         [(shape, axis, num_ctas, dtype_x_str)
+                          for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
+                          for axis in [0, 1]
+                          for num_ctas in num_ctas_list
+                          for dtype_x_str in ['float32', 'uint64', 'int64', 'float64']])
+def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device):
     shape0, shape1 = shape
     # triton kernel

@@ -1460,13 +1462,13 @@ def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr)
         tl.atomic_add(Z + off1, z)

     rs = RandomState(17)
-    x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
+    x = numpy_random((shape0, shape1), dtype_str=dtype_x_str, rs=rs)
     # reference result
     z_ref = np.sum(x, axis=axis, keepdims=False)
     # triton result
     x_tri = to_triton(x, device=device)
     z_shape = (shape0, ) if axis == 1 else (shape1, )
-    z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device)
+    z_tri = to_triton(np.zeros(z_shape, dtype=getattr(np, dtype_x_str)), device=device)
     kernel[(1, )](z_tri, x_tri, axis, shape0, shape1, num_ctas=num_ctas)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)

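For context, below is a self-contained sketch of the pattern this test exercises: a 2D tile is reduced with `tl.sum` and accumulated into the output through `tl.atomic_add`, here using `int64`, one of the newly covered dtypes. The kernel, names, and shapes are illustrative rather than repository code, and a CUDA device is assumed.

```python
# Illustrative sketch (not from this commit): a tensor-wide atomic add
# that reduces a 2D tile along axis 0. Uses int64, one of the dtypes
# the updated test now covers. Assumes a CUDA-capable device.
import torch
import triton
import triton.language as tl


@triton.jit
def sum_rows_kernel(Z, X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
    off0 = tl.arange(0, SHAPE0)
    off1 = tl.arange(0, SHAPE1)
    # load the full SHAPE0 x SHAPE1 tile (row-major layout)
    x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
    # reduce along axis 0, producing one partial sum per column
    z = tl.sum(x, axis=0)
    # atomically accumulate the per-column sums into Z
    tl.atomic_add(Z + off1, z)


x = torch.randint(0, 10, (8, 8), dtype=torch.int64, device="cuda")
z = torch.zeros(8, dtype=torch.int64, device="cuda")
sum_rows_kernel[(1, )](z, x, 8, 8)
assert torch.equal(z, x.sum(dim=0))
```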
13 changes: 7 additions & 6 deletions python/triton/language/core.py
@@ -1729,12 +1729,13 @@ def _decorator(func: T) -> T:
     docstr += """
     :param val: The values with which to perform the atomic operation
     :type val: Block of dtype=pointer.dtype.element_ty
-    :param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default),
-        "ACQUIRE", "RELEASE", or "RELAXED")
-    :type sem: str
-    :param scope: Scope of threads that observe synchronizing effect of the
-        atomic operation ("GPU" (default), "CTA", or "SYSTEM")
-    :type scope: str
+    :param sem: Memory semantics for the operation: "acquire", "release",
+        "acq_rel" (short for "ACQUIRE_RELEASE"), or "relaxed". Defaults to "acq_rel".
+    :type sem: str, optional
+    :param scope: Scope of threads that observe the synchronizing effect of the
+        atomic operation: "gpu" (default), "cta" (cooperative thread array, i.e.
+        thread block), or "sys" (short for "SYSTEM").
+    :type scope: str, optional
     """
     func.__doc__ = docstr
     return func
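To make the documented `sem` and `scope` keywords concrete, here is a hedged usage sketch: a trivial global counter that passes both keywords explicitly to `tl.atomic_add`. The kernel name and launch grid are invented for illustration, and a CUDA device is assumed.

```python
# Hypothetical usage of the `sem` / `scope` keywords documented above.
import torch
import triton
import triton.language as tl


@triton.jit
def counter_kernel(C):
    # "relaxed" gives atomicity without ordering guarantees, which is
    # enough for a plain counter; "gpu" (the default scope) makes the
    # update visible to all CTAs on the device.
    tl.atomic_add(C, 1, sem="relaxed", scope="gpu")


c = torch.zeros(1, dtype=torch.int32, device="cuda")
counter_kernel[(1024, )](c)  # 1024 programs each add 1
assert c.item() == 1024
```

The default "acq_rel" would also be correct here; "relaxed" is simply the weakest ordering that still keeps the increment atomic.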
14 changes: 14 additions & 0 deletions python/triton/testing.py
@@ -161,6 +161,20 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu


 def assert_close(x, y, atol=None, rtol=None, err_msg=''):
+    """
+    Asserts that two inputs are close within a certain tolerance.
+
+    :param x: The first input.
+    :type x: scalar, list, numpy.ndarray, or torch.Tensor
+    :param y: The second input.
+    :type y: scalar, list, numpy.ndarray, or torch.Tensor
+    :param atol: The absolute tolerance. Default value is 1e-2.
+    :type atol: float, optional
+    :param rtol: The relative tolerance. Default value is 0.
+    :type rtol: float, optional
+    :param err_msg: The error message to use if the assertion fails.
+    :type err_msg: str
+    """
     import numpy as np
     import torch

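Since `assert_close` is now part of the documented API, a brief usage sketch based on the signature and defaults above follows; the tensors and message are illustrative only.

```python
# Illustrative use of triton.testing.assert_close with the defaults
# described in the docstring (atol=1e-2, rtol=0).
import torch
from triton.testing import assert_close

x = torch.randn(128)
y = x + 1e-3  # perturbation well within the default atol of 1e-2
assert_close(x, y, err_msg="outputs diverged beyond tolerance")
```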
