lax.root, a primitive for differentiable root finding #1339
Conversation
jax/lax_linalg.py
Outdated
@@ -72,6 +72,9 @@ def svd(x, full_matrices=True, compute_uv=True):

def triangular_solve(a, b, left_side=False, lower=False, transpose_a=False,
                     conjugate_a=False, unit_diagonal=False):
  # TODO(shoyer): remove this hack!
The need for this suggests that I'm doing something wrong...
Zeros shouldn't escape the AD system, i.e. they should only appear in JVP rules. Maybe a check for ad_util.zero just needs to be added upstream, i.e. in the caller (assuming this is called from a JVP rule).
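A minimal sketch of the pattern being suggested here, using a hypothetical SYMBOLIC_ZERO sentinel in place of ad_util.zero (names and signatures are illustrative only, not JAX's actual AD internals, and only the tangent of b is handled for brevity):

```python
# Hypothetical stand-in for the AD system's symbolic zero (ad_util.zero);
# the real object lives inside JAX's autodiff internals.
SYMBOLIC_ZERO = object()

def solve_jvp(primals, tangents, solve):
  # Illustrative JVP rule: the zero check lives in the caller, so the
  # symbolic zero never escapes into the underlying solve itself.
  a, b = primals
  _, b_dot = tangents
  ans = solve(a, b)
  if b_dot is SYMBOLIC_ZERO:
    # No tangent flows through b: propagate the symbolic zero directly
    # instead of materializing an array of zeros and solving with it.
    tangent_out = SYMBOLIC_ZERO
  else:
    tangent_out = solve(a, b_dot)
  return ans, tangent_out
```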
# F(u(m), m) = 0  # system of equations in m
# ∂_0 F(u(m), m) ∂ u(m) + ∂_1 F(u(m), m) = 0
# ∂ u(m) = - (∂_0 F(u*, m))^{-1} ∂_1 F(u*, m)
unchecked_zeros, f_jvp = api.linearize(
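As a quick sanity check of the formula in these comments (my own illustration, not part of the PR), here is a small numeric example for the scalar system F(u, m) = u**2 - m = 0, whose root is u(m) = sqrt(m):

```python
import jax
import jax.numpy as jnp

def F(u, m):
  return u ** 2 - m

m = 2.0
u_star = jnp.sqrt(m)  # the root u* of F(., m)

dF_du = jax.grad(F, argnums=0)(u_star, m)  # ∂_0 F(u*, m) = 2 u*
dF_dm = jax.grad(F, argnums=1)(u_star, m)  # ∂_1 F(u*, m) = -1

du_dm = -dF_dm / dF_du  # ∂ u(m) = -(∂_0 F)^{-1} ∂_1 F
print(du_dm, 1.0 / (2.0 * jnp.sqrt(m)))  # both ≈ 0.35355
```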
I tried using ad.linearize, but the jaxpr it returns isn't a TypedJaxpr, so I couldn't figure out how to evaluate it.
I think we might want to use ad.jvp_jaxpr here, basically following the logic of scan, though I'm not sure yet if we need the fixed-point logic. I can explain more about that in chat.
Nevermind about that fixed-point stuff; that's not relevant, because the variables we're differentiating with respect to are all in the closure of the function being passed in.
I'm paging this stuff back in, and more excited about it than ever. Here are some things I want to check with you: my current best understanding is we want a primitive where the first argument is a function which we can map to math as f : R^n -> R^n, the second argument is an initial guess x_0 (a point in R^n), and the output x* is a point in R^n satisfying f(x*) = 0 (where the RHS is the zero vector in R^n). We expect to apply
where I'm using square brackets to denote the application of the linear function. Does that sound right?
Yes, this looks right to me, with one minor correction:
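For concreteness, here is a hedged sketch (my own, not from the PR) of the forward semantics agreed on above: given f : R^n -> R^n and an initial guess x_0, produce x* with f(x*) ≈ 0. Newton's method is just one possible way a caller might implement the solve; the primitive itself is agnostic to how the root is found.

```python
import jax
import jax.numpy as jnp

def newton_root(f, x0, num_steps=20):
  # Plain Newton iteration: x <- x - J(x)^{-1} f(x).
  def step(x, _):
    J = jax.jacobian(f)(x)
    return x - jnp.linalg.solve(J, f(x)), None
  x_star, _ = jax.lax.scan(step, x0, None, length=num_steps)
  return x_star

f = lambda x: x ** 3 - jnp.array([8.0, 27.0])  # roots at [2., 3.]
print(newton_root(f, jnp.ones(2)))             # ≈ [2., 3.]
```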
LGTM! Thanks for pushing on this. We've got some more polishing work to do, but this is a big step forward: it actually works!
(We've been discussing extensively offline what follow-up stuff we want to do.)
out = solve(f, initial_guess)

out_flat, out_tree = tree_flatten(out)
if out_tree != tree:
I think this check is redundant because you already checked it when you formed jaxpr. It doesn't hurt to include though, other than taking up precious vertical space :)
We evaluated f() when we formed the jaxpr, but not solve(). So I think we do need this. Actually I even wrote a test that catches this error message :)
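For reference, a minimal self-contained sketch of the kind of check being discussed (illustrative only; the PR's actual error message and variable names may differ):

```python
from jax.tree_util import tree_flatten

def check_solve_structure(expected_tree, out):
  # Compare the pytree structure of solve()'s output against the structure
  # recorded when f() was traced to a jaxpr.
  out_flat, out_tree = tree_flatten(out)
  if out_tree != expected_tree:
    raise TypeError(
        "solve() output pytree structure {} does not match f(): {}"
        .format(out_tree, expected_tree))
  return out_flat
```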
core.jaxpr_as_fun(jaxpr), *(params + solution)
)

params_zeros = tuple(_map(ad_util.zeros_like_jaxval, params))
We think we can avoid instantiating zeros here, and otherwise be more conservative about how much work we do (I currently think we do want to run a fixed-point on ad.jvp_jaxpr!), but we can leave that for follow-up work. Land and iterate!
This should solve the issue with closed-over variables not being handled properly in _custom_implicit_solve. Does the implementation look sane? It's a little messier than I would like, because we need to deal with inputs into the provided functions.

See the tests for an example of what using this API looks like. There are lots of ways to make this more user-friendly, but this is intended as a low-level API for defining implicit derivatives. The user-facing APIs will be routines like scipy.optimize.root, which won't require providing functions for solve and tangent_solve.

TODOs:
- test with jtu.check_grads
- check that jit works properly
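For context, a hedged usage sketch in the spirit described above. It uses jax.lax.custom_root, the present-day descendant of the lax.root primitive introduced here; the original lax.root signature and argument order may have differed slightly.

```python
import jax
import jax.numpy as jnp
from jax import lax

def sqrt_via_root(a):
  # Find x such that f(x) = x**2 - a = 0.
  def f(x):
    return x ** 2 - a

  def solve(f, x0):
    # Forward solve: a few Newton steps from the initial guess.
    x = x0
    for _ in range(20):
      x = x - f(x) / (2.0 * x)
    return x

  def tangent_solve(g, y):
    # Solve the scalar linear tangent system g(x) = y for x.
    return y / g(1.0)

  return lax.custom_root(f, 1.0, solve, tangent_solve)

print(sqrt_via_root(2.0))            # ≈ 1.41421
print(jax.grad(sqrt_via_root)(2.0))  # ≈ 0.35355, computed via tangent_solve
```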