Skip to content

Commit

Permalink
Update test_cross to test broadcastable shapes
Browse files Browse the repository at this point in the history
This also updates it to only test axes from [min(x1.ndim, x2.ndim), -1], as
per data-apis/array-api#740
  • Loading branch information
cr313 committed Feb 3, 2024
1 parent 4403061 commit 22a248e
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions array_api_tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,23 +148,29 @@ def cross_args(draw, dtype_objects=dh.real_dtypes):
in the drawn axis.
"""
shape = list(draw(shapes()))
size = len(shape)
assume(size > 0)
shape1, shape2 = draw(two_mutually_broadcastable_shapes)
min_ndim = min(len(shape1), len(shape2))
assume(min_ndim > 0)

kw = draw(kwargs(axis=integers(-size, size-1)))
kw = draw(kwargs(axis=integers(-min_ndim, -1)))
axis = kw.get('axis', -1)
shape[axis] = 3
shape = tuple(shape)
if draw(booleans()):
# Sometimes allow invalid inputs to test it errors
shape1 = list(shape1)
shape1[axis] = 3
shape1 = tuple(shape1)
shape2 = list(shape2)
shape2[axis] = 3
shape2 = tuple(shape2)

mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtype_objects))
arrays1 = arrays(
dtype=mutual_dtypes.map(lambda pair: pair[0]),
shape=shape,
shape=shape1,
)
arrays2 = arrays(
dtype=mutual_dtypes.map(lambda pair: pair[1]),
shape=shape,
shape=shape2,
)
return draw(arrays1), draw(arrays2), kw

Expand All @@ -176,15 +182,17 @@ def test_cross(x1_x2_kw):
x1, x2, kw = x1_x2_kw

axis = kw.get('axis', -1)
err = "test_cross produced invalid input. This indicates a bug in the test suite."
assert x1.shape == x2.shape, err
shape = x1.shape
assert x1.shape[axis] == x2.shape[axis] == 3, err
if not (x1.shape[axis] == x2.shape[axis] == 3):
ph.raises(Exception, lambda: xp.cross(x1, x2, **kw),
"cross did not raise an exception for invalid shapes")
return

res = linalg.cross(x1, x2, **kw)

broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)

assert res.dtype == dh.result_type(x1.dtype, x2.dtype), "cross() did not return the correct dtype"
assert res.shape == shape, "cross() did not return the correct shape"
assert res.shape == broadcasted_shape, "cross() did not return the correct shape"

def exact_cross(a, b):
assert a.shape == b.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
Expand Down

0 comments on commit 22a248e

Please sign in to comment.