diff --git a/tests/api_test.py b/tests/api_test.py index 1e3ca34e0776..cbf60c320974 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -55,6 +55,9 @@ FLAGS = config.FLAGS +numpy_version = tuple(map(int, np.__version__.split('.')[:3])) + + class CPPJitTest(jtu.BufferDonationTestCase): """Shared tests between the Python and the C++ jax,jit implementations. @@ -1484,6 +1487,8 @@ def test_grad_of_int_errors(self): r"sub-dtype of np.floating or np.complexfloating\), but got int.*."), lambda: dfn(3)) + @unittest.skipIf(numpy_version == (1, 21, 0), + "https://github.com/numpy/numpy/issues/19305") def test_jvp_of_int_identity(self): primals = (1,) tangents = (np.zeros(shape=(), dtype=float0),) @@ -1491,6 +1496,8 @@ def test_jvp_of_int_identity(self): _, out = api.jvp(lambda x: x, primals, tangents) self.assertEqual(out, np.zeros(shape=(), dtype=float0)) + @unittest.skipIf(numpy_version == (1, 21, 0), + "https://github.com/numpy/numpy/issues/19305") def test_jvp_of_int_add(self): primals = (2,) tangents = (np.zeros(shape=(), dtype=float0),) @@ -1498,6 +1505,8 @@ def test_jvp_of_int_add(self): _, out_tangent = api.jvp(lambda x: x+1, primals, tangents) self.assertEqual(out_tangent, np.zeros(shape=(), dtype=float0)) + @unittest.skipIf(numpy_version == (1, 21, 0), + "https://github.com/numpy/numpy/issues/19305") def test_jit_jvp_of_int(self): primals = (2,) tangents = (np.zeros(shape=(), dtype=float0),) @@ -1505,6 +1514,8 @@ def test_jit_jvp_of_int(self): _, out_tangent = api.jvp(jax.jit(lambda x: x+1), primals, tangents) self.assertEqual(out_tangent, np.zeros(shape=(), dtype=float0)) + @unittest.skipIf(numpy_version == (1, 21, 0), + "https://github.com/numpy/numpy/issues/19305") def test_vjp_of_int_index(self): primal, fn_vjp = api.vjp(lambda x, i: x[i], np.ones(2)*2, 1) tangent_x, tangent_i = fn_vjp(1.) @@ -1512,12 +1523,16 @@ def test_vjp_of_int_index(self): self.assertAllClose(tangent_x, jnp.array([0., 1.])) self.assertEqual(tangent_i, np.zeros(shape=(), dtype=float0)) + @unittest.skipIf(numpy_version == (1, 21, 0), + "https://github.com/numpy/numpy/issues/19305") def test_vjp_of_int_shapes(self): out, fn_vjp = api.vjp(lambda x: lax.reshape(x, (2, 2)), np.ones((4, 1), dtype=int)) tangent, = fn_vjp(out) self.assertArraysEqual(tangent, np.zeros(shape=(4, 1), dtype=float0)) + @unittest.skipIf(numpy_version == (1, 21, 0), + "https://github.com/numpy/numpy/issues/19305") def test_jit_vjp_of_int(self): primal, fn_vjp = api.vjp(lambda x, y: x+y, 2, 1) tangent_x, tangent_i = jax.jit(fn_vjp)(1) @@ -1525,6 +1540,8 @@ def test_jit_vjp_of_int(self): self.assertEqual(tangent_x, np.zeros(shape=(), dtype=float0)) self.assertEqual(tangent_i, np.zeros(shape=(), dtype=float0)) + @unittest.skipIf(numpy_version == (1, 21, 0), + "https://github.com/numpy/numpy/issues/19305") def test_vjp_of_int_fulllike(self): # Regression test for tangent and cotangent mismatch in convert_element_type # transpose rule wrt a ConstVar @@ -1535,11 +1552,15 @@ def test_vjp_of_int_fulllike(self): self.assertAllClose(tangent_x, jnp.zeros((2, 2))) self.assertEqual(tangent_y, np.zeros(shape=(), dtype=float0)) + @unittest.skipIf(numpy_version == (1, 21, 0), + "https://github.com/numpy/numpy/issues/19305") def test_grad_of_int(self): # Need real-valued output, but testing integer input. out = api.grad(lambda x: x+0., allow_int=True)(1) self.assertEqual(out, np.zeros(shape=(), dtype=float0)) + @unittest.skipIf(numpy_version == (1, 21, 0), + "https://github.com/numpy/numpy/issues/19305") def test_grad_of_bool(self): def cond(pred): return lax.cond(pred, lambda _: 1., lambda _: 2., 1.) @@ -1547,18 +1568,24 @@ def cond(pred): self.assertEqual(value, 1.) self.assertEqual(grd, np.zeros(shape=(), dtype=float0)) + @unittest.skipIf(numpy_version == (1, 21, 0), + "https://github.com/numpy/numpy/issues/19305") def test_grad_of_int_index(self): grad_x, grad_i = api.grad(lambda x, i: x[i], argnums=(0, 1), allow_int=True)(np.ones(2), 1) self.assertAllClose(grad_x, jnp.array([0., 1.])) self.assertEqual(grad_i, np.zeros(shape=(), dtype=float0)) + @unittest.skipIf(numpy_version == (1, 21, 0), + "https://github.com/numpy/numpy/issues/19305") def test_jit_grad_of_int(self): grad_f = api.grad(lambda x, i: x[i], argnums=(0, 1), allow_int=True) grad_x, grad_i = jax.jit(grad_f)(np.ones(2), 1) self.assertAllClose(grad_x, jnp.array([0., 1.])) self.assertEqual(grad_i, np.zeros(shape=(), dtype=float0)) + @unittest.skipIf(numpy_version == (1, 21, 0), + "https://github.com/numpy/numpy/issues/19305") def test_float0_reshape(self): # dtype-agnostic operations are supported float0_array = jax.grad(lambda x: jnp.sum(x+0.), @@ -3894,6 +3921,8 @@ def f_jvp(primal, tangent): api.vmap(sample)(jax.random.split(jax.random.PRNGKey(1), 3)) # don't crash + @unittest.skipIf(numpy_version == (1, 21, 0), + "https://github.com/numpy/numpy/issues/19305") def test_float0(self): @api.custom_jvp def f(x, y): @@ -3909,6 +3938,8 @@ def f_jvp(primals, _): self.assertArraysEqual(api.jvp(f, primals, tangents), (primals, expected_tangents)) + @unittest.skipIf(numpy_version == (1, 21, 0), + "https://github.com/numpy/numpy/issues/19305") def test_float0_initial_style(self): @api.custom_jvp def f(x, y): @@ -4837,6 +4868,8 @@ def bwd(y, g): ans = jax.grad(f)(4.) self.assertAllClose(ans, -2. * jnp.sin(4.)) + @unittest.skipIf(numpy_version == (1, 21, 0), + "https://github.com/numpy/numpy/issues/19305") def test_float0(self): @api.custom_vjp def f(x, _): @@ -4853,6 +4886,8 @@ def f_rev(*_): self.assertEqual(api.grad(f, allow_int=True, argnums=(0, 1))(x, y), (2., np.zeros(shape=(), dtype=float0))) + @unittest.skipIf(numpy_version == (1, 21, 0), + "https://github.com/numpy/numpy/issues/19305") def test_float0_initial_style(self): @api.custom_vjp def f(x):