diff --git a/numpyro/contrib/hsgp/laplacian.py b/numpyro/contrib/hsgp/laplacian.py index 1ff4fa089..16c1d6c31 100644 --- a/numpyro/contrib/hsgp/laplacian.py +++ b/numpyro/contrib/hsgp/laplacian.py @@ -7,14 +7,13 @@ from __future__ import annotations -from typing import Union, get_args +from typing import get_args from jaxlib.xla_extension import ArrayImpl -import numpy as np import jax.numpy as jnp -ARRAY_TYPE = Union[ArrayImpl, np.ndarray] +from numpyro.contrib.hsgp.util import ARRAY_TYPE def eigenindices(m: list[int] | int, dim: int) -> ArrayImpl: diff --git a/numpyro/contrib/hsgp/spectral_densities.py b/numpyro/contrib/hsgp/spectral_densities.py index 0d4d3db3d..4762d5340 100644 --- a/numpyro/contrib/hsgp/spectral_densities.py +++ b/numpyro/contrib/hsgp/spectral_densities.py @@ -16,8 +16,12 @@ from numpyro.contrib.hsgp.laplacian import sqrt_eigenvalues +def align_param(dim, param): + return jnp.broadcast_to(param, jnp.broadcast_shapes(jnp.shape(param), (dim,))) + + def spectral_density_squared_exponential( - dim: int, w: ArrayImpl, alpha: float, length: float + dim: int, w: ArrayImpl, alpha: float, length: float | ArrayImpl ) -> float: """ Spectral density of the squared exponential kernel. @@ -44,13 +48,14 @@ def spectral_density_squared_exponential( :return: spectral density value :rtype: float """ - c = alpha * (jnp.sqrt(2 * jnp.pi) * length) ** dim - e = jnp.exp(-0.5 * (length**2) * jnp.dot(w, w)) + length = align_param(dim, length) + c = alpha * jnp.prod(jnp.sqrt(2 * jnp.pi) * length, axis=-1) + e = jnp.exp(-0.5 * jnp.sum(w**2 * length**2, axis=-1)) return c * e def spectral_density_matern( - dim: int, nu: float, w: ArrayImpl, alpha: float, length: float + dim: int, nu: float, w: ArrayImpl, alpha: float, length: float | ArrayImpl ) -> float: """ Spectral density of the Matérn kernel. @@ -79,6 +84,7 @@ def spectral_density_matern( :return: spectral density value :rtype: float """ # noqa: E501 + length = align_param(dim, length) c1 = ( alpha * (2 ** (dim)) @@ -86,15 +92,15 @@ def spectral_density_matern( * ((2 * nu) ** nu) * special.gamma(nu + dim / 2) ) - c2 = (2 * nu / (length**2) + jnp.dot(w, w)) ** (-nu - dim / 2) - c3 = special.gamma(nu) * length ** (2 * nu) + s = jnp.sum(length**2 * w**2, axis=-1) + c2 = jnp.prod(length, axis=-1) * (2 * nu + s) ** (-nu - dim / 2) + c3 = special.gamma(nu) return c1 * c2 / c3 -# TODO support length-D kernel hyperparameters def diag_spectral_density_squared_exponential( alpha: float, - length: float, + length: float | list[float], ell: float | int | list[float | int], m: int | list[int], dim: int, diff --git a/numpyro/contrib/hsgp/util.py b/numpyro/contrib/hsgp/util.py new file mode 100644 index 000000000..5afbcfb83 --- /dev/null +++ b/numpyro/contrib/hsgp/util.py @@ -0,0 +1,10 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +import numpy as np + +import jax + +ARRAY_TYPE = Union[jax.Array, np.ndarray] # jax.Array covers tracers diff --git a/test/contrib/hsgp/test_approximation.py b/test/contrib/hsgp/test_approximation.py index a941652b1..79ec1dd88 100644 --- a/test/contrib/hsgp/test_approximation.py +++ b/test/contrib/hsgp/test_approximation.py @@ -5,7 +5,7 @@ from functools import reduce from operator import mul -from typing import Literal +from typing import Literal, Union import numpy as np import pytest @@ -74,23 +74,54 @@ def synthetic_two_dim_data() -> tuple[ArrayImpl, ArrayImpl]: @pytest.mark.parametrize( - argnames="x1, x2, length, ell", + argnames="x1, x2, length, ell, xfail", argvalues=[ - (np.array([[1.0]]), np.array([[0.0]]), np.array([1.0]), 5.0), + (np.array([[1.0]]), np.array([[0.0]]), 1.0, 5.0, False), ( np.array([[1.5, 1.25]]), np.array([[0.0, 0.0]]), - np.array([1.0]), + 1.0, + 5.0, + False, + ), + ( + np.array([[1.5, 1.25]]), + np.array([[0.0, 0.0]]), + np.array([1.0, 0.5]), 5.0, + False, + ), + ( + np.array([[1.5, 1.25]]), + np.array([[0.0, 0.0]]), + np.array( + [[1.0, 0.5], [0.5, 1.0]] + ), # different length scale for each point/dimension + 5.0, + False, + ), + ( + np.array([[1.5, 1.25, 1.0]]), + np.array([[0.0, 0.0, 0.0]]), + np.array([[1.0, 0.5], [0.5, 1.0]]), # invalid length scale + 5.0, + True, ), ], ids=[ - "1d", - "2d,1d-length", + "1d,scalar-length", + "2d,scalar-length", + "2d,vector-length", + "2d,matrix-length", + "2d,invalid-length", ], ) def test_kernel_approx_squared_exponential( - x1: ArrayImpl, x2: ArrayImpl, length: ArrayImpl, ell: float + x1: ArrayImpl, + x2: ArrayImpl, + length: Union[float, ArrayImpl], + ell: float, + xfail: bool, ): """ensure that the approximation of the squared exponential kernel is accurate, matching the exact kernel implementation from sklearn. @@ -100,13 +131,26 @@ def test_kernel_approx_squared_exponential( assert x1.shape == x2.shape m = 100 # large enough to ensure the approximation is accurate dim = x1.shape[-1] + if xfail: + with pytest.raises(ValueError): + diag_spectral_density_squared_exponential(1.0, length, ell, m, dim) + return spd = diag_spectral_density_squared_exponential(1.0, length, ell, m, dim) eig_f1 = eigenfunctions(x1, ell=ell, m=m) eig_f2 = eigenfunctions(x2, ell=ell, m=m) approx = (eig_f1 * eig_f2) @ spd - exact = RBF(length)(x1, x2) - assert jnp.isclose(approx, exact, rtol=1e-3) + + def _exact_rbf(length): + return RBF(length)(x1, x2).squeeze(axis=-1) + + if isinstance(length, int) | isinstance(length, float): + exact = _exact_rbf(length) + elif length.ndim == 1: + exact = _exact_rbf(length) + else: + exact = np.apply_along_axis(_exact_rbf, axis=0, arr=length) + assert jnp.isclose(approx, exact, rtol=1e-3).all() @pytest.mark.parametrize( @@ -118,14 +162,32 @@ def test_kernel_approx_squared_exponential( np.array([[1.5, 1.25]]), np.array([[0.0, 0.0]]), 3 / 2, - np.array([1.0]), + np.array([0.25, 0.5]), + 5.0, + ), + ( + np.array([[1.5, 1.25]]), + np.array([[0.0, 0.0]]), + 5 / 2, + np.array([0.25, 0.5]), + 5.0, + ), + ( + np.array([[1.5, 1.25]]), + np.array([[0.0, 0.0]]), + 3 / 2, + np.array( + [[1.0, 0.5], [0.5, 1.0]] + ), # different length scale for each point/dimension 5.0, ), ( np.array([[1.5, 1.25]]), np.array([[0.0, 0.0]]), 5 / 2, - np.array([1.0]), + np.array( + [[1.0, 0.5], [0.5, 1.0]] + ), # different length scale for each point/dimension 5.0, ), ], @@ -134,6 +196,8 @@ def test_kernel_approx_squared_exponential( "1d,nu=5/2", "2d,nu=3/2,1d-length", "2d,nu=5/2,1d-length", + "2d,nu=3/2,2d-length", + "2d,nu=5/2,2d-length", ], ) def test_kernel_approx_squared_matern( @@ -154,8 +218,17 @@ def test_kernel_approx_squared_matern( eig_f1 = eigenfunctions(x1, ell=ell, m=m) eig_f2 = eigenfunctions(x2, ell=ell, m=m) approx = (eig_f1 * eig_f2) @ spd - exact = Matern(length_scale=length, nu=nu)(x1, x2) - assert jnp.isclose(approx, exact, rtol=1e-3) + + def _exact_matern(length): + return Matern(length_scale=length, nu=nu)(x1, x2).squeeze(axis=-1) + + if isinstance(length, float) | isinstance(length, int): + exact = _exact_matern(length) + elif length.ndim == 1: + exact = _exact_matern(length) + else: + exact = np.apply_along_axis(_exact_matern, axis=0, arr=length) + assert jnp.isclose(approx, exact, rtol=1e-3).all() @pytest.mark.parametrize( @@ -211,8 +284,8 @@ def test_approximation_squared_exponential( x: ArrayImpl, alpha: float, length: float, - ell: int | float | list[int | float], - m: int | list[int], + ell: Union[int, float, list[Union[int, float]]], + m: Union[int, list[int]], non_centered: bool, ): def model(x, alpha, length, ell, m, non_centered): @@ -263,8 +336,8 @@ def test_approximation_matern( nu: float, alpha: float, length: float, - ell: int | float | list[int | float], - m: int | list[int], + ell: Union[int, float, list[Union[int, float]]], + m: Union[int, list[int]], non_centered: bool, ): def model(x, nu, alpha, length, ell, m, non_centered): @@ -306,8 +379,8 @@ def model(x, nu, alpha, length, ell, m, non_centered): def test_squared_exponential_gp_model( synthetic_one_dim_data, synthetic_two_dim_data, - ell: float | int | list[float | int], - m: int | list[int], + ell: Union[float, int, list[Union[float, int]]], + m: Union[int, list[int]], non_centered: bool, num_dim: Literal[1, 2], ): @@ -364,8 +437,8 @@ def test_matern_gp_model( synthetic_one_dim_data, synthetic_two_dim_data, nu: float, - ell: int | float | list[float | int], - m: int | list[int], + ell: Union[int, float, list[Union[float, int]]], + m: Union[int, list[int]], non_centered: bool, num_dim: Literal[1, 2], ):