contrib.hsgp: support vector-valued kernel hyperparameters #1805

Closed
brendancooley opened this issue May 22, 2024 · 3 comments · Fixed by #1819
Labels
enhancement New feature or request

Comments

@brendancooley
Contributor

#1803 implements support for multidimensional Hilbert Space Gaussian Process approximations. However, it only supports estimation of a single set of kernel hyperparameters (e.g. one squared exponential lengthscale shared by all input dimensions). In principle, the lengthscale can vary across dimensions of the input space (see Riutort-Mayol 2022 eq 1, 2, and 3 for the associated spectral density functions).
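
For concreteness, here is the squared exponential case written with one lengthscale per input dimension. This is a sketch in my own notation, not copied from the paper: $D$ is the input dimension, $\alpha$ the magnitude parameter, $\ell_i$ the lengthscale of dimension $i$, and $\omega_i$ the corresponding frequency component. Setting every $\ell_i$ to a common $\ell$ recovers the single-lengthscale form.

```latex
S(\boldsymbol{\omega})
  = \alpha \, (2\pi)^{D/2} \prod_{i=1}^{D} \ell_i \,
    \exp\!\left( -\frac{1}{2} \sum_{i=1}^{D} \ell_i^{2} \, \omega_i^{2} \right)
```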

Implementation requires updating the spectral_density_matern and spectral_density_squared_exponential (numpyro.contrib.hsgp.spectral_densities) to accept and process array-valued inputs for the length parameter (this may also require upstream changes to the vmap in diag_spectral_density_squared_exponential and diag_spectral_density_matern). The test models test_squared_exponential_gp_model and test_matern_gp_model (test.contrib.hsgp.test_approximation) should be updated to optionally sample vector-valued lengthscales and test cases demonstrating the functionality should be created.

@brendancooley
Contributor Author

Notes to self as I start working on this...

We might consider first adding some tests that ensure that our kernel approximations come close to matching the exact versions for m large enough. Something like the following:

from sklearn.gaussian_process.kernels import RBF

import jax.numpy as jnp

from numpyro.contrib.hsgp.laplacian import eigenfunctions
from numpyro.contrib.hsgp.spectral_densities import (
    diag_spectral_density_squared_exponential,
)

# Two 2-D input points, each with a leading axis of size 1.
x1 = jnp.array([1.0, 1.0])[None, ...]
x2 = jnp.array([0.0, 0.0])[None, ...]
m = 10  # number of basis functions per dimension
ell = 3  # half-width of the approximation domain [-ell, ell]
# Laplacian eigenfunctions evaluated at each point.
eig_f1 = eigenfunctions(x1, ell=ell, m=m)
eig_f2 = eigenfunctions(x2, ell=ell, m=m)
# Spectral density on the eigenvalue grid (alpha=1.0, length=1.0, dim=2).
spd = diag_spectral_density_squared_exponential(1.0, 1.0, ell, m, 2)[None, ...]
# Low-rank kernel approximation: sum over basis functions.
approx = (eig_f1 * eig_f2 * spd).sum(axis=1)
exact = RBF(1.0)(x1, x2)
assert jnp.isclose(approx, exact)
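
To make the idea behind such a test concrete without depending on the numpyro internals, here is a minimal pure-JAX sketch for the 1-D case. Everything in it (`se_spectral_density`, `hsgp_kernel_1d`, `exact_kernel`, the chosen points and tolerances) is hypothetical and written for illustration; it assumes the standard Dirichlet eigenfunctions on `[-ell, ell]` and is not the library's implementation.

```python
import jax.numpy as jnp


def se_spectral_density(w, alpha, length):
    # 1-D squared exponential spectral density.
    return alpha * jnp.sqrt(2 * jnp.pi) * length * jnp.exp(-0.5 * (w * length) ** 2)


def hsgp_kernel_1d(x1, x2, ell, m, alpha=1.0, length=1.0):
    # Square roots of the Laplacian eigenvalues on [-ell, ell].
    j = jnp.arange(1, m + 1)
    sqrt_eig = j * jnp.pi / (2 * ell)
    # Matching eigenfunctions evaluated at both inputs.
    phi1 = jnp.sqrt(1 / ell) * jnp.sin(sqrt_eig * (x1 + ell))
    phi2 = jnp.sqrt(1 / ell) * jnp.sin(sqrt_eig * (x2 + ell))
    # Truncated expansion of the kernel with m basis functions.
    return jnp.sum(se_spectral_density(sqrt_eig, alpha, length) * phi1 * phi2)


def exact_kernel(x1, x2, alpha=1.0, length=1.0):
    # Exact squared exponential kernel for comparison.
    return alpha * jnp.exp(-0.5 * (x1 - x2) ** 2 / length**2)


x1, x2 = 0.5, -0.3
exact = exact_kernel(x1, x2)
err_small = jnp.abs(hsgp_kernel_1d(x1, x2, ell=5.0, m=4) - exact)
err_large = jnp.abs(hsgp_kernel_1d(x1, x2, ell=5.0, m=50) - exact)
```

With both points well inside the domain, the error should shrink as `m` grows, which is the property a test in the real suite would want to pin down with an explicit tolerance.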

@samanklesaria
Contributor

samanklesaria commented Jun 12, 2024

I'd be interested in finishing this up if you're not too far along! For the squared exponential at least, I'd assume we can do something like the following:

def spectral_density_squared_exponential(
    dim: int, w: ArrayImpl, alpha: float, length: float | ArrayImpl
) -> float:
    ...
    # A scalar length is repeated across all dims; a vector is used as given.
    length = jnp.broadcast_to(length, (dim,))
    c = alpha * jnp.prod(jnp.sqrt(2 * jnp.pi) * length)
    e = jnp.exp(-0.5 * jnp.sum(w**2 * length**2))
    return c * e

This would allow for the current behavior, but also let us have different length-scales for each dimension.
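
As a quick sanity check of that claim, the sketch below restates the snippet as a stand-alone function and compares the scalar and vector cases. The name `sd_squared_exponential` and the test values are hypothetical, chosen only for illustration; this is not the merged numpyro code.

```python
import jax.numpy as jnp


def sd_squared_exponential(dim, w, alpha, length):
    # Stand-alone restatement of the proposed function (illustrative only).
    length = jnp.broadcast_to(jnp.asarray(length), (dim,))
    c = alpha * jnp.prod(jnp.sqrt(2 * jnp.pi) * length)
    e = jnp.exp(-0.5 * jnp.sum(w**2 * length**2))
    return c * e


w = jnp.array([0.3, 1.2])

# A scalar lengthscale and an equal-valued vector should agree.
scalar = sd_squared_exponential(2, w, 1.0, 0.5)
vector = sd_squared_exponential(2, w, 1.0, jnp.array([0.5, 0.5]))

# With distinct lengthscales the density factorizes into 1-D densities.
aniso = sd_squared_exponential(2, w, 1.0, jnp.array([0.5, 2.0]))
product = sd_squared_exponential(1, w[:1], 1.0, 0.5) * sd_squared_exponential(
    1, w[1:], 1.0, 2.0
)
```

Here `scalar` and `vector` should match, and `aniso` should equal `product`, since the squared exponential density separates across dimensions.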

I could add tests like you describe above to the test/contrib/hsgp/test_approximation.py file.

@brendancooley
Contributor Author

@samanklesaria go for it! Perhaps we can swap ideas and merge implementations. I have a little bit of WIP here. I would like to try to support batch dimensions on the lengthscale to enable batched approximate GPs, in addition to lengthscale heterogeneity within a single GP. I just need to work out the API a bit. I have a working example with a few tests on that branch. I still need to do the Matern case, and maybe the periodic one as well.

For a use case on the batching, see hsgp_lvm.ipynb on this branch
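
One way the batching could look, purely as a sketch: map the per-dimension spectral density over a leading batch axis of lengthscales with `jax.vmap`. The function `sd_se` and the `(batch, dim)` layout are assumptions made for illustration, not the API on the branch.

```python
import jax
import jax.numpy as jnp


def sd_se(w, alpha, length):
    # Squared exponential spectral density with one lengthscale per dimension.
    c = alpha * jnp.prod(jnp.sqrt(2 * jnp.pi) * length)
    return c * jnp.exp(-0.5 * jnp.sum(w**2 * length**2))


w = jnp.array([0.3, 1.2])
# Assumed layout: one row of lengthscales per GP in the batch.
lengths = jnp.array([[0.5, 0.5], [0.5, 2.0], [1.0, 3.0]])

# Map over the batch axis of `lengths` only; `w` and `alpha` are shared.
batched = jax.vmap(sd_se, in_axes=(None, None, 0))(w, 1.0, lengths)

# Reference: evaluate each batch member separately.
looped = jnp.stack([sd_se(w, 1.0, row) for row in lengths])
```

The result has one density value per batch member, so `batched` has shape `(3,)` and should match `looped`.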
