lax.root, a primitive for differentiable root finding #1339
Changes from all commits
@@ -1,3 +1,4 @@
+# coding=utf-8
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -31,7 +32,7 @@
from jax.lax import lax
from jax import linear_util as lu
from jax.abstract_arrays import ShapedArray, raise_to_shaped
-from jax.api_util import flatten_fun_nokwargs
+from jax.api_util import flatten_fun_nokwargs, apply_flat_fun_nokwargs
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
@@ -42,7 +43,7 @@
from jax.util import (partial, unzip2, safe_map, safe_zip, split_list,
                      split_dict, cache)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
-                           treedef_children)
+                           treedef_children, tree_map)
from jax import ad_util

_map = safe_map
@@ -829,3 +830,127 @@ def body(i, dst):
  return fori_loop(0, num, body, dst)


masking.masking_rules[lax.concatenate_p] = _concat_masking_rule


def root(f, initial_guess, solve, tangent_solve):
  """Differentiably solve for a root of a function.

  This is a low-level routine, mostly intended for internal use in JAX.
  Gradients of root() are defined with respect to closed-over variables from
  the provided function f.

  Args:
    f: function for which to find a root. Should accept a single argument and
      return a tree of arrays with the same structure as its input.
    initial_guess: initial guess for a zero of f.
    solve: function to solve for the roots of f. Should take two positional
      arguments, f and initial_guess, and return a solution with the same
      structure as initial_guess such that f(solution) = 0. In other words,
      the following is assumed to be true (but not checked)::

        solution = solve(f, initial_guess)
        error = f(solution)
        assert all(error == 0)

    tangent_solve: function to solve the tangent system. Should take two
      positional arguments, a linear function ``g`` (the function ``f``
      linearized at its root) and a tree of array(s) ``y`` with the same
      structure as initial_guess, and return a solution ``x`` such that
      ``g(x) = y``:

      - For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
      - For vector ``y``, you could use a linear solve with the Jacobian, if
        dimensionality of ``y`` is not too large:
        ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.

  Returns:
    The result of calling solve(f, initial_guess) with gradients defined via
    implicit differentiation, assuming ``f(solve(f, initial_guess)) == 0``.
  """
  guess_flat, in_args_tree = tree_flatten((initial_guess,))
  guess_avals = tuple(_map(_abstractify, guess_flat))
  jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_args_tree, guess_avals)
  in_tree, = treedef_children(in_args_tree)
  if in_tree != out_tree:
    raise TypeError(
        "f() output pytree structure must match initial_guess, got {} and {}."
        .format(out_tree, in_tree))
  out_flat = root_p.bind(*itertools.chain(consts, guess_flat),
                         tree=out_tree, num_consts=len(consts),
                         jaxpr=jaxpr, solve=solve, tangent_solve=tangent_solve)
  return tree_unflatten(out_tree, out_flat)
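
As a usage sketch (not part of the diff), here is the scalar case from the docstring: solving x**2 - a = 0 with a hand-written Newton iteration as ``solve`` and the scalar ``tangent_solve``. The helper names ``newton_solve`` and ``sqrt_via_root`` are hypothetical, and this assumes the function is exposed as ``lax.root``, as the PR title suggests. Gradients flow through the closed-over ``a`` by implicit differentiation, not through the Newton iterations themselves:

```python
import jax
from jax import lax

def newton_solve(f, x0, iters=10):
  # Any solver with the (f, initial_guess) signature works here, since
  # root() never differentiates through it.
  def body(_, x):
    return x - f(x) / jax.grad(f)(x)
  return lax.fori_loop(0, iters, body, x0)

def sqrt_via_root(a):
  f = lambda x: x ** 2 - a  # the root of f is sqrt(a)
  return lax.root(f, 1.0, newton_solve, lambda g, y: y / g(1.0))

print(sqrt_via_root(2.0))            # ~1.4142135
print(jax.grad(sqrt_via_root)(2.0))  # ~0.35355338 = 1 / (2 * sqrt(2))
```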


def _root_abstract_eval(*args, **kwargs):
  return args[kwargs['num_consts']:]


def _root_impl(*args, **kwargs):
  tree, num_consts, jaxpr, solve, _ = split_dict(
      kwargs, ['tree', 'num_consts', 'jaxpr', 'solve', 'tangent_solve'])

  f = partial(
      apply_flat_fun_nokwargs,
      partial(core.jaxpr_as_fun(jaxpr), *args[:num_consts]),
      (tree, tree),
  )
  initial_guess = tree_unflatten(tree, args[num_consts:])
  out = solve(f, initial_guess)

  out_flat, out_tree = tree_flatten(out)
  if out_tree != tree:
    raise TypeError(
        "solve() output pytree structure must match initial_guess, got {} and {}"
        .format(out_tree, tree))

  return out_flat


def _root_jvp(primals, tangents, tree, num_consts, jaxpr, solve, tangent_solve):
  params = primals[:num_consts]
  solution = tuple(
      root_p.bind(*primals, tree=tree, num_consts=num_consts,
                  jaxpr=jaxpr, solve=solve, tangent_solve=tangent_solve))

  params_dot = tangents[:num_consts]

  # F(u(m), m) = 0      # system of equations in m
  # ∂_0 F(u(m), m) ∂u(m) + ∂_1 F(u(m), m) = 0
  # ∂u(m) = -(∂_0 F(u*, m))^{-1} ∂_1 F(u*, m)
  unchecked_zeros, f_jvp = api.linearize(
      core.jaxpr_as_fun(jaxpr), *(params + solution))

> Review comment: I tried using […]

> Review comment: I think we might want to use […]

> Review comment: Nevermind about that fixed-point stuff; that's not relevant, because the variables we're differentiating with respect to are all in the closure of the function being passed in.

  params_zeros = tuple(_map(ad_util.zeros_like_jaxval, params))

> Review comment: We think we can avoid instantiating zeros here, and otherwise be more conservative about how much work we do (I currently think we do want to run a fixed-point on […]).

  solution_zeros = tuple(_map(ad_util.zeros_like_jaxval, solution))

  f_linearized_at_solution = partial(
      apply_flat_fun_nokwargs, partial(f_jvp, *params_zeros), (tree, tree))
  rhs = tree_unflatten(tree, f_jvp(*(params_dot + solution_zeros)))
  solution_dot = tree_map(
      operator.neg, tangent_solve(f_linearized_at_solution, rhs))

  solution_dot_flat, out_tree = tree_flatten(solution_dot)
  if out_tree != tree:
    raise TypeError(
        "tangent_solve() output pytree structure must match initial_guess, "
        "got {} and {}".format(out_tree, tree))

  return solution, solution_dot_flat
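
To make the implicit-differentiation comment concrete, here is the same decomposition written out by hand for a scalar equation F(u, m) = 0, using ``jax.linearize`` the same way ``_root_jvp`` does. The function ``implicit_jvp`` is illustrative only, not part of the PR:

```python
import jax

def implicit_jvp(F, u_star, m, m_dot):
  # F(u(m), m) = 0 at u = u_star, so differentiating in m:
  #   ∂_0 F · u_dot + ∂_1 F · m_dot = 0
  #   u_dot = -(∂_0 F)^{-1} (∂_1 F · m_dot)
  _, F_lin = jax.linearize(F, u_star, m)
  rhs = F_lin(0.0, m_dot)    # ∂_1 F(u*, m) · m_dot, with the u tangent zeroed
  dF_du = F_lin(1.0, 0.0)    # ∂_0 F(u*, m), extracted as a scalar
  return -rhs / dF_du

F = lambda u, m: u ** 2 - m  # implicitly defines u(m) = sqrt(m)
print(implicit_jvp(F, 2.0 ** 0.5, 2.0, 1.0))  # ~0.35355338 = d sqrt(m)/dm at m=2
```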

def _root_batch(args, dims, **params):
  return batching.batch_fun(lu.wrap_init(_root_impl, params), args, dims)


root_p = core.Primitive('root')
root_p.multiple_results = True
root_p.def_impl(_root_impl)
root_p.def_abstract_eval(_root_abstract_eval)
ad.primitive_jvps[root_p] = _root_jvp
xla.initial_style_translations[root_p] = xla.lower_fun(_root_impl, initial_style=True)
batching.primitive_batchers[root_p] = _root_batch
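
With these registrations in place, ``root`` should compose with the standard transformations: ``jit`` goes through the initial-style XLA translation, ``grad`` through the JVP rule, and ``vmap`` through the batching rule. A sketch reusing the hypothetical ``sqrt_via_root`` from the earlier example:

```python
import jax
import jax.numpy as np

print(jax.jit(jax.grad(sqrt_via_root))(2.0))               # same ~0.35355338
print(jax.vmap(sqrt_via_root)(np.array([1.0, 2.0, 4.0])))  # ~[1., 1.4142, 2.]
```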
> Review comment: I think this check is redundant because you already checked it when you formed `jaxpr`. It doesn't hurt to include though, other than taking up precious vertical space :)

> Reply: We evaluated `f()` when we formed the `jaxpr`, but not `solve()`. So I think we do need this. Actually I even wrote a test that catches this error message :)
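
For reference, a minimal way to trigger the check under discussion: ``f()`` is validated against ``initial_guess`` when the jaxpr is formed, but a ``solve()`` that returns a different pytree structure is only caught inside ``_root_impl``. A hypothetical repro:

```python
from jax import lax

f = lambda x: x ** 2 - 2.0
bad_solve = lambda f, x0: (x0, x0)  # tuple instead of a scalar like x0
# Raises TypeError: solve() output pytree structure must match initial_guess, ...
lax.root(f, 1.0, bad_solve, lambda g, y: y / g(1.0))
```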