Problem differentiating through solver.run in OptaxSolver #31

Open
phinate opened this issue Sep 17, 2021 · 21 comments
Labels
question Further information is requested

Comments

@phinate
Contributor

phinate commented Sep 17, 2021

I've been trying to use OptaxSolver to perform a simple function minimization, since I want to differentiate through its solution (the fixed point of the solver), but I ran into an issue I'm not familiar with.

Here's a MWE for the error message:

import jax
import jax.scipy as jsp
from jaxopt import OptaxSolver
import optax

def pipeline(param_for_grad, data):
    def to_minimize(latent):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad*latent, scale=1)

    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(3e-4), implicit_diff=True)

    initial, _ = solver.init(init_params = 5.)

    result, _ = solver.run(init_params = initial)

    return result

jax.value_and_grad(pipeline)(2., data=6.)

which yields this error:

CustomVJPException: Detected differentiation of a custom_vjp function with respect to a closed-over value. That isn't supported because the custom VJP rule only specifies how to differentiate the custom_vjp function with respect to explicit input parameters. Try passing the closed-over value into the custom_vjp function as an argument, and adapting the custom_vjp fwd and bwd rules.

My versions are:

jax==0.2.20
jaxlib==0.1.71
jaxopt==0.0.1
optax==0.0.9

Am I doing something very silly? I'm also wondering whether this example is within the scope of the solver API. I noticed that this doesn't occur with solver.update, only with solver.run.

Thanks :)

@fllinares
Collaborator

fllinares commented Sep 17, 2021

Hi,

I think only a couple of small changes would be needed.

To use implicit differentiation with solver.run, you should (1) expose the args with respect to which you'd like to differentiate the solver's solution explicitly in the signature of fun and (2) avoid using keyword arguments in the call to solver.run.

In your MWE:

def pipeline(param_for_grad, data):
    def to_minimize(latent, param_for_grad):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad*latent, scale=1)

    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(5e-2), implicit_diff=True)

    initial, _ = solver.init(init_params = 5.)

    result, _ = solver.run(initial, param_for_grad)

    return result

jax.value_and_grad(pipeline)(2., data=6.)

P.S. I also made a small change to the learning rate so that Adam converges in this example with the default maximum number of steps.

@mblondel
Collaborator

Thanks @phinate for the question and @fllinares for the answer!

Indeed, as Felipe explained, your param_for_grad was in the enclosing scope (this is what "closed-over value" means), but it wasn't an explicit argument of run.
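To make the "explicit argument" requirement concrete, here is a minimal pure-JAX sketch of implicit differentiation through an argmin with `jax.custom_vjp`. It is not jaxopt's implementation: the hypothetical `argmin` helper runs fixed-step gradient descent in place of `OptaxSolver.run`, and the step size and iteration count are arbitrary assumptions. The backward rule can only return cotangents for the values passed as arguments after `fun`, so `param_for_grad` has to be passed explicitly; if it were only captured by the closure of `fun`, JAX would raise the `CustomVJPException` quoted in the original report.

```python
from functools import partial

import jax
import jax.numpy as jnp
import jax.scipy as jsp


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def argmin(fun, x0, *args):
    # Hypothetical stand-in for OptaxSolver.run: fixed-step gradient descent.
    grad_x = jax.grad(fun)
    step = lambda _, x: x - 0.1 * grad_x(x, *args)
    return jax.lax.fori_loop(0, 500, step, x0)


def argmin_fwd(fun, x0, *args):
    x_star = argmin(fun, x0, *args)
    return x_star, (x_star, args)


def argmin_bwd(fun, res, g):
    x_star, args = res
    # Optimality condition: F(x, *args) = d fun / dx = 0 at the solution.
    F = jax.grad(fun)
    dF_dx = jax.grad(F)(x_star, *args)  # scalar curvature at the solution
    # Implicit function theorem: dx*/dargs = -(dF/dx)^-1 * dF/dargs.
    _, vjp_args = jax.vjp(lambda *a: F(x_star, *a), *args)
    arg_cts = vjp_args(-g / dF_dx)
    # The initial guess does not affect the solution: zero cotangent.
    return (jnp.zeros_like(x_star),) + tuple(arg_cts)


argmin.defvjp(argmin_fwd, argmin_bwd)


def pipeline(param_for_grad, data):
    def to_minimize(latent, param_for_grad):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad * latent, scale=1.0)

    # param_for_grad is an explicit argument, so argmin_bwd can return its cotangent.
    return argmin(to_minimize, jnp.float32(5.0), param_for_grad)
```

With `param_for_grad = 2.0` and `data = 6.0`, the minimizer is `latent = data / param_for_grad = 3.0` and its derivative is `-data / param_for_grad**2 = -1.5`, which is what `jax.value_and_grad(pipeline)(2.0, 6.0)` should approximate.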

By the way, since run calls init for you, the line

initial, _ = solver.init(init_params = 5.)

is not needed. You can just set initial = 5. and then call run(initial, param_for_grad).

@mblondel
Collaborator

We are working on documentation; hopefully these things will become clearer soon.

@phinate
Contributor Author

phinate commented Sep 18, 2021

Thanks both @mblondel & @fllinares for your replies, and the helpful information!

I'm struggling a little with this because the suggestion of moving param_for_grad into the to_minimize call explicitly is a bit cumbersome for my use case; the way I'm actually making this objective function looks more like:

def setup_objective(param_for_grad, **kwargs):
    to_minimize = complicated_function(param_for_grad, **kwargs)
    return to_minimize

def pipeline(param_for_grad, **kwargs):
    obj = setup_objective(param_for_grad, **kwargs)
    solver = OptaxSolver(fun=obj, opt=optax.adam(5e-2), implicit_diff=True)
    ... etc ...
    return result

Parametrizing directly with param_for_grad would mean constructing the objective via complicated_function every time it is called in the minimization loop, even though strictly it doesn't change with respect to param_for_grad during the minimization.

Am I missing something here in terms of nicely setting up this problem? Or for implicit diff, do I really need to be explicit in the way you described, even though param_for_grad is only used in the construction of the objective, and not for its evaluation?

Thanks again for the quick response earlier, and sorry if this is somehow unclear!

@mblondel
Collaborator

We decided to use explicit variables because with closed-over variables there is no way to tell which ones need to be differentiated and which don't. This is problematic if you have several big variables in your scope, such as data matrices.

@mblondel mblondel added the question Further information is requested label Sep 21, 2021
@mblondel
Collaborator

By the way, you can also take a look at jax.lax.custom_root, which supports closed-over variables. CC @shoyer
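A sketch of the same toy problem with `jax.lax.custom_root`, assuming fixed-step gradient descent as the inner solver (an arbitrary choice, not part of the thread). Here `param_for_grad` and `data` stay closed over inside `to_minimize`; `custom_root` differentiates the root of the optimality condition `grad(to_minimize)(latent) = 0` implicitly and handles the closed-over values itself. The scalar `tangent_solve` follows the pattern given in the `custom_root` docstring.

```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp


def pipeline(param_for_grad, data):
    def to_minimize(latent):
        # param_for_grad and data are closed over, as in the original MWE.
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad * latent, scale=1.0)

    # The minimizer is a root of the gradient (first-order optimality condition).
    optimality = jax.grad(to_minimize)

    def solve(f, x0):
        # Fixed-step gradient descent on to_minimize: x <- x - lr * f(x).
        return jax.lax.fori_loop(0, 500, lambda _, x: x - 0.1 * f(x), x0)

    def tangent_solve(g, y):
        # Scalar case: g is linear, so g(x) = g(1.0) * x.
        return y / g(1.0)

    return jax.lax.custom_root(optimality, jnp.float32(5.0), solve, tangent_solve)
```

`jax.value_and_grad(pipeline)(2.0, 6.0)` should again give a solution near 3.0 and a gradient near -1.5, without changing the signature of `to_minimize`.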

@lukasheinrich

Thanks @mblondel - could adding support for closed-over variables be in scope for jaxopt later on? We can certainly provide metadata about which variables require differentiation and which don't.

@mblondel
Copy link
Collaborator

Could you sketch what this would look like on the user side?

@lukasheinrich

lukasheinrich commented Sep 21, 2021

Could we co-opt a static_args-like API for this?

Edit: I guess this is equivalent to the argnums kwarg, so would that be sufficient?

def pipeline(param_for_grad, data):
    def to_minimize(latent, param_for_grad):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad*latent, scale=1)

    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(5e-2), implicit_diff=True)

    initial, _ = solver.init(init_params = 5.)

    result, _ = solver.run(initial, param_for_grad)

    return result

pipeline = jaxopt.annotate(pipeline, diff_args = (0,))

jax.value_and_grad(pipeline)(2., data=6.)

@shoyer
Member

shoyer commented Sep 21, 2021

We could support handling closed-over variables via jax.closure_convert (like jax.lax.custom_root), but the trade-off is that it requires tracing Python functions to a jaxpr. This means you can't use dynamic Python control flow or Python-level debugging inside them.

The ideal solution would probably be either to (1) encourage users to use closure_convert themselves, or (2) possibly add an optional argument to allow for opting into automatic closure conversion, e.g., closure_convert=True.
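A small sketch of what `jax.closure_convert` returns, using a toy quadratic objective (an assumption, not from the thread). The converted function takes the original arguments followed by the hoisted values. Only closed-over values that can participate in differentiation (tracers) are hoisted, so the hoisted list can legitimately be empty when the closed-over values are concrete. The `hoisted_counts` list is just a hypothetical probe for illustration.

```python
import jax
import jax.numpy as jnp

hoisted_counts = []  # records len(hoisted) every time pipeline runs


def pipeline(theta):
    def objective(x):
        # theta is closed over; closure_convert decides whether to hoist it.
        return (theta * x - 6.0) ** 2

    x0 = jnp.float32(1.0)
    converted, hoisted = jax.closure_convert(objective, x0)
    hoisted_counts.append(len(hoisted))
    # converted takes the original arguments followed by the hoisted values.
    return converted(x0, *hoisted)
```

Called directly with a Python float, nothing is hoisted because `theta` is a constant. Under `jax.grad`, `theta` is a tracer, so it is passed to `converted` as an explicit argument, which is exactly what a custom-derivative rule needs.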

@lukasheinrich

Thanks @shoyer - can we use closure_convert now to achieve the desired behavior?

@phinate
Contributor Author

phinate commented Oct 10, 2021

Hi again all -- we followed @shoyer's suggestion of using jax.closure_convert, with a function pretty much ripped from the jax docs:

def _minimize(objective_fn, lhood_pars, lr):
    converted_fn, aux_pars = jax.closure_convert(objective_fn, lhood_pars)
    # aux_pars seems to be empty, took that line from docs example
    solver = OptaxSolver(fun=converted_fn, opt=optax.adam(lr), implicit_diff=True)
    return solver.run(lhood_pars, *aux_pars)[0]

where objective_fn is usually created on-the-fly with the aforementioned setup_objective function.

Using this does allow autodiff with no errors, but we encounter a pretty substantial slowdown compared to wrapping an equivalent Adam optimiser using this more explicit implementation of the two-phase method. We were hoping to transition away from this in the interest of keeping up with jax releases and other software that also follows jax, as well as the far more active effort in jaxopt.

As a side note, we do have an additional performance bottleneck coming from external software constraints that is hard to decorrelate from the change to jaxopt (lack of ability to JIT some parts of the pipeline due to changing jax version), but based on comparisons removing the JIT from the old program, I don't think it's nearly enough to explain the ~10x slowdown.

Is there any expected drop in performance from using closure_convert, perhaps given what I said earlier about the complexity of the setup_objective function?

@shoyer
Member

shoyer commented Oct 11, 2021 via email

@mblondel
Collaborator

mblondel commented Oct 11, 2021

I'm not sure what your setup_objective function is doing, but I would try to decompose it like this:

    def pipeline(param_for_grad, latent, **kwargs):
        res = intermediary_step(param_for_grad, **kwargs)

        def objective_fun(params, intermediary_result):
            [...]   # do not use param_for_grad or res here!

        solver = OptaxSolver(fun=objective_fun, opt=optax.adam(5e-2), implicit_diff=True)
        return solver.run(init_params, intermediary_result=res * latent).params

    jax.jacobian(pipeline)(param_for_grad, latent)

The key idea is to use function composition so that the chain rule will apply. You may have to tweak it to your problem but you get the idea.
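A minimal runnable sketch of this decomposition, assuming a hypothetical `intermediary_step` (here just `exp`) and a toy quadratic objective whose minimizer is `intermediary_result`. To stay dependency-free it swaps the implicit-diff `OptaxSolver` for unrolled gradient descent, i.e. JAX differentiates through the loop itself. The structure is what matters: the expensive step runs once outside the solver, the objective only sees its result through an explicit argument, and the chain rule connects `param_for_grad` to the solution.

```python
import jax
import jax.numpy as jnp


def intermediary_step(param_for_grad):
    # Hypothetical stand-in for the expensive, param-dependent setup.
    return jnp.exp(param_for_grad)


def solve(objective_fun, init_params, intermediary_result, lr=0.5, steps=100):
    # Unrolled gradient descent: reverse-mode AD goes straight through the loop.
    grad_fn = jax.grad(objective_fun)

    def body(_, params):
        return params - lr * grad_fn(params, intermediary_result)

    return jax.lax.fori_loop(0, steps, body, init_params)


def pipeline(param_for_grad, latent):
    res = intermediary_step(param_for_grad)  # computed once, outside the solver

    def objective_fun(params, intermediary_result):
        # Uses only its explicit arguments, never param_for_grad or res.
        return 0.5 * jnp.sum((params - intermediary_result) ** 2)

    init_params = jnp.zeros(2)
    return solve(objective_fun, init_params, res * latent)
```

With `param_for_grad = 0.0` and `latent = 2.0`, each component of the solution is `exp(0) * 2 = 2`, and so is its derivative with respect to `param_for_grad`, which is what `jax.jacobian(pipeline)(0.0, 2.0)` should return.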

@mblondel
Collaborator

Any follow up on this? Does your objective seem decomposable in the way I describe?

@phinate
Contributor Author

phinate commented Nov 3, 2021

Quoting @mblondel: "Any follow up on this? Does your objective seem decomposable in the way I describe?"

Having thought about this a bit, I'm not 100% sure this would work, but one potential way for us to decompose the problem could be to build our statistical model (expensive boilerplate), from which we want to call a logpdf method, and then construct the objective like this:

def pipeline(param_for_grad):
    res = model(param_for_grad)

    def objective_fun(params, model):
        return -model.logpdf(params)

    solver = OptaxSolver(fun=objective_fun, opt=optax.adam(5e-2), implicit_diff=True)
    return solver.run(init_params, model=res).params

jax.jacobian(pipeline)(param_for_grad)

Provided this model was registered as a pytree, do you think this would resolve the problem? It's not something we've implemented, but we could if this would work.

@phinate
Contributor Author

phinate commented Nov 29, 2021

Just an update on this: we've managed to get the jit working on our side with closure_convert, and we see the performance recover, so @shoyer got it right on that count despite my (incorrect) assumption -- thanks!

If it's helpful, I'd be happy to summarise this thread as a small entry into the documentation via a PR @mblondel, since it could come up again for other users with similar use cases.

Thanks both for the fast and attentive help!

@mblondel
Collaborator

What would the code snippet look like?

@phinate
Contributor Author

phinate commented Apr 6, 2023

Hi again @mblondel, sorry to resurrect this from the dead -- my solution that uses closure_convert unexpectedly started leaking tracers after one of the jax/jaxlib updates, and it's a bit of a nightmare to debug. Luckily, I found a fairly simple MWE:

from functools import partial

import jax
import jax.numpy as jnp
import jaxopt
import optax


# dummy model for test purposes
class Model:
    x: jax.Array
    def __init__(self, x) -> None:
        self.x = x
    def logpdf(self, pars, data):
        return jnp.sum(pars*data*self.x)

@partial(jax.jit, static_argnames=["objective_fn"])
def _minimize(
    objective_fn,
    init_pars,
    lr,
):
    # this is the line added from our discussion above
    converted_fn, aux_pars = jax.closure_convert(objective_fn, init_pars)
    # aux_pars seems to be empty -- would have assumed it was the closed-over vals or similar?
    solver = jaxopt.OptaxSolver(
        fun=converted_fn, opt=optax.adam(lr), implicit_diff=True, maxiter=5000
    )
    return solver.run(init_pars, *aux_pars)[0]


@partial(jax.jit, static_argnames=["model"])
def fit(
    data,
    model,
    init_pars,
    lr = 1e-3,
):
    def fit_objective(pars):
        return -model.logpdf(pars, data)

    fit_res = _minimize(fit_objective, init_pars, lr)
    return fit_res

def pipeline(x):
    model = Model(x)
    mle_pars = fit(
        model=model,
        data=jnp.array([5.0, 5.0]),
        init_pars=jnp.array([1.0, 1.1]),
        lr=1e-3,
    )
    return mle_pars

jax.jacrev(pipeline)(jnp.asarray(0.5))
# >> JaxStackTraceBeforeTransformation

(this is jaxopt==0.6)

Another thing to note: the jaxpr tracing induced by closure_convert seems to really fill up the cache, which made this quite problematic in practice (I had to use @patrick-kidger's hack from this JAX issue). Just a health warning for anyone else interested in this type of solution!

I can't see an immediate way, but if we could cast this example into the decomposed form you suggested above, that would be the best way around this issue (i.e. avoiding closure_convert altogether).

@phinate
Contributor Author

phinate commented Apr 14, 2023

I explored this a bit, and my particular workflow was made possible by making the Model class a Pytree, which lets me feed the model into the objective function as an explicit argument while keeping jit across the optimization procedure. I think this also means that the relevant parameters for grad are no longer closed over, since the Pytree carries them.

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jaxopt
import optax
from simple_pytree import Pytree

# dummy model for test purposes
class Model(Pytree):
    x: jax.Array
    def __init__(self, x) -> None:
        self.x = x
    def logpdf(self, pars, data):
        return jsp.stats.norm.logpdf(data, loc=pars*self.x, scale=1.0).sum()


@jax.jit
def pipeline(param_for_grad):
    data=jnp.array([5.0, 5.0])
    init_pars=jnp.array([1.0, 1.1])
    lr=1e-3

    model = Model(param_for_grad)

    def fit(pars, model, data):
        def fit_objective(pars, model, data):
            return -model.logpdf(pars, data)

        solver = jaxopt.OptaxSolver(
            fun=fit_objective, opt=optax.adam(lr), implicit_diff=True, maxiter=5000
        )
        return solver.run(init_pars, model=model, data=data)[0]

    return fit(init_pars, model, data)

jax.jacrev(pipeline)(jnp.asarray(0.5))
# > Array([-1.33830826e+01,  7.10542736e-15], dtype=float64, weak_type=True)

I don't know if there's another potential issue that I'm smearing over with this approach, but it works without closure_convert! It may be hard to coerce a complicated model into a Pytree, but that's possibly something for us to worry about.
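For anyone who would rather not depend on simple_pytree, the same dummy Model can likely be registered with JAX's own `jax.tree_util.register_pytree_node_class`; this is a sketch, not what was used above, and `nll` is a hypothetical helper. Because `x` is declared as a leaf, the model can be passed as an ordinary (non-static) `jax.jit` argument, and `jax.grad` with respect to it returns a `Model` holding the gradient with respect to `x`.

```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Model:
    # Same dummy model as above, registered without simple_pytree.
    def __init__(self, x):
        self.x = x

    def logpdf(self, pars, data):
        return jsp.stats.norm.logpdf(data, loc=pars * self.x, scale=1.0).sum()

    def tree_flatten(self):
        # x is a leaf, so it is traced by jit and differentiated by grad.
        return (self.x,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def nll(pars, model, data):
    # model is a regular jit argument because it is a pytree.
    return -model.logpdf(pars, data)
```

For example, `jax.grad(nll, argnums=1)(jnp.array([1.0, 1.1]), Model(jnp.asarray(0.5)), jnp.array([5.0, 5.0]))` returns a `Model` whose `x` is the derivative of the negative log-likelihood, `-sum((data - pars * x) * pars) = -9.395` here.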

@patrick-kidger

I think you're doing the right thing by making the model a PyTree, i.e. I don't think you're smearing over any issue.

This is the same approach Equinox uses ubiquitously, and this handles all the complexity of Diffrax just fine!
