Skip to content

Commit

Permalink
Support vector lengthscales for RBF and Matern kernels (#1819)
Browse files Browse the repository at this point in the history
* Support vector lengthscales for RBF and Matern kernels

* Use broadcast_shapes to align params

* Remove union shorthand

* Update test/contrib/hsgp/test_approximation.py

Co-authored-by: Juan Orduz <[email protected]>

* Run make format

* Remove union in isinstance check

---------

Co-authored-by: Sam Anklesaria <[email protected]>
Co-authored-by: Juan Orduz <[email protected]>
  • Loading branch information
3 people authored Jun 25, 2024
1 parent 209dad9 commit 2984b9b
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 32 deletions.
5 changes: 2 additions & 3 deletions numpyro/contrib/hsgp/laplacian.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@

from __future__ import annotations

from typing import Union, get_args
from typing import get_args

from jaxlib.xla_extension import ArrayImpl
import numpy as np

import jax.numpy as jnp

ARRAY_TYPE = Union[ArrayImpl, np.ndarray]
from numpyro.contrib.hsgp.util import ARRAY_TYPE


def eigenindices(m: list[int] | int, dim: int) -> ArrayImpl:
Expand Down
22 changes: 14 additions & 8 deletions numpyro/contrib/hsgp/spectral_densities.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
from numpyro.contrib.hsgp.laplacian import sqrt_eigenvalues


def align_param(dim, param):
return jnp.broadcast_to(param, jnp.broadcast_shapes(jnp.shape(param), (dim,)))


def spectral_density_squared_exponential(
dim: int, w: ArrayImpl, alpha: float, length: float
dim: int, w: ArrayImpl, alpha: float, length: float | ArrayImpl
) -> float:
"""
Spectral density of the squared exponential kernel.
Expand All @@ -44,13 +48,14 @@ def spectral_density_squared_exponential(
:return: spectral density value
:rtype: float
"""
c = alpha * (jnp.sqrt(2 * jnp.pi) * length) ** dim
e = jnp.exp(-0.5 * (length**2) * jnp.dot(w, w))
length = align_param(dim, length)
c = alpha * jnp.prod(jnp.sqrt(2 * jnp.pi) * length, axis=-1)
e = jnp.exp(-0.5 * jnp.sum(w**2 * length**2, axis=-1))
return c * e


def spectral_density_matern(
dim: int, nu: float, w: ArrayImpl, alpha: float, length: float
dim: int, nu: float, w: ArrayImpl, alpha: float, length: float | ArrayImpl
) -> float:
"""
Spectral density of the Matérn kernel.
Expand Down Expand Up @@ -79,22 +84,23 @@ def spectral_density_matern(
:return: spectral density value
:rtype: float
""" # noqa: E501
length = align_param(dim, length)
c1 = (
alpha
* (2 ** (dim))
* (jnp.pi ** (dim / 2))
* ((2 * nu) ** nu)
* special.gamma(nu + dim / 2)
)
c2 = (2 * nu / (length**2) + jnp.dot(w, w)) ** (-nu - dim / 2)
c3 = special.gamma(nu) * length ** (2 * nu)
s = jnp.sum(length**2 * w**2, axis=-1)
c2 = jnp.prod(length, axis=-1) * (2 * nu + s) ** (-nu - dim / 2)
c3 = special.gamma(nu)
return c1 * c2 / c3


# TODO support length-D kernel hyperparameters
def diag_spectral_density_squared_exponential(
alpha: float,
length: float,
length: float | list[float],
ell: float | int | list[float | int],
m: int | list[int],
dim: int,
Expand Down
10 changes: 10 additions & 0 deletions numpyro/contrib/hsgp/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from typing import Union

import numpy as np

import jax

ARRAY_TYPE = Union[jax.Array, np.ndarray] # jax.Array covers tracers
115 changes: 94 additions & 21 deletions test/contrib/hsgp/test_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from functools import reduce
from operator import mul
from typing import Literal
from typing import Literal, Union

import numpy as np
import pytest
Expand Down Expand Up @@ -74,23 +74,54 @@ def synthetic_two_dim_data() -> tuple[ArrayImpl, ArrayImpl]:


@pytest.mark.parametrize(
argnames="x1, x2, length, ell",
argnames="x1, x2, length, ell, xfail",
argvalues=[
(np.array([[1.0]]), np.array([[0.0]]), np.array([1.0]), 5.0),
(np.array([[1.0]]), np.array([[0.0]]), 1.0, 5.0, False),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
np.array([1.0]),
1.0,
5.0,
False,
),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
np.array([1.0, 0.5]),
5.0,
False,
),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
np.array(
[[1.0, 0.5], [0.5, 1.0]]
), # different length scale for each point/dimension
5.0,
False,
),
(
np.array([[1.5, 1.25, 1.0]]),
np.array([[0.0, 0.0, 0.0]]),
np.array([[1.0, 0.5], [0.5, 1.0]]), # invalid length scale
5.0,
True,
),
],
ids=[
"1d",
"2d,1d-length",
"1d,scalar-length",
"2d,scalar-length",
"2d,vector-length",
"2d,matrix-length",
"2d,invalid-length",
],
)
def test_kernel_approx_squared_exponential(
x1: ArrayImpl, x2: ArrayImpl, length: ArrayImpl, ell: float
x1: ArrayImpl,
x2: ArrayImpl,
length: Union[float, ArrayImpl],
ell: float,
xfail: bool,
):
"""ensure that the approximation of the squared exponential kernel is accurate,
matching the exact kernel implementation from sklearn.
Expand All @@ -100,13 +131,26 @@ def test_kernel_approx_squared_exponential(
assert x1.shape == x2.shape
m = 100 # large enough to ensure the approximation is accurate
dim = x1.shape[-1]
if xfail:
with pytest.raises(ValueError):
diag_spectral_density_squared_exponential(1.0, length, ell, m, dim)
return
spd = diag_spectral_density_squared_exponential(1.0, length, ell, m, dim)

eig_f1 = eigenfunctions(x1, ell=ell, m=m)
eig_f2 = eigenfunctions(x2, ell=ell, m=m)
approx = (eig_f1 * eig_f2) @ spd
exact = RBF(length)(x1, x2)
assert jnp.isclose(approx, exact, rtol=1e-3)

def _exact_rbf(length):
return RBF(length)(x1, x2).squeeze(axis=-1)

if isinstance(length, int) | isinstance(length, float):
exact = _exact_rbf(length)
elif length.ndim == 1:
exact = _exact_rbf(length)
else:
exact = np.apply_along_axis(_exact_rbf, axis=0, arr=length)
assert jnp.isclose(approx, exact, rtol=1e-3).all()


@pytest.mark.parametrize(
Expand All @@ -118,14 +162,32 @@ def test_kernel_approx_squared_exponential(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
3 / 2,
np.array([1.0]),
np.array([0.25, 0.5]),
5.0,
),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
5 / 2,
np.array([0.25, 0.5]),
5.0,
),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
3 / 2,
np.array(
[[1.0, 0.5], [0.5, 1.0]]
), # different length scale for each point/dimension
5.0,
),
(
np.array([[1.5, 1.25]]),
np.array([[0.0, 0.0]]),
5 / 2,
np.array([1.0]),
np.array(
[[1.0, 0.5], [0.5, 1.0]]
), # different length scale for each point/dimension
5.0,
),
],
Expand All @@ -134,6 +196,8 @@ def test_kernel_approx_squared_exponential(
"1d,nu=5/2",
"2d,nu=3/2,1d-length",
"2d,nu=5/2,1d-length",
"2d,nu=3/2,2d-length",
"2d,nu=5/2,2d-length",
],
)
def test_kernel_approx_squared_matern(
Expand All @@ -154,8 +218,17 @@ def test_kernel_approx_squared_matern(
eig_f1 = eigenfunctions(x1, ell=ell, m=m)
eig_f2 = eigenfunctions(x2, ell=ell, m=m)
approx = (eig_f1 * eig_f2) @ spd
exact = Matern(length_scale=length, nu=nu)(x1, x2)
assert jnp.isclose(approx, exact, rtol=1e-3)

def _exact_matern(length):
return Matern(length_scale=length, nu=nu)(x1, x2).squeeze(axis=-1)

if isinstance(length, float) | isinstance(length, int):
exact = _exact_matern(length)
elif length.ndim == 1:
exact = _exact_matern(length)
else:
exact = np.apply_along_axis(_exact_matern, axis=0, arr=length)
assert jnp.isclose(approx, exact, rtol=1e-3).all()


@pytest.mark.parametrize(
Expand Down Expand Up @@ -211,8 +284,8 @@ def test_approximation_squared_exponential(
x: ArrayImpl,
alpha: float,
length: float,
ell: int | float | list[int | float],
m: int | list[int],
ell: Union[int, float, list[Union[int, float]]],
m: Union[int, list[int]],
non_centered: bool,
):
def model(x, alpha, length, ell, m, non_centered):
Expand Down Expand Up @@ -263,8 +336,8 @@ def test_approximation_matern(
nu: float,
alpha: float,
length: float,
ell: int | float | list[int | float],
m: int | list[int],
ell: Union[int, float, list[Union[int, float]]],
m: Union[int, list[int]],
non_centered: bool,
):
def model(x, nu, alpha, length, ell, m, non_centered):
Expand Down Expand Up @@ -306,8 +379,8 @@ def model(x, nu, alpha, length, ell, m, non_centered):
def test_squared_exponential_gp_model(
synthetic_one_dim_data,
synthetic_two_dim_data,
ell: float | int | list[float | int],
m: int | list[int],
ell: Union[float, int, list[Union[float, int]]],
m: Union[int, list[int]],
non_centered: bool,
num_dim: Literal[1, 2],
):
Expand Down Expand Up @@ -364,8 +437,8 @@ def test_matern_gp_model(
synthetic_one_dim_data,
synthetic_two_dim_data,
nu: float,
ell: int | float | list[float | int],
m: int | list[int],
ell: Union[int, float, list[Union[float, int]]],
m: Union[int, list[int]],
non_centered: bool,
num_dim: Literal[1, 2],
):
Expand Down

0 comments on commit 2984b9b

Please sign in to comment.