Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

unit test failures on aarch64 linux with scipy 1.12 #577

Open
ghost opened this issue Feb 4, 2024 · 0 comments
Open

unit test failures on aarch64 linux with scipy 1.12 #577

ghost opened this issue Feb 4, 2024 · 0 comments

Comments

@ghost
Copy link

ghost commented Feb 4, 2024

With patch #574 applied, jaxopt 0.8.3 fails the following unit tests on aarch64 Linux after updating to SciPy 1.12:

=================================== FAILURES ===================================
__________________ LbfgsTest.test_binary_logit_log_likelihood __________________
[gw45] linux -- Python 3.11.7 /nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/bin/python3.11
self = <lbfgs_test.LbfgsTest testMethod=test_binary_logit_log_likelihood>
    def test_binary_logit_log_likelihood(self):
      # See issue #409
      rng = jax.random.PRNGKey(42)
      N = 1000
      beta = jnp.array([[0.5,0.5]]).T
      income = jax.random.normal(rng, shape=(N,1))
      x = jnp.hstack([jnp.ones((N,1)), income])
    
      def simulate_binary_logit(x, beta):
        beta = beta.reshape(-1,1)
        N = x.shape[0]
        J = beta.shape[0]
    
        epsilon = jax.random.gumbel(rng,shape =(N,J))
        Beta_augmented = jnp.hstack([beta, jnp.zeros_like(beta)])
        utility = x @ Beta_augmented + epsilon
    
        choice_idx = onp.argmax(utility, axis=1)
        return (choice_idx).reshape(-1,1)
    
      y = simulate_binary_logit(x, beta)
      y = jnp.ravel(y)
    
      # numpy version
      def binary_logit_log_likelihood(beta, y,x):
        lambda_xb = onp.exp(x@beta) / (1 + onp.exp(x@beta))
        ll_i = y * onp.log(lambda_xb) + (1-y) * onp.log(1-lambda_xb)
        ll = -onp.sum(ll_i)
        return ll
    
      # jax version
      def binary_logit_log_likelihood_jax(beta, y, x):
        lambda_xb = jnp.exp(x@beta) / (1 + jnp.exp(x@beta))
        ll_i = y * jnp.log(lambda_xb) + (1-y) * jnp.log(1-lambda_xb)
        ll = -jnp.sum(ll_i)
        return ll
    
      beta_init = jnp.array([0.01,0.01])
    
      # using scipy
      scipy_res = scipy_opt.minimize(
        fun=binary_logit_log_likelihood,
        args=(onp.asarray(y),onp.asarray(x)),
        x0 = (onp.asarray(beta_init)), method='BFGS'
      ).x
    
      # using jaxopt
      solver = LBFGS(fun=binary_logit_log_likelihood_jax, maxiter=100,
                     linesearch="zoom", maxls=10, tol=1e-12)
      jaxopt_res = solver.run(beta_init, y, x).params
    
      # comparison
      scipy_val = binary_logit_log_likelihood(scipy_res,
                                              onp.asarray(y),
                                              onp.asarray(x))
      jaxopt_val = binary_logit_log_likelihood(jaxopt_res, y, x)
>     self.assertArraysAllClose(scipy_val, jaxopt_val)
tests/lbfgs_test.py:422: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
jaxopt/_src/test_util.py:292: in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
jaxopt/_src/test_util.py:262: in _assert_numpy_allclose
    onp.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
