Problem differentiating through `solver.run` in `OptaxSolver` #31
Comments
Hi, I think only a couple of small changes would be needed. To use implicit differentiation with `OptaxSolver`, the parameter you want to differentiate with respect to should be an explicit argument of the objective (and of `solver.run`) rather than closed over. In your MWE, that means passing it through `solver.run` instead of capturing it in the closure.
P.S. I also made a small change to the learning rate so that Adam converges in this example with the default maximum number of steps.
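(For reference, a minimal sketch of the pattern being suggested here, with illustrative names rather than the original MWE: the quantity being differentiated is passed explicitly to both the objective and `solver.run`.)

```python
# Hypothetical sketch, not the original MWE: `theta` is the quantity we
# differentiate the solution with respect to, passed explicitly rather than
# closed over.
import jax
import jax.numpy as jnp
import optax
from jaxopt import OptaxSolver


def objective(params, theta):
    # simple quadratic whose minimizer is exactly `theta`
    return jnp.sum((params - theta) ** 2)


def solve(theta):
    solver = OptaxSolver(fun=objective, opt=optax.adam(1e-1), implicit_diff=True)
    return solver.run(jnp.zeros(2), theta).params


theta = jnp.array([1.0, -2.0])
print(jax.jacobian(solve)(theta))  # ~ identity matrix, via implicit differentiation
```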
Thanks @phinate for the question and @fllinares for the answer! Indeed, as Felipe explained, your objective needs to take the parameter you differentiate as an explicit argument. By the way, since … is not needed; you can just set …
We are working on documentation; hopefully these things will become clearer soon.
Thanks both @mblondel & @fllinares for your replies, and the helpful information! I'm struggling a little with this, because the suggestion of moving `param_for_grad` into an explicit argument of the objective doesn't map cleanly onto our setup: the objective itself is constructed on the fly from the parameter we want to differentiate, something like

```python
def setup_objective(param_for_grad, **kwargs):
    # builds an objective that closes over param_for_grad
    to_minimize = complicated_function(param_for_grad, **kwargs)
    return to_minimize


def pipeline(param_for_grad, **kwargs):
    obj = setup_objective(param_for_grad, **kwargs)
    solver = OptaxSolver(fun=obj, opt=optax.adam(5e-2), implicit_diff=True)
    # ... etc ...
    return result
```

Parametrizing the objective directly with `param_for_grad` seems awkward here. Am I missing something in terms of nicely setting up this problem? Or, for implicit diff, do I really need to be explicit in the way you described, even though the dependence on `param_for_grad` all sits inside `complicated_function`?

Thanks again for the quick response earlier, and sorry if this is somehow unclear!
We decided to use explicit variables because with closed-over variables there is no way to tell which need to be differentiated and which don't. This is problematic if you have several big variables in your scope, such as data matrices.
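(To make this concrete, a hedged sketch with made-up names: the data matrix is passed as an explicit extra argument of the objective, so it is just carried through, while the regularization strength we actually differentiate with respect to is also explicit.)

```python
# Illustrative sketch (names are not from the thread): explicit extra
# arguments instead of closed-over variables.
import jax
import jax.numpy as jnp
import optax
from jaxopt import OptaxSolver


def ridge_objective(w, l2reg, X, y):
    residual = X @ w - y
    return jnp.mean(residual ** 2) + l2reg * jnp.sum(w ** 2)


X = jnp.ones((10, 3))
y = jnp.ones(10)
solver = OptaxSolver(fun=ridge_objective, opt=optax.adam(1e-2), implicit_diff=True)


def solution(l2reg):
    # X and y are forwarded explicitly through `run`, so jaxopt never has to
    # guess which closed-over arrays matter for differentiation.
    return solver.run(jnp.zeros(3), l2reg, X, y).params


jax.jacobian(solution)(0.1)  # d(solution)/d(l2reg) via implicit differentiation
```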
By the way, you can also take a look at …
thanks @mblondel -- is this something that could be in scope for `jaxopt`?
Could you sketch how this would look on the user side?
could we co-opt the … ? Edit: I guess this is equivalent to the …
We could support handling closed-over variables via `jax.closure_convert`. The ideal solution would probably be either to (1) encourage users to use it directly, or (2) …
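(A small sketch of what `jax.closure_convert` does, with illustrative values; the exact set of hoisted constants depends on what the function closes over.)

```python
# Sketch: closure_convert rewrites a function that closes over array values
# into one that takes those values as explicit trailing arguments, returning
# (converted_fun, hoisted_args) with
#   fun(*example_args) == converted_fun(*example_args, *hoisted_args).
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0])


def closed_over_loss(params):
    return jnp.sum((params - data) ** 2)  # `data` is closed over


params = jnp.zeros(3)
converted_loss, hoisted_args = jax.closure_convert(closed_over_loss, params)
converted_loss(params, *hoisted_args)  # same value as closed_over_loss(params)
```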
thanks @shoyer -- can we use `jax.closure_convert` for this?
Hi again all -- we followed @shoyer's suggestion of using `jax.closure_convert` via a function pretty much ripped from the jax docs:

```python
def _minimize(objective_fn, lhood_pars, lr):
    converted_fn, aux_pars = jax.closure_convert(objective_fn, lhood_pars)
    # aux_pars seems to be empty, took that line from docs example
    solver = OptaxSolver(fun=converted_fn, opt=optax.adam(lr), implicit_diff=True)
    return solver.run(lhood_pars, *aux_pars)[0]
```

where `objective_fn` is usually created on the fly with the aforementioned `setup_objective` function.

Using this does allow autodiff with no errors, but we encounter a pretty substantial slowdown compared to wrapping an equivalent Adam optimiser using this more explicit implementation of the two-phase method: https://github.com/gehring/fax/blob/1bc0f4ddf4f0d54370b0c04282de49aa85685791/fax/implicit/twophase.py#L79. We were hoping to transition away from this in the interest of keeping up with jax releases and other software that also follows jax, as well as the far more active effort in jaxopt.

As a side note, we do have an additional performance bottleneck coming from external software constraints that is hard to decorrelate from the change to jaxopt (lack of ability to JIT some parts of the pipeline due to a changing jax version), but based on comparisons removing the JIT from the old program, I don't think it's nearly enough to explain the ~10x slowdown.

Is there any expected drop in performance from using `closure_convert`, perhaps given my previous statements on the complexity of the `setup_objective` function?
Closure conversion relies on JAX's jaxpr interpreter, which is much slower than Python's interpreter. If you can't JIT the entire thing, that probably explains the performance slowdown.
I'm not sure what your intermediate computation looks like exactly, but something along these lines should work:

```python
def pipeline(param_for_grad, latent):
    res = intermediary_step(param_for_grad, **kwargs)

    def objective_fun(params, intermediary_result):
        [...]  # do not use param_for_grad or res here!

    solver = OptaxSolver(fun=objective_fun, opt=optax.adam(5e-2), implicit_diff=True)
    return solver.run(init_params, intermediary_result=res * latent).params


jax.jacobian(pipeline)(param_for_grad, latent)
```

The key idea is to use function composition so that the chain rule will apply. You may have to tweak it to your problem, but you get the idea.
Any follow-up on this? Does your objective seem decomposable in the way I described?
Have just thought about this a bit -- I'm not 100% sure if this would work, but one potential resolution for us, in terms of decomposing the problem, could be to build our statistical model (expensive boilerplate) up front and pass it into the objective, from which we then call a `logpdf` method:

```python
def pipeline(param_for_grad):
    model = Model(param_for_grad)  # expensive statistical model construction

    def objective_fun(params, model):
        return model.logpdf(params)

    solver = OptaxSolver(fun=objective_fun, opt=optax.adam(5e-2), implicit_diff=True)
    return solver.run(init_params, model=model).params


jax.jacobian(pipeline)(param_for_grad)
```

Provided this model was registered as a pytree, do you think this would resolve the problem? It's not something we've implemented, but it could be if this would work.
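(On the "registered as a pytree" part: a hedged sketch using JAX's registration utility directly, with the dummy `Model` from later in the thread; `simple_pytree` or `equinox`, mentioned below, achieve the same registration with less boilerplate.)

```python
# Sketch: manually registering a model class as a pytree so instances can be
# passed through `solver.run` like any other JAX value.
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Model:
    def __init__(self, x):
        self.x = x

    def logpdf(self, pars, data):
        return jnp.sum(pars * data * self.x)

    def tree_flatten(self):
        return (self.x,), None  # (dynamic children, static aux data)

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


model = Model(jnp.asarray(0.5))
jax.tree_util.tree_leaves(model)  # -> [Array(0.5)]
```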
Just an update on this: we've managed to get the jit working on our side with this approach. If it's helpful, I'd be happy to summarise this thread as a small entry in the documentation via a PR @mblondel, since it could come up again for other users with similar use cases. Thanks both for the fast and attentive help!
What would the code snippet look like?
hi again @mblondel, sorry to resurrect this from the dead -- my solution that uses `jax.closure_convert` seems to have stopped working; here's a self-contained example that errors out:

```python
from functools import partial

import jax
import jax.numpy as jnp
import jaxopt
import optax


# dummy model for test purposes
class Model:
    x: jax.Array

    def __init__(self, x) -> None:
        self.x = x

    def logpdf(self, pars, data):
        return jnp.sum(pars * data * self.x)


@partial(jax.jit, static_argnames=["objective_fn"])
def _minimize(objective_fn, init_pars, lr):
    # this is the line added from our discussion above
    converted_fn, aux_pars = jax.closure_convert(objective_fn, init_pars)
    # aux_pars seems to be empty -- would have assumed it was the closed-over vals or similar?
    solver = jaxopt.OptaxSolver(
        fun=converted_fn, opt=optax.adam(lr), implicit_diff=True, maxiter=5000
    )
    return solver.run(init_pars, *aux_pars)[0]


@partial(jax.jit, static_argnames=["model"])
def fit(data, model, init_pars, lr=1e-3):
    def fit_objective(pars):
        return -model.logpdf(pars, data)

    fit_res = _minimize(fit_objective, init_pars, lr)
    return fit_res


def pipeline(x):
    model = Model(x)
    mle_pars = fit(
        model=model,
        data=jnp.array([5.0, 5.0]),
        init_pars=jnp.array([1.0, 1.1]),
        lr=1e-3,
    )
    return mle_pars


jax.jacrev(pipeline)(jnp.asarray(0.5))
# >> JaxStackTraceBeforeTransformation (this is …)
```

Another thing to note: the …

I can't see an immediate way, but if we could cast this example into the form you referenced above with the decomposed derivatives, that would be the best way to get around this issue (i.e. avoid `closure_convert`).
I explored this a bit, and my particular workflow here was made possible if one makes the model a pytree (here via `simple_pytree`) and passes it to `solver.run` explicitly:

```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jaxopt
import optax
from simple_pytree import Pytree


# dummy model for test purposes
class Model(Pytree):
    x: jax.Array

    def __init__(self, x) -> None:
        self.x = x

    def logpdf(self, pars, data):
        return jsp.stats.norm.logpdf(data, loc=pars * self.x, scale=1.0).sum()


@jax.jit
def pipeline(param_for_grad):
    data = jnp.array([5.0, 5.0])
    init_pars = jnp.array([1.0, 1.1])
    lr = 1e-3
    model = Model(param_for_grad)

    def fit(pars, model, data):
        def fit_objective(pars, model, data):
            return -model.logpdf(pars, data)

        solver = jaxopt.OptaxSolver(
            fun=fit_objective, opt=optax.adam(lr), implicit_diff=True, maxiter=5000
        )
        return solver.run(init_pars, model=model, data=data)[0]

    return fit(init_pars, model, data)


jax.jacrev(pipeline)(jnp.asarray(0.5))
# > Array([-1.33830826e+01,  7.10542736e-15], dtype=float64, weak_type=True)
```

Don't know if there's another potential issue above that I'm smearing over with this approach, but it works without `closure_convert`!
I think you're doing the right thing by making the model a PyTree, i.e. I don't think you're smearing over any issue. This is the same approach Equinox uses ubiquitously, and this handles all the complexity of Diffrax just fine!
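(For completeness, a minimal sketch of the same dummy model written as an `equinox.Module`, which is registered as a pytree automatically; this mirrors the `simple_pytree` version above rather than adding anything new.)

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.scipy as jsp


class Model(eqx.Module):
    x: jax.Array

    def logpdf(self, pars, data):
        return jsp.stats.norm.logpdf(data, loc=pars * self.x, scale=1.0).sum()


model = Model(x=jnp.asarray(0.5))
jax.tree_util.tree_leaves(model)  # -> [Array(0.5, ...)] : already a pytree
```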
I've been trying to use `OptaxSolver` to perform a simple function minimization, since I want to differentiate through its solution (the fixed point of the solver), but I ran into an issue I'm not familiar with.

Here's a MWE for the error message: …

which yields this error: …

My versions are: …

Am I doing something very silly? I guess I'm also wondering if this example is within the scope of the solver API? I noticed that this doesn't occur with `solver.update`, just with `solver.run`.

Thanks :)
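(The original MWE and error are missing above; based on the rest of the thread, a hypothetical reproduction of the kind of setup being described would look something like the following, where the objective closes over the parameter being differentiated and implicit differentiation through `solver.run` then fails.)

```python
# Hypothetical sketch, not the original MWE: differentiating through
# `solver.run` when the objective closes over `theta`.
import jax
import jax.numpy as jnp
import optax
from jaxopt import OptaxSolver


def pipeline(theta):
    def objective(params):
        return jnp.sum((params - theta) ** 2)  # `theta` is closed over

    solver = OptaxSolver(fun=objective, opt=optax.adam(1e-1), implicit_diff=True)
    return solver.run(jnp.zeros(2)).params


jax.jacobian(pipeline)(jnp.array([1.0, -2.0]))  # fails; passing `theta`
# explicitly to the objective and to `run` (as discussed above) avoids this
```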