Add helper to build hessian vector product
Co-authored-by: Adrian Seyboldt <[email protected]>
ricardoV94 and aseyboldt committed Jun 26, 2024
1 parent d3bd1f1 commit 93a768d
Showing 3 changed files with 156 additions and 0 deletions.
10 changes: 10 additions & 0 deletions doc/tutorial/gradients.rst
@@ -267,6 +267,16 @@ or, making use of the R-operator:
>>> f([4, 4], [2, 2])
array([ 4., 4.])

There is a built-in helper that uses the first method:

>>> x = pt.dvector('x')
>>> v = pt.dvector('v')
>>> y = pt.sum(x ** 2)
>>> Hv = pytensor.gradient.hessian_vector_product(y, x, v)
>>> f = pytensor.function([x, v], Hv)
>>> f([4, 4], [2, 2])
array([ 4., 4.])
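
The helper also accepts a list of variables for ``wrt``, together with a matching
list of vectors for ``p``, returning one Hessian-times-vector expression per input.
A minimal sketch along the same lines (not from the original tutorial; variable
names and printed output are illustrative):

>>> a = pt.dvector('a')
>>> b = pt.dvector('b')
>>> va = pt.dvector('va')
>>> vb = pt.dvector('vb')
>>> cost = pt.sum(a ** 2) + pt.sum(b ** 3)
>>> Hva, Hvb = pytensor.gradient.hessian_vector_product(cost, [a, b], [va, vb])
>>> f = pytensor.function([a, b, va, vb], [Hva, Hvb])
>>> f([4, 4], [2, 2], [1, 1], [1, 1])
[array([2., 2.]), array([12., 12.])]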


Final Pointers
==============
76 changes: 76 additions & 0 deletions pytensor/gradient.py
@@ -2052,6 +2052,82 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
    return as_list_or_tuple(using_list, using_tuple, hessians)


def hessian_vector_product(cost, wrt, p, **grad_kwargs):
"""Return the expression of the Hessian times a vector p.
Notes
-----
This function uses backward autodiff twice to obtain the desired expression.
You may want to manually build the equivalent expression by combining backward
followed by forward (if all Ops support it) autodiff.
See {ref}`docs/_tutcomputinggrads#Hessian-times-a-Vector` for how to do this.
Parameters
----------
cost: Scalar (0-dimensional) variable.
wrt: Vector (1-dimensional tensor) 'Variable' or list of Vectors
p: Vector (1-dimensional tensor) 'Variable' or list of Vectors
Each vector will be used for the hessp wirt to exach input variable
**grad_kwargs:
Keyword arguments passed to `grad` function.
Returns
-------
:class:` Vector or list of Vectors
The Hessian times p of the `cost` with respect to (elements of) `wrt`.
Examples
--------
.. testcode::

        import numpy as np
        from scipy.optimize import minimize

        from pytensor import function
        from pytensor.gradient import jacobian, hessian_vector_product
        from pytensor.tensor import vector

        x = vector('x')
        p = vector('p')

        rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
        rosen_jac = jacobian(rosen, x)
        rosen_hessp = hessian_vector_product(rosen, x, p)

        rosen_fn = function([x], rosen)
        rosen_jac_fn = function([x], rosen_jac)
        rosen_hessp_fn = function([x, p], rosen_hessp)

        x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])
        res = minimize(
            rosen_fn,
            x0,
            method="Newton-CG",
            jac=rosen_jac_fn,
            hessp=rosen_hessp_fn,
            options={"xtol": 1e-8},
        )
        assert res.success
        np.testing.assert_allclose(res.x, np.ones_like(x0))
    """
    wrt_list = wrt if isinstance(wrt, Sequence) else [wrt]
    p_list = p if isinstance(p, Sequence) else [p]
    grad_wrt_list = grad(cost, wrt=wrt_list, **grad_kwargs)
    hessian_cost = pytensor.tensor.add(
        *[
            (grad_wrt * p).sum()
            for grad_wrt, p in zip(grad_wrt_list, p_list, strict=True)
        ]
    )
    Hp_list = grad(hessian_cost, wrt=wrt_list, **grad_kwargs)

    if isinstance(wrt, Variable):
        return Hp_list[0]
    return Hp_list


def _is_zero(x):
"""
Returns 'yes', 'no', or 'maybe' indicating whether x
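
For reference, a minimal sketch (not part of this diff) of the alternative the docstring
Notes above point to, namely reverse-mode `grad` followed by the forward-mode R-operator.
It assumes every Op along the gradient graph implements `R_op`:

import pytensor
import pytensor.tensor as pt
from pytensor.gradient import Rop

x = pt.dvector("x")
v = pt.dvector("v")
cost = pt.sum(x ** 2)
gx = pytensor.grad(cost, x)   # reverse mode: d cost / d x
Hv = Rop(gx, x, v)            # forward mode: (d gx / d x) times v, i.e. the Hessian times v
f = pytensor.function([x, v], Hv)
f([4.0, 4.0], [2.0, 2.0])     # -> array([4., 4.])
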
70 changes: 70 additions & 0 deletions tests/test_gradient.py
@@ -2,6 +2,7 @@

import numpy as np
import pytest
from scipy.optimize import rosen_hess_prod

import pytensor
import pytensor.tensor.basic as ptb
@@ -22,6 +23,7 @@
    grad_scale,
    grad_undefined,
    hessian,
    hessian_vector_product,
    jacobian,
    subgraph_grad,
    zero_grad,
@@ -1081,3 +1083,71 @@ def test_jacobian_disconnected_inputs():
    func_s = pytensor.function([s2], jacobian_s)
    val = np.array(1.0).astype(pytensor.config.floatX)
    assert np.allclose(func_s(val), np.zeros(1))


class TestHessianVectorProduct:
    def test_rosen(self):
        x = vector("x", dtype="float64")
        rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()

        p = vector("p", dtype="float64")
        rosen_hess_prod_pt = hessian_vector_product(rosen, wrt=x, p=p)

        x_test = 0.1 * np.arange(9)
        p_test = 0.5 * np.arange(9)
        np.testing.assert_allclose(
            rosen_hess_prod_pt.eval({x: x_test, p: p_test}),
            rosen_hess_prod(x_test, p_test),
        )

    def test_multiple_wrt(self):
        x = vector("x", dtype="float64")
        y = vector("y", dtype="float64")
        p_x = vector("p_x", dtype="float64")
        p_y = vector("p_y", dtype="float64")

        cost = (x**2 - y**2).sum()
        hessp_x, hessp_y = hessian_vector_product(cost, wrt=[x, y], p=[p_x, p_y])

        hessp_fn = pytensor.function([x, y, p_x, p_y], [hessp_x, hessp_y])
        test = {
            # x, y don't matter
            "x": np.full((3,), np.nan),
            "y": np.full((3,), np.nan),
            "p_x": [1, 2, 3],
            "p_y": [3, 2, 1],
        }
        hessp_x_eval, hessp_y_eval = hessp_fn(**test)
        np.testing.assert_allclose(hessp_x_eval, [2, 4, 6])
        np.testing.assert_allclose(hessp_y_eval, [-6, -4, -2])

    @pytest.mark.skipif(config.floatX == "float32", reason="No point")
    def test_doc_example(self):
        import numpy as np
        from scipy.optimize import minimize

        from pytensor import function
        from pytensor.gradient import hessian_vector_product, jacobian
        from pytensor.tensor import vector

        x = vector("x")
        p = vector("p")

        rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
        rosen_jac = jacobian(rosen, x)
        rosen_hessp = hessian_vector_product(rosen, x, p)

        rosen_fn = function([x], rosen)
        rosen_jac_fn = function([x], rosen_jac)
        rosen_hessp_fn = function([x, p], rosen_hessp)
        x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])
        res = minimize(
            rosen_fn,
            x0,
            method="Newton-CG",
            jac=rosen_jac_fn,
            hessp=rosen_hessp_fn,
            options={"xtol": 1e-8},
        )
        assert res.success
        np.testing.assert_allclose(res.x, np.ones_like(x0))
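
A further check one could add (sketched here, not part of this commit) is that the helper
agrees with multiplying the dense Hessian by `p`, while never materializing the full
Hessian matrix itself; the test name and values below are illustrative:

def test_matches_dense_hessian():
    import pytensor.tensor as pt

    x = vector("x", dtype="float64")
    p = vector("p", dtype="float64")
    cost = (x**3).sum()

    # The helper builds the product directly, without forming the Hessian.
    hp = hessian_vector_product(cost, wrt=x, p=p)
    # Reference: dense Hessian (diag(6 * x) here) multiplied by p.
    dense_hp = pt.dot(hessian(cost, x), p)

    x_test = np.array([1.0, 2.0, 3.0])
    p_test = np.array([0.5, -1.0, 2.0])
    np.testing.assert_allclose(
        hp.eval({x: x_test, p: p_test}),
        dense_hp.eval({x: x_test, p: p_test}),
    )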
