
lax.root, a primitive for differentiable root finding #1339

Merged (24 commits) on Sep 27, 2019

Conversation

shoyer (Collaborator) commented Sep 12, 2019

This should solve the issue with closed-over variables not being handled properly in _custom_implicit_solve.

Does the implementation look sane? It's a little messier than I would like, because we need to deal with the inputs to the provided functions.

See the tests for an example of what using this API looks like. There are lots of ways to make this more user-friendly, but this is intended as a low-level API for defining implicit derivatives. The user-facing APIs will be routines like scipy.optimize.root, which won't require providing functions for solve and tangent_solve.
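For readers new to the technique, the derivative of a root finder can be computed without differentiating through the solver's iterations. Here is a minimal pure-Python sketch (not this PR's actual API; `newton_root` and the cube-root system are illustrative choices):

```python
def newton_root(f, f_prime, x0, tol=1e-12, max_iter=50):
    """Plain Newton iteration: the forward `solve` step."""
    x = x0
    for _ in range(max_iter):
        step = f(x) / f_prime(x)
        x -= step
        if abs(step) < tol:
            break
    return x

# Example system: f(x, a) = x**3 - a, whose root is x*(a) = a**(1/3).
a = 2.0
x_star = newton_root(lambda x: x**3 - a, lambda x: 3.0 * x**2, x0=1.0)

# Implicit function theorem: differentiating f(x*(a), a) = 0 gives
#   dx*/da = -(df/dx)^(-1) (df/da) = -(1 / (3 x*^2)) * (-1) = 1 / (3 x*^2)
dx_star_da = 1.0 / (3.0 * x_star**2)
```

Only a linear solve at the converged solution is required, which is what the tangent_solve argument supplies in the API described above.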

TODOs:

  • replace asserts with errors
  • more test coverage
    • something checking higher dimensional arrays / nested structures
    • check grads with jtu.check_grads
    • verify the jit works properly

jax/lax/lax_control_flow.py
@@ -72,6 +72,9 @@ def svd(x, full_matrices=True, compute_uv=True):

def triangular_solve(a, b, left_side=False, lower=False, transpose_a=False,
                     conjugate_a=False, unit_diagonal=False):
  # TODO(shoyer): remove this hack!
shoyer (Collaborator, Author):

The need for this suggests that I'm doing something wrong...

Collaborator:

Zeros shouldn't escape the AD system, i.e. they should only appear in JVP rules. Maybe a check for ad_util.zero just needs to be added upstream, i.e. in the caller (assuming this is called from a JVP rule).

# F(u(m), m) = 0 # system of equations in m
# ∂_0 F(u(m), m) ∂ u(m) + ∂_1 F(u(m), m) = 0
# ∂ u(m) = - (∂_0 F(u*, m))^{-1} ∂_1 F(u*, m)
unchecked_zeros, f_jvp = api.linearize(
shoyer (Collaborator, Author):

I tried using ad.linearize, but the jaxpr it returns isn't a TypedJaxpr, so I couldn't figure out how to evaluate it.

Collaborator:

I think we might want to use ad.jvp_jaxpr here, basically following the logic of scan, though I'm not sure yet if we need the fixed-point logic. I can explain more about that in chat.

Collaborator:

Never mind about that fixed-point stuff; it's not relevant, because the variables we're differentiating with respect to are all in the closure of the function being passed in.
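To make the closure point concrete, here is a toy illustration (hypothetical names, not the PR's machinery) of closure conversion: a value captured by the function becomes an explicit parameter, which is what lets a JVP rule thread tangents for it:

```python
a = 3.0

def f(x):
    # `a` lives in the closure; machinery that only sees f's explicit
    # inputs has no handle for differentiating with respect to `a`
    return x * x - a

def f_converted(x, params):
    # closure-converted form: the captured value is now an explicit argument
    (a_param,) = params
    return x * x - a_param
```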

tests/lax_control_flow_test.py
mattjj (Collaborator) commented Sep 25, 2019

I'm paging this stuff back in, and I'm more excited about it than ever.

Here are some things I want to check with you: my current best understanding is that we want a primitive

root :: (a -> a) -> a -> a

where the first argument is a function which we can map to math as f : R^n -> R^n, the second argument is an initial guess x_0 (a point in R^n), and the output x* is a point in R^n satisfying f(x*) = 0 (where the RHS is the zero vector in R^n).

We expect to apply root to functions that have nontrivial closures, and in particular ones that might close over values involved in differentiation. So while the above is a description at the API level, really we should think of the mathematical function in a closure-converted way as f : R^n x R^p -> R^n, where the first argument is some set of parameters. Then we can say x*(a) solves f(x*(a), a) = 0.

I think we want the JVP rule to look something like

root_jvp f x0 a adot =
  let x_star = root (lambda x: f(x, a)) x0
      x_star_dot = linear_solve (∂_0 f (x_star, a)) (∂_1 f(x_star, a)[-a_dot])
  in (x_star, x_star_dot)

where I'm using square brackets to denote the application of the linear function ∂_1 f(x_star, a) to a vector.

Does that sound right?
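The proposed rule can be sanity-checked numerically in the scalar case. A sketch (illustrative only, with f(x, a) = x² − a and hand-written partial derivatives) comparing the implicit tangent against a finite difference:

```python
def f(x, a):
    return x * x - a  # root: x*(a) = sqrt(a)

def root(a, x0=1.0):
    # Newton iteration standing in for the `solve` step
    x = x0
    for _ in range(50):
        x -= f(x, a) / (2.0 * x)
    return x

def root_jvp(a, a_dot):
    # x_star_dot solves  d_0 f(x*, a) * x_star_dot = -d_1 f(x*, a) * a_dot
    x_star = root(a)
    d0 = 2.0 * x_star   # d_0 f at (x*, a)
    d1 = -1.0           # d_1 f at (x*, a)
    x_star_dot = -(d1 * a_dot) / d0
    return x_star, x_star_dot

x_star, x_star_dot = root_jvp(2.0, 1.0)

# finite-difference check of dx*/da at a = 2
eps = 1e-6
finite_diff = (root(2.0 + eps) - root(2.0 - eps)) / (2.0 * eps)
```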

shoyer (Collaborator, Author) commented Sep 25, 2019

Yes, this looks right to me, with one minor correction:

> So while the above is a description at the API level, really we should think of the mathematical function in a closure-converted way as f : R^n x R^p -> R^n, where the ~~first~~ second argument is some set of parameters

mattjj self-assigned this Sep 27, 2019
mattjj (Collaborator) left a comment:

LGTM! Thanks for pushing on this. We've got some more polishing work to do, but this is a big step forward: it actually works!

(We've been discussing extensively offline what follow-up stuff we want to do.)

out = solve(f, initial_guess)

out_flat, out_tree = tree_flatten(out)
if out_tree != tree:
Collaborator:

I think this check is redundant because you already checked it when you formed jaxpr. It doesn't hurt to include though, other than taking up precious vertical space :)

shoyer (Collaborator, Author):

We evaluated f() when we formed the jaxpr, but not solve(). So I think we do need this. Actually, I even wrote a test that catches this error message :)
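The check under discussion compares the pytree structure of solve's output against the structure recorded when tracing f. A rough pure-Python analogue (toy helpers, not JAX's actual tree_util internals):

```python
def tree_structure(x):
    # toy stand-in for a treedef: records container shapes, leaves as None
    if isinstance(x, (tuple, list)):
        return (type(x).__name__, tuple(tree_structure(v) for v in x))
    if isinstance(x, dict):
        return ('dict', tuple((k, tree_structure(v)) for k, v in sorted(x.items())))
    return None

def check_solve_output(expected, out):
    # mirrors the `out_tree != tree` check: structures must match exactly
    got = tree_structure(out)
    if got != expected:
        raise TypeError(
            f"solve() output structure {got!r} does not match "
            f"f() output structure {expected!r}")
    return out
```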

core.jaxpr_as_fun(jaxpr), *(params + solution)
)

params_zeros = tuple(_map(ad_util.zeros_like_jaxval, params))
Collaborator:

We think we can avoid instantiating zeros here, and otherwise be more conservative about how much work we do (I currently think we do want to run a fixed-point on ad.jvp_jaxpr!), but we can leave that for follow-up work. Land and iterate!

mattjj merged commit c33e8cb into jax-ml:master on Sep 27, 2019