args = (<function assert_allclose.<locals>.compare at 0xfffeb82b56c0>, array(636.76217796), array(636.7615, dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-06, atol=1e-06', 'verbose': True}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=1e-06, atol=1e-06
E           
E           Mismatched elements: 1 / 1 (100%)
E           Max absolute difference: 0.00070335
E           Max relative difference: 1.10457124e-06
E            x: array(636.762178)
E            y: array(636.7615, dtype=float32)
/nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/lib/python3.11/contextlib.py:81: AssertionError
______________________ LinearSolveTest.test_solve_sparse _______________________
[gw32] linux -- Python 3.11.7 /nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/bin/python3.11
self = <linear_solve_test.LinearSolveTest testMethod=test_solve_sparse>
    def test_solve_sparse(self):
      rng = onp.random.RandomState(0)
    
      # Matrix case.
      A = rng.randn(5, 5)
      b = rng.randn(5)
    
      def matvec(x):
        return jnp.dot(A, x)
    
      x = linear_solve.solve_lu(matvec, b)
      x2 = linear_solve.solve_normal_cg(matvec, b)
      x3 = linear_solve.solve_gmres(matvec, b)
      x4 = linear_solve.solve_bicgstab(matvec, b)
      x5 = linear_solve.solve_iterative_refinement(matvec, b)
      x6 = linear_solve.solve_qr(matvec, b)
    
      self.assertArraysAllClose(x, x2, atol=1e-4)
>     self.assertArraysAllClose(x, x3, atol=1e-4)
tests/linear_solve_test.py:132: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
jaxopt/_src/test_util.py:292: in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
jaxopt/_src/test_util.py:262: in _assert_numpy_allclose
    onp.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
args = (<function assert_allclose.<locals>.compare at 0xfffe946f58a0>, array([-6.9443398, -1.9871643,  7.747069 ,  7.654946 ,...875  ],
      dtype=float32), array([-6.944449 , -1.9872042,  7.747199 ,  7.655077 , -7.038865 ],
      dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-06, atol=0.0001', 'verbose': True}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=1e-06, atol=0.0001
E           
E           Mismatched elements: 4 / 5 (80%)
E           Max absolute difference: 0.00013113
E           Max relative difference: 2.009613e-05
E            x: array([-6.94434 , -1.987164,  7.747069,  7.654946, -7.03875 ],
E                 dtype=float32)
E            y: array([-6.944449, -1.987204,  7.747199,  7.655077, -7.038865],
E                 dtype=float32)
/nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/lib/python3.11/contextlib.py:81: AssertionError
____________ PolyakSgdTest.test_logreg_with_intercept_manual_loop3 _____________
[gw13] linux -- Python 3.11.7 /nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/bin/python3.11
self = <polyak_sgd_test.PolyakSgdTest testMethod=test_logreg_with_intercept_manual_loop3>
momentum = 0.9, sps_variant = 'SPS+'
    @parameterized.product(momentum=[0.0, 0.9], sps_variant=['SPS_max', 'SPS+'])
    def test_logreg_with_intercept_manual_loop(self, momentum, sps_variant):
      x, y = datasets.make_classification(n_samples=10, n_features=5, n_classes=3,
                                          n_informative=3, random_state=0)
      data = (x, y)
      l2reg = 0.1
      # fun(params, l2reg, data)
      fun = objective.l2_multiclass_logreg_with_intercept
      n_classes = len(jnp.unique(y))
    
      w_init = jnp.zeros((x.shape[1], n_classes))
      b_init = jnp.zeros(n_classes)
      params = (w_init, b_init)
    
      opt = PolyakSGD(
          fun=fun, fun_min=0.6975, momentum=momentum, variant=sps_variant
      )
      error_init = opt.l2_optimality_error(params, l2reg=l2reg, data=data)
    
      state = opt.init_state(params, l2reg=l2reg, data=data)
      for _ in range(200):
        params, state = opt.update(params, state, l2reg=l2reg, data=data)
    
      # Check optimality conditions.
      error = opt.l2_optimality_error(params, l2reg=l2reg, data=data)
>     self.assertLessEqual(error / error_init, 0.02)
E     AssertionError: Array(0.02369377, dtype=float32) not less than or equal to 0.02
tests/polyak_sgd_test.py:79: AssertionError
=============================== warnings summary ===============================
jaxopt/_src/osqp.py:299
  /build/source/jaxopt/_src/osqp.py:299: DeprecationWarning: invalid escape sequence '\m'
    """Operator Splitting Solver for Quadratic Programs.
tests/isotonic_test.py::IsotonicPavTest::test_compare_with_sklearn0
tests/isotonic_test.py::IsotonicPavTest::test_compare_with_sklearn0
tests/isotonic_test.py::IsotonicPavTest::test_compare_with_sklearn1
tests/isotonic_test.py::IsotonicPavTest::test_compare_with_sklearn1
tests/isotonic_test.py::IsotonicPavTest::test_output_shape_and_dtype
tests/isotonic_test.py::IsotonicPavTest::test_vmap
tests/isotonic_test.py::IsotonicPavTest::test_gradient1
tests/isotonic_test.py::IsotonicPavTest::test_gradient0
tests/isotonic_test.py::IsotonicPavTest::test_gradient_min_max
  /build/source/jaxopt/_src/isotonic.py:94: UserWarning: Numba could not be imported. Code will run much more slowly. To install, run 'pip install numba'.
    warnings.warn(
tests/lbfgs_test.py::LbfgsTest::test_against_scipy1
tests/lbfgs_test.py::LbfgsTest::test_against_scipy3
tests/lbfgs_test.py::LbfgsTest::test_against_scipy0
tests/lbfgs_test.py::LbfgsTest::test_against_scipy4
tests/lbfgs_test.py::LbfgsTest::test_minimize_bad_initial_values
tests/lbfgs_test.py::LbfgsTest::test_against_scipy2
  /build/source/jaxopt/_src/lbfgs.py:119: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    fun = lambda leaf: jnp.zeros((history_size,) + leaf.shape, dtype=leaf.dtype)
tests/linear_solve_test.py::LinearSolveTest::test_solve_1d
tests/linear_solve_test.py::LinearSolveTest::test_solve_dense
tests/linear_solve_test.py::LinearSolveTest::test_solve_sparse
tests/linear_solve_test.py::LinearSolveTest::test_solve_sparse_ridge
  /build/source/jaxopt/_src/linear_solve.py:31: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    x = jnp.zeros(shape, dtype)
tests/levenberg_marquardt_test.py::LevenbergMarquardtTest::test_scaled_meyer_x327
  /build/source/jaxopt/_src/levenberg_marquardt.py:507: UserWarning: The linear solver inv that requires materialization of J^T.J matrix is used with materialize_jac=False, which may cause a computational overhead. Consider using either a matrix-free iterative solver such as cg or bicg or using materialize_jac=True.
    warnings.warn(f"The linear solver {self.solver} that requires materialization of "
tests/loss_test.py::LossTest::test_multiclass_logistic_loss
tests/loss_test.py::LossTest::test_multiclass_sparsemax_loss
  /nix/store/kay9rbfsfmi0mlp7f19xqxyykk2kb00b-python3.11-jax-0.4.23/lib/python3.11/site-packages/jax/_src/ops/scatter.py:96: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.
    warnings.warn(
tests/levenberg_marquardt_test.py::LevenbergMarquardtTest::test_scaled_meyer_x325
  /build/source/jaxopt/_src/levenberg_marquardt.py:507: UserWarning: The linear solver lu that requires materialization of J^T.J matrix is used with materialize_jac=False, which may cause a computational overhead. Consider using either a matrix-free iterative solver such as cg or bicg or using materialize_jac=True.
    warnings.warn(f"The linear solver {self.solver} that requires materialization of "
tests/common_test.py::CommonTest::test_dtype_consistency
  /nix/store/kay9rbfsfmi0mlp7f19xqxyykk2kb00b-python3.11-jax-0.4.23/lib/python3.11/site-packages/jax/_src/lax/lax.py:2385: RuntimeWarning: overflow encountered in cast
    out = np.array(c, eqn.params['new_dtype'])
tests/scipy_wrappers_test.py::ScipyMinimizeTest::test_no_njev0
  /build/source/jaxopt/_src/scipy_wrappers.py:343: RuntimeWarning: Method Nelder-Mead does not use gradient information (jac).
    res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),
tests/scipy_wrappers_test.py::ScipyMinimizeTest::test_no_njev1
  /build/source/jaxopt/_src/scipy_wrappers.py:343: RuntimeWarning: Method Powell does not use gradient information (jac).
    res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),
tests/scipy_wrappers_test.py::ScipyMinimizeTest::test_no_njev2
  /build/source/jaxopt/_src/scipy_wrappers.py:343: OptimizeWarning: Unknown solver options: maxiter
    res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),
tests/levenberg_marquardt_test.py::LevenbergMarquardtTest::test_scaled_meyer_x324
  /build/source/jaxopt/_src/levenberg_marquardt.py:507: UserWarning: The linear solver cholesky that requires materialization of J^T.J matrix is used with materialize_jac=False, which may cause a computational overhead. Consider using either a matrix-free iterative solver such as cg or bicg or using materialize_jac=True.
    warnings.warn(f"The linear solver {self.solver} that requires materialization of "
tests/scipy_wrappers_test.py::ScipyRootFindingTest::test_broyden
  /nix/store/1z0wr5pb0ckj88qy92mwh7zkc0yaym80-python3.11-scipy-1.12.0/lib/python3.11/site-packages/scipy/optimize/_root.py:245: RuntimeWarning: Method broyden1 does not use the jacobian (jac).
    _warn_jac_unused(jac, method)
tests/mirror_descent_test.py::MirrorDescentTest::test_multiclass_svm_dual_implicit_diff_kl_stable
tests/mirror_descent_test.py::MirrorDescentTest::test_multiclass_svm_dual_implicit_diff_kl_stable
tests/mirror_descent_test.py::MirrorDescentTest::test_multiclass_svm_dual_implicit_diff_kl_stable
  /nix/store/ndvyzqskd5yqzybwfpqk1dyc9qp2k00f-python3.11-scikit-learn-1.4.0/lib/python3.11/site-packages/sklearn/svm/_base.py:1237: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
    warnings.warn(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ============================
FAILED tests/lbfgs_test.py::LbfgsTest::test_binary_logit_log_likelihood - AssertionError: 
FAILED tests/linear_solve_test.py::LinearSolveTest::test_solve_sparse - AssertionError: 
FAILED tests/polyak_sgd_test.py::PolyakSgdTest::test_logreg_with_intercept_manual_loop3 - AssertionError: Array(0.02369377, dtype=float32) not less than or equal to ...
============ 3 failed, 552 passed, 6 skipped, 33 warnings in 49.90s ============
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

0 participants