Skip to content

Commit

Permalink
Ignore signbit testing if not available
Browse files Browse the repository at this point in the history
  • Loading branch information
honno committed Feb 26, 2024
1 parent a257446 commit 2cfaeb7
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,25 +422,30 @@ def assert_fill(


def _real_float_strict_equals(out: Array, expected: Array) -> bool:
assert hasattr(_xp, "signbit") # sanity check

nan_mask = xp.isnan(out)
if not xp.all(nan_mask == xp.isnan(expected)):
return False
ignore_mask = nan_mask

# Test sign of zeroes if xp.signbit() available, otherwise ignore as it's
# not that big of a deal for the perf costs.
if hasattr(_xp, "signbit"):
out_zero_mask = out == 0
out_sign_mask = xp.signbit(out)
out_pos_zero_mask = out_zero_mask & out_sign_mask
out_neg_zero_mask = out_zero_mask & ~out_sign_mask
expected_zero_mask = expected == 0
expected_sign_mask = xp.signbit(expected)
expected_pos_zero_mask = expected_zero_mask & expected_sign_mask
expected_neg_zero_mask = expected_zero_mask & ~expected_sign_mask
pos_zero_match = out_pos_zero_mask == expected_pos_zero_mask
neg_zero_match = out_neg_zero_mask == expected_neg_zero_mask
if not (xp.all(pos_zero_match) and xp.all(neg_zero_match)):
return False
ignore_mask |= out_zero_mask

out_zero_mask = out == 0
out_sign_mask = xp.signbit(out)
out_pos_zero_mask = out_zero_mask & out_sign_mask
out_neg_zero_mask = out_zero_mask & ~out_sign_mask
expected_zero_mask = expected == 0
expected_sign_mask = xp.signbit(expected)
expected_pos_zero_mask = expected_zero_mask & expected_sign_mask
expected_neg_zero_mask = expected_zero_mask & ~expected_sign_mask
if not (xp.all(out_pos_zero_mask == expected_pos_zero_mask) and xp.all(out_neg_zero_mask == expected_neg_zero_mask)):
return False

ignore_mask = nan_mask | out_zero_mask
replacement = xp.asarray(42, dtype=out.dtype) # i.e. an arbitrary non-zero value that equals itself
assert replacement == replacement # sanity check
match = xp.where(ignore_mask, replacement, out) == xp.where(ignore_mask, replacement, expected)
return xp.all(match)

Expand Down

0 comments on commit 2cfaeb7

Please sign in to comment.