diff --git a/numpyro/contrib/hsgp/laplacian.py b/numpyro/contrib/hsgp/laplacian.py index 99fcc2c78..1ff4fa089 100644 --- a/numpyro/contrib/hsgp/laplacian.py +++ b/numpyro/contrib/hsgp/laplacian.py @@ -7,11 +7,15 @@ from __future__ import annotations +from typing import Union, get_args + from jaxlib.xla_extension import ArrayImpl +import numpy as np -import jax import jax.numpy as jnp +ARRAY_TYPE = Union[ArrayImpl, np.ndarray] + def eigenindices(m: list[int] | int, dim: int) -> ArrayImpl: """Returns the indices of the first :math:`D \\times m^\\star` eigenvalues of the laplacian operator. @@ -210,7 +214,7 @@ def _convert_ell( "The length of ell must be equal to the dimension of the space." ) ell_ = jnp.array(ell)[..., None] # dim x 1 array - elif isinstance(ell, jax.Array): + elif isinstance(ell, get_args(ARRAY_TYPE)): ell_ = ell if ell_.shape != (dim, 1): raise ValueError("ell must be a scalar or a list of length `dim`.") diff --git a/numpyro/contrib/hsgp/spectral_densities.py b/numpyro/contrib/hsgp/spectral_densities.py index 8befff885..0d4d3db3d 100644 --- a/numpyro/contrib/hsgp/spectral_densities.py +++ b/numpyro/contrib/hsgp/spectral_densities.py @@ -61,7 +61,7 @@ def spectral_density_matern( S(\\boldsymbol{\\omega}) = \\alpha \\frac{2^{D} \\pi^{D/2} \\Gamma(\\nu + D/2) (2 \\nu)^{\\nu}}{\\Gamma(\\nu) \\ell^{2 \\nu}} - \\left(\\frac{2 \\nu}{\\ell^2} + 4 \\pi^2 \\boldsymbol{\\omega}^{T} \\boldsymbol{\\omega}\\right)^{-\\nu - D/2} + \\left(\\frac{2 \\nu}{\\ell^2} + \\boldsymbol{\\omega}^{T} \\boldsymbol{\\omega}\\right)^{-\\nu - D/2} **References:** @@ -86,7 +86,7 @@ def spectral_density_matern( * ((2 * nu) ** nu) * special.gamma(nu + dim / 2) ) - c2 = ((2 * nu / (length**2)) + 4 * jnp.pi ** jnp.dot(w, w)) ** (-nu - dim / 2) + c2 = (2 * nu / (length**2) + jnp.dot(w, w)) ** (-nu - dim / 2) c3 = special.gamma(nu) * length ** (2 * nu) return c1 * c2 / c3 @@ -166,6 +166,7 @@ def modified_bessel_first_kind(v, z): ) from e v = jnp.asarray(v, dtype=float) + z = jnp.asarray(z, dtype=float) return jnp.exp(jnp.abs(z)) * tfp.math.bessel_ive(v, z) diff --git a/setup.py b/setup.py index d666574c1..0d9e4fb02 100644 --- a/setup.py +++ b/setup.py @@ -53,6 +53,7 @@ "ruff>=0.1.8", "pytest>=4.1", "pyro-api>=0.1.1", + "scikit-learn", "scipy>=1.9", ], "dev": [ diff --git a/test/contrib/hsgp/test_approximation.py b/test/contrib/hsgp/test_approximation.py index 52aeb6f98..a941652b1 100644 --- a/test/contrib/hsgp/test_approximation.py +++ b/test/contrib/hsgp/test_approximation.py @@ -7,7 +7,9 @@ from operator import mul from typing import Literal +import numpy as np import pytest +from sklearn.gaussian_process.kernels import RBF, ExpSineSquared, Matern from jax import random from jax._src.array import ArrayImpl @@ -19,6 +21,12 @@ hsgp_periodic_non_centered, hsgp_squared_exponential, ) +from numpyro.contrib.hsgp.laplacian import eigenfunctions, eigenfunctions_periodic +from numpyro.contrib.hsgp.spectral_densities import ( + diag_spectral_density_matern, + diag_spectral_density_periodic, + diag_spectral_density_squared_exponential, +) import numpyro.distributions as dist from numpyro.handlers import scope, seed, trace @@ -65,13 +73,137 @@ def synthetic_two_dim_data() -> tuple[ArrayImpl, ArrayImpl]: return generate_synthetic_two_dim_data(**kwargs) +@pytest.mark.parametrize( + argnames="x1, x2, length, ell", + argvalues=[ + (np.array([[1.0]]), np.array([[0.0]]), np.array([1.0]), 5.0), + ( + np.array([[1.5, 1.25]]), + np.array([[0.0, 0.0]]), + np.array([1.0]), + 5.0, + ), + ], + ids=[ + "1d", + "2d,1d-length", + ], +) +def test_kernel_approx_squared_exponential( + x1: ArrayImpl, x2: ArrayImpl, length: ArrayImpl, ell: float +): + """ensure that the approximation of the squared exponential kernel is accurate, + matching the exact kernel implementation from sklearn. + + See Riutort-Mayol 2023 equation (13) for the approximation formula. + """ + assert x1.shape == x2.shape + m = 100 # large enough to ensure the approximation is accurate + dim = x1.shape[-1] + spd = diag_spectral_density_squared_exponential(1.0, length, ell, m, dim) + + eig_f1 = eigenfunctions(x1, ell=ell, m=m) + eig_f2 = eigenfunctions(x2, ell=ell, m=m) + approx = (eig_f1 * eig_f2) @ spd + exact = RBF(length)(x1, x2) + assert jnp.isclose(approx, exact, rtol=1e-3) + + +@pytest.mark.parametrize( + argnames="x1, x2, nu, length, ell", + argvalues=[ + (np.array([[1.0]]), np.array([[0.0]]), 3 / 2, np.array([1.0]), 5.0), + (np.array([[1.0]]), np.array([[0.0]]), 5 / 2, np.array([1.0]), 5.0), + ( + np.array([[1.5, 1.25]]), + np.array([[0.0, 0.0]]), + 3 / 2, + np.array([1.0]), + 5.0, + ), + ( + np.array([[1.5, 1.25]]), + np.array([[0.0, 0.0]]), + 5 / 2, + np.array([1.0]), + 5.0, + ), + ], + ids=[ + "1d,nu=3/2", + "1d,nu=5/2", + "2d,nu=3/2,1d-length", + "2d,nu=5/2,1d-length", + ], +) +def test_kernel_approx_squared_matern( + x1: ArrayImpl, x2: ArrayImpl, nu: float, length: ArrayImpl, ell: float +): + """ensure that the approximation of the matern kernel is accurate, + matching the exact kernel implementation from sklearn. + + See Riutort-Mayol 2023 equation (13) for the approximation formula. + """ + assert x1.shape == x2.shape + m = 100 # large enough to ensure the approximation is accurate + dim = x1.shape[-1] + spd = diag_spectral_density_matern( + nu=nu, alpha=1.0, length=length, ell=ell, m=m, dim=dim + ) + + eig_f1 = eigenfunctions(x1, ell=ell, m=m) + eig_f2 = eigenfunctions(x2, ell=ell, m=m) + approx = (eig_f1 * eig_f2) @ spd + exact = Matern(length_scale=length, nu=nu)(x1, x2) + assert jnp.isclose(approx, exact, rtol=1e-3) + + +@pytest.mark.parametrize( + argnames="x1, x2, w0, length", + argvalues=[ + (np.array([1.0]), np.array([0.0]), 1.0, 1.0), + (np.array([1.0]), np.array([0.0]), 1.5, 1.0), + ], + ids=[ + "1d,w0=1.0", + "1d,w0=1.5", + ], +) +def test_kernel_approx_periodic( + x1: ArrayImpl, + x2: ArrayImpl, + w0: float, + length: float, +): + """ensure that the approximation of the periodic kernel is accurate, + matching the exact kernel implementation from sklearn + + Note that the exact kernel implementation is parameterized with respect to the period, + and the periodicity is w0**(-1). We adjust the input values by dividing by 2*pi. + + See Riutort-Mayol 2023 appendix B for the approximation formula. + """ + assert x1.shape == x2.shape + m = 100 + q2 = diag_spectral_density_periodic(alpha=1.0, length=length, m=m) + q2_sine = jnp.concatenate([jnp.array([0.0]), q2[1:]]) + + cosines_f1, sines_f1 = eigenfunctions_periodic(x1, w0=w0, m=m) + cosines_f2, sines_f2 = eigenfunctions_periodic(x2, w0=w0, m=m) + approx = (cosines_f1 * cosines_f2) @ q2 + (sines_f1 * sines_f2) @ q2_sine + exact = ExpSineSquared(length_scale=length, periodicity=w0 ** (-1))( + x1[..., None] / (2 * jnp.pi), x2[..., None] / (2 * jnp.pi) + ) + assert jnp.isclose(approx, exact, rtol=1e-3) + + @pytest.mark.parametrize( argnames="x, alpha, length, ell, m, non_centered", argvalues=[ - (jnp.linspace(0, 1, 10), 1.0, 0.2, 12, 10, True), - (jnp.linspace(0, 1, 10), 1.0, 0.2, 12, 10, False), - (jnp.linspace(0, 10, 100), 3.0, 0.5, 120, 100, True), - (jnp.linspace(jnp.zeros(2), jnp.ones(2), 10), 1.0, 0.2, 12, [3, 3], True), + (np.linspace(0, 1, 10), 1.0, 0.2, 12, 10, True), + (np.linspace(0, 1, 10), 1.0, 0.2, 12, 10, False), + (np.linspace(0, 10, 100), 3.0, 0.5, 120, 100, True), + (np.linspace(np.zeros(2), np.ones(2), 10), 1.0, 0.2, 12, [3, 3], True), ], ids=["non_centered", "centered", "non_centered-large-domain", "non_centered-2d"], ) @@ -111,11 +243,11 @@ def model(x, alpha, length, ell, m, non_centered): @pytest.mark.parametrize( argnames="x, nu, alpha, length, ell, m, non_centered", argvalues=[ - (jnp.linspace(0, 1, 10), 3 / 2, 1.0, 0.2, 12, 10, True), - (jnp.linspace(0, 1, 10), 5 / 2, 1.0, 0.2, 12, 10, False), - (jnp.linspace(0, 10, 100), 7 / 2, 3.0, 0.5, 120, 100, True), + (np.linspace(0, 1, 10), 3 / 2, 1.0, 0.2, 12, 10, True), + (np.linspace(0, 1, 10), 5 / 2, 1.0, 0.2, 12, 10, False), + (np.linspace(0, 10, 100), 7 / 2, 3.0, 0.5, 120, 100, True), ( - jnp.linspace(jnp.zeros(2), jnp.ones(2), 10), + np.linspace(np.zeros(2), np.ones(2), 10), 3 / 2, 1.0, 0.2, @@ -289,9 +421,9 @@ def model(x, nu, ell, m, non_centered, y=None): @pytest.mark.parametrize( argnames="w0, m", argvalues=[ - (2 * jnp.pi / 7, 2), - (2 * jnp.pi / 10, 3), - (2 * jnp.pi / 5, 10), + (2 * np.pi / 7, 2), + (2 * np.pi / 10, 3), + (2 * np.pi / 5, 10), ], ids=["m=2", "m=3", "m=10"], ) diff --git a/test/contrib/hsgp/test_laplacian.py b/test/contrib/hsgp/test_laplacian.py index f7b79d295..2749a2c7d 100644 --- a/test/contrib/hsgp/test_laplacian.py +++ b/test/contrib/hsgp/test_laplacian.py @@ -6,6 +6,7 @@ from functools import reduce from operator import mul +import numpy as np import pytest from jax._src.array import ArrayImpl @@ -96,13 +97,13 @@ def test_sqrt_eigenvalues(ell: float | int, m: int | list[int], dim: int): @pytest.mark.parametrize( argnames="x, ell, m", argvalues=[ - (jnp.linspace(0, 1, 10), 1, 1), - (jnp.linspace(-1, 1, 10), 1, 21), - (jnp.linspace(-2, -1, 10), 2, 10), - (jnp.linspace(0, 100, 500), 120, 100), - (jnp.linspace(jnp.zeros(3), jnp.ones(3), 10), 2, [2, 2, 3]), + (np.linspace(0, 1, 10), 1, 1), + (np.linspace(-1, 1, 10), 1, 21), + (np.linspace(-2, -1, 10), 2, 10), + (np.linspace(0, 100, 500), 120, 100), + (np.linspace(np.zeros(3), np.ones(3), 10), 2, [2, 2, 3]), ( - jnp.linspace(jnp.zeros(3), jnp.ones(3), 100).reshape((10, 10, 3)), + np.linspace(np.zeros(3), np.ones(3), 100).reshape((10, 10, 3)), 2, [2, 2, 3], ), @@ -129,8 +130,8 @@ def test_eigenfunctions(x: ArrayImpl, ell: float | int, m: int | list[int]): (1, 1, False), (1, 2, False), ([1, 1], 2, False), - (jnp.array([1, 1])[..., None], 2, False), - (jnp.array([1, 1]), 2, True), + (np.array([1, 1])[..., None], 2, False), + (np.array([1, 1]), 2, True), ([1, 1], 1, True), ], ids=[ diff --git a/test/contrib/hsgp/test_spectral_densities.py b/test/contrib/hsgp/test_spectral_densities.py index 4794015e7..c51d9caf5 100644 --- a/test/contrib/hsgp/test_spectral_densities.py +++ b/test/contrib/hsgp/test_spectral_densities.py @@ -4,6 +4,7 @@ from functools import reduce from operator import mul +import numpy as np import pytest import jax.numpy as jnp @@ -22,8 +23,8 @@ argnames="dim, w, alpha, length", argvalues=[ (1, 0.1, 1.0, 0.2), - (2, jnp.array([0.1, 0.2]), 1.0, 0.2), - (3, jnp.array([0.1, 0.2, 0.3]), 1.0, 5.0), + (2, np.array([0.1, 0.2]), 1.0, 0.2), + (3, np.array([0.1, 0.2, 0.3]), 1.0, 5.0), ], ids=["dim=1", "dim=2", "dim=3"], ) @@ -39,8 +40,8 @@ def test_spectral_density_squared_exponential(dim, w, alpha, length): argnames="dim, nu, w, alpha, length", argvalues=[ (1, 3 / 2, 0.1, 1.0, 0.2), - (2, 5 / 2, jnp.array([0.1, 0.2]), 1.0, 0.2), - (3, 5 / 2, jnp.array([0.1, 0.2, 0.3]), 1.0, 5.0), + (2, 5 / 2, np.array([0.1, 0.2]), 1.0, 0.2), + (3, 5 / 2, np.array([0.1, 0.2, 0.3]), 1.0, 5.0), ], ids=["dim=1", "dim=2", "dim=3"], ) @@ -113,8 +114,8 @@ def test_modified_bessel_first_kind_one_dim(v, z): @pytest.mark.parametrize( argnames="v, z", argvalues=[ - (jnp.linspace(0.1, 1.0, 10), jnp.array([0.1])), - (jnp.linspace(0.1, 1.0, 10), jnp.linspace(0.1, 1.0, 10)), + (np.linspace(0.1, 1.0, 10), np.array([0.1])), + (np.linspace(0.1, 1.0, 10), np.linspace(0.1, 1.0, 10)), ], ids=["z=0.1", "z=0.2"], )