
Add helper to build hessian vector product
Co-authored-by: Adrian Seyboldt <[email protected]>
ricardoV94 and aseyboldt committed Jul 8, 2024
1 parent db1c161 commit 5fd729d
Showing 3 changed files with 128 additions and 0 deletions.
10 changes: 10 additions & 0 deletions doc/tutorial/gradients.rst
@@ -267,6 +267,16 @@ or, making use of the R-operator:
>>> f([4, 4], [2, 2])
array([ 4., 4.])

There is a built-in helper that implements the first method:

>>> x = pt.dvector('x')
>>> v = pt.dvector('v')
>>> y = pt.sum(x ** 2)
>>> Hv = pytensor.gradient.hessian_vector_product(y, x, v)
>>> f = pytensor.function([x, v], Hv)
>>> f([4, 4], [2, 2])
array([ 4., 4.])
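
For comparison, a minimal sketch of the construction this helper wraps (reusing
``x``, ``v``, and ``y`` from above): the gradient of the scalar
``(grad(y) * v).sum()`` is precisely the Hessian times ``v``.

>>> gy = pytensor.grad(y, x)
>>> Hv_manual = pytensor.grad(pt.sum(gy * v), x)
>>> f2 = pytensor.function([x, v], Hv_manual)
>>> f2([4, 4], [2, 2])
array([ 4., 4.])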


Final Pointers
==============
79 changes: 79 additions & 0 deletions pytensor/gradient.py
@@ -2050,6 +2050,85 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
    return as_list_or_tuple(using_list, using_tuple, hessians)


def hessian_vector_product(cost, wrt, p, **grad_kwargs):
"""Return the expression of the Hessian times a vector p.
Notes
-----
This function uses backward autodiff twice to obtain the desired expression.
You may want to manually build the equivalent expression by combining backward
followed by forward (if all Ops support it) autodiff.
See {ref}`docs/_tutcomputinggrads#Hessian-times-a-Vector` for how to do this.
Parameters
----------
cost: Scalar (0-dimensional) variable.
wrt: Vector (1-dimensional tensor) 'Variable' or list of Vectors
p: Vector (1-dimensional tensor) 'Variable' or list of Vectors
Each vector will be used for the hessp wirt to exach input variable
**grad_kwargs:
Keyword arguments passed to `grad` function.
Returns
-------
:class:` Vector or list of Vectors
The Hessian times p of the `cost` with respect to (elements of) `wrt`.
Examples
--------
.. testcode::
import numpy as np
from scipy.optimize import minimize
from pytensor import function
from pytensor.tensor import vector
from pytensor.gradient import grad, hessian_vector_product
x = vector('x')
p = vector('p')
rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
rosen_jac = grad(rosen, x)
rosen_hessp = hessian_vector_product(rosen, x, p)
rosen_fn = function([x], rosen)
rosen_jac_fn = function([x], rosen_jac)
rosen_hessp_fn = function([x, p], rosen_hessp)
x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])
res = minimize(
rosen_fn,
x0,
method="Newton-CG",
jac=rosen_jac_fn,
hessp=rosen_hessp_fn,
options={"xtol": 1e-8},
)
print(res.x)
.. testoutput::
[1. 1. 1. 0.99999999 0.99999999]
"""
    wrt_list = wrt if isinstance(wrt, Sequence) else [wrt]
    p_list = p if isinstance(p, Sequence) else [p]
    grad_wrt_list = grad(cost, wrt=wrt_list, **grad_kwargs)
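    # The scalar sum(grad_wrt * p) has gradient H @ p, so differentiating it
    # again below yields the Hessian-vector product without materializing H.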
    hessian_cost = pytensor.tensor.add(
        *[
            (grad_wrt * p).sum()
            for grad_wrt, p in zip(grad_wrt_list, p_list, strict=True)
        ]
    )
    Hp_list = grad(hessian_cost, wrt=wrt_list, **grad_kwargs)

    if isinstance(wrt, Variable):
        return Hp_list[0]
    return Hp_list
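
# Illustrative sketch, not part of this commit: the Notes section of
# `hessian_vector_product` mentions a backward-then-forward alternative.
# Assuming every Op involved implements `R_op`, it could be built as:
#
#     gy = grad(cost, wrt)   # backward pass: symbolic gradient of the cost
#     Hp = Rop(gy, wrt, p)   # forward pass: directional derivative of the gradient
#
# This replaces the second backward pass with a forward pass, but fails for
# Ops that do not implement `R_op`.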


def _is_zero(x):
"""
Returns 'yes', 'no', or 'maybe' indicating whether x
39 changes: 39 additions & 0 deletions tests/test_gradient.py
@@ -1,5 +1,6 @@
import numpy as np
import pytest
from scipy.optimize import rosen_hess_prod

import pytensor
import pytensor.tensor.basic as ptb
@@ -20,6 +21,7 @@
    grad_scale,
    grad_undefined,
    hessian,
    hessian_vector_product,
    jacobian,
    subgraph_grad,
    zero_grad,
@@ -1079,3 +1081,40 @@ def test_jacobian_disconnected_inputs():
    func_s = pytensor.function([s2], jacobian_s)
    val = np.array(1.0).astype(pytensor.config.floatX)
    assert np.allclose(func_s(val), np.zeros(1))


class TestHessianVectorProduct:
    def test_rosen(self):
        x = vector("x", dtype="float64")
        rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()

        p = vector("p", dtype="float64")
        rosen_hess_prod_pt = hessian_vector_product(rosen, wrt=x, p=p)

        x_test = 0.1 * np.arange(9)
        p_test = 0.5 * np.arange(9)
        np.testing.assert_allclose(
            rosen_hess_prod_pt.eval({x: x_test, p: p_test}),
            rosen_hess_prod(x_test, p_test),
        )

    def test_multiple_wrt(self):
        x = vector("x", dtype="float64")
        y = vector("y", dtype="float64")
        p_x = vector("p_x", dtype="float64")
        p_y = vector("p_y", dtype="float64")

        cost = (x**2 - y**2).sum()
        hessp_x, hessp_y = hessian_vector_product(cost, wrt=[x, y], p=[p_x, p_y])

        hessp_fn = pytensor.function([x, y, p_x, p_y], [hessp_x, hessp_y])
        test = {
            # x, y don't matter
            "x": np.full((3,), np.nan),
            "y": np.full((3,), np.nan),
            "p_x": [1, 2, 3],
            "p_y": [3, 2, 1],
        }
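        # The Hessian of (x**2 - y**2).sum() is block diagonal: 2*I for x and
        # -2*I for y, so the expected products are 2*p_x and -2*p_y.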
        hessp_x_eval, hessp_y_eval = hessp_fn(**test)
        np.testing.assert_allclose(hessp_x_eval, [2, 4, 6])
        np.testing.assert_allclose(hessp_y_eval, [-6, -4, -2])
