Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Short-circuit with vectorisation ph.assert_array_elements() #236

Merged
merged 7 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion array-api
Submodule array-api updated 116 files
4 changes: 4 additions & 0 deletions array_api_tests/meta/test_pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,7 @@ def test_assert_array_elements():
ph.assert_array_elements("mixed sign zeros", out=xp.asarray(0.0), expected=xp.asarray(-0.0))
with raises(AssertionError):
ph.assert_array_elements("mixed sign zeros", out=xp.asarray(-0.0), expected=xp.asarray(0.0))

ph.assert_array_elements("nans", out=xp.asarray(float("nan")), expected=xp.asarray(float("nan")))
with raises(AssertionError):
ph.assert_array_elements("nan and zero", out=xp.asarray(float("nan")), expected=xp.asarray(0.0))
68 changes: 56 additions & 12 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from . import dtype_helpers as dh
from . import shape_helpers as sh
from . import stubs
from . import xp as _xp
from .typing import Array, DataType, Scalar, ScalarType, Shape

__all__ = [
Expand Down Expand Up @@ -420,6 +421,35 @@ def assert_fill(
assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg


def _real_float_strict_equals(out: Array, expected: Array) -> bool:
nan_mask = xp.isnan(out)
if not xp.all(nan_mask == xp.isnan(expected)):
return False
ignore_mask = nan_mask

# Test sign of zeroes if xp.signbit() available, otherwise ignore as it's
# not that big of a deal for the perf costs.
if hasattr(_xp, "signbit"):
out_zero_mask = out == 0
out_sign_mask = _xp.signbit(out)
out_pos_zero_mask = out_zero_mask & out_sign_mask
out_neg_zero_mask = out_zero_mask & ~out_sign_mask
expected_zero_mask = expected == 0
expected_sign_mask = _xp.signbit(expected)
expected_pos_zero_mask = expected_zero_mask & expected_sign_mask
expected_neg_zero_mask = expected_zero_mask & ~expected_sign_mask
pos_zero_match = out_pos_zero_mask == expected_pos_zero_mask
neg_zero_match = out_neg_zero_mask == expected_neg_zero_mask
if not (xp.all(pos_zero_match) and xp.all(neg_zero_match)):
return False
ignore_mask |= out_zero_mask

replacement = xp.asarray(42, dtype=out.dtype) # i.e. an arbitrary non-zero value that equals itself
assert replacement == replacement # sanity check
match = xp.where(ignore_mask, replacement, out) == xp.where(ignore_mask, replacement, expected)
return xp.all(match)


def _assert_float_element(at_out: Array, at_expected: Array, msg: str):
if xp.isnan(at_expected):
assert xp.isnan(at_out), msg
Expand Down Expand Up @@ -455,31 +485,45 @@ def assert_array_elements(
>>> assert xp.all(out == x)

"""
__tracebackhide__ = True
# __tracebackhide__ = True
dh.result_type(out.dtype, expected.dtype) # sanity check
assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check
f_func = f"[{func_name}({fmt_kw(kw)})]"

# First we try short-circuit for a successful assertion by using vectorised checks.
if out.dtype in dh.real_float_dtypes:
if _real_float_strict_equals(out, expected):
return
elif out.dtype in dh.complex_dtypes:
real_match = _real_float_strict_equals(xp.real(out), xp.real(expected))
imag_match = _real_float_strict_equals(xp.imag(out), xp.imag(expected))
if real_match and imag_match:
return
else:
match = out == expected
if xp.all(match):
return

# In case of mismatch, generate a more helpful error. Cycling through all indices is
# costly in some array api implementations, so we only do this in the case of a failure.
msg_template = "{}={}, but should be {} " + f_func
if out.dtype in dh.real_float_dtypes:
for idx in sh.ndindex(out.shape):
at_out = out[idx]
at_expected = expected[idx]
msg = (
f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} "
f"{f_func}"
)
msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected)
_assert_float_element(at_out, at_expected, msg)
elif out.dtype in dh.complex_dtypes:
assert (out.dtype in dh.complex_dtypes) == (expected.dtype in dh.complex_dtypes)
for idx in sh.ndindex(out.shape):
at_out = out[idx]
at_expected = expected[idx]
msg = (
f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} "
f"{f_func}"
)
msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected)
_assert_float_element(xp.real(at_out), xp.real(at_expected), msg)
_assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg)
else:
assert xp.all(
out == expected
), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}"
for idx in sh.ndindex(out.shape):
at_out = out[idx]
at_expected = expected[idx]
msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected)
assert at_out == at_expected, msg
16 changes: 8 additions & 8 deletions array_api_tests/test_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,14 +354,14 @@ def test_eye(n_rows, n_cols, kw):
ph.assert_kw_dtype("eye", kw_dtype=kw["dtype"], out_dtype=out.dtype)
_n_cols = n_rows if n_cols is None else n_cols
ph.assert_shape("eye", out_shape=out.shape, expected=(n_rows, _n_cols), kw=dict(n_rows=n_rows, n_cols=n_cols))
f_func = f"[eye({n_rows=}, {n_cols=})]"
for i in range(n_rows):
for j in range(_n_cols):
f_indexed_out = f"out[{i}, {j}]={out[i, j]}"
if j - i == kw.get("k", 0):
assert out[i, j] == 1, f"{f_indexed_out}, should be 1 {f_func}"
else:
assert out[i, j] == 0, f"{f_indexed_out}, should be 0 {f_func}"
k = kw.get("k", 0)
expected = xp.asarray(
[[1 if j - i == k else 0 for j in range(_n_cols)] for i in range(n_rows)],
dtype=out.dtype # Note: dtype already checked above.
)
if expected.size == 0:
expected = xp.reshape(expected, (n_rows, _n_cols))
ph.assert_array_elements("eye", out=out, expected=expected, kw=kw)


default_unsafe_dtypes = [xp.uint64]
Expand Down
Loading