add L-BFGS optimizer #6053

Merged · 3 commits · Jun 2, 2021

Conversation

@Jakob-Unfried (Contributor) commented Mar 13, 2021

As discussed in #1400

  • Type check
  • Rolling buffer for histories (added a TODO for now)
  • Someone double-check the math (@shoyer, @Joshuaalbert?)
  • Compare histories to SciPy

@google-cla (bot) added the cla: yes label on Mar 13, 2021
@Jakob-Unfried (Contributor, Author) commented Mar 13, 2021

The type check is complaining about possibly using >= on None.
However, I check for None in optimize.lbfgs line 112, so this will not be a problem.

I would like to keep None as the default argument for maxiter, maxfun and maxgrad, since it intuitively means "no restriction".

How should I address this issue?

@fehiepsi (Contributor) commented Mar 13, 2021

@Jakob-Unfried Could you test whether the implementation works with the following simple test (from this discussion)?

import time
import scipy.optimize
import numpy as np
import jax.numpy as jnp
import jax.scipy.optimize
jax.config.update("jax_enable_x64", True)

def rosenbrock(x):
    answer = sum(100.0*(x[1:]-x[:-1]**2.0)**2.0 + (1-x[:-1])**2.0)
    return answer

evalFunc = rosenbrock
jEvalFunc = jax.jit(evalFunc)

x0_1e2 = [76.0, 97.0, 20.0, 120.0, 0.01, 1e2]
x0_1e3 = [76.0, 97.0, 20.0, 120.0, 0.01, 1e3]
x0_1e4 = [76.0, 97.0, 20.0, 120.0, 0.01, 1e4]
x0s = [x0_1e2, x0_1e3, x0_1e4]

toleranceGrad = 1e-5
toleranceChange = 1e-9
xConvTol = 1e-6
lineSearchFn = "strong_wolfe"
maxIter = 1000

for x0 in x0s:
    jnpTrueMin = jnp.ones(len(x0))
    minimizeOptions = {'gtol': toleranceGrad, 'maxiter': maxIter}
    jx0 = jnp.array(x0)
    tStart = time.time()
    optimRes = jax.scipy.optimize.minimize(fun=jEvalFunc, x0=jx0, method='BFGS', options=minimizeOptions)
    elapsedTime = time.time()-tStart

    print("\tjax.scipy.optimize")
    print("\t\tConverged: {}".format(jnp.linalg.norm(optimRes.x-jnpTrueMin, ord=2)<xConvTol))
    print("\t\tDiff: {}".format(jnp.linalg.norm(optimRes.x-jnpTrueMin, ord=2)))
    print("\t\tFunction evaluations: {:d}".format(optimRes.nfev))
    print("\t\tIterations: {:d}".format(optimRes.nit))
    print("\t\tElapsedTime: {}\n\n".format(elapsedTime))

The current BFGS method failed in the x0_1e3 case, so I hope that L-BFGS will resolve the issue.

@Jakob-Unfried (Contributor, Author) commented Mar 13, 2021

@fehiepsi

I did not take the time to understand what's going on here and just blindly ran your snippet (using method='L-BFGS' and removing the closure [I guess that is for some other library?]).

This is the output:

jax.scipy.optimize
		Converged: True
		Diff: 5.995648082766192e-07
		Function evaluations: 157
		Iterations: 110
		ElapsedTime: 2.1565237045288086


	jax.scipy.optimize
		Converged: False
		Diff: 2.7307485928955455e-06
		Function evaluations: 353
		Iterations: 257
		ElapsedTime: 2.411688804626465


	jax.scipy.optimize
		Converged: False
		Diff: 3.85740797801385e-06
		Function evaluations: 713
		Iterations: 511
		ElapsedTime: 2.5120770931243896

@fehiepsi (Contributor):

Oh yes, the closure stuff is from pytorch. The result looks great! Awesome work, @Jakob-Unfried!

@shoyer (Collaborator) left a comment

Generally looks very nice, though I haven't checked the math carefully yet!

(Two review comments on jax/_src/scipy/optimize/lbfgs.py, now outdated and resolved.)
@shoyer (Collaborator) commented Mar 14, 2021

> The type check is complaining about possibly using >= on None.
> However, I check for None in optimize.lbfgs line 112, so this will not be a problem.
>
> I would like to keep None as the default argument for maxiter, maxfun and maxgrad, since it intuitively means "no restriction".
>
> How should I address this issue?

I would suggest disabling type checking on that line by adding a comment # type: ignore
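A minimal sketch of that pattern, with hypothetical names (the real check lives in optimize.lbfgs; the point here is only the placement of the comment):

from typing import Optional

def too_many_steps(k: int, maxiter: Optional[int]) -> bool:
    # maxiter=None means "no restriction" and is handled before this runs,
    # so the comparison is safe; the ignore only silences the type checker.
    return k >= maxiter  # type: ignore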

@shoyer (Collaborator) commented Mar 14, 2021

With regards to verifying consistency with SciPy, what about comparing the state history (s and y) after a few steps? I know our line search algorithm is different, but maybe there is a simple case where line search should give the same result? E.g., perhaps some quadratic function with a non-diagonal Hessian?

@Jakob-Unfried (Contributor, Author):

> With regards to verifying consistency with SciPy, what about comparing the state history (s and y) after a few steps? I know our line search algorithm is different, but maybe there is a simple case where line search should give the same result? E.g., perhaps some quadratic function with a non-diagonal Hessian?

The example would only need to be constructed such that 1 is an acceptable step length at every step;
then the issue of different line searches does not matter.
I will come up with something.

For my own future reference: the histories can be extracted from scipy via result.hess_inv.sk and result.hess_inv.yk.
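For concreteness, a small sketch of that extraction, assuming SciPy's L-BFGS-B implementation (result.hess_inv is an LbfgsInvHessProduct carrying the sk/yk arrays):

import numpy as np
import scipy.optimize

A = np.array([[1.0, 0.8], [0.6, 1.5]])   # an arbitrary quadratic form
fun = lambda x: x @ A @ x

result = scipy.optimize.minimize(fun, x0=np.array([1.0, 1.0]), method='L-BFGS-B')
s_history = result.hess_inv.sk   # steps s_k = x_{k+1} - x_k
y_history = result.hess_inv.yk   # gradient differences y_k = g_{k+1} - g_k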

@Jakob-Unfried (Contributor, Author) commented Mar 16, 2021

@shoyer

> With regards to verifying consistency with SciPy, what about comparing the state history (s and y) after a few steps? I know our line search algorithm is different, but maybe there is a simple case where line search should give the same result? E.g., perhaps some quadratic function with a non-diagonal Hessian?

I don't think my current implementation can be compared to scipy directly.
For reference, the s-history contains the steps s_k == x_{k+1} - x_k and the y-history contains the gradient differences y_k == g_{k+1} - g_k.
I tried the following things:

  • Just comparing for the "sphere" function lambda x: np.dot(x, x) starting from [1. 1.]: scipy gives an s-history with a single entry [-.5, -.5],
    which is not the step to the minimum at [0. 0.]. My implementation gives the expected [-1. -1.].
  • For a quadratic form np.dot(x, np.dot(A, x)) with the arbitrary positive matrix A == array([[1., 0.8], [0.6, 1.5]]) (true minimum [0. 0.], starting point [1. 1.]), I get the following s-histories (oldest to newest)
    scipy:
    [[-0.6114475  -0.791285  ]
     [-0.2797986  -0.28719221]
     [-0.03014821  0.01816222]
     [-0.07861546  0.06028725]]
    
    jax / my implementation:
    [[-0.8540882  -1.1052905 ]
     [-0.03715794  0.02681332]
     [-0.10875393  0.07847734]]
    
    (where I removed zeros corresponding to unfilled entries in my implementation)
    Here the steps from scipy add up to the correct total step of [-1. -1.]...

However, in both cases the minima agree with each other and with the true minimum up to floating precision (~5e-8).

This could be caused by the choice of gamma.
I don't know exactly what's going on here, especially since the scipy implementation is some 4000 lines of Fortran (to be fair, many of them comments), which I do not plan on dealing with.

@Joshuaalbert (Contributor):

@Jakob-Unfried There is another possible test you could do: set maxiter=m and compare to BFGS. L-BFGS and BFGS are identical on the first m iterations when gamma=1. Since gamma is not always one, you can try a problem where, approximately, a) the step size is 1 and b) the trajectory "momentum" and the gradients are parallel. Try the loss f(x) = x@x. In this case, the point should roll down the smooth symmetric potential with step size 1 from the start.

@Jakob-Unfried (Contributor, Author):

@Joshuaalbert
x@x isn't much use, since the first step directly finds the minimum, so no non-trivial L-BFGS update happens.

@Joshuaalbert (Contributor):

True, that's probably always going to be the case when gamma=1 exactly. Maybe a shallower basin like `(x@x)**(1.1/2.0)`? The symmetric aspect is important, as is the step size being close to one. It may be that there is no good example of this sort, since whenever the step size is close to one, it converges in only a few steps.
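A rough sketch of that comparison (method='L-BFGS' is the name used in this branch, not a released API; the hypothesis is that with gamma close to 1 the early iterates should nearly coincide):

import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

jax.config.update("jax_enable_x64", True)

fun = lambda x: jnp.dot(x, x) ** (1.1 / 2.0)   # shallow, symmetric basin
x0 = jnp.array([1.5, -0.5, 2.0])

res_bfgs = minimize(fun, x0, method='BFGS', options={'maxiter': 5})
res_lbfgs = minimize(fun, x0, method='L-BFGS', options={'maxiter': 5})

# If gamma stays close to 1, the early iterates should agree closely.
print(res_bfgs.x)
print(res_lbfgs.x)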

@Joshuaalbert (Contributor):

BTW, the PR math seems good to me.

@shoyer (Collaborator) commented Mar 26, 2021

Do we want to wait on additional tests here or just go ahead? I think this is probably close enough...

@Joshuaalbert (Contributor):

I support moving forward and letting the community vet it if anything strange should pop up.

@gnecula (Collaborator) commented Mar 29, 2021

@shoyer Should I go ahead with merging this?

@mblondel commented Mar 30, 2021

It would be great to compare the performance with SciPy's solver on a standard machine learning problem, say regularized binary logistic regression (aka binary cross-entropy loss) or whatever problem is easier for you. For comparison with SciPy, you can use autograd to calculate the gradients.

from sklearn.datasets import make_classification
X, y = make_classification(n_samples=100000, n_features=200, n_classes=2, random_state=0)

@salcats commented Apr 14, 2021

Hi,

Just an interested potential user here.

Is there any idea of when this might be merged?

@Jakob-Unfried (Contributor, Author):

@salcats

I am quite busy at the moment, so I cannot spend much time on further tests.

I am quite confident that the code is well-behaved. I have been using its "prototype" for half a year now with good results.

So unless someone else is willing to invest time into further tests in the near future, I suggest merging as is. @shoyer?

@nikikilbertus:

Hi, I'm also really keen on using this. PR LGTM. @shoyer can we merge?

@Jakob-Unfried (Contributor, Author):

@gnecula @shoyer
Anything I can do to see this merged sometime soon?
I still could not come up with a meaningful way of comparing to scipy...

@mblondel:

@Jakob-Unfried How about measuring the l2 norm ||∇f(x_t)|| for maxiter = 50, 100, 150, ...? Then you can plot the sequence of points (runtime, error) for scipy and jax.
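A hedged sketch of that measurement loop (illustrative names; using the already-exposed method='BFGS' so the snippet stands on its own):

import time
import jax
import jax.numpy as jnp
import jax.scipy.optimize

def runtime_vs_grad_norm(fun, x0, maxiters=(50, 100, 150, 200)):
    grad = jax.jit(jax.grad(fun))
    points = []
    for m in maxiters:
        t0 = time.time()
        res = jax.scipy.optimize.minimize(fun, x0, method='BFGS',
                                          options={'maxiter': m})
        res.x.block_until_ready()   # include compute time, not just dispatch
        points.append((time.time() - t0, float(jnp.linalg.norm(grad(res.x)))))
    return points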

@shoyer added, then removed, the pull ready (Ready for copybara import and testing) label on May 10, 2021
@mblondel:

Here's the objective of binary logistic regression (not tested):

import jax
import jax.numpy as jnp
import jax.scipy.optimize
from jax.nn import softplus
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=100000, n_features=200, n_classes=2, random_state=0)
binary_logreg = jax.vmap(lambda label, logit: softplus(logit) - label * logit)
lam = 1.0

def fun(w):
  logits = jnp.dot(X, w)
  # mean data loss plus l2 regularization
  return jnp.mean(binary_logreg(y, logits)) + 0.5 * lam * jnp.dot(w, w)

w_init = jnp.zeros(X.shape[1])
res = jax.scipy.optimize.minimize(fun=fun, x0=w_init, method='L-BFGS')
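For the SciPy side of the comparison suggested above, one possible sketch (reusing fun and X from the snippet, with jax.grad supplying the gradient instead of autograd):

import numpy as np
import scipy.optimize

grad_fun = jax.jit(jax.grad(fun))
res_scipy = scipy.optimize.minimize(
    fun=lambda w: float(fun(jnp.asarray(w))),
    x0=np.zeros(X.shape[1]),
    jac=lambda w: np.asarray(grad_fun(jnp.asarray(w))),
    method='L-BFGS-B')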

@shoyer (Collaborator) commented May 10, 2021

If somebody else wants to work on validation, I would be happy to merge this into JAX to facilitate that work, without documenting it or exposing it in the public API. E.g., we could remove it from jax.scipy.optimize.minimize, or require using some absurd method name like method='l-bfgs-experimental-do-not-rely-on-this'?
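For reference, a sketch of what calling the hidden method would look like (the method string is the one proposed above; fun and w_init as in the earlier logistic regression snippet):

res = jax.scipy.optimize.minimize(
    fun, w_init, method='l-bfgs-experimental-do-not-rely-on-this')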

@Joshuaalbert (Contributor):

That's a good approach @shoyer.

@Jakob-Unfried (Contributor, Author):

@shoyer,
I like this approach.
Feel free to go ahead with it; otherwise I will push something tonight.

@mblondel
The plot seems like a good idea, I will look into it.

@copybara-service (bot) merged commit ab52da0 into jax-ml:master on Jun 2, 2021
@acclaude commented Jun 4, 2021

Hi, I tried to replicate the speed test from Joshua George Albert's gist (https://gist.github.com/Joshuaalbert/214f14bbdd55d413693b8b413a384cae, see #1400).

[reference benchmark figure from #1400]

Unfortunately, I did not find the same results:

[Figure 1: speed test results]

Have a nice weekend

PS: AMD Ryzen 7 3700X

@Joshuaalbert (Contributor):

@acclaude Oh wow, interesting: I'm getting the same result now. I'm not sure which versions of jax and jaxlib I was using before, but I suspect performance degraded at some point. I will do a regression analysis against jax versions and see if I can find where that degradation happened.

@Joshuaalbert (Contributor):

Looks like I'll need to test from Jul 29, 2020, which was jax-0.1.74, all the way up to now.

@Joshuaalbert (Contributor):

I found it, @acclaude @mattjj @jakevdp @shoyer!

For some reason BFGS becomes slow at jax-0.2.6 and jaxlib-0.1.57. Any ideas why?

[speed results plot: jax-0.2.5 / jaxlib-0.1.56]

[speed results plot: jax-0.2.6 / jaxlib-0.1.57]

You can run the regress test yourself here: https://gist.github.com/Joshuaalbert/39e5e44a06cb00e7e154b37504e30fa1

@Joshuaalbert (Contributor):

I will also note that something else performance-wise happened at jax-0.2.11 and jaxlib-0.1.64, as evidenced by the speedup of scipy+jitted(func) below:

[speed results plot: jax-0.2.10 / jaxlib-0.1.61]

[speed results plot: jax-0.2.11 / jaxlib-0.1.64]

@Joshuaalbert (Contributor) commented Jun 4, 2021

Note, @acclaude, this is likely not the right place to report performance problems. I'll let others redirect this conversation to the right place, and I'm happy to continue the discussion there.

@acclaude commented Jun 4, 2021

Thank you very much for this interesting analysis. I wasn't sure where to post these performance issues. I will continue to do performance tests on my side.
Thank you

@jakevdp (Collaborator) commented Apr 10, 2023

Hi all - it seems like there hasn't been any follow-up work on this, and the l-bfgs-experimental-do-not-rely-on-this method is still poorly tested.

Are there any objections to removing this code from JAX? Or is there someone willing to take on support of this method and graduating it to something more reliable?

(note: the JAXOpt package has well-supported routines for this kind of minimization, so that would be a good alternative)
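For anyone landing here, a minimal sketch of the JAXOpt alternative (assuming JAXOpt's LBFGS solver; treat parameter names as approximate):

import jax.numpy as jnp
import jaxopt

# Rosenbrock, as in the test earlier in this thread
fun = lambda x: jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2)
solver = jaxopt.LBFGS(fun=fun, maxiter=500)
params, state = solver.run(jnp.zeros(6))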

@Jakob-Unfried (Contributor, Author):

I cannot spend any time on this, unfortunately.
And since there is JAXOpt, I don't think it makes sense to prioritize this.
