lax.root, a primitive for differentiable root finding #1339

Merged · 24 commits · Sep 27, 2019
Changes from 11 commits
53 changes: 0 additions & 53 deletions jax/api.py
@@ -1809,56 +1809,3 @@ def abstractify(x):
  out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
  out = [ShapeDtypeStruct(x.shape, x.dtype) for x in out]
  return tree_unflatten(out_tree(), out)


def _custom_implicit_solve(solve, tangent_solve):
  """Define gradients for a function that performs an implicit solve.

  Note: this isn't ready for widespread use yet -- it does not yet handle
  closed-over values inside solve.

  Args:
    solve: callable that takes two positional arguments, func and params, and
      returns a solution such that func(solution, params) = 0. In other words,
      the following is assumed to be true (but not checked)::

        solution = solve(func, params)
        error = func(solution, params)
        assert tree_all(tree_map(partial(np.allclose, 0.0), error))

    tangent_solve: callable that takes two positional arguments, a linear
      function ``f`` and (possibly nested) array(s) ``y``, and returns a
      solution ``x`` such that ``f(x)=y``:

      - For scalar ``y``, use ``lambda f, y: y / f(1.0)``.
      - For vector ``y``, you could use a linear solve with the Jacobian, if
        dimensionality of ``y`` is not too large:
        ``lambda f, y: np.linalg.solve(jacobian(f)(y), y)``.

  Returns:
    Wrapped version of solve with JVP and VJPs defined with respect to
    ``params`` via implicit differentiation, rather than differentiating
    through the solve.
  """
  @wraps(solve)
  def wrapper(func, params):

    @custom_transforms
    def solve_impl(params):
      return solve(func, params)

    @partial(defjvp_all, solve_impl)
    def solve_impl_jvp(primals, tangents):
      # F(u(m), m) = 0      # system of equations in m
      # ∂_0 F(u(m), m) ∂ u(m) + ∂_1 F(u(m), m) = 0
      # ∂ u(m) = - (∂_0 F(u*, m))^{-1} ∂_1 F(u*, m)
      params, = primals
      grad_params, = tangents
      solution = solve_impl(params)
      unchecked_zeros, f_jvp = vjp(func, solution, params)
      grad_solution = tree_map(
          lambda x: -x,
          tangent_solve(lambda p: f_jvp(p)[0], f_jvp(grad_params)[1])
      )
      return solution, grad_solution

    return solve_impl(params)
  return wrapper
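
The JVP comment in the deleted code above encodes the implicit function
theorem: if F(u(m), m) = 0 defines u(m), then ∂u = -(∂_0 F(u*, m))^{-1}
∂_1 F(u*, m). A minimal sketch of that identity (not part of this PR;
`newton_solve` is an illustrative stand-in solver):

import jax

def func(solution, params):
  return solution ** 2 - params  # root: solution = sqrt(params)

def newton_solve(func, params, x0=1.0, num_steps=20):
  # Plain Newton iteration; any black-box solver would do here.
  x = x0
  for _ in range(num_steps):
    x = x - func(x, params) / jax.grad(func)(x, params)
  return x

params = 5.0
u_star = newton_solve(func, params)
d0 = jax.grad(func, argnums=0)(u_star, params)  # ∂_0 F(u*, m) = 2 * u*
d1 = jax.grad(func, argnums=1)(u_star, params)  # ∂_1 F(u*, m) = -1
implicit_grad = -d1 / d0
# implicit_grad ~= 1 / (2 * sqrt(5)) ~= 0.2236, the derivative of sqrt(m),
# obtained without differentiating through the Newton iterations.
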
122 changes: 120 additions & 2 deletions jax/lax/lax_control_flow.py
@@ -1,3 +1,4 @@
# coding=utf-8
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -31,7 +32,7 @@
from jax.lax import lax
from jax import linear_util as lu
from jax.abstract_arrays import ShapedArray, raise_to_shaped
from jax.api_util import flatten_fun_nokwargs
from jax.api_util import flatten_fun_nokwargs, apply_flat_fun_nokwargs
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
@@ -42,7 +43,7 @@
from jax.util import (partial, unzip2, safe_map, safe_zip, split_list,
                      split_dict, cache)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
                           treedef_children)
                           treedef_children, tree_map)
from jax import ad_util

_map = safe_map
@@ -822,3 +823,120 @@ def body(i, dst):
  return fori_loop(0, num, body, dst)

masking.masking_rules[lax.concatenate_p] = _concat_masking_rule


def root(f, initial_guess, solve, tangent_solve):
  """Differentiably solve for a root of a function.

  This is a low-level routine, mostly intended for internal use in JAX.
  Gradients of root() are defined with respect to closed-over variables from
  the provided function f.

  Args:
    f: function for which to find a root. Should accept a single argument and
      return a tree of arrays with the same structure as its input.
    initial_guess: initial guess for a zero of f.
    solve: function to solve for the roots of f. Should take two positional
      arguments, f and initial_guess, and return a solution with the same
      structure as initial_guess such that f(solution) = 0. In other words,
      the following is assumed to be true (but not checked)::

        solution = solve(f, initial_guess)
        error = f(solution)
        assert all(error == 0)

    tangent_solve: function to solve the tangent system. Should take two
      positional arguments, a linear function ``g`` (the function ``f``
      linearized at its root) and a tree of array(s) ``y`` with the same
      structure as initial_guess, and return a solution ``x`` such that
      ``g(x)=y``:

      - For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
      - For vector ``y``, you could use a linear solve with the Jacobian, if
        dimensionality of ``y`` is not too large:
        ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.

  Returns:
    The result of calling solve(f, initial_guess) with gradients defined via
    implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
  """
  guess_flat, in_args_tree = tree_flatten((initial_guess,))
  guess_avals = tuple(_map(_abstractify, guess_flat))
  jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_args_tree, guess_avals)
  in_tree, = treedef_children(in_args_tree)
  if in_tree != out_tree:
    raise TypeError(
        "f output pytree structure must match initial_guess, got {} and {}."
        .format(out_tree, in_tree)
    )
  out_flat = root_p.bind(*itertools.chain(consts, guess_flat),
                         tree=out_tree, num_consts=len(consts),
                         jaxpr=jaxpr, solve=solve, tangent_solve=tangent_solve)
  return tree_unflatten(out_tree, out_flat)
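
To make the contract above concrete, a hedged usage sketch (illustrative
names, assuming this PR is applied; only the root condition f is ever
differentiated, not the solver's iterations):

import jax
from jax import lax

def newton_solve(f, x0):
  # Simple Newton iteration standing in for an arbitrary solver.
  x = x0
  for _ in range(20):
    x = x - f(x) / jax.grad(f)(x)
  return x

def my_sqrt(x):
  # Scalar tangent_solve: g is linear, so g(1.0) recovers its slope.
  return lax.root(lambda y: y ** 2 - x, 1.0, newton_solve,
                  lambda g, y: y / g(1.0))

value = my_sqrt(2.0)           # ~= 1.4142
grad = jax.grad(my_sqrt)(2.0)  # ~= 0.3536 == 1 / (2 * sqrt(2))
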


def _root_abstract_eval(*args, **kwargs):
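  # The solution shares the abstract values (shapes/dtypes) of the flattened
  # initial guess, i.e. the arguments after the leading `num_consts` constants.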
  return args[kwargs['num_consts']:]


def _root_impl(*args, **kwargs):
  tree, num_consts, jaxpr, solve, _ = split_dict(
      kwargs, ['tree', 'num_consts', 'jaxpr', 'solve', 'tangent_solve'])

  f = partial(
      apply_flat_fun_nokwargs,
      partial(core.jaxpr_as_fun(jaxpr), *args[:num_consts]),
      (tree, tree),
  )
  initial_guess = tree_unflatten(tree, args[num_consts:])
  out = solve(f, initial_guess)

  out_flat, out_tree = tree_flatten(out)
  if out_tree != tree:

Collaborator: I think this check is redundant because you already checked it
when you formed jaxpr. It doesn't hurt to include though, other than taking up
precious vertical space :)

Author: We evaluated f() when we formed the jaxpr, but not solve(). So I think
we do need this. Actually I even wrote a test that catches this error
message :)

    raise TypeError(
        "solve output pytree structure must match initial_guess, got {} and {}"
        .format(out_tree, tree))

  return out_flat


def _root_jvp(
    primals, tangents, tree, num_consts, jaxpr, solve, tangent_solve):
  solution = root_p.bind(*primals, tree=tree, num_consts=num_consts,
                         jaxpr=jaxpr, solve=solve, tangent_solve=tangent_solve)

  # F(u(m), m) = 0      # system of equations in m
  # ∂_0 F(u(m), m) ∂ u(m) + ∂_1 F(u(m), m) = 0
  # ∂ u(m) = - (∂_0 F(u*, m))^{-1} ∂_1 F(u*, m)
  unchecked_zeros, f_jvp = ad.vjp(
      lu.wrap_init(core.jaxpr_as_fun(jaxpr)),
      primals[:num_consts] + tuple(solution)
  )
  f_linearized = partial(
      apply_flat_fun_nokwargs,
      lambda *xs: f_jvp(*xs)[num_consts:],
      (tree, tree),
  )
  params_jvp = tree_unflatten(
      tree, f_jvp(*tangents[:num_consts])[:num_consts])
  negative_grad = tangent_solve(f_linearized, params_jvp)

  negative_grad_flat, out_tree = tree_flatten(negative_grad)
  if out_tree != tree:
    raise TypeError(
        "tangent_solve output pytree structure must match initial_guess, "
        "got {} and {}".format(out_tree, tree))

  grad_solution = _map(operator.neg, negative_grad_flat)
  return solution, grad_solution
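
To connect _root_jvp to the comment above: for F(u, m) = u**2 - m with root
u* = sqrt(m), the rule gives ∂u = -(∂_0 F)^{-1} ∂_1 F ∂m. A hedged numerical
check of that identity (illustrative only, not part of this diff):

import jax

F = lambda u, m: u ** 2 - m
m, dm = 4.0, 1.0
u_star = m ** 0.5                  # root of F(., m)
_, f_vjp = jax.vjp(F, u_star, m)   # scalar output, so one cotangent pullback
d0F, d1F = f_vjp(1.0)              # gives both partials: (2 * u_star, -1.0)
du = -(1.0 / d0F) * (d1F * dm)     # 0.25 == d sqrt(m)/dm at m = 4.0
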

def _root_batch(args, dims, **params):
  return batching.batch_fun(lu.wrap_init(_root_impl, params), args, dims)


root_p = core.Primitive('root')
root_p.multiple_results = True
root_p.def_impl(_root_impl)
ad.primitive_jvps[root_p] = _root_jvp
xla.initial_style_translations[root_p] = xla.lower_fun(_root_impl, initial_style=True)
batching.primitive_batchers[root_p] = _root_batch
35 changes: 0 additions & 35 deletions tests/api_test.py
@@ -885,41 +885,6 @@ def f(x):
    xla_comp = api.xla_computation(f)
    xla_comp(np.arange(8)).GetHloText()  # doesn't crash

  def test_custom_implicit_solve(self):

    def scalar_solve(f, y):
      return y / f(1.0)

    def _binary_search(func, params, low=0.0, high=100.0, tolerance=1e-6):
      def cond(state):
        low, high = state
        return high - low > tolerance

      def body(state):
        low, high = state
        midpoint = 0.5 * (low + high)
        update_upper = func(midpoint, params) > 0
        low = np.where(update_upper, low, midpoint)
        high = np.where(update_upper, midpoint, high)
        return (low, high)

      solution, _ = lax.while_loop(cond, body, (low, high))
      return solution

    binary_search = api._custom_implicit_solve(_binary_search, scalar_solve)
    sqrt_cubed = lambda y, x: y ** 2 - x ** 3
    value, grad = api.value_and_grad(binary_search, argnums=1)(sqrt_cubed, 5.0)
    self.assertAllClose(value, 5 ** 1.5, check_dtypes=False)
    self.assertAllClose(grad, api.grad(pow)(5.0, 1.5), check_dtypes=False)

    def scalar_solve2(f, y):
      y_1d = y[np.newaxis]
      return np.linalg.solve(api.jacobian(f)(y_1d), y_1d).squeeze()

    binary_search = api._custom_implicit_solve(_binary_search, scalar_solve2)
    grad = api.grad(binary_search, argnums=1)(sqrt_cubed, 5.0)
    self.assertAllClose(grad, api.grad(pow)(5.0, 1.5), check_dtypes=False)

  def test_jit_device_assignment(self):
    raise unittest.SkipTest("Temporarily disabled while device API is being changed.")
    device_num = xb.device_count() - 1
59 changes: 59 additions & 0 deletions tests/lax_control_flow_test.py
@@ -981,6 +981,65 @@ def f(carry, _):
    key = random.PRNGKey(0)
    api.grad(lambda c: lax.scan(f, (c, key), onp.ones(3))[0][0])(0.)  # doesn't crash

  def test_root(self):

    def scalar_solve(f, y):
      return y / f(1.0)

    def binary_search(func, x0, low=0.0, high=100.0, tolerance=1e-6):
      del x0  # unused

      def cond(state):
        low, high = state
        return high - low > tolerance

      def body(state):
        low, high = state
        midpoint = 0.5 * (low + high)
        update_upper = func(midpoint) > 0
        low = np.where(update_upper, low, midpoint)
        high = np.where(update_upper, midpoint, high)
        return (low, high)

      solution, _ = lax.while_loop(cond, body, (low, high))
      return solution

    def sqrt_cubed(x, tangent_solve=scalar_solve):
      f = lambda y: y ** 2 - x ** 3
      return lax.root(f, 0.0, binary_search, tangent_solve)

    value, grad = api.value_and_grad(sqrt_cubed)(5.0)
    self.assertAllClose(value, 5 ** 1.5, check_dtypes=False)
    self.assertAllClose(grad, api.grad(pow)(5.0, 1.5), check_dtypes=False)

    jtu.check_grads(sqrt_cubed, (5.0,), order=2, rtol=1e-3)

    inputs = np.array([4.0, 5.0])
    results = api.vmap(sqrt_cubed)(inputs)
    self.assertAllClose(results, inputs ** 1.5, check_dtypes=False)

    def nd_solve(f, y):
      g = lambda z: f(z.reshape(y.shape)).ravel()
      jacobian = api.jacobian(g)(y.ravel())
      return np.linalg.solve(jacobian, y.ravel()).reshape(y.shape)

    sqrt_cubed_alt = partial(sqrt_cubed, tangent_solve=nd_solve)
    value, grad = api.value_and_grad(sqrt_cubed_alt)(5.0)
    self.assertAllClose(grad, api.grad(pow)(5.0, 1.5), check_dtypes=False)

  def test_root_errors(self):
    with self.assertRaisesRegexp(TypeError, "f output pytree"):
      lax.root(lambda x: (x, x), 0.0, lambda f, x: x, lambda f, x: x)
    with self.assertRaisesRegexp(TypeError, "solve output pytree"):
      lax.root(lambda x: x, 0.0, lambda f, x: (x, x), lambda f, x: x)

    def dummy_root_usage(x):
      f = lambda y: x - y
      return lax.root(f, 0.0, lambda f, x: x, lambda f, x: (x, x))

    with self.assertRaisesRegexp(TypeError, "tangent_solve output pytree"):
      api.jvp(dummy_root_usage, (0.0,), (0.0,))


if __name__ == '__main__':
  absltest.main()