From 5fd729d0f7690140302728484b45bb3ec015b809 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 26 Jun 2024 13:09:08 +0200
Subject: [PATCH] Add helper to build hessian vector product

Co-authored-by: Adrian Seyboldt
---
 doc/tutorial/gradients.rst | 10 +++++
 pytensor/gradient.py       | 79 ++++++++++++++++++++++++++++++++++++++
 tests/test_gradient.py     | 39 +++++++++++++++++++
 3 files changed, 128 insertions(+)

diff --git a/doc/tutorial/gradients.rst b/doc/tutorial/gradients.rst
index 28cdda7165..edb38bb018 100644
--- a/doc/tutorial/gradients.rst
+++ b/doc/tutorial/gradients.rst
@@ -267,6 +267,16 @@ or, making use of the R-operator:
 >>> f([4, 4], [2, 2])
 array([ 4., 4.])
 
+There is a built-in helper that uses the first method:
+
+>>> x = pt.dvector('x')
+>>> v = pt.dvector('v')
+>>> y = pt.sum(x ** 2)
+>>> Hv = pytensor.gradient.hessian_vector_product(y, x, v)
+>>> f = pytensor.function([x, v], Hv)
+>>> f([4, 4], [2, 2])
+array([ 4., 4.])
+
 Final Pointers
 ==============
 
diff --git a/pytensor/gradient.py b/pytensor/gradient.py
index c5823fa068..bf05a0f392 100644
--- a/pytensor/gradient.py
+++ b/pytensor/gradient.py
@@ -2050,6 +2050,85 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
     return as_list_or_tuple(using_list, using_tuple, hessians)
 
 
+def hessian_vector_product(cost, wrt, p, **grad_kwargs):
+    """Return an expression of the Hessian times a vector `p`.
+
+    Notes
+    -----
+    This function uses backward autodiff twice to obtain the desired expression.
+    You may want to build the equivalent expression manually, by combining a
+    backward pass with a forward one (if all the Ops involved support it).
+    See {ref}`docs/_tutcomputinggrads#Hessian-times-a-Vector` for how to do this.
+
+    Parameters
+    ----------
+    cost: Scalar (0-dimensional) `Variable`.
+    wrt: Vector (1-dimensional tensor) `Variable` or list of Vectors
+    p: Vector (1-dimensional tensor) `Variable` or list of Vectors
+        Each vector is used for the Hessian-vector product wrt each input variable.
+    **grad_kwargs:
+        Keyword arguments passed to the `grad` function.
+
+    Returns
+    -------
+    Vector or list of Vectors
+        The Hessian times `p` of the `cost` with respect to (elements of) `wrt`.
+
+    Examples
+    --------
+
+    .. testcode::
+
+        import numpy as np
+        from scipy.optimize import minimize
+        from pytensor import function
+        from pytensor.tensor import vector
+        from pytensor.gradient import grad, hessian_vector_product
+
+        x = vector('x')
+        p = vector('p')
+
+        rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
+        rosen_jac = grad(rosen, x)
+        rosen_hessp = hessian_vector_product(rosen, x, p)
+
+        rosen_fn = function([x], rosen)
+        rosen_jac_fn = function([x], rosen_jac)
+        rosen_hessp_fn = function([x, p], rosen_hessp)
+        x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])
+        res = minimize(
+            rosen_fn,
+            x0,
+            method="Newton-CG",
+            jac=rosen_jac_fn,
+            hessp=rosen_hessp_fn,
+            options={"xtol": 1e-8},
+        )
+        print(res.x)
+
+    .. testoutput::
+
+        [1.         1.         1.         0.99999999 0.99999999]
+
+
+
+    """
+    wrt_list = wrt if isinstance(wrt, Sequence) else [wrt]
+    p_list = p if isinstance(p, Sequence) else [p]
+    grad_wrt_list = grad(cost, wrt=wrt_list, **grad_kwargs)
+    hessian_cost = pytensor.tensor.add(
+        *[
+            (grad_wrt * p).sum()
+            for grad_wrt, p in zip(grad_wrt_list, p_list, strict=True)
+        ]
+    )
+    Hp_list = grad(hessian_cost, wrt=wrt_list, **grad_kwargs)
+
+    if isinstance(wrt, Variable):
+        return Hp_list[0]
+    return Hp_list
+
+
 def _is_zero(x):
     """
     Returns 'yes', 'no', or 'maybe' indicating whether x
diff --git a/tests/test_gradient.py b/tests/test_gradient.py
index a92939369f..c45d07662d 100644
--- a/tests/test_gradient.py
+++ b/tests/test_gradient.py
@@ -1,5 +1,6 @@
 import numpy as np
 import pytest
+from scipy.optimize import rosen_hess_prod
 
 import pytensor
 import pytensor.tensor.basic as ptb
@@ -20,6 +21,7 @@
     grad_scale,
     grad_undefined,
     hessian,
+    hessian_vector_product,
     jacobian,
     subgraph_grad,
     zero_grad,
@@ -1079,3 +1081,40 @@ def test_jacobian_disconnected_inputs():
     func_s = pytensor.function([s2], jacobian_s)
     val = np.array(1.0).astype(pytensor.config.floatX)
     assert np.allclose(func_s(val), np.zeros(1))
+
+
+class TestHessianVectorProduct:
+    def test_rosen(self):
+        x = vector("x", dtype="float64")
+        rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
+
+        p = vector("p", dtype="float64")
+        rosen_hess_prod_pt = hessian_vector_product(rosen, wrt=x, p=p)
+
+        x_test = 0.1 * np.arange(9)
+        p_test = 0.5 * np.arange(9)
+        np.testing.assert_allclose(
+            rosen_hess_prod_pt.eval({x: x_test, p: p_test}),
+            rosen_hess_prod(x_test, p_test),
+        )
+
+    def test_multiple_wrt(self):
+        x = vector("x", dtype="float64")
+        y = vector("y", dtype="float64")
+        p_x = vector("p_x", dtype="float64")
+        p_y = vector("p_y", dtype="float64")
+
+        cost = (x**2 - y**2).sum()
+        hessp_x, hessp_y = hessian_vector_product(cost, wrt=[x, y], p=[p_x, p_y])
+
+        hessp_fn = pytensor.function([x, y, p_x, p_y], [hessp_x, hessp_y])
+        test = {
+            # x, y values don't matter, only the vectors p_x, p_y do
+            "x": np.full((3,), np.nan),
+            "y": np.full((3,), np.nan),
+            "p_x": [1, 2, 3],
+            "p_y": [3, 2, 1],
+        }
+        hessp_x_eval, hessp_y_eval = hessp_fn(**test)
+        np.testing.assert_allclose(hessp_x_eval, [2, 4, 6])
+        np.testing.assert_allclose(hessp_y_eval, [-6, -4, -2])
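
The docstring's Notes mention a second way to build the same quantity: one backward pass followed by one forward pass. Below is a minimal sketch of that alternative, using the existing `Rop` helper from `pytensor.gradient` and mirroring the R-operator example already present in `doc/tutorial/gradients.rst`. It is not part of the patch, and it only works when every Op in the gradient graph implements `Rop`:

.. code-block:: python

    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import Rop, grad

    x = pt.dvector("x")
    v = pt.dvector("v")
    y = pt.sum(x ** 2)

    # Backward pass: gy = dy/dx (= 2 * x for this cost).
    gy = grad(y, x)

    # Forward pass: push v through the gradient graph with the R-operator,
    # yielding H @ v without ever materializing the Hessian H.
    Hv = Rop(gy, x, v)

    f = pytensor.function([x, v], Hv)
    print(f([4, 4], [2, 2]))  # -> [4. 4.], since H = 2 * I here

Compared to the double-backward construction used by `hessian_vector_product`, this backward-then-forward form can produce a leaner graph when `Rop` support is available, which is why the docstring points readers to the tutorial section.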