add L-BFGS optimizer #6053
Conversation
The typecheck is complaining about possibly using… I would like to keep… How should I address this issue? |
@Jakob-Unfried Could you test if the implementation works with the following simple test (from this discussion)?
import pdb
import time
import copy
import scipy.optimize
import numpy as np
import jax.numpy as jnp
import jax.scipy.optimize
jax.config.update("jax_enable_x64", True)
def rosenbrock(x):
    answer = sum(100.0*(x[1:]-x[:-1]**2.0)**2.0 + (1-x[:-1])**2.0)
    return answer
evalFunc = rosenbrock
jEvalFunc = jax.jit(evalFunc)
x0_1e2 = [76.0, 97.0, 20.0, 120.0, 0.01, 1e2]
x0_1e3 = [76.0, 97.0, 20.0, 120.0, 0.01, 1e3]
x0_1e4 = [76.0, 97.0, 20.0, 120.0, 0.01, 1e4]
x0s = [x0_1e2, x0_1e3, x0_1e4]
toleranceGrad = 1e-5
toleranceChange = 1e-9
xConvTol = 1e-6
lineSearchFn = "strong_wolfe"
maxIter = 1000
for x0 in x0s:
    jnpTrueMin = jnp.ones(len(x0))
    minimizeOptions = {'gtol': toleranceGrad, 'maxiter': maxIter}
    jx0 = jnp.array(x0)
    tStart = time.time()
    optimRes = jax.scipy.optimize.minimize(fun=jEvalFunc, x0=jx0, method='BFGS', options=minimizeOptions)
    elapsedTime = time.time()-tStart
    print("\tjax.scipy.optimize")
    print("\t\tConverged: {}".format(jnp.linalg.norm(optimRes.x-jnpTrueMin, ord=2)<xConvTol))
    print("\t\tDiff: {}".format(jnp.linalg.norm(optimRes.x-jnpTrueMin, ord=2)))
    print("\t\tFunction evaluations: {:d}".format(optimRes.nfev))
    print("\t\tIterations: {:d}".format(optimRes.nit))
    print("\t\tElapsedTime: {}\n\n".format(elapsedTime))
The current BFGS method failed with |
I did not take the time to understand what's going on here and just blindly ran your snippet (using…). This is the output:
|
Oh yes, the closure stuff is from PyTorch. The result looks great! Awesome work, @Jakob-Unfried! |
Generally looks very nice, though I haven't checked the math carefully yet!
I would suggest disabling type checking on that line by adding a comment |
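For reference, a minimal sketch of that kind of suppression comment, assuming a mypy-style checker and a hypothetical attribute it flags as possibly None (the names below are illustrative, not from the PR):
# Hypothetical illustration: silence a "possibly None" complaint on one line only.
# The exact error code in brackets depends on the type checker being used.
step_size = maybe_line_search_result.step  # type: ignore[union-attr]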
With regard to verifying consistency with SciPy, what about comparing the state history (…)? |
The example would only need to be constructed such that… For my own future reference: the histories can be extracted from scipy via… |
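For reference, a minimal sketch of one way to record the iterate history on the SciPy side, using the callback argument of scipy.optimize.minimize (this reuses rosenbrock, jEvalFunc, and x0_1e2 from the snippet above; the gradient wrapper is an assumption, not something from the thread):
import numpy as np
import scipy.optimize

history = []

def record(xk):
    # SciPy calls this once per iteration with the current iterate.
    history.append(np.copy(xk))

grad = jax.grad(jEvalFunc)
res = scipy.optimize.minimize(rosenbrock, np.array(x0_1e2), method='BFGS',
                              jac=lambda x: np.asarray(grad(jnp.array(x))),
                              callback=record)
x_history = np.stack(history)  # shape (n_iterations, n_params)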
I don't think my current implementation can be compared to scipy directly.
However, in both cases the minima agree with each other and with the true minimum up to floating-point precision (~…). This could be caused by the choice of gamma. |
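For context, assuming gamma here refers to the standard initial inverse-Hessian scaling used in L-BFGS (Nocedal & Wright, Numerical Optimization, Sec. 7.2), that factor is $\gamma_k = \frac{s_{k-1}^\top y_{k-1}}{y_{k-1}^\top y_{k-1}}$ with $s_{k-1} = x_k - x_{k-1}$ and $y_{k-1} = \nabla f(x_k) - \nabla f(x_{k-1})$; a different choice of this scaling changes the intermediate iterates (and hence any history comparison) even when the final minimizer is essentially the same.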
@Jakob-Unfried There is another possible test you could do. That would be to set… |
@Joshuaalbert |
True, that's probably always going to be the case when |
BTW, the PR math seems good to me. |
Do we want to wait on additional tests here or just go ahead? I think this is probably close enough... |
I support moving forward and letting the community vet it if anything strange should pop up. |
@shoyer Should I go ahead with merging this? |
It would be great to compare the performance with SciPy's solver on a standard machine learning problem, say regularized binary logistic regression (aka binary cross-entropy loss), or whatever problem is easier for you. For comparison with SciPy, you can use autograd to calculate the gradients.
from sklearn.datasets import make_classification
X, y = make_classification(n_samples=100000, n_features=200, n_classes=2, random_state=0) |
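For reference, a rough sketch of what the SciPy side of such a comparison might look like, using jax.grad rather than autograd for the gradient; the exact loss definition and regularization strength below are assumptions for illustration, not from the thread:
import numpy as np
import scipy.optimize
import jax
import jax.numpy as jnp
from jax.nn import softplus
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=100000, n_features=200, n_classes=2, random_state=0)
lam = 1.0

def loss(w):
    logits = jnp.dot(X, w)
    # mean binary cross-entropy plus L2 regularization
    return jnp.mean(softplus(logits) - y * logits) + 0.5 * lam * jnp.dot(w, w)

value_and_grad = jax.jit(jax.value_and_grad(loss))

def scipy_fun(w):
    v, g = value_and_grad(jnp.asarray(w))
    return float(v), np.asarray(g, dtype=np.float64)

res_scipy = scipy.optimize.minimize(scipy_fun, np.zeros(X.shape[1]),
                                    jac=True, method='L-BFGS-B')
The resulting res_scipy.x and wall-clock time could then be compared against the JAX implementation run on the same data.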
Hi, just an interested potential user here. Is there any idea of when this might be merged? |
I am quite busy atm, so I cannot spend much time on further tests. I am quite confident that the code is well-behaved. I have been using its "prototype" for half a year now with good results. So unless someone else is willing to invest time into further tests in the near future, I suggest merging as-is. @shoyer ? |
Hi, I'm also really keen on using this. PR LGTM. @shoyer can we merge? |
@Jakob-Unfried How about measuring the l2 norm of |
Here's the objective of binary logistic regression (not tested):
import jax
import jax.numpy as jnp
import jax.scipy.optimize
from jax.nn import softplus
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=100000, n_features=200, n_classes=2, random_state=0)
binary_logreg = jax.vmap(lambda label, logit: softplus(logit) - label * logit)
lam = 1.0

def fun(w):
    logits = jnp.dot(X, w)
    # mean per-example loss plus L2 regularization; minimize requires a scalar objective
    return jnp.mean(binary_logreg(y, logits)) + 0.5 * lam * jnp.dot(w, w)

w_init = jnp.zeros(X.shape[1])
res = jax.scipy.optimize.minimize(fun=fun, x0=w_init, method='LBFGS') |
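A small follow-up check along the lines of the earlier comment, assuming the quantity of interest is the gradient at the returned solution (it should be close to zero at a minimizer):
grad_norm = jnp.linalg.norm(jax.grad(fun)(res.x))
print("||grad f(x*)||_2:", grad_norm, " success:", res.success, " iterations:", int(res.nit))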
If somebody else wants to work on validation, I would be happy to merge this into JAX to facilitate that work, without documenting it or exposing it in the public API. E.g., we could remove it from |
That's a good approach @shoyer. |
Hi, I tried to replicate the speed test performance: Joshua George Albert, https://gist.github.com/Joshuaalbert/214f14bbdd55d413693b8b413a384cae. Unfortunately I did not find the same results. Have a nice weekend. PS: AMD Ryzen 7 3700X |
@acclaude Oh wow interesting, I'm getting the same result now. I'm not sure what version of jax and jaxlib I was using before for that but suspect performance degraded at some point. I will do a regression analysis against jax versions and see if I can find where that degradation happened. |
Looks like I'll need to test from Jul 29, 2020, which was jax-0.1.74, all the way up to now. |
I found it @acclaude @mattjj @jakevdp @shoyer! For some reason BFGS becomes slow at jax-0.2.5 / jaxlib-0.1.56. You can run the regression test yourself here: https://gist.github.com/Joshuaalbert/39e5e44a06cb00e7e154b37504e30fa1 |
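For anyone wanting a quick local timing check without the gist, a minimal sketch of the kind of measurement involved, reusing the Rosenbrock problem from earlier in the thread; block_until_ready is used so JAX's asynchronous dispatch does not distort the wall-clock time, and the first call includes compilation:
import time
import jax
import jax.numpy as jnp
import jax.scipy.optimize

jax.config.update("jax_enable_x64", True)

def rosenbrock(x):
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)

run = jax.jit(lambda x0: jax.scipy.optimize.minimize(rosenbrock, x0, method='BFGS'))
x0 = jnp.array([76.0, 97.0, 20.0, 120.0, 0.01, 1e2])

for label in ("warm-up (includes compile)", "steady state"):
    t0 = time.time()
    res = run(x0)
    res.x.block_until_ready()
    print(label, time.time() - t0, "s,", int(res.nit), "iterations")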
Note, @acclaude, this is likely not the right place to report performance problems. I'll let others redirect this conversation elsewhere, though, and I'm happy to continue the discussion there. |
Thank you very much for this interesting analysis. I wasn't sure where to post these performance issues. I will continue to do performance tests on my side. |
Hi all - it seems like there hasn't been any follow-up work on this, and the… Are there any objections to removing this code from JAX? Or is there someone willing to take on support of this method and graduate it to something more reliable? (note: the JAXOpt package has well-supported routines for this kind of minimization, so that would be a good alternative) |
I cannot spend any time on this, unfortunately. |
As discussed in #1400