BFGS/Quasi-Newton optimizers? #1400
Comments
Yes, I would be really excited about this! In a non-stochastic setting, these are usually much more effective than stochastic gradient descent. We could even define gradients for these optimizers using implicit root finding (i.e., based on #1339). |
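For readers unfamiliar with the implicit-root-finding idea mentioned above, here is the underlying identity, written out as a sketch (notation is mine, not taken from #1339). If x*(θ) is an interior minimizer of f(x, θ), the stationarity condition can be differentiated in θ:

```latex
% Stationarity of the inner problem x^*(\theta) = \arg\min_x f(x, \theta):
%   \nabla_x f(x^*(\theta), \theta) = 0.
% Differentiating in \theta and applying the implicit function theorem:
\frac{\partial x^*}{\partial \theta}
  = -\left[\nabla^2_{xx} f(x^*, \theta)\right]^{-1} \nabla^2_{x\theta} f(x^*, \theta)
```

so the gradient of the solution requires only a linear solve at the optimum, rather than differentiating through the optimizer's iterations. |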
L-BFGS in particular would be really nice to have: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/optimizer/lbfgs.py |
This would be great! I am interested in and happy to review the math of your PR. (Another reference is PyTorch's LBFGS, which is mostly based on https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html.) |
Whoa - differentiable root finding is pretty whack! The PyTorch L-BFGS seems significantly more advanced than the tfp one - let me read them over to see if we have everything that's needed. |
I think the difference between the two versions is: PyTorch uses a strong Wolfe line search (not more advanced) while TensorFlow uses a Hager-Zhang line search. I would propose sticking with the Hager-Zhang line search if you are familiar with the TensorFlow code. :) |
It looks like scipy also uses a strong Wolfe line search. I'm sure this makes a small difference, but it's all at the margin -- I'd be happy to have either in JAX.
|
When I used the strong Wolfe line search to do Bayesian optimization on some toy datasets, I faced instability issues (the Hessian blows up -> NaNs appear) near the optimum. So I hope that the Hager-Zhang line search will do better. Looking at the Hager-Zhang documentation, it seems to address the issue I faced: "On a finite precision machine, the exact Wolfe conditions can be difficult to satisfy when one is very close to the minimum and as argued by [Hager and Zhang (2005)][1], one can only expect the minimum to be determined within square root of machine precision". |
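For readers following along, the two sets of line-search conditions being compared can be stated for φ(α) = f(x + αp) along a descent direction p (standard definitions; this is my summary, not either library's documentation):

```latex
% Strong Wolfe conditions, with 0 < c_1 < c_2 < 1:
\varphi(\alpha) \le \varphi(0) + c_1 \alpha \varphi'(0), \qquad
|\varphi'(\alpha)| \le c_2\, |\varphi'(0)|.

% Hager--Zhang ``approximate Wolfe'' conditions, applied when
% \varphi(\alpha) \le \varphi(0) + \epsilon |\varphi(0)|, with 0 < \delta < 1/2,
% \delta \le \sigma < 1:
(2\delta - 1)\,\varphi'(0) \;\ge\; \varphi'(\alpha) \;\ge\; \sigma\,\varphi'(0).
```

The approximate form avoids the sufficient-decrease test, which suffers from cancellation in φ(α) − φ(0) once α is within roughly square-root machine precision of the minimizer, which is exactly the failure mode described above. |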
@proteneer Just wondering if you are still working on this! I'm also interested in attempting an implementation. |
Hi @TuanNguyen27 - sorry this fell off the radar. I ended up implementing the FIRE optimizer instead based on @sschoenholz 's code. Feel free to take a stab yourself, as I'm not working on this right now. |
@shoyer @fehiepsi I've been staring at pytorch's implementation, but it contains too many parameters (both initial and auxiliary) to pack into |
optimizers.py is specialized to stochastic first-order optimizers; for second-order optimizers, don't feel constrained to follow its APIs. IMO just do what makes sense to you (while hopefully following the spirit of being functional and minimal). |
@TuanNguyen27 PyTorch LBFGS is implemented to have a similar interface as other stochastic optimizers. In practice, I only use 1 "step" of PyTorch LBFGS to find the solution, so I guess you don't need to follow other JAX optimizers, and no need to manage state variables except for internal while loops. In my view, the initial API in JAX would be close to scipy lbfgs-b (https://docs.scipy.org/doc/scipy/reference/optimize.minimize-lbfgsb.html#optimize-minimize-lbfgsb), that is:

```python
def lbfgs(fun, x0, maxtol=..., gtol=..., maxfun=..., maxiter=...):
    return solution
```

or similar to build_ode (https://github.com/google/jax/blob/master/jax/experimental/ode.py#L377):

```python
def lbfgs(fun, maxtol=..., gtol=..., maxfun=..., maxiter=...):
    return optimize_fun

solution = lbfgs(f)(x0)
```

But as Matt said, just do what makes sense to you. FYI, Wikipedia has a nice list of test functions (https://en.wikipedia.org/wiki/Test_functions_for_optimization) to verify your implementation. |
For JAX, I think it makes sense to target an interface like SciPy here. |
Any news on this? I wrote an adaptation of PyTorch's L-BFGS implementation for my personal use (it even supports complex arguments) that works well for me. |
I'm also more than happy to help here @shoyer and @Jakob-Unfried |
It would be great for someone to start on this! Ideally it would be broken up into a number of smaller PRs, e.g., starting with line search, then adding L-BFGS, then making it differentiable via the implicit function theorem. I think the TF-probability implementation is probably the best place to start for a JAX port. The TF control flow ops are all functional and have close equivalents in JAX. The pytorch version looks fine but it’s very object oriented/stateful, which could make it harder to integrate into JAX. |
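For whoever picks this up, the heart of L-BFGS is the two-loop recursion, which applies the inverse-Hessian approximation implicitly from the last m curvature pairs (s_i, y_i). Here is a minimal sketch in plain Python/JAX, using dynamic lists for clarity; a jit-friendly version would instead keep a fixed-size history and use lax.fori_loop:

```python
import jax.numpy as jnp

def two_loop_recursion(g, s_hist, y_hist):
  """Return r ~= H_k^{-1} g from the stored pairs s_i = x_{i+1} - x_i,
  y_i = grad_{i+1} - grad_i (oldest first; assumes at least one pair)."""
  q = g
  alphas = []
  # First loop: newest pair to oldest.
  for s, y in zip(reversed(s_hist), reversed(y_hist)):
    rho = 1.0 / jnp.vdot(y, s)
    alpha = rho * jnp.vdot(s, q)
    q = q - alpha * y
    alphas.append(alpha)
  # Initial inverse-Hessian scaling gamma_k = s^T y / y^T y for the newest pair.
  gamma = jnp.vdot(s_hist[-1], y_hist[-1]) / jnp.vdot(y_hist[-1], y_hist[-1])
  r = gamma * q
  # Second loop: oldest pair to newest.
  for s, y, alpha in zip(s_hist, y_hist, reversed(alphas)):
    rho = 1.0 / jnp.vdot(y, s)
    beta = rho * jnp.vdot(y, r)
    r = r + (alpha - beta) * s
  return r  # the search direction handed to the line search is p = -r
```
|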
I have an implementation of L-BFGS with strong Wolfe line search lying around. It just needs a few touch-ups to make it nicer, as well as documentation that people who are not me will understand. I will get it ready for someone to review, probably late this evening (European time). A couple of questions:
|
|
Let me know what you think about my implementation: https://github.com/Jakob-Unfried/jax/commit/b28300cbcc234cda0b40b16cf15678e3f78e4085 Open Questions:
|
@Jakob-Unfried it looks like a great start, but ideally we would use JAX's functional control flow so it is compatible with |
Unfortunately, I am not (yet) very familiar with those requirements. I see two paths forward:
|
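To make the "functional control flow" requirement concrete, here is a toy sketch (plain gradient descent, not the proposed L-BFGS API) of an optimizer loop written with lax.while_loop so that the whole solve can sit under jit:

```python
import jax
import jax.numpy as jnp
from jax import lax

def minimize_gd(fun, x0, step_size=1e-2, gtol=1e-5, maxiter=1000):
  """Toy gradient descent showing the jit-compatible loop structure."""
  grad_fun = jax.grad(fun)

  def cond(state):
    x, g, it = state
    return (jnp.linalg.norm(g) > gtol) & (it < maxiter)

  def body(state):
    x, g, it = state
    x_new = x - step_size * g
    return x_new, grad_fun(x_new), it + 1

  x0 = jnp.asarray(x0)
  x, g, it = lax.while_loop(cond, body, (x0, grad_fun(x0), 0))
  return x

# The entire loop compiles as one XLA computation:
solve = jax.jit(lambda x0: minimize_gd(lambda x: jnp.sum((x - 3.0) ** 2), x0))
print(solve(jnp.zeros(3)))
```
|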
I'd also love to see an implementation of Levenberg-Marquardt. The MATLAB ML toolkit includes an implementation, and it's regularly used to train time-series (regression) RNN (NARX) models - it seems to perform much better than gradient descent. I'd love to be able to validate this and compare it with other second-order solvers. I'll watch the progress of this thread with interest and try to contribute where I can, and if/when I feel I understand enough I might give it a go. |
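For reference, the Levenberg-Marquardt step for a residual vector r(x) with Jacobian J is the solution of a damped normal-equations system (standard textbook form, not tied to any particular toolkit):

```latex
\left(J^\top J + \lambda\, \mathrm{diag}(J^\top J)\right)\,\delta = -J^\top r,
\qquad x \leftarrow x + \delta
```

with the damping λ increased when a step fails to reduce ||r||^2 and decreased when it succeeds (Levenberg's original variant uses λI in place of λ diag(JᵀJ)). |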
@Jakob-Unfried Let me suggest an iterative process:
This can and should happen over multiple PRs. As long as you or others express willingness to help, I am happy to start on the process. As for "how to write code that works well with
|
Also happy to help! |
@shoyer I will then get back here with a TODO list. @TuanNguyen27 |
@Jakob-Unfried sounds good, I just sent you a gitter DM to follow up |
What is the benefit of adding a minimizer? The minimizer calls the cost function and its gradient. Both can be JAX accelerated already and the functions dominate the computing time. I don't see how one could gain here. The power of JAX is that you can mix with other libraries. |
You absolutely can and should leverage optimizers from SciPy for many applications. I'm interested in optimizers written in JAX for use cases that involve nested optimization, like meta-optimization.
|
The main reason to rewrite algorithms in JAX is to support JAX's function transformations. In this case, that would mostly be relevant for performance:
|
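As a concrete illustration of the point above, using the jax.scipy.optimize.minimize API that eventually landed (so this sketch is retrospective, not something available at the time of the comment): once the whole driver loop is traceable, the entire solve can be jit-compiled and vmapped over a batch of starting points, which is not possible when the loop lives in SciPy on the host.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def f(x):
  return jnp.sum((x - jnp.arange(3.0)) ** 2)

# Compile the entire optimization, not just the objective and gradient.
solve = jax.jit(lambda x0: minimize(f, x0, method="BFGS").x)

# Batch the solver over many initial points with vmap.
batched_solve = jax.vmap(solve)
print(batched_solve(jnp.zeros((8, 3))).shape)  # (8, 3)
```
|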
I see, thank you for the explanation! |
TFP-on-JAX now supports L-BFGS:

```python
!pip install -q tfp-nightly[jax]

from tensorflow_probability.substrates import jax as tfp
from jax import numpy as jnp

def rosenbrock(coord):
  """The Rosenbrock function in two dimensions with a=1, b=100.

  Args:
    coord: Array with shape [2]. The coordinate of the point to evaluate
      the function at.

  Returns:
    fv: A scalar tensor containing the value of the Rosenbrock function at
      the supplied point.
    dcoord: Array with shape [2]. The derivative of the function with respect
      to `coord`.
  """
  x, y = coord[0], coord[1]
  fv = (1 - x)**2 + 100 * (y - x**2)**2
  dfx = 2 * (x - 1) + 400 * x * (x**2 - y)
  dfy = 200 * (y - x**2)
  return fv, jnp.stack([dfx, dfy])

start = jnp.float32([-1.2, 1.0])
results = tfp.optimizer.lbfgs_minimize(
    rosenbrock, initial_position=start, tolerance=1e-5)
results.position  # => DeviceArray([1.0000001, 1.0000002], dtype=float32)
```
|
Also, in tomorrow's nightly build, BFGS and Nelder-Mead optimizers will be supported in TFP JAX. |
Referenced commit (squashed): "BFGS algorithm", addressing #1400 - implements BFGS with the same line search, results format, and testing as SciPy, a SciPy-style `minimize` API with an `args` argument, NamedTuple state classes, documentation, type checking, and cleanups (modules moved to `_bfgs.py`, informative termination message and status, experimental analytic Hessian removed). Co-authored-by: Stephan Hoyer
Now that BFGS support has been merged and provides some general scaffolding (see #3101), we'd love to get L-BFGS (and other optimizers) in JAX proper, as well as support for pytrees and implicit differentiation. |
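For anyone landing here later, the merged interface is jax.scipy.optimize.minimize; a minimal usage sketch:

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def rosenbrock(x):
  return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)

res = minimize(rosenbrock, jnp.array([-1.2, 1.0]), method="BFGS")
print(res.x, res.success)
```
|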
Thanks @shoyer for your hard work getting this in. I'll be happy to add some more. I'm interested in getting (L)BFGS in with bounded constraints. I can also add pytree support to BFGS, and, as promised in my code comment in the line search, I still intend to profile the line search between the versions using where's and cond's.
|
Hi. I'm interested in mixing an Adam optimizer and an L-BFGS optimizer. The code below conceptually shows what I want to do.

```python
from tensorflow_probability.substrates import jax as tfp
from jax import numpy as jnp
import numpy as np
from jax.experimental import stax, optimizers
import jax
from jax import grad

def create_mlp(num_channels = []):
    modules = []
    for nc in num_channels:
        modules.append(stax.Dense(nc))
        modules.append(stax.Softplus)
    modules.append(stax.Dense(1))
    return stax.serial(*modules)

def main():
    # 1. create a network
    net_init_random, net_apply = create_mlp([10]*3)
    rng = jax.random.PRNGKey(0)
    in_shape = (-1, 2)

    # 2. create a gradient descent optimizer
    out_shape, net_params = net_init_random(rng, in_shape)
    opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)

    def loss(params, x, y):
        return jnp.mean((net_apply(params, x) - y)**2)

    def step(i, opt_state, x, y):
        p = get_params(opt_state)
        g = grad(loss)(p, x, y)
        return opt_update(i, g, opt_state)

    opt_state = opt_init(net_params)

    # 3. optimize
    for i in range(100):
        x = np.random.random_sample((10,2))
        y = np.random.random_sample((10,1))
        opt_state = step(i, opt_state, x, y)  # keep the updated optimizer state

    # 4. lbfgs optimization
    _x = np.random.random_sample((10,2))
    _y = np.random.random_sample((10,1))

    def func(params):
        return loss(params, _x, _y)

    net_params = get_params(opt_state)
    results = tfp.optimizer.lbfgs_minimize(
        func, initial_position=net_params, tolerance=1e-5)

if __name__ == "__main__":
    main()
```

Any kind of advice would be very helpful to me. |
Re: @nrontsis #1400 (comment).
I made a comparison of BFGS against scipy+numpy, scipy+jax, and pure jax on a benchmark problem, N-d least squares + L1 regularisation. @KeunwooPark I can't speak to using TFP in combination, but maybe #3847 is interesting to you. If you wanted to use pure JAX L-BFGS then you'll need to wait. I plan on implementing it mid-November, as well as pytree arguments. |
That's quite impressive! Can you share the script used to generate it? |
Just put it up here: https://gist.github.com/Joshuaalbert/214f14bbdd55d413693b8b413a384cae EDIT: the scipy+numpy and scipy+jitted(func) runs use numerical Jacobians, which require many more function evaluations. Best to compare pure JAX and scipy+jitted(func and grad). |
@Joshuaalbert Thank you for your comment! I guess I have to think more about a workaround or just use TensorFlow. Or use your L-BFGS implementation. According to your graph, mixing scipy and jax doesn't seem to be a good idea. I'm interested in implementing L-BFGS, but I'm really new to these concepts and still learning. |
@KeunwooPark If you're looking for something immediately, scipy+jax will probably do the job. Unless the speed is crucial. |
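On the "scipy+jax will probably do the job" route, the pattern that avoids the numerical-Jacobian overhead seen in the benchmark above is to hand SciPy a jitted value-and-gradient function (a sketch; the objective here is just a stand-in):

```python
import numpy as np
import scipy.optimize
import jax
import jax.numpy as jnp

def loss(w):
  return jnp.sum((w - jnp.arange(5.0)) ** 2)

# One compiled callable returning (value, gradient).
value_and_grad = jax.jit(jax.value_and_grad(loss))

def scipy_fun(w):
  v, g = value_and_grad(jnp.asarray(w))
  # SciPy's L-BFGS-B works in float64 on the host, so cast back to NumPy.
  return float(v), np.asarray(g, dtype=np.float64)

res = scipy.optimize.minimize(scipy_fun, np.zeros(5), jac=True, method="L-BFGS-B")
print(res.x)
```
|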
@KeunwooPark I modified your example to feed the neural network weights into L-BFGS. Unfortunately, TFP's lbfgs_minimize doesn't handle structured (pytree) inputs directly, so the parameters have to be flattened into a single vector first.

```python
from tensorflow_probability.substrates import jax as tfp
from jax import numpy as jnp
import numpy as np
from jax.experimental import stax, optimizers
import jax
from jax import grad

def create_mlp(num_channels = []):
    modules = []
    for nc in num_channels:
        modules.append(stax.Dense(nc))
        modules.append(stax.Softplus)
    modules.append(stax.Dense(1))
    return stax.serial(*modules)

def main():
    # 1. create a network
    net_init_random, net_apply = create_mlp([10]*3)
    rng = jax.random.PRNGKey(0)
    in_shape = (-1, 2)

    # 2. create a gradient descent optimizer
    out_shape, net_params = net_init_random(rng, in_shape)
    opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)

    def loss(params, x, y):
        return jnp.mean((net_apply(params, x) - y)**2)

    def step(i, opt_state, x, y):
        p = get_params(opt_state)
        g = grad(loss)(p, x, y)
        return opt_update(i, g, opt_state)

    opt_state = opt_init(net_params)

    # 3. optimize
    for i in range(100):
        x = np.random.random_sample((10,2))
        y = np.random.random_sample((10,1))
        opt_state = step(i, opt_state, x, y)  # keep the updated optimizer state

    # 4. lbfgs optimization
    _x = np.random.random_sample((10,2))
    _y = np.random.random_sample((10,1))
    net_params = get_params(opt_state)

    def concat_params(params):
        flat_params, params_tree = jax.tree_util.tree_flatten(params)
        params_shape = [x.shape for x in flat_params]
        return jnp.concatenate([x.reshape(-1) for x in flat_params]), (params_tree, params_shape)

    param_vector, (params_tree, params_shape) = concat_params(net_params)

    @jax.value_and_grad
    def func(param_vector):
        split_params = jnp.split(param_vector,
                                 np.cumsum([np.prod(s) for s in params_shape[:-1]]))
        flat_params = [x.reshape(s) for x, s in zip(split_params, params_shape)]
        params = jax.tree_util.tree_unflatten(params_tree, flat_params)
        return loss(params, _x, _y)

    results = tfp.optimizer.lbfgs_minimize(
        jax.jit(func), initial_position=param_vector, tolerance=1e-5)

if __name__ == "__main__":
    main()
```

Hope this is helpful! In the long run, better support for these structured inputs to L-BFGS in TFP will likely be the path forward. |
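A side note on the flattening boilerplate in the snippet above: jax.flatten_util.ravel_pytree does the same concatenate/split round trip in one call. A sketch reusing net_params, loss, _x, _y, and tfp from the example above:

```python
import jax
from jax.flatten_util import ravel_pytree

flat_params, unravel = ravel_pytree(net_params)

@jax.jit
@jax.value_and_grad
def func(flat):
  # Rebuild the original pytree of weights before evaluating the loss.
  return loss(unravel(flat), _x, _y)

results = tfp.optimizer.lbfgs_minimize(
    func, initial_position=flat_params, tolerance=1e-5)
```
|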
@sharadmv Wow. Thank you very much! This helps me a lot. |
Thanks for bringing BFGS support to JAX! Are there also any plans in bringing L-BFGS to JAX? |
Yes, there is a plan for L-BFGS. Q1 2021. |
If you need something now, you can use tfp.substrates.jax.optimizer.lbfgs_minimize.
|
Hi, has there been any progress on the pure JAX implementation of L-BFGS? |
Hi, I have recently completed a jittable implementation of L-BFGS for my own purposes. PR-ing to jax is on my todo-list but keeps getting bumped down... A few design choices (@Joshuaalbert, @shoyer ?):
I can, of course, "downgrade" all of that, but maybe they should be part of jax? I can't publish the code without dealing with license issues first (I copied parts from the BFGS that's already in jax), but I can share it privately if someone wants to have a look. |
Yes, please! We need this for neural nets.
Also sounds useful! I guess you just add a few complex conjugates? As long as this matches up with JAX's gradient convention for complex numbers (which I assume it does) this seems like a straightforward extension.
"log progress to console or to a file" sounds a little too opinionated to build into JAX itself, but this is absolutely very important! I would love to see what this looks like, and provide a "recipe" in the JAX docs so users can easily do this themselves.
I also would leave this part out for JAX. It's totally sensible, but I don't think it's worth adding a second API end-point for all optimizers.
The great thing about open source is that you don't need our permission to copy code :). As long as you keep the Apache license & copyright notice on it you are good to go! |
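On the logging question, here is the kind of "recipe" that could live in the docs: jax.debug.print (a newer API than this discussion) works inside jit and inside lax control flow, so an optimizer loop can report progress without giving up compilation. A minimal sketch:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def noisy_descent(x0, step_size=0.1, steps=50):
  grad_f = jax.grad(lambda x: jnp.sum(x ** 2))

  def body(i, x):
    x = x - step_size * grad_f(x)
    # Host-side print triggered from inside the compiled loop.
    jax.debug.print("iter {i}: |x| = {n}", i=i, n=jnp.linalg.norm(x))
    return x

  return lax.fori_loop(0, steps, body, x0)

noisy_descent(jnp.ones(3))
```
|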
Agree with @shoyer. Your code can go in |
Ok, first draft is done, @Joshuaalbert let me know what you think. A couple of things to talk about:
PyTree support will be a separate PR. |
Thanks for the paper reference. I think you're right about BFGS doing something upon complex input. I'll think about the options and suggest something, in a small PR. Highest precision in line search is fine. Agreed to leaving inverse Hessian estimate as None. I looked over your code and it looks suitable for a PR. I need to go over the equations in more detail, but a first indication will be if it passes tests against scipy's L-BFGS. |
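On testing against SciPy, a sketch of the comparison pattern, using the already-merged BFGS as a stand-in for the eventual L-BFGS entry point (the L-BFGS method name and tolerances are still to be decided, so none are assumed here):

```python
import numpy as np
import scipy.optimize
import jax.numpy as jnp
from jax.scipy.optimize import minimize as jax_minimize

def rosen_np(x):
  return np.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)

def rosen_jnp(x):
  return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)

x0 = np.array([-1.2, 1.0, 0.8])
scipy_res = scipy.optimize.minimize(rosen_np, x0, method="L-BFGS-B")
jax_res = jax_minimize(rosen_jnp, jnp.asarray(x0), method="BFGS")

# Both should land on the global minimum at all-ones
# (float32 vs float64 limits the achievable tolerance).
np.testing.assert_allclose(jax_res.x, scipy_res.x, atol=1e-3)
```
|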
If you are looking at the complex formulas, note that I chose the JAX convention for all gradient-type variables (g and y). This means that all |
Oh, what a great thread! |
Is there any interest in adding a quasi-Newton based optimizer? I was thinking of porting over:
https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/optimizer/bfgs.py
But wasn't sure if anyone else was interested or had something similar already.