
BFGS/Quasi-Newton optimizers? #1400

Open · proteneer opened this issue Sep 26, 2019 · 61 comments

@proteneer
Contributor

Is there any interest in adding a quasi-Newton based optimizer? I was thinking of porting over:

https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/optimizer/bfgs.py

But wasn't sure if anyone else was interested or had something similar already.

@shoyer
Collaborator

shoyer commented Sep 26, 2019

Yes, I would be really excited about this! In a non-stochastic setting, these are usually much more effective than stochastic gradient descent.

We could even define gradients for these optimizers using implicit root finding (i.e., based on #1339).
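For context, the identity this would rely on: if x*(θ) minimizes f(x, θ), then stationarity

    \nabla_x f(x^*(\theta), \theta) = 0

defines x* implicitly, and differentiating gives

    \frac{dx^*}{d\theta} = -\left[\nabla^2_{xx} f(x^*, \theta)\right]^{-1} \nabla^2_{x\theta} f(x^*, \theta),

so a custom backward pass only needs to solve this linear system instead of differentiating through the optimizer's iterations.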

@shoyer
Collaborator

shoyer commented Sep 26, 2019

L-BFGS in particular would be really nice to have: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/optimizer/lbfgs.py

@fehiepsi
Contributor

fehiepsi commented Sep 26, 2019

This would be great! I am interested and happy to review the math of your PR. (Another reference is the PyTorch LBFGS, which is mostly based on https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html.)

@proteneer
Contributor Author

Whoa - differentiable root finding is pretty whack! The PyTorch L-BFGS seems significantly more advanced than the tfp one - let me read them over to see if we have everything that's needed.

@fehiepsi
Contributor

fehiepsi commented Sep 27, 2019

I think the difference between the two versions is that PyTorch uses a strong Wolfe line search (not more advanced), while TensorFlow uses the Hager-Zhang line search. I would propose sticking with the Hager-Zhang line search if you are familiar with the TensorFlow code. :)

@shoyer
Collaborator

shoyer commented Sep 27, 2019 via email

@fehiepsi
Contributor

fehiepsi commented Sep 27, 2019

When I used the strong Wolfe line search to do Bayesian optimization on some toy datasets, I faced instability issues (the Hessian blew up, leading to NaNs) near the optimum. So I hope that the Hager-Zhang line search will do better. Looking at its documentation, it seems to deal with the issue I faced: "On a finite precision machine, the exact Wolfe conditions can be difficult to satisfy when one is very close to the minimum and as argued by [Hager and Zhang (2005)][1], one can only expect the minimum to be determined within square root of machine precision".
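For reference, the strong Wolfe conditions on a step size α along a descent direction p (with 0 < c_1 < c_2 < 1) are

    f(x + \alpha p) \le f(x) + c_1 \alpha \nabla f(x)^T p
    |\nabla f(x + \alpha p)^T p| \le c_2 |\nabla f(x)^T p|

and, as I read Hager and Zhang (2005), their "approximate Wolfe" variant replaces the first (sufficient-decrease) test with a derivative-only check, roughly (2\delta - 1)\,\phi'(0) \ge \phi'(\alpha) \ge \sigma\,\phi'(0) with \phi(\alpha) = f(x + \alpha p), which remains numerically meaningful close to the minimum.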

@TuanNguyen27
Contributor

@proteneer Just wondering if you are still working on this! I'm also interested in attempting an implementation.

@proteneer
Contributor Author

Hi @TuanNguyen27 - sorry this fell off the radar. I ended up implementing the FIRE optimizer instead based on @sschoenholz 's code. Feel free to take a stab yourself, as I'm not working on this right now.

@TuanNguyen27
Contributor

@shoyer @fehiepsi I've been staring at PyTorch's implementation, but it contains too many parameters (both initial and auxiliary) to pack into the state variable if I try to follow the init_fun, update_fun, get_params template defined in optimizers.py. Is that something I need to strictly follow, or do you have any design advice/suggestions that could make this more approachable? :D

@mattjj
Collaborator

mattjj commented Nov 16, 2019

optimizers.py is specialized to stochastic first-order optimizers; for second-order optimizers, don't feel constrained to follow its APIs. IMO just do what makes sense to you (while hopefully following the spirit of being functional and minimal).

@fehiepsi
Contributor

fehiepsi commented Nov 17, 2019

@TuanNguyen27 PyTorch's LBFGS is implemented to have a similar interface to the other stochastic optimizers. In practice, I only use one "step" of PyTorch LBFGS to find the solution, so I guess you don't need to follow the other JAX optimizers, and there is no need to manage a state variable except for the internal while loops. In my view, the initial API in JAX would be close to scipy's lbfgs-b, that is:

def lbfgs(fun, x0, maxtol=..., gtol=..., maxfun=..., maxiter=...):
    return solution

or similar to build_ode

def lbfgs(fun, maxtol=..., gtol=..., maxfun=..., maxiter=...):
    return optimize_fun

solution = lbfgs(f)(x0)

But as Matt said, just do what makes sense to you.
FYI, Wikipedia has a nice list of test functions for verifying your implementation.
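For example, with the hypothetical lbfgs(fun, x0, ...) signature sketched above (not an existing JAX function), a check against the Rosenbrock function could look roughly like:

import jax.numpy as jnp

def rosenbrock(x):
    # Global minimum at x = (1, 1), where f = 0.
    return (1.0 - x[0]) ** 2 + 100.0 * (x[1] - x[0] ** 2) ** 2

x_opt = lbfgs(rosenbrock, x0=jnp.array([-1.2, 1.0]), gtol=1e-6)
assert jnp.allclose(x_opt, jnp.ones(2), atol=1e-4)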

@shoyer
Collaborator

shoyer commented Nov 17, 2019 via email

@Jakob-Unfried
Contributor

Jakob-Unfried commented Feb 17, 2020

Any news on this?

I wrote an adaptation of PyTorch's L-BFGS implementation for my personal use (it even supports complex arguments) that works well for me.
I am happy to share it, but I am new to contributing to large projects and would need some pointers on where to start, how to write tests, etc. (@shoyer?)

@joglekara
Contributor

I'm also more than happy to help here @shoyer and @Jakob-Unfried

@shoyer
Collaborator

shoyer commented Feb 18, 2020

It would be great for someone to start on this! Ideally it would be broken up into a number of smaller PRs, e.g., starting with line search, then adding L-BFGS, then making it differentiable via the implicit function theorem.

I think the TF-probability implementation is probably the best place to start for a JAX port. The TF control flow ops are all functional and have close equivalents in JAX. The PyTorch version looks fine, but it's very object-oriented/stateful, which could make it harder to integrate into JAX.

@Jakob-Unfried
Contributor

Jakob-Unfried commented Feb 18, 2020

I have an implementation of L-BFGS with strong Wolfe line search lying around. It just needs a few touch-ups to make it nicer, as well as documentation that people other than me will understand.

I will get it ready for someone to review, probably late this evening (European time).

A couple of questions:

  1. There are quite a lot of parameters (11, if I didn't miscount) in the algorithm, and in most cases people will just use the defaults. My idea would be to make them **kwargs. Thoughts?

  2. I wanted to enable the cost function to take an arbitrary pytree as input. Since it represents a finite number of independent variables, I think of it as a single column vector for the maths of the optimisation. During the algorithm, I need to perform vector operations (addition, scalar product, etc.) on these "vectors". I implemented a lot of helper functions, e.g. _vector_add(x1, x2), which essentially just use the tree utils to perform the necessary steps; in this example, return tree_multimap(lambda arr1, arr2: arr1 + arr2, x1, x2). I use a total of 12 such functions, so it might unnecessarily clutter the code (see the sketch after this list).
    The alternative I see is to implement the algorithm for cost functions which can only take a single 1D array as input and use a decorator to enable arbitrary pytrees. Thoughts?

  3. Where do I write tests? Is there a conventional format to follow? Maybe a good example to look at?

  4. Do you prefer the semantics x_optim = lbfgs(fun, x0) or x_optim = lbfgs(fun)(x0)?
    The latter would probably make more sense if we want to think about it as implicitly defining a new function. But in the end, the latter can just be a thin wrapper around the former.
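To make question 2 concrete, here is roughly what those helpers look like in my draft (the _vector_* names are just illustrative):

import jax.numpy as jnp
from jax.tree_util import tree_map, tree_multimap, tree_reduce

def _vector_add(x1, x2):
    # Leaf-wise sum of two pytrees with identical structure.
    return tree_multimap(lambda a, b: a + b, x1, x2)

def _vector_scale(alpha, x):
    # Multiply every leaf by a scalar.
    return tree_map(lambda a: alpha * a, x)

def _vector_dot(x1, x2):
    # Inner product over all leaves (vdot conjugates, which matters for complex inputs).
    return tree_reduce(lambda s, d: s + d,
                       tree_multimap(lambda a, b: jnp.vdot(a, b), x1, x2))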

@shoyer
Collaborator

shoyer commented Feb 18, 2020

  1. To the extent it makes sense, I would suggest sticking to the general interface of scipy.optimize.minimize. This suggests stuffing everything into options. We can drop arguments that don't really make sense for JAX; e.g., we could use jax.grad or jax.value_and_grad internally rather than requiring that the gradient be passed explicitly with jac=True.

  2. Handling pytrees sounds great! Helper functions are fine, though if you can group a series of operations and apply tree_multimap only once, that's even better. Note that you can use it as a decorator with curry:

In [6]: from jax.util import curry

In [7]: from jax.tree_util import tree_multimap

In [8]: @curry(tree_multimap)
   ...: def f(x, y):
   ...:     return x + y
   ...:

In [9]: f({'a': 1}, {'a': 2})
Out[9]: {'a': 3}
  3. Take a look at the existing tests in tests/. We run tests with pytest, with classes inheriting from JaxTestCase, which handles parameterization. vectorize_test.py might be a good minimal example to look at.

  4. I would lean towards the former, which matches SciPy.

@Jakob-Unfried
Contributor

Jakob-Unfried commented Feb 19, 2020

@joglekara
@shoyer

Let me know what you think about my implementation: https://github.com/Jakob-Unfried/jax/commit/b28300cbcc234cda0b40b16cf15678e3f78e4085

Open Questions:

  • Think about default options
  • More elaborate tests

@shoyer
Collaborator

shoyer commented Feb 19, 2020

@Jakob-Unfried it looks like a great start, but ideally we would use JAX's functional control flow so it is compatible with jit. There is a lot of scalar arithmetic and book-keeping that XLA could likely accelerate nicely. Note: the implementation doesn't need to be differentiable, since we can define the gradient pass as a separate root-finding problem via lax.custom_root.
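To sketch what that could look like (illustrative only: run_lbfgs is a placeholder for whatever jittable inner solver we end up with, and the dense tangent solve is just for clarity):

import jax
import jax.numpy as jnp
from jax import lax

def minimize_implicit(objective, x0):
    # The optimum is a root of grad(objective); the forward solve need not be
    # differentiable, because custom_root defines the backward pass implicitly.
    grad_fn = jax.grad(objective)

    def solve(_, x_init):
        # Hypothetical non-differentiable inner optimizer.
        return run_lbfgs(objective, x_init)

    def tangent_solve(g, y):
        # g is linear here, so its Jacobian is constant; dense solve for clarity.
        return jnp.linalg.solve(jax.jacobian(g)(y), y)

    return lax.custom_root(grad_fn, x0, solve, tangent_solve)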

@Jakob-Unfried
Contributor

Jakob-Unfried commented Feb 20, 2020

Unfortunately, I am not (yet) very familiar with the requirements of jit and how to write code that it can process efficiently.

I see two paths forward:

  • Someone gives me pointers on what needs to be done
    (e.g. "Instead of for loops I should use jax.lax.fori_loop", etc.)
    With these, two main points come up which I am not sure about:

    • How do you "break" these? Should I rather use a while_loop and have cond_fun do the checks that could break the loop?
    • There are many variables that change from loop iteration to loop iteration, and they are potentially rather large (e.g. the y_history, s_history and rho_history). Would I just pack all of these into val (the (2nd) argument of body_fun)?
  • Someone takes my implementation and improves it. You don't necessarily need to understand the details of the implementation.

@david-waterworth

I'd also love to see an implementation of Levenberg-Marquardt. The MATLAB ML toolkit includes an implementation, and it's regularly used to train time-series (regression) RNN (NARX) models - it seems to perform much better than gradient descent. I'd love to be able to validate this and compare it with other 2nd-order solvers.

I'll watch the progress of this thread with interest, try to contribute where I can, and if/when I feel I understand enough I might give it a go.

@shoyer
Collaborator

shoyer commented Feb 20, 2020

@Jakob-Unfried Let me suggest an iterative process:

  1. Check in a working version of your current code (with tests) without changing the control flow from Python.
  2. Convert functions one by one into a form suitable for use with jit, starting with the inner-most layers and working outwards. We can verify that this works by adding explicit @jit annotations.
  3. Make the whole thing differentiable, via the implicit function theorem.

This can and should happen over multiple PRs. As long as you or others express willingness to help, I am happy to start on the process.

As for "how to write code that works well with jit", the general guidance is:

  1. Convert loops/recursion into lax.while_loop. Yes, this typically means you need to pack a lot of extra state into the val.
  2. Convert if statements into np.where (usually it isn't worth using lax.cond).
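A toy example of both points (plain gradient descent rather than L-BFGS, just to show the pattern of packing state into val and moving the "break" into cond_fun):

import jax.numpy as jnp
from jax import lax

def toy_descent(grad_fn, x0, lr=0.1, gtol=1e-6, maxiter=1000):
    def cond_fun(val):
        k, x, g = val
        # The "break" lives here: stop on convergence or at the iteration limit.
        return (k < maxiter) & (jnp.linalg.norm(g) > gtol)

    def body_fun(val):
        k, x, g = val
        x_new = x - lr * g
        # An `if` becomes a data-dependent select:
        x_new = jnp.where(jnp.isfinite(x_new), x_new, x)
        return k + 1, x_new, grad_fn(x_new)

    _, x_final, _ = lax.while_loop(cond_fun, body_fun, (0, x0, grad_fn(x0)))
    return x_final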

@TuanNguyen27
Contributor

Also happy to help!

@Jakob-Unfried
Contributor

@shoyer
Sounds good!
I will write (and run) a full set of tests and PR the version with pure-python control-flow later today.

I will then get back here with a TODO list.

@TuanNguyen27
Let's chat on Gitter?

@TuanNguyen27
Contributor

@Jakob-Unfried sounds good, I just sent you a gitter DM to follow up

@HDembinski

What is the benefit of adding a minimizer? The minimizer calls the cost function and its gradient. Both can already be JAX-accelerated, and those functions dominate the computing time. I don't see how one could gain here. The power of JAX is that you can mix it with other libraries.

@shoyer
Collaborator

shoyer commented Feb 22, 2020 via email

@shoyer
Collaborator

shoyer commented Jun 23, 2020

The main reason to rewrite algorithms in JAX is to support JAX's function transformations. In this case, that would mostly be relevant for performance:

  • We can run optimizations entirely on GPU/TPU without bringing data back to the CPU (which can be quite expensive/slow).
  • For cheap objective functions, Python's overhead is a concern. It can be much more efficient to use JAX's jit to compile over the full optimization process.
  • For many optimizations, it might make sense to vmap an optimizer to allow for vectorization.
  • Unnecessary data copies are often a concern, both for performance and memory constraint reasons. SciPy's L-BFGS-B is written in Fortran and requires parameters in the form of float64 vectors. In contrast, with JIT compilation we will (eventually) be able to support optimization over arbitrarily nested arrays and with any shape/dtype, e.g., with tree_vectorize from WIP: tree vectorizing transformation #3263.
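As a small illustration of the jit/vmap points (assuming a jax.scipy.optimize.minimize with BFGS along the lines discussed in this thread):

import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def objective(x):
    # Simple convex quadratic with its minimum at (1, 2).
    return jnp.sum((x - jnp.array([1.0, 2.0])) ** 2)

# Compile the whole optimization and batch it over many starting points.
solve_batch = jax.jit(jax.vmap(lambda x0: minimize(objective, x0, method="BFGS").x))
starts = jax.random.normal(jax.random.PRNGKey(0), (8, 2))
solutions = solve_batch(starts)  # shape (8, 2); every row is close to [1., 2.]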

@nrontsis

I see, thank you for the explanation!

@brianwa84
Contributor

TFP-on-JAX now supports L-BFGS:

!pip install -q tfp-nightly[jax]
from tensorflow_probability.substrates import jax as tfp
from jax import numpy as jnp

def rosenbrock(coord):
  """The Rosenbrock function in two dimensions with a=1, b=100.

  Args:
    coord: Array with shape [2]. The coordinate of the point to evaluate
      the function at.

  Returns:
    fv: A scalar tensor containing the value of the Rosenbrock function at
      the supplied point.
    dcoord: Array with shape [2]. The derivative of the function with respect to 
      `coord`.
  """
  x, y = coord[0], coord[1]
  fv = (1 - x)**2 + 100 * (y - x**2)**2
  dfx = 2 * (x - 1) + 400 * x * (x**2 - y)
  dfy = 200 * (y - x**2)
  return fv, jnp.stack([dfx, dfy])

start = jnp.float32([-1.2, 1.0])
results = tfp.optimizer.lbfgs_minimize(
    rosenbrock, initial_position=start, tolerance=1e-5)

results.position  # => DeviceArray([1.0000001, 1.0000002], dtype=float32)

@brianwa84
Contributor

Also, in tomorrow's nightly build, BFGS and Nelder-Mead optimizers will be supported in TFP JAX.

shoyer added a commit that referenced this issue Jul 29, 2020
* BFGS algorithm
Addressing #1400

* * addresses @shoyer comments of PR

* * skip dtype checks

* * backslash in docstring

* * increase closeness tol

* * increase closeness atol to 1.6e-6

* * addresses jakevdp comments

* * same line search as scipy
* same results format
* same (and more) testing as in scipy for line search and bfgs
* 2 spacing
* documenting
* analytic hessian non default but still available
* NamedTuple classes

* * small fix in setup_method

* * small doc string addition

* * increase atol to 2e-5 for comparison

* * removed experimental analytic_hessian
* using jnp.where for all binary replace operations
* removed _nojit as this is what disable_jit does

* * fix indentation mangling
* remove remaining _nojit

* * fixing more indentation mangling

* * segregate third_party test

* * use parametrise

* * use parametrise

* * minor nitpicking

* * fix some errors

* * use _CompileAndCheck

* * replace f_0 and g_0 for (ugly) scipy variable names

* * remove unused function

* * fix spacing

* * add args argument to minimize
* adhere fmin_bfgs to scipy api

* * remove unused function

* * ignore F401

* * look into unittest

* * fix unittest error

* * delete unused function
* more adherence to scipy's api
* add scipy's old_old_fval arg though unused
* increase line_search default maxiter to 20 (10 not enough in some cases)

* * remove unused imports

* * add ord=norm to the initial convergence check

* * remove helper function

* * merge jax/master

* * Resolve a remnant conflict from merging master to solve ReadTheDocs issue.

* * Add an informative termination message and status number.

* Revert changes to unrelated files

* cleanup bfgs_minimize

* cleanup minimize.py

* Move minimize_bfgs.py to _bfgs.py

* Move more modules around

* improve docs

* high precision einsum

* Formatting in line search

* fixup

* Type checking

* fix mypy failures

* minor fixup

Co-authored-by: Stephan Hoyer <[email protected]>
@shoyer
Collaborator

shoyer commented Jul 30, 2020

Now that BFGS support has been merged and provides some general scaffolding (see #3101), we'd love to get L-BFGS (and other optimizers) in JAX proper, as well as support for pytrees and implicit differentiation.
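For anyone finding this thread now, the merged interface mirrors scipy.optimize.minimize; a minimal example:

import jax.numpy as jnp
from jax.scipy.optimize import minimize

def loss(w):
    # Tiny least-squares problem; the exact minimizer is w = (1, -2).
    A = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    b = jnp.array([1.0, -2.0, -1.0])
    return jnp.sum((A @ w - b) ** 2)

res = minimize(loss, jnp.zeros(2), method="BFGS")
print(res.x, res.success)  # ~[1., -2.], True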

@Joshuaalbert
Contributor

Joshuaalbert commented Jul 30, 2020 via email

@KeunwooPark

Hi. I'm interested in mixing an adam optimizer and an lbfgs optimizer.
JAX provides an adam optimizer, so I used that. But I don't understand how I can turn the network parameters from JAX's adam optimizer into the input of tfp.optimizer.lbfgs_minimize().

The code below conceptually shows what I want to do.
It tries to optimize a network with adam first, and then uses lbfgs.

from tensorflow_probability.substrates import jax as tfp
from jax import numpy as jnp
import numpy as np
from jax.experimental import stax, optimizers
import jax
from jax import grad

def create_mlp(num_channels = []):
    modules = []
    for nc in num_channels:
        modules.append(stax.Dense(nc))
        modules.append(stax.Softplus)

    modules.append(stax.Dense(1))
    return stax.serial(*modules)

def main():
    # 1. create a network
    net_init_random, net_apply = create_mlp([10]*3)
    rng = jax.random.PRNGKey(0)
    in_shape = (-1, 2)

    # 2. create a gradient descent optimizer
    out_shape, net_params = net_init_random(rng, in_shape)
    opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)

    def loss(params, x, y):
        return jnp.mean((net_apply(params, x) - y)**2)

    def step(i, opt_state, x, y):
        p = get_params(opt_state)
        g = grad(loss)(p, x, y)
        return opt_update(i, g, opt_state)

    opt_state = opt_init(net_params)

    # 3. optimize
    for i in range(100):
        x = np.random.random_sample((10,2))
        y = np.random.random_sample((10,1))
        opt_state = step(i, opt_state, x, y)  # keep the updated optimizer state

    # 4. lbfgs optimization
    _x = np.random.random_sample((10,2))
    _y = np.random.random_sample((10,1))

    def func(params):
        return loss(params, _x, _y)

    net_params = get_params(opt_state)
    results = tfp.optimizer.lbfgs_minimize(
        func, initial_position=net_params, tolerance=1e-5)

if __name__ == "__main__":
    main()

Any kind of advice would be very helpful to me.
@brianwa84 Could you provide an example that mixes the two kinds of optimizers?

@Joshuaalbert
Contributor

Re: @nrontsis #1400 (comment).

I am confused about the benefits of implementing BFGS or any other non-trivial optimisation solvers in JAX.

I made a comparison of BFGS against scipy+numpy, scipy+jax, and pure jax on a benchmark problem, N-d least squares + L1 regularisation.
[Benchmark plot: runtime of BFGS on this problem for scipy+numpy, scipy+jax, and pure JAX]

@KeunwooPark I can't speak to using TFP in combination, but maybe #3847 is interesting to you. If you want to use a pure JAX L-BFGS, then you'll need to wait. I plan on implementing it mid-November, as well as pytree arguments.

@nrontsis

That's quite impressive! Can you share the script used to generate it?

@Joshuaalbert
Contributor

Joshuaalbert commented Oct 16, 2020

Just put it up here: https://gist.github.com/Joshuaalbert/214f14bbdd55d413693b8b413a384cae

EDIT: the scipy+numpy and scipy+jitted(func) runs use numerical Jacobians, which require many more function evaluations. It is best to compare pure JAX with scipy+jitted(func and grad).

@KeunwooPark

@Joshuaalbert Thank you for your comment! I guess I have to think more about a workaround or just use TensorFlow. Or use your L-BFGS implementation. According to your graph, mixing scipy and jax doesn't seem to be a good idea. I'm interested in implementing L-BFGS, but I'm really new to these concepts and still learning.

@Joshuaalbert
Contributor

@KeunwooPark If you're looking for something immediate, scipy+jax will probably do the job, unless speed is crucial.

@sharadmv
Collaborator

@KeunwooPark I modified your example to feed the neural network weights into L-BFGS. Unfortunately, tfp.optimizer.lbfgs_minimize does not support optimizing over structures of tensors/arrays, but you can concatenate the network weights together and then split them back up in the loss function.

from tensorflow_probability.substrates import jax as tfp
from jax import numpy as jnp
import numpy as np
from jax.experimental import stax, optimizers
import jax
from jax import grad

def create_mlp(num_channels = []):
    modules = []
    for nc in num_channels:
        modules.append(stax.Dense(nc))
        modules.append(stax.Softplus)

    modules.append(stax.Dense(1))
    return stax.serial(*modules)

def main():
    # 1. create a network
    net_init_random, net_apply = create_mlp([10]*3)
    rng = jax.random.PRNGKey(0)
    in_shape = (-1, 2)

    # 2. create a gradient descent optimizer
    out_shape, net_params = net_init_random(rng, in_shape)
    opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)

    def loss(params, x, y):
        return jnp.mean((net_apply(params, x) - y)**2)

    def step(i, opt_state, x, y):
        p = get_params(opt_state)
        g = grad(loss)(p, x, y)
        return opt_update(i, g, opt_state)

    opt_state = opt_init(net_params)

    # 3. optimize
    for i in range(100):
        x = np.random.random_sample((10,2))
        y = np.random.random_sample((10,1))
        opt_state = step(i, opt_state, x, y)  # keep the updated optimizer state

    # 4. lbfgs optimization
    _x = np.random.random_sample((10,2))
    _y = np.random.random_sample((10,1))

    net_params = get_params(opt_state)

    def concat_params(params):
        flat_params, params_tree = jax.tree_util.tree_flatten(params)
        params_shape = [x.shape for x in flat_params]
        return jnp.concatenate([x.reshape(-1) for x in flat_params]), (params_tree, params_shape)

    param_vector, (params_tree, params_shape) = concat_params(net_params)

    @jax.value_and_grad
    def func(param_vector):
        split_params = jnp.split(param_vector,
                np.cumsum([np.prod(s) for s in params_shape[:-1]]))
        flat_params = [x.reshape(s) for x, s in zip(split_params, params_shape)]
        params = jax.tree_util.tree_unflatten(params_tree, flat_params)
        return loss(params, _x, _y)

    results = tfp.optimizer.lbfgs_minimize(
        jax.jit(func), initial_position=param_vector, tolerance=1e-5)

if __name__ == "__main__":
    main()

Hope this is helpful! In the long run, better support for these structured inputs to L-BFGS in TFP will likely be the path forward.
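A small addendum: jax.flatten_util.ravel_pytree can do the flatten/unflatten bookkeeping above in one call, so the loss wrapper reduces to roughly (net_params, loss, _x, _y as in the example above):

import jax
from jax.flatten_util import ravel_pytree

param_vector, unravel = ravel_pytree(net_params)

@jax.value_and_grad
def func(param_vector):
    # `unravel` rebuilds the original parameter pytree from the flat vector.
    return loss(unravel(param_vector), _x, _y)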

@KeunwooPark

@sharadmv Wow. Thank you very much! This helps me a lot.

@tetterl

tetterl commented Dec 17, 2020

Thanks for bringing BFGS support to JAX! Are there also any plans to bring L-BFGS to JAX?

@Joshuaalbert
Contributor

Yes, there is a plan for L-BFGS. Q1 2021.

@brianwa84
Contributor

brianwa84 commented Dec 18, 2020 via email

@salcats

salcats commented Mar 11, 2021

Hi, has there been any progress on the pure JAX implementation of L-BFGS?

@Jakob-Unfried
Contributor

Jakob-Unfried commented Mar 11, 2021

Hi, I have recently completed a jittable implementation of L-BFGS for my own purposes.

PR-ing to jax is on my todo list but keeps getting bumped down...
Since there is continued interest, I will prioritise it.
Probably this weekend, no promises though.

A few design choices (@Joshuaalbert, @shoyer ?):
I included the following features because they are convenient for my purposes, but they don't match the current interface of jax.scipy.optimize:

  • The optimisation parameters (inputs to the function to be optimised) can be arbitrary pytrees
  • The optimisation parameters can be complex
  • I have an option to log progress to console or to a file in real time using jax.experimental.host_callback (this is because my jobs are regularly killed)
  • In addition to minimise, I have maximise, which just feeds the negative of the cost function to minimise and readjusts this sign in the logged data and output

I can, of course, "downgrade" all of that, but maybe they should be part of JAX?

I can't publish the code without dealing with license issues first (I copied parts from the BFGS that's already in JAX), but I can share it privately if someone wants to have a look.

@shoyer
Collaborator

shoyer commented Mar 11, 2021

  • The optimisation parameters (inputs to the function to be optimised) can be arbitrary pytrees

Yes, please! We need this for neural nets.

  • The optimisation parameters can be complex

Also sounds useful! I guess you just add a few complex conjugates? As long as this matches up with JAX's gradient convention for complex numbers (which I assume it does) this seems like a straightforward extension.

  • I have an option to log progress to console or to a file in real time using jax.experimental.host_callback (this is because my jobs are regularly killed)

"log progress to console or to a file" sounds a little too opinionated to build into JAX itself, but this is absolutely very important! I would love to see what this looks like, and provide a "recipe" in the JAX docs so users can easily do this themselves.

  • In addition to minimise, I have maximise, which just feeds the negative of the cost function to minimise and readjusts this sign in the logged data and output

I also would leave this part out for JAX. It's totally sensible, but I don't think it's worth adding a second API end-point for all optimizers.

I can't publish the code without dealing with license issues first (I copied parts from the BFGS that's already in JAX), but I can share it privately if someone wants to have a look.

The great thing about open source is that you don't need our permission to copy code :). As long as you keep the Apache license & copyright notice on it you are good to go!

@Joshuaalbert
Contributor

Agree with @shoyer. Your code can go in jax/_src/scipy/optimize/lbfgs.py and tests in tests/third_party/scipy, assuming you're using tests from scipy. Also, you should add the method to _src/scipy/optimize/minimize.py. I'll be happy to review this. I was off-and-on working on my own L-BFGS, but I love that @Jakob-Unfried got there first.

@Jakob-Unfried
Contributor

Ok, first draft is done, @Joshuaalbert let me know what you think.

A couple of things to talk about:

  • The behavior with complex inputs should be documented, depending on what we end up doing:
  • L-BFGS can handle complex numbers, but BFGS cannot. And there is no good way to implement complex BFGS, essentially because second complex derivatives have "4 real degrees of freedom", which cannot be stored in a single complex number
    (more details are discussed in this paper, unfortunately paywalled...).
    But BFGS should do something when given complex inputs. It could
    • raise an error, or
    • optimise a function of twice as many real values,
      e.g. optimise lambda x: original_fun(x[:dim] + 1.j * x[dim:]) using the existing implementation (see the sketch after this comment).
  • I also used dot with highest precision in line_search; I hope that is OK?
  • I do not see the point of extracting an inverse Hessian approximation from L-BFGS, so I just set it to None.
    SciPy constructs a LinearOperator (see here), but I think that would be overkill, right? In particular, constructing an entire matrix would defeat the whole point of using L-BFGS (which is why scipy does not do that).

PyTree support will be a separate PR.
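For the second option above, a rough sketch of the real-split wrapper (assuming a flat 1D complex input; minimize here is the existing jax.scipy.optimize.minimize):

import jax.numpy as jnp
from jax.scipy.optimize import minimize

def minimize_complex_as_real(fun, z0, **kwargs):
    # Optimise a real-valued function of a complex vector by stacking the real
    # and imaginary parts into one real vector of twice the length.
    dim = z0.shape[0]
    x0 = jnp.concatenate([jnp.real(z0), jnp.imag(z0)])

    def real_fun(x):
        return fun(x[:dim] + 1.0j * x[dim:])

    res = minimize(real_fun, x0, method="BFGS", **kwargs)
    return res.x[:dim] + 1.0j * res.x[dim:]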

@Joshuaalbert
Contributor

Thanks for the paper reference. I think you're right about BFGS doing something upon complex input; I'll think about the options and suggest something in a small PR. Highest precision in line search is fine. Agreed on leaving the inverse Hessian estimate as None. I looked over your code and it looks suitable for a PR. I need to go over the equations in more detail, but a first indication will be whether it passes tests against scipy's L-BFGS.

@Jakob-Unfried
Contributor

If you are looking at the complex formulas, note that I chose the JAX convention for all gradient-type variables (g and y).

This means that all gs and ys differ from the Sorber paper by a complex conjugation.

@jecampagne
Contributor

Oh, what a great thread!
I was playing with jaxopt.ScipyBoundedMinimize with the "L-BFGS-B" method and faced some problems, since this scipy wrapper deals with jnp-to-onp conversion. So reading this thread may be the solution; where can I find the how-to for jax 0.3.10? Thanks.
