jaxopt 0.8.3 with patch #574 is failing the following unit tests after updating to scipy 1.12
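For reference, the three failing tests can be re-run in isolation (test IDs taken from the short test summary at the end of the full log below). A minimal sketch, assuming pytest is installed and the command is run from the jaxopt source root:

```python
# Sketch only: re-run just the three tests that fail against scipy 1.12.
import pytest

failing_tests = [
    "tests/lbfgs_test.py::LbfgsTest::test_binary_logit_log_likelihood",
    "tests/linear_solve_test.py::LinearSolveTest::test_solve_sparse",
    "tests/polyak_sgd_test.py::PolyakSgdTest::test_logreg_with_intercept_manual_loop3",
]
raise SystemExit(pytest.main(failing_tests))
```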
```
=================================== FAILURES ===================================
__________________ LbfgsTest.test_binary_logit_log_likelihood __________________
[gw45] linux -- Python 3.11.7 /nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/bin/python3.11

self = <lbfgs_test.LbfgsTest testMethod=test_binary_logit_log_likelihood>

    def test_binary_logit_log_likelihood(self):
      # See issue #409
      rng = jax.random.PRNGKey(42)
      N = 1000
      beta = jnp.array([[0.5,0.5]]).T
      income = jax.random.normal(rng, shape=(N,1))
      x = jnp.hstack([jnp.ones((N,1)), income])

      def simulate_binary_logit(x, beta):
        beta = beta.reshape(-1,1)
        N = x.shape[0]
        J = beta.shape[0]
        epsilon = jax.random.gumbel(rng,shape =(N,J))
        Beta_augmented = jnp.hstack([beta, jnp.zeros_like(beta)])
        utility = x @ Beta_augmented + epsilon
        choice_idx = onp.argmax(utility, axis=1)
        return (choice_idx).reshape(-1,1)

      y = simulate_binary_logit(x, beta)
      y = jnp.ravel(y)

      # numpy version
      def binary_logit_log_likelihood(beta, y,x):
        lambda_xb = onp.exp(x@beta) / (1 + onp.exp(x@beta))
        ll_i = y * onp.log(lambda_xb) + (1-y) * onp.log(1-lambda_xb)
        ll = -onp.sum(ll_i)
        return ll

      # jax version
      def binary_logit_log_likelihood_jax(beta, y, x):
        lambda_xb = jnp.exp(x@beta) / (1 + jnp.exp(x@beta))
        ll_i = y * jnp.log(lambda_xb) + (1-y) * jnp.log(1-lambda_xb)
        ll = -jnp.sum(ll_i)
        return ll

      beta_init = jnp.array([0.01,0.01])

      # using scipy
      scipy_res = scipy_opt.minimize(
          fun=binary_logit_log_likelihood,
          args=(onp.asarray(y),onp.asarray(x)),
          x0 = (onp.asarray(beta_init)),
          method='BFGS'
      ).x

      # using jaxopt
      solver = LBFGS(fun=binary_logit_log_likelihood_jax, maxiter=100,
                     linesearch="zoom", maxls=10, tol=1e-12)
      jaxopt_res = solver.run(beta_init, y, x).params

      # comparison
      scipy_val = binary_logit_log_likelihood(scipy_res, onp.asarray(y), onp.asarray(x))
      jaxopt_val = binary_logit_log_likelihood(jaxopt_res, y, x)
>     self.assertArraysAllClose(scipy_val, jaxopt_val)

tests/lbfgs_test.py:422:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxopt/_src/test_util.py:292: in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
jaxopt/_src/test_util.py:262: in _assert_numpy_allclose
    onp.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = (<function assert_allclose.<locals>.compare at 0xfffeb82b56c0>, array(636.76217796), array(636.7615, dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-06, atol=1e-06', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError:
E           Not equal to tolerance rtol=1e-06, atol=1e-06
E
E           Mismatched elements: 1 / 1 (100%)
E           Max absolute difference: 0.00070335
E           Max relative difference: 1.10457124e-06
E            x: array(636.762178)
E            y: array(636.7615, dtype=float32)

/nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/lib/python3.11/contextlib.py:81: AssertionError
______________________ LinearSolveTest.test_solve_sparse _______________________
[gw32] linux -- Python 3.11.7 /nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/bin/python3.11

self = <linear_solve_test.LinearSolveTest testMethod=test_solve_sparse>

    def test_solve_sparse(self):
      rng = onp.random.RandomState(0)

      # Matrix case.
      A = rng.randn(5, 5)
      b = rng.randn(5)

      def matvec(x):
        return jnp.dot(A, x)

      x = linear_solve.solve_lu(matvec, b)
      x2 = linear_solve.solve_normal_cg(matvec, b)
      x3 = linear_solve.solve_gmres(matvec, b)
      x4 = linear_solve.solve_bicgstab(matvec, b)
      x5 = linear_solve.solve_iterative_refinement(matvec, b)
      x6 = linear_solve.solve_qr(matvec, b)
      self.assertArraysAllClose(x, x2, atol=1e-4)
>     self.assertArraysAllClose(x, x3, atol=1e-4)

tests/linear_solve_test.py:132:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxopt/_src/test_util.py:292: in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
jaxopt/_src/test_util.py:262: in _assert_numpy_allclose
    onp.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = (<function assert_allclose.<locals>.compare at 0xfffe946f58a0>, array([-6.9443398, -1.9871643, 7.747069 , 7.654946 ,...875 ], dtype=float32), array([-6.944449 , -1.9872042, 7.747199 , 7.655077 , -7.038865 ], dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-06, atol=0.0001', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError:
E           Not equal to tolerance rtol=1e-06, atol=0.0001
E
E           Mismatched elements: 4 / 5 (80%)
E           Max absolute difference: 0.00013113
E           Max relative difference: 2.009613e-05
E            x: array([-6.94434 , -1.987164, 7.747069, 7.654946, -7.03875 ],
E                 dtype=float32)
E            y: array([-6.944449, -1.987204, 7.747199, 7.655077, -7.038865],
E                 dtype=float32)

/nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/lib/python3.11/contextlib.py:81: AssertionError
____________ PolyakSgdTest.test_logreg_with_intercept_manual_loop3 _____________
[gw13] linux -- Python 3.11.7 /nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/bin/python3.11

self = <polyak_sgd_test.PolyakSgdTest testMethod=test_logreg_with_intercept_manual_loop3>
momentum = 0.9, sps_variant = 'SPS+'

    @parameterized.product(momentum=[0.0, 0.9], sps_variant=['SPS_max', 'SPS+'])
    def test_logreg_with_intercept_manual_loop(self, momentum, sps_variant):
      x, y = datasets.make_classification(n_samples=10, n_features=5, n_classes=3,
                                          n_informative=3, random_state=0)
      data = (x, y)
      l2reg = 0.1
      # fun(params, l2reg, data)
      fun = objective.l2_multiclass_logreg_with_intercept
      n_classes = len(jnp.unique(y))

      w_init = jnp.zeros((x.shape[1], n_classes))
      b_init = jnp.zeros(n_classes)
      params = (w_init, b_init)

      opt = PolyakSGD(
          fun=fun, fun_min=0.6975, momentum=momentum, variant=sps_variant
      )

      error_init = opt.l2_optimality_error(params, l2reg=l2reg, data=data)

      state = opt.init_state(params, l2reg=l2reg, data=data)
      for _ in range(200):
        params, state = opt.update(params, state, l2reg=l2reg, data=data)

      # Check optimality conditions.
      error = opt.l2_optimality_error(params, l2reg=l2reg, data=data)
>     self.assertLessEqual(error / error_init, 0.02)
E     AssertionError: Array(0.02369377, dtype=float32) not less than or equal to 0.02

tests/polyak_sgd_test.py:79: AssertionError
=============================== warnings summary ===============================
jaxopt/_src/osqp.py:299
  /build/source/jaxopt/_src/osqp.py:299: DeprecationWarning: invalid escape sequence '\m'
    """Operator Splitting Solver for Quadratic Programs.

tests/isotonic_test.py::IsotonicPavTest::test_compare_with_sklearn0
tests/isotonic_test.py::IsotonicPavTest::test_compare_with_sklearn0
tests/isotonic_test.py::IsotonicPavTest::test_compare_with_sklearn1
tests/isotonic_test.py::IsotonicPavTest::test_compare_with_sklearn1
tests/isotonic_test.py::IsotonicPavTest::test_output_shape_and_dtype
tests/isotonic_test.py::IsotonicPavTest::test_vmap
tests/isotonic_test.py::IsotonicPavTest::test_gradient1
tests/isotonic_test.py::IsotonicPavTest::test_gradient0
tests/isotonic_test.py::IsotonicPavTest::test_gradient_min_max
  /build/source/jaxopt/_src/isotonic.py:94: UserWarning: Numba could not be imported. Code will run much more slowly. To install, run 'pip install numba'.
    warnings.warn(

tests/lbfgs_test.py::LbfgsTest::test_against_scipy1
tests/lbfgs_test.py::LbfgsTest::test_against_scipy3
tests/lbfgs_test.py::LbfgsTest::test_against_scipy0
tests/lbfgs_test.py::LbfgsTest::test_against_scipy4
tests/lbfgs_test.py::LbfgsTest::test_minimize_bad_initial_values
tests/lbfgs_test.py::LbfgsTest::test_against_scipy2
  /build/source/jaxopt/_src/lbfgs.py:119: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    fun = lambda leaf: jnp.zeros((history_size,) + leaf.shape, dtype=leaf.dtype)

tests/linear_solve_test.py::LinearSolveTest::test_solve_1d
tests/linear_solve_test.py::LinearSolveTest::test_solve_dense
tests/linear_solve_test.py::LinearSolveTest::test_solve_sparse
tests/linear_solve_test.py::LinearSolveTest::test_solve_sparse_ridge
  /build/source/jaxopt/_src/linear_solve.py:31: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    x = jnp.zeros(shape, dtype)

tests/levenberg_marquardt_test.py::LevenbergMarquardtTest::test_scaled_meyer_x327
  /build/source/jaxopt/_src/levenberg_marquardt.py:507: UserWarning: The linear solver inv that requires materialization of J^T.J matrix is used with materialize_jac=False, which may cause a computational overhead. Consider using either a matrix-free iterative solver such as cg or bicg or using materialize_jac=True.
    warnings.warn(f"The linear solver {self.solver} that requires materialization of "

tests/loss_test.py::LossTest::test_multiclass_logistic_loss
tests/loss_test.py::LossTest::test_multiclass_sparsemax_loss
  /nix/store/kay9rbfsfmi0mlp7f19xqxyykk2kb00b-python3.11-jax-0.4.23/lib/python3.11/site-packages/jax/_src/ops/scatter.py:96: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.
    warnings.warn(

tests/levenberg_marquardt_test.py::LevenbergMarquardtTest::test_scaled_meyer_x325
  /build/source/jaxopt/_src/levenberg_marquardt.py:507: UserWarning: The linear solver lu that requires materialization of J^T.J matrix is used with materialize_jac=False, which may cause a computational overhead. Consider using either a matrix-free iterative solver such as cg or bicg or using materialize_jac=True.
    warnings.warn(f"The linear solver {self.solver} that requires materialization of "

tests/common_test.py::CommonTest::test_dtype_consistency
  /nix/store/kay9rbfsfmi0mlp7f19xqxyykk2kb00b-python3.11-jax-0.4.23/lib/python3.11/site-packages/jax/_src/lax/lax.py:2385: RuntimeWarning: overflow encountered in cast
    out = np.array(c, eqn.params['new_dtype'])

tests/scipy_wrappers_test.py::ScipyMinimizeTest::test_no_njev0
  /build/source/jaxopt/_src/scipy_wrappers.py:343: RuntimeWarning: Method Nelder-Mead does not use gradient information (jac).
    res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),

tests/scipy_wrappers_test.py::ScipyMinimizeTest::test_no_njev1
  /build/source/jaxopt/_src/scipy_wrappers.py:343: RuntimeWarning: Method Powell does not use gradient information (jac).
    res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),

tests/scipy_wrappers_test.py::ScipyMinimizeTest::test_no_njev2
  /build/source/jaxopt/_src/scipy_wrappers.py:343: OptimizeWarning: Unknown solver options: maxiter
    res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),

tests/levenberg_marquardt_test.py::LevenbergMarquardtTest::test_scaled_meyer_x324
  /build/source/jaxopt/_src/levenberg_marquardt.py:507: UserWarning: The linear solver cholesky that requires materialization of J^T.J matrix is used with materialize_jac=False, which may cause a computational overhead. Consider using either a matrix-free iterative solver such as cg or bicg or using materialize_jac=True.
    warnings.warn(f"The linear solver {self.solver} that requires materialization of "

tests/scipy_wrappers_test.py::ScipyRootFindingTest::test_broyden
  /nix/store/1z0wr5pb0ckj88qy92mwh7zkc0yaym80-python3.11-scipy-1.12.0/lib/python3.11/site-packages/scipy/optimize/_root.py:245: RuntimeWarning: Method broyden1 does not use the jacobian (jac).
    _warn_jac_unused(jac, method)

tests/mirror_descent_test.py::MirrorDescentTest::test_multiclass_svm_dual_implicit_diff_kl_stable
tests/mirror_descent_test.py::MirrorDescentTest::test_multiclass_svm_dual_implicit_diff_kl_stable
tests/mirror_descent_test.py::MirrorDescentTest::test_multiclass_svm_dual_implicit_diff_kl_stable
  /nix/store/ndvyzqskd5yqzybwfpqk1dyc9qp2k00f-python3.11-scikit-learn-1.4.0/lib/python3.11/site-packages/sklearn/svm/_base.py:1237: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ============================
FAILED tests/lbfgs_test.py::LbfgsTest::test_binary_logit_log_likelihood - AssertionError:
FAILED tests/linear_solve_test.py::LinearSolveTest::test_solve_sparse - AssertionError:
FAILED tests/polyak_sgd_test.py::PolyakSgdTest::test_logreg_with_intercept_manual_loop3 - AssertionError: Array(0.02369377, dtype=float32) not less than or equal to ...
============ 3 failed, 552 passed, 6 skipped, 33 warnings in 49.90s ============
```
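All three failures look like narrow tolerance misses between float32 and float64 results rather than qualitatively wrong answers. As a minimal sketch using only the numbers reported in the first failure above, the LBFGS-vs-scipy gap exceeds the rtol=1e-06/atol=1e-06 check by roughly 10%, but passes once the tolerance allows for float32 rounding; the rtol=2e-6 value below is an illustrative choice, not a proposed fix:

```python
import numpy as onp

# Values copied from the LbfgsTest.test_binary_logit_log_likelihood failure above.
scipy_val = onp.array(636.76217796)                   # float64 objective from scipy BFGS
jaxopt_val = onp.array(636.7615, dtype=onp.float32)   # float32 objective from jaxopt LBFGS

# The comparison used by the test (rtol=1e-06, atol=1e-06) fails by a small margin...
try:
    onp.testing.assert_allclose(scipy_val, jaxopt_val, rtol=1e-6, atol=1e-6)
except AssertionError as exc:
    print(exc)

# ...while a slightly looser rtol, consistent with float32 precision, passes.
onp.testing.assert_allclose(scipy_val, jaxopt_val, rtol=2e-6, atol=1e-6)
```

The UserWarnings in the log ("Explicitly requested dtype float64 ... will be truncated to dtype float32") already point at the `jax_enable_x64` configuration option as the knob controlling this truncation, so another way to probe the failures would be to run the suite with `JAX_ENABLE_X64=1` and see whether the gaps shrink; whether that is the intended setup for the jaxopt test suite is an assumption on my part.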