Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable float0 tests that fail under NumPy 1.21. #7096

Merged
merged 1 commit into from
Jun 24, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@
FLAGS = config.FLAGS


numpy_version = tuple(map(int, np.__version__.split('.')[:3]))


class CPPJitTest(jtu.BufferDonationTestCase):
"""Shared tests between the Python and the C++ jax,jit implementations.

Expand Down Expand Up @@ -1484,47 +1487,61 @@ def test_grad_of_int_errors(self):
r"sub-dtype of np.floating or np.complexfloating\), but got int.*."),
lambda: dfn(3))

@unittest.skipIf(numpy_version == (1, 21, 0),
"https://github.com/numpy/numpy/issues/19305")
def test_jvp_of_int_identity(self):
primals = (1,)
tangents = (np.zeros(shape=(), dtype=float0),)

_, out = api.jvp(lambda x: x, primals, tangents)
self.assertEqual(out, np.zeros(shape=(), dtype=float0))

@unittest.skipIf(numpy_version == (1, 21, 0),
"https://github.com/numpy/numpy/issues/19305")
def test_jvp_of_int_add(self):
primals = (2,)
tangents = (np.zeros(shape=(), dtype=float0),)

_, out_tangent = api.jvp(lambda x: x+1, primals, tangents)
self.assertEqual(out_tangent, np.zeros(shape=(), dtype=float0))

@unittest.skipIf(numpy_version == (1, 21, 0),
"https://github.com/numpy/numpy/issues/19305")
def test_jit_jvp_of_int(self):
primals = (2,)
tangents = (np.zeros(shape=(), dtype=float0),)

_, out_tangent = api.jvp(jax.jit(lambda x: x+1), primals, tangents)
self.assertEqual(out_tangent, np.zeros(shape=(), dtype=float0))

@unittest.skipIf(numpy_version == (1, 21, 0),
"https://github.com/numpy/numpy/issues/19305")
def test_vjp_of_int_index(self):
primal, fn_vjp = api.vjp(lambda x, i: x[i], np.ones(2)*2, 1)
tangent_x, tangent_i = fn_vjp(1.)
self.assertEqual(primal, 2.)
self.assertAllClose(tangent_x, jnp.array([0., 1.]))
self.assertEqual(tangent_i, np.zeros(shape=(), dtype=float0))

@unittest.skipIf(numpy_version == (1, 21, 0),
"https://github.com/numpy/numpy/issues/19305")
def test_vjp_of_int_shapes(self):
out, fn_vjp = api.vjp(lambda x: lax.reshape(x, (2, 2)), np.ones((4, 1),
dtype=int))
tangent, = fn_vjp(out)
self.assertArraysEqual(tangent, np.zeros(shape=(4, 1), dtype=float0))

@unittest.skipIf(numpy_version == (1, 21, 0),
"https://github.com/numpy/numpy/issues/19305")
def test_jit_vjp_of_int(self):
primal, fn_vjp = api.vjp(lambda x, y: x+y, 2, 1)
tangent_x, tangent_i = jax.jit(fn_vjp)(1)
self.assertEqual(primal, 3)
self.assertEqual(tangent_x, np.zeros(shape=(), dtype=float0))
self.assertEqual(tangent_i, np.zeros(shape=(), dtype=float0))

@unittest.skipIf(numpy_version == (1, 21, 0),
"https://github.com/numpy/numpy/issues/19305")
def test_vjp_of_int_fulllike(self):
# Regression test for tangent and cotangent mismatch in convert_element_type
# transpose rule wrt a ConstVar
Expand All @@ -1535,30 +1552,40 @@ def test_vjp_of_int_fulllike(self):
self.assertAllClose(tangent_x, jnp.zeros((2, 2)))
self.assertEqual(tangent_y, np.zeros(shape=(), dtype=float0))

@unittest.skipIf(numpy_version == (1, 21, 0),
"https://github.com/numpy/numpy/issues/19305")
def test_grad_of_int(self):
# Need real-valued output, but testing integer input.
out = api.grad(lambda x: x+0., allow_int=True)(1)
self.assertEqual(out, np.zeros(shape=(), dtype=float0))

@unittest.skipIf(numpy_version == (1, 21, 0),
"https://github.com/numpy/numpy/issues/19305")
def test_grad_of_bool(self):
def cond(pred):
return lax.cond(pred, lambda _: 1., lambda _: 2., 1.)
value, grd = api.value_and_grad(cond, allow_int=True)(True)
self.assertEqual(value, 1.)
self.assertEqual(grd, np.zeros(shape=(), dtype=float0))

@unittest.skipIf(numpy_version == (1, 21, 0),
"https://github.com/numpy/numpy/issues/19305")
def test_grad_of_int_index(self):
grad_x, grad_i = api.grad(lambda x, i: x[i], argnums=(0, 1),
allow_int=True)(np.ones(2), 1)
self.assertAllClose(grad_x, jnp.array([0., 1.]))
self.assertEqual(grad_i, np.zeros(shape=(), dtype=float0))

@unittest.skipIf(numpy_version == (1, 21, 0),
"https://github.com/numpy/numpy/issues/19305")
def test_jit_grad_of_int(self):
grad_f = api.grad(lambda x, i: x[i], argnums=(0, 1), allow_int=True)
grad_x, grad_i = jax.jit(grad_f)(np.ones(2), 1)
self.assertAllClose(grad_x, jnp.array([0., 1.]))
self.assertEqual(grad_i, np.zeros(shape=(), dtype=float0))

@unittest.skipIf(numpy_version == (1, 21, 0),
"https://github.com/numpy/numpy/issues/19305")
def test_float0_reshape(self):
# dtype-agnostic operations are supported
float0_array = jax.grad(lambda x: jnp.sum(x+0.),
Expand Down Expand Up @@ -3894,6 +3921,8 @@ def f_jvp(primal, tangent):

api.vmap(sample)(jax.random.split(jax.random.PRNGKey(1), 3)) # don't crash

@unittest.skipIf(numpy_version == (1, 21, 0),
"https://github.com/numpy/numpy/issues/19305")
def test_float0(self):
@api.custom_jvp
def f(x, y):
Expand All @@ -3909,6 +3938,8 @@ def f_jvp(primals, _):
self.assertArraysEqual(api.jvp(f, primals, tangents),
(primals, expected_tangents))

@unittest.skipIf(numpy_version == (1, 21, 0),
"https://github.com/numpy/numpy/issues/19305")
def test_float0_initial_style(self):
@api.custom_jvp
def f(x, y):
Expand Down Expand Up @@ -4837,6 +4868,8 @@ def bwd(y, g):
ans = jax.grad(f)(4.)
self.assertAllClose(ans, -2. * jnp.sin(4.))

@unittest.skipIf(numpy_version == (1, 21, 0),
"https://github.com/numpy/numpy/issues/19305")
def test_float0(self):
@api.custom_vjp
def f(x, _):
Expand All @@ -4853,6 +4886,8 @@ def f_rev(*_):
self.assertEqual(api.grad(f, allow_int=True, argnums=(0, 1))(x, y),
(2., np.zeros(shape=(), dtype=float0)))

@unittest.skipIf(numpy_version == (1, 21, 0),
"https://github.com/numpy/numpy/issues/19305")
def test_float0_initial_style(self):
@api.custom_vjp
def f(x):
Expand Down