From eb100236c0d434b233a927d996cc2d464d47ac91 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 13 Nov 2023 10:17:32 -0800 Subject: [PATCH 1/7] Make assert_array_elements more efficient in the non-error case --- array_api_tests/pytest_helpers.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index e6ede7b2..6d3899a9 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -459,6 +459,13 @@ def assert_array_elements( dh.result_type(out.dtype, expected.dtype) # sanity check assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check f_func = f"[{func_name}({fmt_kw(kw)})]" + + match = (out == expected) + if xp.all(match): + return + + # In case of mismatch, generate a more helpful error. Cycling through all indices is + # costly in some array api implementations, so we only do this in the case of a failure. if out.dtype in dh.real_float_dtypes: for idx in sh.ndindex(out.shape): at_out = out[idx] @@ -480,6 +487,4 @@ def assert_array_elements( _assert_float_element(xp.real(at_out), xp.real(at_expected), msg) _assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg) else: - assert xp.all( - out == expected - ), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}" + assert xp.all(match), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}" From 557e6217299d82c7e355d35176da22b8938ea818 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 13 Nov 2023 10:18:01 -0800 Subject: [PATCH 2/7] test_eye: use assert_array_elements utility --- array_api_tests/test_creation_functions.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 19e945ca..ec2df060 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -354,14 +354,14 @@ def test_eye(n_rows, n_cols, kw): ph.assert_kw_dtype("eye", kw_dtype=kw["dtype"], out_dtype=out.dtype) _n_cols = n_rows if n_cols is None else n_cols ph.assert_shape("eye", out_shape=out.shape, expected=(n_rows, _n_cols), kw=dict(n_rows=n_rows, n_cols=n_cols)) - f_func = f"[eye({n_rows=}, {n_cols=})]" - for i in range(n_rows): - for j in range(_n_cols): - f_indexed_out = f"out[{i}, {j}]={out[i, j]}" - if j - i == kw.get("k", 0): - assert out[i, j] == 1, f"{f_indexed_out}, should be 1 {f_func}" - else: - assert out[i, j] == 0, f"{f_indexed_out}, should be 0 {f_func}" + k = kw.get("k", 0) + expected = xp.asarray( + [[1 if j - i == k else 0 for j in range(_n_cols)] for i in range(n_rows)], + dtype=out.dtype # Note: dtype already checked above. + ) + if expected.size == 0: + expected = xp.reshape(expected, (n_rows, _n_cols)) + ph.assert_array_elements("eye", out=out, expected=expected, kw=kw) default_unsafe_dtypes = [xp.uint64] From c697fda4a6c872405231eb72304ef2fa5449f9f6 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 26 Feb 2024 13:07:16 +0000 Subject: [PATCH 3/7] Bump `array-api` 2023.12 release! --- array-api | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array-api b/array-api index ab69aa24..ea6a47f0 160000 --- a/array-api +++ b/array-api @@ -1 +1 @@ -Subproject commit ab69aa240025ff1d52525ce3859b69ebfd6b7faf +Subproject commit ea6a47f03e0aa26a9b17e70deba12e4096cfd2f3 From 2cb2e5d279a6e55eedd685d30f65e70bbb4b521a Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 26 Feb 2024 13:09:25 +0000 Subject: [PATCH 4/7] Use vectorised checks in `assert_array_elements` --- array_api_tests/meta/test_pytest_helpers.py | 4 ++ array_api_tests/pytest_helpers.py | 60 ++++++++++++++++----- 2 files changed, 51 insertions(+), 13 deletions(-) diff --git a/array_api_tests/meta/test_pytest_helpers.py b/array_api_tests/meta/test_pytest_helpers.py index 17ed5534..a32c6f33 100644 --- a/array_api_tests/meta/test_pytest_helpers.py +++ b/array_api_tests/meta/test_pytest_helpers.py @@ -20,3 +20,7 @@ def test_assert_array_elements(): ph.assert_array_elements("mixed sign zeros", out=xp.asarray(0.0), expected=xp.asarray(-0.0)) with raises(AssertionError): ph.assert_array_elements("mixed sign zeros", out=xp.asarray(-0.0), expected=xp.asarray(0.0)) + + ph.assert_array_elements("nans", out=xp.asarray(float("nan")), expected=xp.asarray(float("nan"))) + with raises(AssertionError): + ph.assert_array_elements("nan and zero", out=xp.asarray(float("nan")), expected=xp.asarray(0.0)) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 6d3899a9..174d27d1 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -6,7 +6,8 @@ from . import _array_module as xp from . import dtype_helpers as dh from . import shape_helpers as sh -from . import stubs +from . import stubs, api_version +from . import xp as _xp from .typing import Array, DataType, Scalar, ScalarType, Shape __all__ = [ @@ -420,6 +421,30 @@ def assert_fill( assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg +def _real_float_strict_equals(out: Array, expected: Array) -> bool: + assert hasattr(_xp, "signbit") # sanity check + + nan_mask = xp.isnan(out) + if not xp.all(nan_mask == xp.isnan(expected)): + return False + + out_zero_mask = out == 0 + out_sign_mask = xp.signbit(out) + out_pos_zero_mask = out_zero_mask & out_sign_mask + out_neg_zero_mask = out_zero_mask & ~out_sign_mask + expected_zero_mask = expected == 0 + expected_sign_mask = xp.signbit(expected) + expected_pos_zero_mask = expected_zero_mask & expected_sign_mask + expected_neg_zero_mask = expected_zero_mask & ~expected_sign_mask + if not (xp.all(out_pos_zero_mask == expected_pos_zero_mask) and xp.all(out_neg_zero_mask == expected_neg_zero_mask)): + return False + + ignore_mask = nan_mask | out_zero_mask + replacement = xp.asarray(42, dtype=out.dtype) # i.e. an arbitrary non-zero value that equals itself + match = xp.where(ignore_mask, replacement, out) == xp.where(ignore_mask, replacement, expected) + return xp.all(match) + + def _assert_float_element(at_out: Array, at_expected: Array, msg: str): if xp.isnan(at_expected): assert xp.isnan(at_out), msg @@ -460,31 +485,40 @@ def assert_array_elements( assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check f_func = f"[{func_name}({fmt_kw(kw)})]" - match = (out == expected) - if xp.all(match): - return + # First we try short-circuit for a successful assertion by using vectorised checks. + if out.dtype in dh.real_float_dtypes and api_version >= "2023.12": + if _real_float_strict_equals(out, expected): + return + elif out.dtype in dh.complex_dtypes and api_version >= "2023.12": + real_match = _real_float_strict_equals(out.real, expected.real) + imag_match = _real_float_strict_equals(out.imag, expected.imag) + if real_match and imag_match: + return + else: + match = out == expected + if xp.all(match): + return # In case of mismatch, generate a more helpful error. Cycling through all indices is # costly in some array api implementations, so we only do this in the case of a failure. + msg_template = "{}={}, but should be {} " + f_func if out.dtype in dh.real_float_dtypes: for idx in sh.ndindex(out.shape): at_out = out[idx] at_expected = expected[idx] - msg = ( - f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " - f"{f_func}" - ) + msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected) _assert_float_element(at_out, at_expected, msg) elif out.dtype in dh.complex_dtypes: assert (out.dtype in dh.complex_dtypes) == (expected.dtype in dh.complex_dtypes) for idx in sh.ndindex(out.shape): at_out = out[idx] at_expected = expected[idx] - msg = ( - f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " - f"{f_func}" - ) + msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected) _assert_float_element(xp.real(at_out), xp.real(at_expected), msg) _assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg) else: - assert xp.all(match), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}" + for idx in sh.ndindex(out.shape): + at_out = out[idx] + at_expected = expected[idx] + msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected) + assert at_out == at_expected, msg From a0fde7b574d6e2b3af22c7e47a878e7fee8d8ac0 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 26 Feb 2024 16:23:46 +0000 Subject: [PATCH 5/7] Ignore `signbit` testing if not available --- array_api_tests/pytest_helpers.py | 37 ++++++++++++++++++------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 174d27d1..9b559010 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -422,25 +422,30 @@ def assert_fill( def _real_float_strict_equals(out: Array, expected: Array) -> bool: - assert hasattr(_xp, "signbit") # sanity check - nan_mask = xp.isnan(out) if not xp.all(nan_mask == xp.isnan(expected)): return False + ignore_mask = nan_mask + + # Test sign of zeroes if xp.signbit() available, otherwise ignore as it's + # not that big of a deal for the perf costs. + if api_version >= "2023.12" and hasattr(_xp, "signbit"): + out_zero_mask = out == 0 + out_sign_mask = xp.signbit(out) + out_pos_zero_mask = out_zero_mask & out_sign_mask + out_neg_zero_mask = out_zero_mask & ~out_sign_mask + expected_zero_mask = expected == 0 + expected_sign_mask = xp.signbit(expected) + expected_pos_zero_mask = expected_zero_mask & expected_sign_mask + expected_neg_zero_mask = expected_zero_mask & ~expected_sign_mask + pos_zero_match = out_pos_zero_mask == expected_pos_zero_mask + neg_zero_match = out_neg_zero_mask == expected_neg_zero_mask + if not (xp.all(pos_zero_match) and xp.all(neg_zero_match)): + return False + ignore_mask |= out_zero_mask - out_zero_mask = out == 0 - out_sign_mask = xp.signbit(out) - out_pos_zero_mask = out_zero_mask & out_sign_mask - out_neg_zero_mask = out_zero_mask & ~out_sign_mask - expected_zero_mask = expected == 0 - expected_sign_mask = xp.signbit(expected) - expected_pos_zero_mask = expected_zero_mask & expected_sign_mask - expected_neg_zero_mask = expected_zero_mask & ~expected_sign_mask - if not (xp.all(out_pos_zero_mask == expected_pos_zero_mask) and xp.all(out_neg_zero_mask == expected_neg_zero_mask)): - return False - - ignore_mask = nan_mask | out_zero_mask replacement = xp.asarray(42, dtype=out.dtype) # i.e. an arbitrary non-zero value that equals itself + assert replacement == replacement # sanity check match = xp.where(ignore_mask, replacement, out) == xp.where(ignore_mask, replacement, expected) return xp.all(match) @@ -486,10 +491,10 @@ def assert_array_elements( f_func = f"[{func_name}({fmt_kw(kw)})]" # First we try short-circuit for a successful assertion by using vectorised checks. - if out.dtype in dh.real_float_dtypes and api_version >= "2023.12": + if out.dtype in dh.real_float_dtypes: if _real_float_strict_equals(out, expected): return - elif out.dtype in dh.complex_dtypes and api_version >= "2023.12": + elif out.dtype in dh.complex_dtypes: real_match = _real_float_strict_equals(out.real, expected.real) imag_match = _real_float_strict_equals(out.imag, expected.imag) if real_match and imag_match: From 0e980e6f66a1f35989ef647dbb6bd179e90e15bd Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 26 Feb 2024 17:08:43 +0000 Subject: [PATCH 6/7] Use functional `real()`/`imag()` APIs `real` and `imag` are not properties of API arrays! --- array_api_tests/pytest_helpers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 9b559010..14c40668 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -485,7 +485,7 @@ def assert_array_elements( >>> assert xp.all(out == x) """ - __tracebackhide__ = True + # __tracebackhide__ = True dh.result_type(out.dtype, expected.dtype) # sanity check assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check f_func = f"[{func_name}({fmt_kw(kw)})]" @@ -495,8 +495,8 @@ def assert_array_elements( if _real_float_strict_equals(out, expected): return elif out.dtype in dh.complex_dtypes: - real_match = _real_float_strict_equals(out.real, expected.real) - imag_match = _real_float_strict_equals(out.imag, expected.imag) + real_match = _real_float_strict_equals(xp.real(out), xp.real(expected)) + imag_match = _real_float_strict_equals(xp.imag(out), xp.imag(expected)) if real_match and imag_match: return else: From 82125d12f8f53730723ee9b0f92fd04d7be1f8e5 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 27 Feb 2024 16:22:08 +0000 Subject: [PATCH 7/7] Always use `signbit` if available --- array_api_tests/pytest_helpers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 14c40668..ead9fc6e 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -6,7 +6,7 @@ from . import _array_module as xp from . import dtype_helpers as dh from . import shape_helpers as sh -from . import stubs, api_version +from . import stubs from . import xp as _xp from .typing import Array, DataType, Scalar, ScalarType, Shape @@ -429,13 +429,13 @@ def _real_float_strict_equals(out: Array, expected: Array) -> bool: # Test sign of zeroes if xp.signbit() available, otherwise ignore as it's # not that big of a deal for the perf costs. - if api_version >= "2023.12" and hasattr(_xp, "signbit"): + if hasattr(_xp, "signbit"): out_zero_mask = out == 0 - out_sign_mask = xp.signbit(out) + out_sign_mask = _xp.signbit(out) out_pos_zero_mask = out_zero_mask & out_sign_mask out_neg_zero_mask = out_zero_mask & ~out_sign_mask expected_zero_mask = expected == 0 - expected_sign_mask = xp.signbit(expected) + expected_sign_mask = _xp.signbit(expected) expected_pos_zero_mask = expected_zero_mask & expected_sign_mask expected_neg_zero_mask = expected_zero_mask & ~expected_sign_mask pos_zero_match = out_pos_zero_mask == expected_pos_zero_mask