diff --git a/docs/python-api/triton.testing.rst b/docs/python-api/triton.testing.rst
index 824e10c6f511..c89b0ba4280b 100644
--- a/docs/python-api/triton.testing.rst
+++ b/docs/python-api/triton.testing.rst
@@ -11,3 +11,4 @@ triton.testing
     do_bench
     do_bench_cudagraph
     perf_report
+    assert_close
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index d94e3ce2f884..20a41c8dd48a 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1440,11 +1440,13 @@ def kernel(X):
 
 
 @pytest.mark.interpreter
-@pytest.mark.parametrize("shape, axis, num_ctas", [(shape, axis, num_ctas)
-                                                   for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
-                                                   for axis in [0, 1]
-                                                   for num_ctas in num_ctas_list])
-def test_tensor_atomic_rmw(shape, axis, num_ctas, device):
+@pytest.mark.parametrize("shape, axis, num_ctas, dtype_x_str",
+                         [(shape, axis, num_ctas, dtype_x_str)
+                          for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
+                          for axis in [0, 1]
+                          for num_ctas in num_ctas_list
+                          for dtype_x_str in ['float32', 'uint64', 'int64', 'float64']])
+def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device):
     shape0, shape1 = shape
     # triton kernel
 
@@ -1460,13 +1462,13 @@ def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr)
             tl.atomic_add(Z + off1, z)
 
     rs = RandomState(17)
-    x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
+    x = numpy_random((shape0, shape1), dtype_str=dtype_x_str, rs=rs)
     # reference result
     z_ref = np.sum(x, axis=axis, keepdims=False)
     # triton result
     x_tri = to_triton(x, device=device)
     z_shape = (shape0, ) if axis == 1 else (shape1, )
-    z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device)
+    z_tri = to_triton(np.zeros(z_shape, dtype=getattr(np, dtype_x_str)), device=device)
     kernel[(1, )](z_tri, x_tri, axis, shape0, shape1, num_ctas=num_ctas)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
 
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index 3163305536bf..51be31a34622 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -1729,12 +1729,13 @@ def _decorator(func: T) -> T:
         docstr += """
     :param val: The values with which to perform the atomic operation
     :type val: Block of dtype=pointer.dtype.element_ty
-    :param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default),
-        "ACQUIRE", "RELEASE", or "RELAXED")
-    :type sem: str
-    :param scope: Scope of threads that observe synchronizing effect of the
-        atomic operation ("GPU" (default), "CTA", or "SYSTEM")
-    :type scope: str
+    :param sem: Specifies the memory semantics of the operation. Acceptable values are
+        "acquire", "release", "acq_rel" (acquire-release), and "relaxed". Defaults to "acq_rel".
+    :type sem: str, optional
+    :param scope: Defines the scope of threads that observe the synchronizing effect of the
+        atomic operation. Acceptable values are "gpu" (default), "cta" (a cooperative thread
+        array, i.e. a thread block), or "sys" (system-wide scope).
+    :type scope: str, optional
         """
         func.__doc__ = docstr
         return func
diff --git a/python/triton/testing.py b/python/triton/testing.py
index 0b938184a6c2..39229e2e7aa3 100644
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -161,6 +161,20 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
 
 
 def assert_close(x, y, atol=None, rtol=None, err_msg=''):
+    """
+    Asserts that two inputs are close within a certain tolerance.
+
+    :param x: The first input.
+    :type x: scalar, list, numpy.ndarray, or torch.Tensor
+    :param y: The second input.
+    :type y: scalar, list, numpy.ndarray, or torch.Tensor
+    :param atol: The absolute tolerance. Defaults to 1e-2.
+    :type atol: float, optional
+    :param rtol: The relative tolerance. Defaults to 0.
+    :type rtol: float, optional
+    :param err_msg: The error message to print if the assertion fails.
+    :type err_msg: str, optional
+    """
     import numpy as np
     import torch
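
The test change above hinges on mapping a dtype string to a NumPy scalar type with `getattr(np, dtype_x_str)`. A quick self-contained illustration of that idiom, using the same dtype list as the test (the shape is made up):

import numpy as np

for dtype_x_str in ['float32', 'uint64', 'int64', 'float64']:
    dtype = getattr(np, dtype_x_str)  # e.g. np.float32
    z = np.zeros((8, ), dtype=dtype)
    assert z.dtype == np.dtype(dtype_x_str)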
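
The reworded `sem`/`scope` docstring in core.py is easiest to read against a concrete call site. A minimal sketch, assuming a CUDA device; the kernel name, grid size, and the choice of "relaxed"/"gpu" are illustrative, not part of the patch:

import torch
import triton
import triton.language as tl

@triton.jit
def count_kernel(counter_ptr):
    # Device-wide atomic increment with relaxed ordering; omitting the
    # keywords would give the documented defaults, sem="acq_rel" and scope="gpu".
    tl.atomic_add(counter_ptr, 1, sem="relaxed", scope="gpu")

counter = torch.zeros(1, dtype=torch.int32, device="cuda")
count_kernel[(64, )](counter)  # one increment per program instance
assert counter.item() == 64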
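
Likewise, a hedged usage sketch for the newly documented `triton.testing.assert_close`; the tensors and perturbation are made up for illustration:

import torch
import triton.testing

x = torch.randn(128, device="cuda")
y = x + 5e-3 * torch.rand(128, device="cuda")  # perturbation bounded below the default atol

# With the documented defaults (atol=1e-2, rtol=0) this passes, since
# |x - y| < 5e-3 < 1e-2 elementwise.
triton.testing.assert_close(x, y, err_msg="outputs diverged")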