From fcd0101bb9abbdacf08b7d0803dd81747dc17264 Mon Sep 17 00:00:00 2001 From: Kei Ishikawa Date: Sat, 12 Aug 2023 00:22:44 +0100 Subject: [PATCH 01/17] calculate cross covariance in parallel inference --- .../demos/lgssm_learning.py | 164 ++++++++++++++++++ dynamax/linear_gaussian_ssm/models.py | 24 ++- .../linear_gaussian_ssm/parallel_inference.py | 27 ++- 3 files changed, 204 insertions(+), 11 deletions(-) create mode 100644 dynamax/linear_gaussian_ssm/demos/lgssm_learning.py diff --git a/dynamax/linear_gaussian_ssm/demos/lgssm_learning.py b/dynamax/linear_gaussian_ssm/demos/lgssm_learning.py new file mode 100644 index 00000000..d6a68dd3 --- /dev/null +++ b/dynamax/linear_gaussian_ssm/demos/lgssm_learning.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python +# coding: utf-8 + +# # MAP parameter estimation for an LG-SSM using EM and SGD +# +# +# + +# ## Setup + +# In[1]: + + +# In[2]: + + +from jax import numpy as jnp +import jax.random as jr +from matplotlib import pyplot as plt + +from dynamax.linear_gaussian_ssm import LinearGaussianConjugateSSM +from dynamax.utils.utils import monotonically_increasing + + +# ## Data +# +# + +# In[3]: + + +state_dim=2 +emission_dim=5 +num_timesteps=200 +key = jr.PRNGKey(0) + +true_model = LinearGaussianConjugateSSM(state_dim, emission_dim) +key, key_root = jr.split(key) +true_params, param_props = true_model.initialize(key) + +key, key_root = jr.split(key) +true_states, emissions = true_model.sample(true_params, key, num_timesteps) + +# Plot the true states and emissions +fig, ax = plt.subplots(figsize=(10, 8)) +ax.plot(emissions + 3 * jnp.arange(emission_dim)) +ax.set_ylabel("data") +ax.set_xlabel("time") +ax.set_xlim(0, num_timesteps - 1) + + +# ## Plot results + +# In[4]: + + +def plot_learning_curve(marginal_lls, true_model, true_params, test_model, test_params, emissions): + plt.figure() + plt.xlabel("iteration") + nsteps = len(marginal_lls) + plt.plot(marginal_lls, label="estimated") + true_logjoint = (true_model.log_prior(true_params) + true_model.marginal_log_prob(true_params, emissions)) + plt.axhline(true_logjoint, color = 'k', linestyle = ':', label="true") + plt.ylabel("marginal joint probability") + plt.legend() + + +# In[5]: + + +def plot_predictions(true_model, true_params, test_model, test_params, emissions): + smoothed_emissions, smoothed_emissions_std = test_model.posterior_predictive(test_params, emissions) + + spc = 3 + plt.figure(figsize=(10, 4)) + for i in range(emission_dim): + plt.plot(emissions[:, i] + spc * i, "--k", label="observed" if i == 0 else None) + ln = plt.plot(smoothed_emissions[:, i] + spc * i, + label="smoothed" if i == 0 else None)[0] + plt.fill_between( + jnp.arange(num_timesteps), + spc * i + smoothed_emissions[:, i] - 2 * smoothed_emissions_std[i], + spc * i + smoothed_emissions[:, i] + 2 * smoothed_emissions_std[i], + color=ln.get_color(), + alpha=0.25, + ) + plt.xlabel("time") + plt.xlim(0, num_timesteps - 1) + plt.ylabel("true and predicted emissions") + plt.legend() + plt.show() + + +# In[6]: + + +# Plot predictions from a random, untrained model + +test_model = LinearGaussianConjugateSSM(state_dim, emission_dim) +key = jr.PRNGKey(42) +test_params, param_props = test_model.initialize(key) + +plot_predictions(true_model, true_params, test_model, test_params, emissions) + + +# ## Fit with EM + +# In[23]: + + +test_model = LinearGaussianConjugateSSM(state_dim, emission_dim) +key = jr.PRNGKey(42) +test_params, param_props = test_model.initialize(key) +num_iters = 100 +test_params, marginal_lls = 
test_model.fit_em(test_params, param_props, emissions, num_iters=num_iters) + +assert monotonically_increasing(marginal_lls, atol=1e-2, rtol=1e-2) + + +# In[7]: + + +test_model = LinearGaussianConjugateSSM(state_dim, emission_dim) +key = jr.PRNGKey(42) +test_params, param_props = test_model.initialize(key) +num_iters = 100 +test_params, marginal_lls = test_model.fit_em(test_params, param_props, emissions, num_iters=num_iters) + +assert monotonically_increasing(marginal_lls, atol=1e-2, rtol=1e-2) + + +# In[ ]: + + +plot_learning_curve(marginal_lls, true_model, true_params, test_model, test_params, emissions) +plot_predictions(true_model, true_params, test_model, test_params, emissions) + + +# ## Fit with SGD + +# In[ ]: + + +test_model = LinearGaussianConjugateSSM(state_dim, emission_dim) +key = jr.PRNGKey(42) +num_iters = 100 +test_params, param_props = test_model.initialize(key) + +test_params, neg_marginal_lls = test_model.fit_sgd(test_params, param_props, emissions, num_epochs=num_iters * 20) +marginal_lls = -neg_marginal_lls * emissions.size + + +# In[ ]: + + +plot_learning_curve(marginal_lls, true_model, true_params, test_model, test_params, emissions) +plot_predictions(true_model, true_params, test_model, test_params, emissions) + + +# In[ ]: + + + + diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py index 8e88c6bf..39a10db6 100644 --- a/dynamax/linear_gaussian_ssm/models.py +++ b/dynamax/linear_gaussian_ssm/models.py @@ -12,6 +12,9 @@ from dynamax.ssm import SSM from dynamax.linear_gaussian_ssm.inference import lgssm_filter, lgssm_smoother, lgssm_posterior_sample +from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_filter as parallel_lgssm_filter +from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_smoother as parallel_lgssm_smoother +from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_posterior_sample as parallel_lgssm_posterior_sample from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM, ParamsLGSSMInitial, ParamsLGSSMDynamics, ParamsLGSSMEmissions from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed from dynamax.parameters import ParameterProperties, ParameterSet @@ -205,7 +208,8 @@ def marginal_log_prob( emissions: Float[Array, "ntime emission_dim"], inputs: Optional[Float[Array, "ntime input_dim"]] = None ) -> Scalar: - filtered_posterior = lgssm_filter(params, emissions, inputs) + # filtered_posterior = lgssm_filter(params, emissions, inputs) + filtered_posterior = parallel_lgssm_filter(params, emissions) return filtered_posterior.marginal_loglik def filter( @@ -214,7 +218,8 @@ def filter( emissions: Float[Array, "ntime emission_dim"], inputs: Optional[Float[Array, "ntime input_dim"]] = None ) -> PosteriorGSSMFiltered: - return lgssm_filter(params, emissions, inputs) + # return lgssm_filter(params, emissions, inputs) + return parallel_lgssm_filter(params, emissions) def smoother( self, @@ -222,7 +227,8 @@ def smoother( emissions: Float[Array, "ntime emission_dim"], inputs: Optional[Float[Array, "ntime input_dim"]] = None ) -> PosteriorGSSMSmoothed: - return lgssm_smoother(params, emissions, inputs) + # return lgssm_smoother(params, emissions, inputs) + return parallel_lgssm_smoother(params, emissions) def posterior_sample( self, @@ -231,7 +237,8 @@ def posterior_sample( emissions: Float[Array, "ntime emission_dim"], inputs: Optional[Float[Array, "ntime input_dim"]]=None ) -> Float[Array, "ntime state_dim"]: - return lgssm_posterior_sample(key, 
params, emissions, inputs) + # return lgssm_posterior_sample(key, params, emissions, inputs) + return parallel_lgssm_posterior_sample(key, params, emissions) def posterior_predictive( self, @@ -250,7 +257,8 @@ def posterior_predictive( :posterior predictive means $\mathbb{E}[y_{t,d} \mid y_{1:T}]$ and standard deviations $\mathrm{std}[y_{t,d} \mid y_{1:T}]$ """ - posterior = lgssm_smoother(params, emissions, inputs) + # posterior = lgssm_smoother(params, emissions, inputs) + posterior = parallel_lgssm_smoother(params, emissions) H = params.emissions.weights b = params.emissions.bias R = params.emissions.cov @@ -275,7 +283,8 @@ def e_step( inputs = jnp.zeros((num_timesteps, 0)) # Run the smoother to get posterior expectations - posterior = lgssm_smoother(params, emissions, inputs) + # posterior = lgssm_smoother(params, emissions, inputs) + posterior = parallel_lgssm_smoother(params, emissions) # shorthand Ex = posterior.smoothed_means @@ -580,7 +589,8 @@ def lgssm_params_sample(rng, stats): def one_sample(_params, rng): rngs = jr.split(rng, 2) # Sample latent states - states = lgssm_posterior_sample(rngs[0], _params, emissions, inputs) + # states = lgssm_posterior_sample(rngs[0], _params, emissions, inputs) + states = parallel_lgssm_posterior_sample(rngs[0], _params, emissions) # Sample parameters _stats = sufficient_stats_from_sample(states) return lgssm_params_sample(rngs[1], _stats) diff --git a/dynamax/linear_gaussian_ssm/parallel_inference.py b/dynamax/linear_gaussian_ssm/parallel_inference.py index ad21a814..7114f37d 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference.py @@ -248,16 +248,35 @@ def _operator(elem1, elem2): initial_messages = _initialize_smoothing_messages(params, filtered_means, filtered_covs) final_messages = lax.associative_scan(_operator, initial_messages, reverse=True) - + G = initial_messages.E[:-1] + smoothed_means = final_messages.g + smoothed_covariances = final_messages.L + smoothed_cross_covariances = compute_smoothed_cross_covariances( + G, smoothed_means[:-1], smoothed_means[1:], smoothed_covariances[1:] + ) return PosteriorGSSMSmoothed( marginal_loglik=filtered_posterior.marginal_loglik, filtered_means=filtered_means, filtered_covariances=filtered_covs, - smoothed_means=final_messages.g, - smoothed_covariances=final_messages.L + smoothed_means=smoothed_means, + smoothed_covariances=smoothed_covariances, + smoothed_cross_covariances=smoothed_cross_covariances, ) +@vmap +def compute_smoothed_cross_covariances( + G: Float[Array, "state_dim state_dim"], + smoothed_mean: Float[Array, "state_dim"], + smoothed_mean_next: Float[Array, "state_dim"], + smoothed_cov_next: Float[Array, "state_dim state_dim"], +) -> Float[Array, "state_dim state_dim"]: + # Compute the smoothed expectation of z_t z_{t+1}^T + # This is precomputed + # G = psd_solve(Q + F @ filtered_cov @ F.T, F @ filtered_cov).T + return G @ smoothed_cov_next + jnp.outer(smoothed_mean, smoothed_mean_next) + + #---------------------------------------------------------------------------# # Sampling # #---------------------------------------------------------------------------# @@ -310,4 +329,4 @@ def _operator(elem1, elem2): initial_messages = _initialize_sampling_messages(key, params, filtered_means, filtered_covs) _, samples = lax.associative_scan(_operator, initial_messages, reverse=True) - return samples \ No newline at end of file + return samples From 57eb463c38c85158d20aa4983c64799ed39b28ca Mon Sep 17 00:00:00 2001 From: Kei Ishikawa 
Date: Sat, 12 Aug 2023 17:39:58 +0100
Subject: [PATCH 02/17] add support for parallel inference and sampling

---
 dynamax/linear_gaussian_ssm/inference.py      |  7 ++-
 dynamax/linear_gaussian_ssm/models.py         | 34 +++++++++------
 .../linear_gaussian_ssm/parallel_inference.py | 43 +++++++++++++------
 3 files changed, 54 insertions(+), 30 deletions(-)

diff --git a/dynamax/linear_gaussian_ssm/inference.py b/dynamax/linear_gaussian_ssm/inference.py
index 2db9a260..75e1ef7d 100644
--- a/dynamax/linear_gaussian_ssm/inference.py
+++ b/dynamax/linear_gaussian_ssm/inference.py
@@ -415,7 +415,11 @@ def _step(carry, t):
     # Run the Kalman filter
     carry = (0.0, params.initial.mean, params.initial.cov)
     (ll, _, _), (filtered_means, filtered_covs) = lax.scan(_step, carry, jnp.arange(num_timesteps))
-    return PosteriorGSSMFiltered(marginal_loglik=ll, filtered_means=filtered_means, filtered_covariances=filtered_covs)
+
+    return PosteriorGSSMFiltered(
+        marginal_loglik=ll,
+        filtered_means=filtered_means,
+        filtered_covariances=filtered_covs)


 @preprocess_args
@@ -495,7 +499,6 @@ def lgssm_posterior_sample(
     emissions: Float[Array, "ntime emission_dim"],
     inputs: Optional[Float[Array, "ntime input_dim"]]=None,
     jitter: Optional[Scalar]=0
-
 ) -> Float[Array, "ntime state_dim"]:
     r"""Run forward-filtering, backward-sampling to draw samples
     from $p(z_{1:T} \mid y_{1:T}, u_{1:T})$.

diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py
index 39a10db6..3c2a674c 100644
--- a/dynamax/linear_gaussian_ssm/models.py
+++ b/dynamax/linear_gaussian_ssm/models.py
@@ -64,6 +64,8 @@ class LinearGaussianSSM(SSM):
     :param input_dim: Dimensionality of input vector. Defaults to 0.
     :param has_dynamics_bias: Whether model contains an offset term $b$. Defaults to True.
     :param has_emissions_bias: Whether model contains an offset term $d$. Defaults to True.
+    :param use_parallel_inference: Whether to use parallel algorithms for filtering, smoothing,
+        and sampling instead of sequential ones. Defaults to False.
""" def __init__( @@ -73,12 +75,14 @@ def __init__( input_dim: int=0, has_dynamics_bias: bool=True, has_emissions_bias: bool=True + use_parallel_inference: bool=False, ): self.state_dim = state_dim self.emission_dim = emission_dim self.input_dim = input_dim self.has_dynamics_bias = has_dynamics_bias self.has_emissions_bias = has_emissions_bias + self.use_parallel_inference = use_parallel_inference @property def emission_shape(self): @@ -208,8 +212,7 @@ def marginal_log_prob( emissions: Float[Array, "ntime emission_dim"], inputs: Optional[Float[Array, "ntime input_dim"]] = None ) -> Scalar: - # filtered_posterior = lgssm_filter(params, emissions, inputs) - filtered_posterior = parallel_lgssm_filter(params, emissions) + filtered_posterior = self.filter(params, emissions, inputs) return filtered_posterior.marginal_loglik def filter( @@ -218,8 +221,10 @@ def filter( emissions: Float[Array, "ntime emission_dim"], inputs: Optional[Float[Array, "ntime input_dim"]] = None ) -> PosteriorGSSMFiltered: - # return lgssm_filter(params, emissions, inputs) - return parallel_lgssm_filter(params, emissions) + if self.use_parallel_inference: + return parallel_lgssm_filter(params, emissions, inputs) + else: + return lgssm_filter(params, emissions, inputs) def smoother( self, @@ -227,8 +232,10 @@ def smoother( emissions: Float[Array, "ntime emission_dim"], inputs: Optional[Float[Array, "ntime input_dim"]] = None ) -> PosteriorGSSMSmoothed: - # return lgssm_smoother(params, emissions, inputs) - return parallel_lgssm_smoother(params, emissions) + if self.use_parallel_inference: + return parallel_lgssm_smoother(params, emissions, inputs) + else: + return lgssm_smoother(params, emissions, inputs) def posterior_sample( self, @@ -237,8 +244,10 @@ def posterior_sample( emissions: Float[Array, "ntime emission_dim"], inputs: Optional[Float[Array, "ntime input_dim"]]=None ) -> Float[Array, "ntime state_dim"]: - # return lgssm_posterior_sample(key, params, emissions, inputs) - return parallel_lgssm_posterior_sample(key, params, emissions) + if use_parallel_inference: + return parallel_lgssm_posterior_sample(key, params, emissions, inputs) + else: + return lgssm_posterior_sample(key, params, emissions, inputs) def posterior_predictive( self, @@ -257,8 +266,7 @@ def posterior_predictive( :posterior predictive means $\mathbb{E}[y_{t,d} \mid y_{1:T}]$ and standard deviations $\mathrm{std}[y_{t,d} \mid y_{1:T}]$ """ - # posterior = lgssm_smoother(params, emissions, inputs) - posterior = parallel_lgssm_smoother(params, emissions) + posterior = self.smoother(params, emissions, inputs) H = params.emissions.weights b = params.emissions.bias R = params.emissions.cov @@ -283,8 +291,7 @@ def e_step( inputs = jnp.zeros((num_timesteps, 0)) # Run the smoother to get posterior expectations - # posterior = lgssm_smoother(params, emissions, inputs) - posterior = parallel_lgssm_smoother(params, emissions) + posterior = self.smoother(params, emissions, inputs) # shorthand Ex = posterior.smoothed_means @@ -589,8 +596,7 @@ def lgssm_params_sample(rng, stats): def one_sample(_params, rng): rngs = jr.split(rng, 2) # Sample latent states - # states = lgssm_posterior_sample(rngs[0], _params, emissions, inputs) - states = parallel_lgssm_posterior_sample(rngs[0], _params, emissions) + states = self.posterior_sample(rngs[0], _params, emissions, inputs) # Sample parameters _stats = sufficient_stats_from_sample(states) return lgssm_params_sample(rngs[1], _stats) diff --git a/dynamax/linear_gaussian_ssm/parallel_inference.py 
b/dynamax/linear_gaussian_ssm/parallel_inference.py index 7114f37d..fa94cabc 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference.py @@ -40,7 +40,7 @@ from jax.scipy.linalg import cho_solve, cho_factor from dynamax.utils.utils import symmetrize -from dynamax.linear_gaussian_ssm import PosteriorGSSMFiltered, PosteriorGSSMSmoothed, ParamsLGSSM +from dynamax.linear_gaussian_ssm import PosteriorGSSMFiltered, PosteriorGSSMSmoothed, ParamsLGSSM, preprocess_args def _get_params(x, dim, t): @@ -74,16 +74,20 @@ class FilterMessage(NamedTuple): logZ: Float[Array, "ntime"] -def _initialize_filtering_messages(params, emissions): +def _initialize_filtering_messages(params, emissions, inputs): """Preprocess observations to construct input for filtering assocative scan.""" - def _first_message(params, y): + def _first_message(params, y, u): H = _get_params(params.emissions.weights, 2, 0) R = _get_params(params.emissions.cov, 2, 0) + D = _get_params(params.emissions.input_weights, 2, t) d = _get_params(params.emissions.bias, 1, 0) m = params.initial.mean P = params.initial.cov + # Adjust the bias term accoding to the input + d = d + D @ u + S = H @ P @ H.T + R CF, low = cho_factor(S) K = cho_solve((CF, low), H @ P).T @@ -98,15 +102,21 @@ def _first_message(params, y): return A, b, C, J, eta, logZ - @partial(vmap, in_axes=(None, 0, 0)) - def _generic_message(params, y, t): + @partial(vmap, in_axes=(None, 0, 0, 0)) + def _generic_message(params, y, u, t): F = _get_params(params.dynamics.weights, 2, t) + B = _get_params(params.dynamics.input_weights, 2, t) Q = _get_params(params.dynamics.cov, 2, t) b = _get_params(params.dynamics.bias, 1, t) H = _get_params(params.emissions.weights, 2, t+1) R = _get_params(params.emissions.cov, 2, t+1) + D = _get_params(params.emissions.input_weights, 2, t) d = _get_params(params.emissions.bias, 1, t+1) + # Adjust the bias terms accoding to the input + d = d + D @ u + b = b + B @ u + S = H @ Q @ H.T + R CF, low = cho_factor(S) K = cho_solve((CF, low), H @ Q).T @@ -122,8 +132,8 @@ def _generic_message(params, y, t): return A, b, C, J, eta, logZ - A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0]) - At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], jnp.arange(len(emissions)-1)) + A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0], inputs[0]) + At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], inputs[1:], jnp.arange(len(emissions)-1)) return FilterMessage( A=jnp.concatenate([A0[None], At]), @@ -136,9 +146,11 @@ def _generic_message(params, y, t): +@preprocess_args def lgssm_filter( params: ParamsLGSSM, - emissions: Float[Array, "ntime emission_dim"] + emissions: Float[Array, "ntime emission_dim"], + inputs: Optional[Float[Array, "ntime input_dim"]]=None ) -> PosteriorGSSMFiltered: """A parallel version of the lgssm filtering algorithm. 
@@ -168,13 +180,13 @@ def _operator(elem1, elem2): logZ = (logZ1 + logZ2 + 0.5 * jnp.linalg.slogdet(I_C1J2)[1] + 0.5 * t1) return FilterMessage(A, b, C, J, eta, logZ) - initial_messages = _initialize_filtering_messages(params, emissions) + initial_messages = _initialize_filtering_messages(params, emissions, inputs) final_messages = lax.associative_scan(_operator, initial_messages) return PosteriorGSSMFiltered( + marginal_loglik=-final_messages.logZ[-1], filtered_means=final_messages.b, - filtered_covariances=final_messages.C, - marginal_loglik=-final_messages.logZ[-1]) + filtered_covariances=final_messages.C) #---------------------------------------------------------------------------# @@ -223,9 +235,11 @@ def _generic_message(params, m, P, t): ) +@preprocess_args def lgssm_smoother( params: ParamsLGSSM, - emissions: Float[Array, "ntime emission_dim"] + emissions: Float[Array, "ntime emission_dim"], + inputs: Optional[Float[Array, "ntime input_dim"]]=None ) -> PosteriorGSSMSmoothed: """A parallel version of the lgssm smoothing algorithm. @@ -306,7 +320,8 @@ def _initialize_sampling_messages(key, params, filtered_means, filtered_covarian def lgssm_posterior_sample( key: PRNGKey, params: ParamsLGSSM, - emissions: Float[Array, "ntime emission_dim"] + emissions: Float[Array, "ntime emission_dim"], + inputs: Optional[Float[Array, "ntime input_dim"]]=None ) -> Float[Array, "ntime state_dim"]: """A parallel version of the lgssm sampling algorithm. @@ -314,7 +329,7 @@ def lgssm_posterior_sample( Note: This function does not yet handle `inputs` to the system. """ - filtered_posterior = lgssm_filter(params, emissions) + filtered_posterior = lgssm_filter(params, emissions, inputs) filtered_means = filtered_posterior.filtered_means filtered_covs = filtered_posterior.filtered_covariances From 792fc2c767b10d9e5b9788f8d94088c691172e18 Mon Sep 17 00:00:00 2001 From: Kei Ishikawa Date: Sat, 12 Aug 2023 23:05:24 +0100 Subject: [PATCH 03/17] fix emission logic of parallel filter --- dynamax/linear_gaussian_ssm/models.py | 4 +- .../linear_gaussian_ssm/parallel_inference.py | 45 +++++++------- .../parallel_inference_test.py | 59 ++++++++++++------- 3 files changed, 66 insertions(+), 42 deletions(-) diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py index 3c2a674c..615a85bf 100644 --- a/dynamax/linear_gaussian_ssm/models.py +++ b/dynamax/linear_gaussian_ssm/models.py @@ -74,8 +74,8 @@ def __init__( emission_dim: int, input_dim: int=0, has_dynamics_bias: bool=True, - has_emissions_bias: bool=True - use_parallel_inference: bool=False, + has_emissions_bias: bool=True, + use_parallel_inference: bool=False ): self.state_dim = state_dim self.emission_dim = emission_dim diff --git a/dynamax/linear_gaussian_ssm/parallel_inference.py b/dynamax/linear_gaussian_ssm/parallel_inference.py index fa94cabc..dad301d1 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference.py @@ -34,13 +34,14 @@ from jax import vmap, lax from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN from jaxtyping import Array, Float -from typing import NamedTuple +from typing import NamedTuple, Optional from dynamax.types import PRNGKey from functools import partial from jax.scipy.linalg import cho_solve, cho_factor from dynamax.utils.utils import symmetrize -from dynamax.linear_gaussian_ssm import PosteriorGSSMFiltered, PosteriorGSSMSmoothed, ParamsLGSSM, preprocess_args +from dynamax.linear_gaussian_ssm 
import PosteriorGSSMFiltered, PosteriorGSSMSmoothed, ParamsLGSSM +from dynamax.linear_gaussian_ssm.inference import preprocess_args def _get_params(x, dim, t): @@ -80,12 +81,12 @@ def _initialize_filtering_messages(params, emissions, inputs): def _first_message(params, y, u): H = _get_params(params.emissions.weights, 2, 0) R = _get_params(params.emissions.cov, 2, 0) - D = _get_params(params.emissions.input_weights, 2, t) + D = _get_params(params.emissions.input_weights, 2, 0) d = _get_params(params.emissions.bias, 1, 0) m = params.initial.mean P = params.initial.cov - # Adjust the bias term accoding to the input + # Adjust the bias term accoding to the input d = d + D @ u S = H @ P @ H.T + R @@ -98,7 +99,7 @@ def _first_message(params, y, u): eta = jnp.zeros_like(b) J = jnp.eye(len(b)) - logZ = -MVN(loc=jnp.zeros_like(y), covariance_matrix=H @ P @ H.T + R).log_prob(y) + logZ = -MVN(loc=H @ m + d, covariance_matrix=H @ P @ H.T + R).log_prob(y) return A, b, C, J, eta, logZ @@ -109,14 +110,16 @@ def _generic_message(params, y, u, t): Q = _get_params(params.dynamics.cov, 2, t) b = _get_params(params.dynamics.bias, 1, t) H = _get_params(params.emissions.weights, 2, t+1) - R = _get_params(params.emissions.cov, 2, t+1) D = _get_params(params.emissions.input_weights, 2, t) + R = _get_params(params.emissions.cov, 2, t+1) d = _get_params(params.emissions.bias, 1, t+1) - # Adjust the bias terms accoding to the input + # Adjust the bias terms accoding to the input d = d + D @ u b = b + B @ u + mu_y = H @ b + d + S = H @ Q @ H.T + R CF, low = cho_factor(S) K = cho_solve((CF, low), H @ Q).T @@ -128,7 +131,7 @@ def _generic_message(params, y, u, t): b = b + K @ (y - H @ b - d) C = symmetrize(Q - K @ H @ Q) - logZ = -MVN(loc=jnp.zeros_like(y), covariance_matrix=S).log_prob(y) + logZ = -MVN(loc=mu_y, covariance_matrix=S).log_prob(y) return A, b, C, J, eta, logZ @@ -155,8 +158,6 @@ def lgssm_filter( """A parallel version of the lgssm filtering algorithm. See S. Särkkä and Á. F. García-Fernández (2021) - https://arxiv.org/abs/1905.13002. - - Note: This function does not yet handle `inputs` to the system. """ @vmap def _operator(elem1, elem2): @@ -207,7 +208,7 @@ class SmoothMessage(NamedTuple): L: Float[Array, "ntime state_dim state_dim"] -def _initialize_smoothing_messages(params, filtered_means, filtered_covariances): +def _initialize_smoothing_messages(params, inputs, filtered_means, filtered_covariances): """Preprocess filtering output to construct input for smoothing assocative scan.""" def _last_message(m, P): @@ -217,7 +218,12 @@ def _last_message(m, P): def _generic_message(params, m, P, t): F = _get_params(params.dynamics.weights, 2, t) Q = _get_params(params.dynamics.cov, 2, t) + B = _get_params(params.dynamics.input_weights, 2, t) b = _get_params(params.dynamics.bias, 1, t) + u = inputs[t] + + # Adjust the bias terms accoding to the input + b = b + B @ u CF, low = cho_factor(F @ P @ F.T + Q) E = cho_solve((CF, low), F @ P).T @@ -244,10 +250,8 @@ def lgssm_smoother( """A parallel version of the lgssm smoothing algorithm. See S. Särkkä and Á. F. García-Fernández (2021) - https://arxiv.org/abs/1905.13002. - - Note: This function does not yet handle `inputs` to the system. 
""" - filtered_posterior = lgssm_filter(params, emissions) + filtered_posterior = lgssm_filter(params, emissions, inputs) filtered_means = filtered_posterior.filtered_means filtered_covs = filtered_posterior.filtered_covariances @@ -260,7 +264,7 @@ def _operator(elem1, elem2): L = symmetrize(E2 @ L1 @ E2.T + L2) return E, g, L - initial_messages = _initialize_smoothing_messages(params, filtered_means, filtered_covs) + initial_messages = _initialize_smoothing_messages(params, inputs, filtered_means, filtered_covs) final_messages = lax.associative_scan(_operator, initial_messages, reverse=True) G = initial_messages.E[:-1] smoothed_means = final_messages.g @@ -307,13 +311,13 @@ class SampleMessage(NamedTuple): h: Float[Array, "ntime state_dim"] -def _initialize_sampling_messages(key, params, filtered_means, filtered_covariances): +def _initialize_sampling_messages(key, params, inputs, filtered_means, filtered_covariances): """A parallel version of the lgssm sampling algorithm. Given parallel smoothing messages `z_i ~ N(E_i z_{i+1} + g_i, L_i)`, the parallel sampling messages are `(E_i,h_i)` where `h_i ~ N(g_i, L_i)`. """ - E, g, L = _initialize_smoothing_messages(params, filtered_means, filtered_covariances) + E, g, L = _initialize_smoothing_messages(params, inputs, filtered_means, filtered_covariances) return SampleMessage(E=E, h=MVN(g, L).sample(seed=key)) @@ -326,9 +330,10 @@ def lgssm_posterior_sample( """A parallel version of the lgssm sampling algorithm. See S. Särkkä and Á. F. García-Fernández (2021) - https://arxiv.org/abs/1905.13002. - - Note: This function does not yet handle `inputs` to the system. """ + num_timesteps = len(emissions) + inputs = jnp.zeros((num_timesteps, 0)) if inputs is None else inputs + filtered_posterior = lgssm_filter(params, emissions, inputs) filtered_means = filtered_posterior.filtered_means filtered_covs = filtered_posterior.filtered_covariances @@ -342,6 +347,6 @@ def _operator(elem1, elem2): h = E2 @ h1 + h2 return E, h - initial_messages = _initialize_sampling_messages(key, params, filtered_means, filtered_covs) + initial_messages = _initialize_sampling_messages(key, params, inputs, filtered_means, filtered_covs) _, samples = lax.associative_scan(_operator, initial_messages, reverse=True) return samples diff --git a/dynamax/linear_gaussian_ssm/parallel_inference_test.py b/dynamax/linear_gaussian_ssm/parallel_inference_test.py index 260c86ec..7b5c05bc 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference_test.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference_test.py @@ -21,27 +21,37 @@ def allclose(x,y, atol=1e-2): def make_static_lgssm_params(): + latent_dim = 4 + observation_dim = 2 + input_dim = 3 + dt = 0.1 F = jnp.eye(4) + dt * jnp.eye(4, k=2) + b = 0.1 * jnp.ones(4) Q = 1. 
* jnp.kron(jnp.array([[dt**3/3, dt**2/2], [dt**2/2, dt]]), jnp.eye(2)) H = jnp.eye(2, 4) + d = jnp.ones(2) R = 0.5 ** 2 * jnp.eye(2) - μ0 = jnp.array([0.,0.,1.,-1.]) + μ0 = jnp.array([0.,1.,1.,-1.]) Σ0 = jnp.eye(4) - latent_dim = 4 - observation_dim = 2 + B = jnp.eye(latent_dim, input_dim) * 0 + D = jnp.eye(observation_dim, input_dim) - lgssm = LinearGaussianSSM(latent_dim, observation_dim) + lgssm = LinearGaussianSSM(latent_dim, observation_dim, input_dim) params, _ = lgssm.initialize(jr.PRNGKey(0), initial_mean=μ0, initial_covariance= Σ0, dynamics_weights=F, + dynamics_input_weights=B, + dynamics_bias=b, dynamics_covariance=Q, emission_weights=H, + emission_input_weights=D, + emission_bias=d, emission_covariance=R) return params, lgssm @@ -49,6 +59,7 @@ def make_static_lgssm_params(): def make_dynamic_lgssm_params(num_timesteps): latent_dim = 4 observation_dim = 2 + input_dim = 2 key = jr.PRNGKey(0) key, key_f, key_r, key_init = jr.split(key, 4) @@ -70,13 +81,18 @@ def make_dynamic_lgssm_params(num_timesteps): μ0 = jnp.array([0.,0.,1.,-1.]) Σ0 = jnp.eye(latent_dim) - lgssm = LinearGaussianSSM(latent_dim, observation_dim) + B = jnp.eye(latent_dim, input_dim) + D = jnp.eye(observation_dim, input_dim) + + lgssm = LinearGaussianSSM(latent_dim, observation_dim, input_dim) params, _ = lgssm.initialize(key_init, initial_mean=μ0, initial_covariance=Σ0, dynamics_weights=F, + dynamics_input_weights=B, dynamics_covariance=Q, emission_weights=H, + emission_input_weights=D, emission_covariance=R) return params, lgssm @@ -85,13 +101,14 @@ class TestParallelLGSSMSmoother: """ Compare parallel and serial lgssm smoothing implementations.""" num_timesteps = 50 - key = jr.PRNGKey(1) + keys = jr.split(jr.PRNGKey(1), 2) params, lgssm = make_static_lgssm_params() - _, emissions = lgssm_joint_sample(params, key, num_timesteps) + inputs = jr.normal(keys[0], (num_timesteps, params.dynamics.input_weights.shape[-1])) + _, emissions = lgssm_joint_sample(params, keys[1], num_timesteps, inputs) - serial_posterior = serial_lgssm_smoother(params, emissions) - parallel_posterior = parallel_lgssm_smoother(params, emissions) + serial_posterior = serial_lgssm_smoother(params, emissions, inputs) + parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs) def test_filtered_means(self): assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior.filtered_means) @@ -117,13 +134,14 @@ class TestTimeVaryingParallelLGSSMSmoother: Vary dynamics weights and observation covariances with time. 
""" num_timesteps = 50 - key = jr.PRNGKey(1) + keys = jr.split(jr.PRNGKey(1), 2) params, lgssm = make_dynamic_lgssm_params(num_timesteps) - _, emissions = lgssm_joint_sample(params, key, num_timesteps) + inputs = jr.normal(keys[0], (num_timesteps, params.emissions.input_weights.shape[-1])) + _, emissions = lgssm_joint_sample(params, keys[1], num_timesteps, inputs) - serial_posterior = serial_lgssm_smoother(params, emissions) - parallel_posterior = parallel_lgssm_smoother(params, emissions) + serial_posterior = serial_lgssm_smoother(params, emissions, inputs) + parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs) def test_filtered_means(self): assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior.filtered_means) @@ -146,20 +164,21 @@ class TestTimeVaryingParallelLGSSMSampler(): """Compare parallel and serial lgssm posterior sampling implementations in expectation.""" num_timesteps = 50 - key = jr.PRNGKey(1) + keys = jr.split(jr.PRNGKey(1), 2) params, lgssm = make_dynamic_lgssm_params(num_timesteps) - _, emissions = lgssm_joint_sample(params, key, num_timesteps) + inputs = jr.normal(keys[0], (num_timesteps, params.emissions.input_weights.shape[-1])) + _, emissions = lgssm_joint_sample(params, keys[1], num_timesteps, inputs) num_samples = 1000 serial_keys = jr.split(jr.PRNGKey(2), num_samples) parallel_keys = jr.split(jr.PRNGKey(3), num_samples) - serial_samples = vmap(serial_lgssm_posterior_sample, in_axes=(0,None,None))( - serial_keys, params, emissions) + serial_samples = vmap(serial_lgssm_posterior_sample, in_axes=(0, None, None, None))( + serial_keys, params, emissions, inputs) - parallel_samples = vmap(parallel_lgssm_posterior_sample, in_axes=(0, None, None))( - parallel_keys, params, emissions) + parallel_samples = vmap(parallel_lgssm_posterior_sample, in_axes=(0, None, None, None))( + parallel_keys, params, emissions, inputs) def test_sampled_means(self): serial_mean = self.serial_samples.mean(axis=0) @@ -170,4 +189,4 @@ def test_sampled_covariances(self): # samples have shape (N, T, D): vmap over the T axis, calculate cov over N axis serial_cov = vmap(partial(jnp.cov, rowvar=False), in_axes=1)(self.serial_samples) parallel_cov = vmap(partial(jnp.cov, rowvar=False), in_axes=1)(self.parallel_samples) - assert allclose(serial_cov, parallel_cov, atol=1e-1) \ No newline at end of file + assert allclose(serial_cov, parallel_cov, atol=1e-1) From 0874e03ecbb018a2dd0664298191d83714746287 Mon Sep 17 00:00:00 2001 From: Kei Ishikawa Date: Sun, 13 Aug 2023 02:51:41 +0100 Subject: [PATCH 04/17] saving while debugging filtering and smoothing --- .../linear_gaussian_ssm/parallel_inference.py | 7 ++-- .../parallel_inference_test.py | 33 +++++++++++++++---- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/dynamax/linear_gaussian_ssm/parallel_inference.py b/dynamax/linear_gaussian_ssm/parallel_inference.py index dad301d1..70c08abd 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference.py @@ -118,7 +118,7 @@ def _generic_message(params, y, u, t): d = d + D @ u b = b + B @ u - mu_y = H @ b + d + mu_y = H @ b + d # mean of p(y_t|x_{t-1}=0) S = H @ Q @ H.T + R CF, low = cho_factor(S) @@ -217,12 +217,12 @@ def _last_message(m, P): @partial(vmap, in_axes=(None, 0, 0, 0)) def _generic_message(params, m, P, t): F = _get_params(params.dynamics.weights, 2, t) - Q = _get_params(params.dynamics.cov, 2, t) B = _get_params(params.dynamics.input_weights, 2, t) b = _get_params(params.dynamics.bias, 
1, t) + Q = _get_params(params.dynamics.cov, 2, t) u = inputs[t] - # Adjust the bias terms accoding to the input + # Adjust the bias terms accoding to the input b = b + B @ u CF, low = cho_factor(F @ P @ F.T + Q) @@ -265,6 +265,7 @@ def _operator(elem1, elem2): return E, g, L initial_messages = _initialize_smoothing_messages(params, inputs, filtered_means, filtered_covs) + # breakpoint() final_messages = lax.associative_scan(_operator, initial_messages, reverse=True) G = initial_messages.E[:-1] smoothed_means = final_messages.g diff --git a/dynamax/linear_gaussian_ssm/parallel_inference_test.py b/dynamax/linear_gaussian_ssm/parallel_inference_test.py index 7b5c05bc..439d6d6f 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference_test.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference_test.py @@ -27,19 +27,19 @@ def make_static_lgssm_params(): dt = 0.1 F = jnp.eye(4) + dt * jnp.eye(4, k=2) - b = 0.1 * jnp.ones(4) + b = 0.1 * jnp.arange(4) Q = 1. * jnp.kron(jnp.array([[dt**3/3, dt**2/2], [dt**2/2, dt]]), jnp.eye(2)) H = jnp.eye(2, 4) - d = jnp.ones(2) + d = 0.1 * jnp.ones(2) R = 0.5 ** 2 * jnp.eye(2) μ0 = jnp.array([0.,1.,1.,-1.]) Σ0 = jnp.eye(4) B = jnp.eye(latent_dim, input_dim) * 0 - D = jnp.eye(observation_dim, input_dim) + D = jnp.eye(observation_dim, input_dim) * 0 lgssm = LinearGaussianSSM(latent_dim, observation_dim, input_dim) params, _ = lgssm.initialize(jr.PRNGKey(0), @@ -61,6 +61,8 @@ def make_dynamic_lgssm_params(num_timesteps): observation_dim = 2 input_dim = 2 + keys = jr.split(jr.PRNGKey(1), 100) + key = jr.PRNGKey(0) key, key_f, key_r, key_init = jr.split(key, 4) @@ -73,16 +75,21 @@ def make_dynamic_lgssm_params(num_timesteps): [dt**2/2, dt]]), jnp.eye(latent_dim // 2)) assert Q.shape[-1] == latent_dim - H = jnp.eye(observation_dim, latent_dim) + Q = Q[None] * jr.uniform(keys[3], (num_timesteps, 1, 1)) + + # H = jnp.eye(observation_dim, latent_dim) + H = jr.normal(keys[4], (num_timesteps, observation_dim, latent_dim)) r_scale = jr.normal(key_r, (num_timesteps,)) * 0.1 R = (r_scale**2)[:,None,None] * jnp.tile(jnp.eye(observation_dim), (num_timesteps, 1, 1)) - μ0 = jnp.array([0.,0.,1.,-1.]) + μ0 = jnp.array([1.,-2.,1.,-1.]) Σ0 = jnp.eye(latent_dim) - B = jnp.eye(latent_dim, input_dim) - D = jnp.eye(observation_dim, input_dim) + B = jnp.eye(latent_dim, input_dim) * 0 + D = jnp.eye(observation_dim, input_dim) * 0 + b = jr.normal(keys[0], (num_timesteps, latent_dim)) * 0 + d = jr.normal(keys[1], (num_timesteps, observation_dim)) * 0 lgssm = LinearGaussianSSM(latent_dim, observation_dim, input_dim) params, _ = lgssm.initialize(key_init, @@ -90,9 +97,11 @@ def make_dynamic_lgssm_params(num_timesteps): initial_covariance=Σ0, dynamics_weights=F, dynamics_input_weights=B, + dynamics_bias=b, dynamics_covariance=Q, emission_weights=H, emission_input_weights=D, + emission_bias=d, emission_covariance=R) return params, lgssm @@ -122,6 +131,11 @@ def test_smoothed_means(self): def test_smoothed_covariances(self): assert allclose(self.serial_posterior.smoothed_covariances, self.parallel_posterior.smoothed_covariances) + def test_smoothed_cross_covariances(self): + assert allclose( + self.serial_posterior.smoothed_cross_covariances, + self.parallel_posterior.smoothed_cross_covariances) + def test_marginal_loglik(self): assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik) @@ -155,6 +169,11 @@ def test_smoothed_means(self): def test_smoothed_covariances(self): assert allclose(self.serial_posterior.smoothed_covariances, 
self.parallel_posterior.smoothed_covariances) + def test_smoothed_cross_covariances(self): + assert allclose( + self.serial_posterior.smoothed_cross_covariances, + self.parallel_posterior.smoothed_cross_covariances) + def test_marginal_loglik(self): assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik, atol=1e-1) From 44e13e8e55bcda975107aeaf122d2ba921b3aaf2 Mon Sep 17 00:00:00 2001 From: Kei Ishikawa Date: Sun, 13 Aug 2023 13:31:07 +0100 Subject: [PATCH 05/17] fix sampling logic --- dynamax/linear_gaussian_ssm/inference.py | 22 +++++++------- .../linear_gaussian_ssm/parallel_inference.py | 30 +++++++++++-------- .../parallel_inference_test.py | 16 ++++++---- 3 files changed, 40 insertions(+), 28 deletions(-) diff --git a/dynamax/linear_gaussian_ssm/inference.py b/dynamax/linear_gaussian_ssm/inference.py index 75e1ef7d..89323ece 100644 --- a/dynamax/linear_gaussian_ssm/inference.py +++ b/dynamax/linear_gaussian_ssm/inference.py @@ -305,8 +305,8 @@ def lgssm_joint_sample( params, inputs = preprocess_params_and_inputs(params, num_timesteps, inputs) - def _sample_transition(key, F, B, b, Q, x_tm1, u): - mean = F @ x_tm1 + B @ u + b + def _sample_transition(key, F, B, b, Q, x, u): + mean = F @ x + B @ u + b return MVN(mean, Q).sample(seed=key) def _sample_emission(key, H, D, d, R, x, u): @@ -327,22 +327,22 @@ def _sample_initial(key, params, inputs): initial_emission = _sample_emission(key2, H0, D0, d0, R0, initial_state, u0) return initial_state, initial_emission - def _step(prev_state, args): + def _step(state_prev, args): key, t, inpt = args key1, key2 = jr.split(key, 2) # Shorthand: get parameters and inputs for time index t - F = _get_params(params.dynamics.weights, 2, t) - B = _get_params(params.dynamics.input_weights, 2, t) - b = _get_params(params.dynamics.bias, 1, t) - Q = _get_params(params.dynamics.cov, 2, t) + F_prev = _get_params(params.dynamics.weights, 2, t - 1) + B_prev = _get_params(params.dynamics.input_weights, 2, t - 1) + b_prev = _get_params(params.dynamics.bias, 1, t - 1) + Q_prev = _get_params(params.dynamics.cov, 2, t - 1) H = _get_params(params.emissions.weights, 2, t) D = _get_params(params.emissions.input_weights, 2, t) d = _get_params(params.emissions.bias, 1, t) R = _get_params(params.emissions.cov, 2, t) # Sample from transition and emission distributions - state = _sample_transition(key1, F, B, b, Q, prev_state, inpt) + state = _sample_transition(key1, F_prev, B_prev, b_prev, Q_prev, state_prev, inpt) emission = _sample_emission(key2, H, D, d, R, state, inpt) return state, (state, emission) @@ -410,11 +410,13 @@ def _step(carry, t): # Predict the next state pred_mean, pred_cov = _predict(filtered_mean, filtered_cov, F, B, b, Q, u) - return (ll, pred_mean, pred_cov), (filtered_mean, filtered_cov) + return (ll, pred_mean, pred_cov), (filtered_mean, filtered_cov, ll) # Run the Kalman filter carry = (0.0, params.initial.mean, params.initial.cov) - (ll, _, _), (filtered_means, filtered_covs) = lax.scan(_step, carry, jnp.arange(num_timesteps)) + (ll, _, _), (filtered_means, filtered_covs, lik) = lax.scan(_step, carry, jnp.arange(num_timesteps)) + + # breakpoint() return PosteriorGSSMFiltered( marginal_loglik=ll, diff --git a/dynamax/linear_gaussian_ssm/parallel_inference.py b/dynamax/linear_gaussian_ssm/parallel_inference.py index 70c08abd..8d2769af 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference.py @@ -183,6 +183,7 @@ def _operator(elem1, elem2): 
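+    # Each initial message carries a local normalization constant logZ, and the
+    # associative operator combines them (together with the slogdet and
+    # quadratic correction terms) so that the composed logZ[-1] equals
+    # -log p(y_{1:T}); hence the sign flip in the marginal_loglik returned below.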
initial_messages = _initialize_filtering_messages(params, emissions, inputs) final_messages = lax.associative_scan(_operator, initial_messages) + # breakpoint() return PosteriorGSSMFiltered( marginal_loglik=-final_messages.logZ[-1], @@ -203,6 +204,7 @@ class SmoothMessage(NamedTuple): g: P(z_i | y_{1:j}, z_{j+1}) bias. L: P(z_i | y_{1:j}, z_{j+1}) covariance. """ + G: Float[Array, "ntime state_dim state_dim"] E: Float[Array, "ntime state_dim state_dim"] g: Float[Array, "ntime state_dim"] L: Float[Array, "ntime state_dim state_dim"] @@ -212,7 +214,7 @@ def _initialize_smoothing_messages(params, inputs, filtered_means, filtered_cova """Preprocess filtering output to construct input for smoothing assocative scan.""" def _last_message(m, P): - return jnp.zeros_like(P), m, P + return jnp.zeros_like(P), jnp.zeros_like(P), m, P @partial(vmap, in_axes=(None, 0, 0, 0)) def _generic_message(params, m, P, t): @@ -226,15 +228,18 @@ def _generic_message(params, m, P, t): b = b + B @ u CF, low = cho_factor(F @ P @ F.T + Q) - E = cho_solve((CF, low), F @ P).T - g = m - E @ (F @ m + b) - L = symmetrize(P - E @ F @ P) - return E, g, L + G = cho_solve((CF, low), F @ P).T + g = m - G @ (F @ m + b) + L = symmetrize(P - G @ F @ P) + E = jnp.linalg.solve(Q, F @ L).T + gg = m - G @ (F @ m) + E @ b + return G, G, g, L - En, gn, Ln = _last_message(filtered_means[-1], filtered_covariances[-1]) - Et, gt, Lt = _generic_message(params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_means)-1)) + Gn, En, gn, Ln = _last_message(filtered_means[-1], filtered_covariances[-1]) + Gt, Et, gt, Lt = _generic_message(params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_means)-1)) return SmoothMessage( + G=jnp.concatenate([Gt, Gn[None]]), E=jnp.concatenate([Et, En[None]]), g=jnp.concatenate([gt, gn[None]]), L=jnp.concatenate([Lt, Ln[None]]) @@ -257,17 +262,16 @@ def lgssm_smoother( @vmap def _operator(elem1, elem2): - E1, g1, L1 = elem1 - E2, g2, L2 = elem2 + _, E1, g1, L1 = elem1 + _, E2, g2, L2 = elem2 E = E2 @ E1 g = E2 @ g1 + g2 L = symmetrize(E2 @ L1 @ E2.T + L2) - return E, g, L + return _, E, g, L initial_messages = _initialize_smoothing_messages(params, inputs, filtered_means, filtered_covs) - # breakpoint() final_messages = lax.associative_scan(_operator, initial_messages, reverse=True) - G = initial_messages.E[:-1] + G = initial_messages.G[:-1] smoothed_means = final_messages.g smoothed_covariances = final_messages.L smoothed_cross_covariances = compute_smoothed_cross_covariances( @@ -318,7 +322,7 @@ def _initialize_sampling_messages(key, params, inputs, filtered_means, filtered_ Given parallel smoothing messages `z_i ~ N(E_i z_{i+1} + g_i, L_i)`, the parallel sampling messages are `(E_i,h_i)` where `h_i ~ N(g_i, L_i)`. 
""" - E, g, L = _initialize_smoothing_messages(params, inputs, filtered_means, filtered_covariances) + _, E, g, L = _initialize_smoothing_messages(params, inputs, filtered_means, filtered_covariances) return SampleMessage(E=E, h=MVN(g, L).sample(seed=key)) diff --git a/dynamax/linear_gaussian_ssm/parallel_inference_test.py b/dynamax/linear_gaussian_ssm/parallel_inference_test.py index 439d6d6f..93e99fcf 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference_test.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference_test.py @@ -11,6 +11,11 @@ from dynamax.linear_gaussian_ssm import parallel_lgssm_posterior_sample + +from jax.config import config; config.update("jax_enable_x64", True) + + + def allclose(x,y, atol=1e-2): m = jnp.abs(jnp.max(x-y)) if m > atol: @@ -86,10 +91,10 @@ def make_dynamic_lgssm_params(num_timesteps): μ0 = jnp.array([1.,-2.,1.,-1.]) Σ0 = jnp.eye(latent_dim) - B = jnp.eye(latent_dim, input_dim) * 0 - D = jnp.eye(observation_dim, input_dim) * 0 - b = jr.normal(keys[0], (num_timesteps, latent_dim)) * 0 - d = jr.normal(keys[1], (num_timesteps, observation_dim)) * 0 + B = jnp.eye(latent_dim, input_dim) + D = jnp.eye(observation_dim, input_dim) + b = jr.normal(keys[0], (num_timesteps, latent_dim)) + d = jr.normal(keys[1], (num_timesteps, observation_dim)) lgssm = LinearGaussianSSM(latent_dim, observation_dim, input_dim) params, _ = lgssm.initialize(key_init, @@ -141,7 +146,6 @@ def test_marginal_loglik(self): - class TestTimeVaryingParallelLGSSMSmoother: """Compare parallel and serial time-varying lgssm smoothing implementations. @@ -157,6 +161,8 @@ class TestTimeVaryingParallelLGSSMSmoother: serial_posterior = serial_lgssm_smoother(params, emissions, inputs) parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs) + breakpoint() + def test_filtered_means(self): assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior.filtered_means) From 3f890215cad833fbe9ba57f9ad4d3a73b0435454 Mon Sep 17 00:00:00 2001 From: Kei Ishikawa Date: Sun, 13 Aug 2023 15:46:32 +0100 Subject: [PATCH 06/17] maybe fixed the parallel smoother with dynamic emission with inputs --- dynamax/linear_gaussian_ssm/inference.py | 5 +- .../linear_gaussian_ssm/parallel_inference.py | 46 ++++++++++--------- .../parallel_inference_test.py | 11 +++-- 3 files changed, 33 insertions(+), 29 deletions(-) diff --git a/dynamax/linear_gaussian_ssm/inference.py b/dynamax/linear_gaussian_ssm/inference.py index 89323ece..c9918726 100644 --- a/dynamax/linear_gaussian_ssm/inference.py +++ b/dynamax/linear_gaussian_ssm/inference.py @@ -123,6 +123,7 @@ def _get_params(x, dim, t): return x[t] else: return x + _zeros_if_none = lambda x, shape: x if x is not None else jnp.zeros(shape) @@ -167,7 +168,7 @@ def _predict(m, S, F, B, b, Q, u): r"""Predict next mean and covariance under a linear Gaussian model. p(z_{t+1}) = int N(z_t \mid m, S) N(z_{t+1} \mid Fz_t + Bu + b, Q) - = N(z_{t+1} \mid Fm + Bu, F S F^T + Q) + = N(z_{t+1} \mid Fm + Bu + b, F S F^T + Q) Args: m (D_hid,): prior mean. 
@@ -414,7 +415,7 @@ def _step(carry, t): # Run the Kalman filter carry = (0.0, params.initial.mean, params.initial.cov) - (ll, _, _), (filtered_means, filtered_covs, lik) = lax.scan(_step, carry, jnp.arange(num_timesteps)) + (ll, _, _), (filtered_means, filtered_covs, cumll) = lax.scan(_step, carry, jnp.arange(num_timesteps)) # breakpoint() diff --git a/dynamax/linear_gaussian_ssm/parallel_inference.py b/dynamax/linear_gaussian_ssm/parallel_inference.py index 8d2769af..fa97961c 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference.py @@ -103,40 +103,42 @@ def _first_message(params, y, u): return A, b, C, J, eta, logZ - @partial(vmap, in_axes=(None, 0, 0, 0)) - def _generic_message(params, y, u, t): - F = _get_params(params.dynamics.weights, 2, t) - B = _get_params(params.dynamics.input_weights, 2, t) - Q = _get_params(params.dynamics.cov, 2, t) - b = _get_params(params.dynamics.bias, 1, t) - H = _get_params(params.emissions.weights, 2, t+1) + @partial(vmap, in_axes=(None, 0, 0)) + def _generic_message(params, y, t): + F_prev = _get_params(params.dynamics.weights, 2, t - 1) + B_prev = _get_params(params.dynamics.input_weights, 2, t - 1) + Q_prev = _get_params(params.dynamics.cov, 2, t - 1) + b_prev = _get_params(params.dynamics.bias, 1, t - 1) + H = _get_params(params.emissions.weights, 2, t) D = _get_params(params.emissions.input_weights, 2, t) - R = _get_params(params.emissions.cov, 2, t+1) - d = _get_params(params.emissions.bias, 1, t+1) + R = _get_params(params.emissions.cov, 2, t) + d = _get_params(params.emissions.bias, 1, t) + u_prev = inputs[t - 1] + u = inputs[t] - # Adjust the bias terms accoding to the input + # Adjust the b_previas terms accoding to the input d = d + D @ u - b = b + B @ u + b_prev = b_prev + B_prev @ u_prev - mu_y = H @ b + d # mean of p(y_t|x_{t-1}=0) + mu_y = H @ b_prev + d # mean of p(y_t|x_{t-1}=0) - S = H @ Q @ H.T + R - CF, low = cho_factor(S) - K = cho_solve((CF, low), H @ Q).T + S = H @ Q_prev @ H.T + R + CF_prev, low = cho_factor(S) + K = cho_solve((CF_prev, low), H @ Q_prev).T - eta = F.T @ H.T @ cho_solve((CF, low), y - H @ b - d) - J = symmetrize(F.T @ H.T @ cho_solve((CF, low), H @ F)) + eta = F_prev.T @ H.T @ cho_solve((CF_prev, low), y - H @ b_prev - d) + J = symmetrize(F_prev.T @ H.T @ cho_solve((CF_prev, low), H @ F_prev)) - A = F - K @ H @ F - b = b + K @ (y - H @ b - d) - C = symmetrize(Q - K @ H @ Q) + A = F_prev - K @ H @ F_prev + b_prev = b_prev + K @ (y - H @ b_prev - d) + C = symmetrize(Q_prev - K @ H @ Q_prev) logZ = -MVN(loc=mu_y, covariance_matrix=S).log_prob(y) - return A, b, C, J, eta, logZ + return A, b_prev, C, J, eta, logZ A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0], inputs[0]) - At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], inputs[1:], jnp.arange(len(emissions)-1)) + At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], jnp.arange(1, len(emissions))) return FilterMessage( A=jnp.concatenate([A0[None], At]), diff --git a/dynamax/linear_gaussian_ssm/parallel_inference_test.py b/dynamax/linear_gaussian_ssm/parallel_inference_test.py index 93e99fcf..b0e662b9 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference_test.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference_test.py @@ -91,8 +91,9 @@ def make_dynamic_lgssm_params(num_timesteps): μ0 = jnp.array([1.,-2.,1.,-1.]) Σ0 = jnp.eye(latent_dim) - B = jnp.eye(latent_dim, input_dim) - D = jnp.eye(observation_dim, input_dim) + B = jnp.eye(latent_dim, 
input_dim)[None] + 0.1 * jr.normal(keys[6], (num_timesteps, latent_dim, input_dim)) + # B = B * 0 + D = jnp.eye(observation_dim, input_dim)[None] + 0.1 * jr.normal(keys[7], (num_timesteps, observation_dim, input_dim)) b = jr.normal(keys[0], (num_timesteps, latent_dim)) d = jr.normal(keys[1], (num_timesteps, observation_dim)) @@ -119,10 +120,10 @@ class TestParallelLGSSMSmoother: params, lgssm = make_static_lgssm_params() inputs = jr.normal(keys[0], (num_timesteps, params.dynamics.input_weights.shape[-1])) - _, emissions = lgssm_joint_sample(params, keys[1], num_timesteps, inputs) + # _, emissions = lgssm_joint_sample(params, keys[1], num_timesteps, inputs) - serial_posterior = serial_lgssm_smoother(params, emissions, inputs) - parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs) + # serial_posterior = serial_lgssm_smoother(params, emissions, inputs) + # parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs) def test_filtered_means(self): assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior.filtered_means) From 221d147fe52be49a21eccad656c1ffcb7c0b7e00 Mon Sep 17 00:00:00 2001 From: Kei Ishikawa Date: Sun, 13 Aug 2023 18:17:04 +0100 Subject: [PATCH 07/17] get all tests to work --- dynamax/linear_gaussian_ssm/inference.py | 11 ++-- dynamax/linear_gaussian_ssm/models.py | 10 ++-- dynamax/linear_gaussian_ssm/models_test.py | 6 ++- .../linear_gaussian_ssm/parallel_inference.py | 31 +++++------ .../parallel_inference_test.py | 54 ++++++++----------- 5 files changed, 49 insertions(+), 63 deletions(-) diff --git a/dynamax/linear_gaussian_ssm/inference.py b/dynamax/linear_gaussian_ssm/inference.py index c9918726..f3663e15 100644 --- a/dynamax/linear_gaussian_ssm/inference.py +++ b/dynamax/linear_gaussian_ssm/inference.py @@ -411,18 +411,13 @@ def _step(carry, t): # Predict the next state pred_mean, pred_cov = _predict(filtered_mean, filtered_cov, F, B, b, Q, u) - return (ll, pred_mean, pred_cov), (filtered_mean, filtered_cov, ll) + return (ll, pred_mean, pred_cov), (filtered_mean, filtered_cov) # Run the Kalman filter carry = (0.0, params.initial.mean, params.initial.cov) - (ll, _, _), (filtered_means, filtered_covs, cumll) = lax.scan(_step, carry, jnp.arange(num_timesteps)) + (ll, _, _), (filtered_means, filtered_covs) = lax.scan(_step, carry, jnp.arange(num_timesteps)) - # breakpoint() - - return PosteriorGSSMFiltered( - marginal_loglik=ll, - filtered_means=filtered_means, - filtered_covariances=filtered_covs) + return PosteriorGSSMFiltered(marginal_loglik=ll, filtered_means=filtered_means, filtered_covariances=filtered_covs) @preprocess_args diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py index 615a85bf..d3ee1583 100644 --- a/dynamax/linear_gaussian_ssm/models.py +++ b/dynamax/linear_gaussian_ssm/models.py @@ -11,7 +11,9 @@ from typing_extensions import Protocol from dynamax.ssm import SSM -from dynamax.linear_gaussian_ssm.inference import lgssm_filter, lgssm_smoother, lgssm_posterior_sample +from dynamax.linear_gaussian_ssm.inference import lgssm_filter as serial_lgssm_filter +from dynamax.linear_gaussian_ssm.inference import lgssm_smoother as serial_lgssm_smoother +from dynamax.linear_gaussian_ssm.inference import lgssm_posterior_sample as serial_lgssm_posterior_sample from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_filter as parallel_lgssm_filter from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_smoother as parallel_lgssm_smoother from 
dynamax.linear_gaussian_ssm.parallel_inference import lgssm_posterior_sample as parallel_lgssm_posterior_sample
@@ -224,7 +226,7 @@ def filter(
         if self.use_parallel_inference:
             return parallel_lgssm_filter(params, emissions, inputs)
         else:
-            return lgssm_filter(params, emissions, inputs)
+            return serial_lgssm_filter(params, emissions, inputs)

     def smoother(
@@ -235,7 +237,7 @@ def smoother(
         if self.use_parallel_inference:
             return parallel_lgssm_smoother(params, emissions, inputs)
         else:
-            return lgssm_smoother(params, emissions, inputs)
+            return serial_lgssm_smoother(params, emissions, inputs)

     def posterior_sample(
@@ -247,7 +249,7 @@ def posterior_sample(
         if self.use_parallel_inference:
             return parallel_lgssm_posterior_sample(key, params, emissions, inputs)
         else:
-            return lgssm_posterior_sample(key, params, emissions, inputs)
+            return serial_lgssm_posterior_sample(key, params, emissions, inputs)

     def posterior_predictive(
diff --git a/dynamax/linear_gaussian_ssm/models_test.py b/dynamax/linear_gaussian_ssm/models_test.py
index c4394858..6ac68ddd 100644
--- a/dynamax/linear_gaussian_ssm/models_test.py
+++ b/dynamax/linear_gaussian_ssm/models_test.py
@@ -8,8 +8,10 @@
 NUM_TIMESTEPS = 100

 CONFIGS = [
-    (LinearGaussianSSM, dict(state_dim=2, emission_dim=10), None),
-    (LinearGaussianConjugateSSM, dict(state_dim=2, emission_dim=10), None),
+    (LinearGaussianSSM, dict(state_dim=2, emission_dim=10, use_parallel_inference=False), None),
+    (LinearGaussianSSM, dict(state_dim=2, emission_dim=10, use_parallel_inference=True), None),
+    (LinearGaussianConjugateSSM, dict(state_dim=2, emission_dim=10, use_parallel_inference=False), None),
+    (LinearGaussianConjugateSSM, dict(state_dim=2, emission_dim=10, use_parallel_inference=True), None),
 ]

 @pytest.mark.parametrize(["cls", "kwargs", "inputs"], CONFIGS)
diff --git a/dynamax/linear_gaussian_ssm/parallel_inference.py b/dynamax/linear_gaussian_ssm/parallel_inference.py
index fa97961c..f2548018 100644
--- a/dynamax/linear_gaussian_ssm/parallel_inference.py
+++ b/dynamax/linear_gaussian_ssm/parallel_inference.py
@@ -116,7 +116,7 @@ def _generic_message(params, y, t):
         u_prev = inputs[t - 1]
         u = inputs[t]

-        # Adjust the b_previas terms accoding to the input
+        # Adjust the bias terms according to the input
         d = d + D @ u
         b_prev = b_prev + B_prev @ u_prev
@@ -185,7 +185,6 @@ def _operator(elem1, elem2):
     initial_messages = _initialize_filtering_messages(params, emissions, inputs)
     final_messages = lax.associative_scan(_operator, initial_messages)
-    # breakpoint()

     return PosteriorGSSMFiltered(
         marginal_loglik=-final_messages.logZ[-1],
@@ -206,7 +205,6 @@ class SmoothMessage(NamedTuple):
     g: P(z_i | y_{1:j}, z_{j+1}) bias.
     L: P(z_i | y_{1:j}, z_{j+1}) covariance.
""" - G: Float[Array, "ntime state_dim state_dim"] E: Float[Array, "ntime state_dim state_dim"] g: Float[Array, "ntime state_dim"] L: Float[Array, "ntime state_dim state_dim"] @@ -216,7 +214,7 @@ def _initialize_smoothing_messages(params, inputs, filtered_means, filtered_cova """Preprocess filtering output to construct input for smoothing assocative scan.""" def _last_message(m, P): - return jnp.zeros_like(P), jnp.zeros_like(P), m, P + return jnp.zeros_like(P), m, P @partial(vmap, in_axes=(None, 0, 0, 0)) def _generic_message(params, m, P, t): @@ -230,18 +228,15 @@ def _generic_message(params, m, P, t): b = b + B @ u CF, low = cho_factor(F @ P @ F.T + Q) - G = cho_solve((CF, low), F @ P).T - g = m - G @ (F @ m + b) - L = symmetrize(P - G @ F @ P) - E = jnp.linalg.solve(Q, F @ L).T - gg = m - G @ (F @ m) + E @ b - return G, G, g, L + E = cho_solve((CF, low), F @ P).T + g = m - E @ (F @ m + b) + L = symmetrize(P - E @ F @ P) + return E, g, L - Gn, En, gn, Ln = _last_message(filtered_means[-1], filtered_covariances[-1]) - Gt, Et, gt, Lt = _generic_message(params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_means)-1)) + En, gn, Ln = _last_message(filtered_means[-1], filtered_covariances[-1]) + Et, gt, Lt = _generic_message(params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_means)-1)) return SmoothMessage( - G=jnp.concatenate([Gt, Gn[None]]), E=jnp.concatenate([Et, En[None]]), g=jnp.concatenate([gt, gn[None]]), L=jnp.concatenate([Lt, Ln[None]]) @@ -264,16 +259,16 @@ def lgssm_smoother( @vmap def _operator(elem1, elem2): - _, E1, g1, L1 = elem1 - _, E2, g2, L2 = elem2 + E1, g1, L1 = elem1 + E2, g2, L2 = elem2 E = E2 @ E1 g = E2 @ g1 + g2 L = symmetrize(E2 @ L1 @ E2.T + L2) - return _, E, g, L + return E, g, L initial_messages = _initialize_smoothing_messages(params, inputs, filtered_means, filtered_covs) final_messages = lax.associative_scan(_operator, initial_messages, reverse=True) - G = initial_messages.G[:-1] + G = initial_messages.E[:-1] smoothed_means = final_messages.g smoothed_covariances = final_messages.L smoothed_cross_covariances = compute_smoothed_cross_covariances( @@ -324,7 +319,7 @@ def _initialize_sampling_messages(key, params, inputs, filtered_means, filtered_ Given parallel smoothing messages `z_i ~ N(E_i z_{i+1} + g_i, L_i)`, the parallel sampling messages are `(E_i,h_i)` where `h_i ~ N(g_i, L_i)`. 
""" - _, E, g, L = _initialize_smoothing_messages(params, inputs, filtered_means, filtered_covariances) + E, g, L = _initialize_smoothing_messages(params, inputs, filtered_means, filtered_covariances) return SampleMessage(E=E, h=MVN(g, L).sample(seed=key)) diff --git a/dynamax/linear_gaussian_ssm/parallel_inference_test.py b/dynamax/linear_gaussian_ssm/parallel_inference_test.py index b0e662b9..bcf7f30e 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference_test.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference_test.py @@ -15,14 +15,7 @@ from jax.config import config; config.update("jax_enable_x64", True) - -def allclose(x,y, atol=1e-2): - m = jnp.abs(jnp.max(x-y)) - if m > atol: - print(m) - return False - else: - return True +allclose = partial(jnp.allclose, atol=1e-2, rtol=1e-2) def make_static_lgssm_params(): @@ -92,23 +85,22 @@ def make_dynamic_lgssm_params(num_timesteps): Σ0 = jnp.eye(latent_dim) B = jnp.eye(latent_dim, input_dim)[None] + 0.1 * jr.normal(keys[6], (num_timesteps, latent_dim, input_dim)) - # B = B * 0 D = jnp.eye(observation_dim, input_dim)[None] + 0.1 * jr.normal(keys[7], (num_timesteps, observation_dim, input_dim)) b = jr.normal(keys[0], (num_timesteps, latent_dim)) d = jr.normal(keys[1], (num_timesteps, observation_dim)) lgssm = LinearGaussianSSM(latent_dim, observation_dim, input_dim) params, _ = lgssm.initialize(key_init, - initial_mean=μ0, - initial_covariance=Σ0, - dynamics_weights=F, - dynamics_input_weights=B, - dynamics_bias=b, - dynamics_covariance=Q, - emission_weights=H, - emission_input_weights=D, - emission_bias=d, - emission_covariance=R) + initial_mean=μ0, + initial_covariance=Σ0, + dynamics_weights=F, + dynamics_input_weights=B, + dynamics_bias=b, + dynamics_covariance=Q, + emission_weights=H, + emission_input_weights=D, + emission_bias=d, + emission_covariance=R) return params, lgssm @@ -120,10 +112,10 @@ class TestParallelLGSSMSmoother: params, lgssm = make_static_lgssm_params() inputs = jr.normal(keys[0], (num_timesteps, params.dynamics.input_weights.shape[-1])) - # _, emissions = lgssm_joint_sample(params, keys[1], num_timesteps, inputs) + _, emissions = lgssm_joint_sample(params, keys[1], num_timesteps, inputs) - # serial_posterior = serial_lgssm_smoother(params, emissions, inputs) - # parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs) + serial_posterior = serial_lgssm_smoother(params, emissions, inputs) + parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs) def test_filtered_means(self): assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior.filtered_means) @@ -138,9 +130,10 @@ def test_smoothed_covariances(self): assert allclose(self.serial_posterior.smoothed_covariances, self.parallel_posterior.smoothed_covariances) def test_smoothed_cross_covariances(self): - assert allclose( - self.serial_posterior.smoothed_cross_covariances, - self.parallel_posterior.smoothed_cross_covariances) + x = self.serial_posterior.smoothed_cross_covariances + y = self.parallel_posterior.smoothed_cross_covariances + matrix_norm_rel_diff = jnp.linalg.norm(x - y, axis=(1, 2)) / jnp.linalg.norm(x, axis=(1, 2)) + assert allclose(matrix_norm_rel_diff, 0) def test_marginal_loglik(self): assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik) @@ -162,8 +155,6 @@ class TestTimeVaryingParallelLGSSMSmoother: serial_posterior = serial_lgssm_smoother(params, emissions, inputs) parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs) - breakpoint() 
- def test_filtered_means(self): assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior.filtered_means) @@ -177,12 +168,13 @@ def test_smoothed_covariances(self): assert allclose(self.serial_posterior.smoothed_covariances, self.parallel_posterior.smoothed_covariances) def test_smoothed_cross_covariances(self): - assert allclose( - self.serial_posterior.smoothed_cross_covariances, - self.parallel_posterior.smoothed_cross_covariances) + x = self.serial_posterior.smoothed_cross_covariances + y = self.parallel_posterior.smoothed_cross_covariances + matrix_norm_rel_diff = jnp.linalg.norm(x - y, axis=(1, 2)) / jnp.linalg.norm(x, axis=(1, 2)) + assert allclose(matrix_norm_rel_diff, 0) def test_marginal_loglik(self): - assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik, atol=1e-1) + assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik, rtol=2e-2) From 17724cfe20358e4fc3551f63df89f9469ce52bbc Mon Sep 17 00:00:00 2001 From: Kei Ishikawa <30857855+kstoneriv3@users.noreply.github.com> Date: Sun, 13 Aug 2023 18:49:42 +0100 Subject: [PATCH 08/17] Delete lgssm_learning.py --- .../demos/lgssm_learning.py | 164 ------------------ 1 file changed, 164 deletions(-) delete mode 100644 dynamax/linear_gaussian_ssm/demos/lgssm_learning.py diff --git a/dynamax/linear_gaussian_ssm/demos/lgssm_learning.py b/dynamax/linear_gaussian_ssm/demos/lgssm_learning.py deleted file mode 100644 index d6a68dd3..00000000 --- a/dynamax/linear_gaussian_ssm/demos/lgssm_learning.py +++ /dev/null @@ -1,164 +0,0 @@ -#!/usr/bin/env python -# coding: utf-8 - -# # MAP parameter estimation for an LG-SSM using EM and SGD -# -# -# - -# ## Setup - -# In[1]: - - -# In[2]: - - -from jax import numpy as jnp -import jax.random as jr -from matplotlib import pyplot as plt - -from dynamax.linear_gaussian_ssm import LinearGaussianConjugateSSM -from dynamax.utils.utils import monotonically_increasing - - -# ## Data -# -# - -# In[3]: - - -state_dim=2 -emission_dim=5 -num_timesteps=200 -key = jr.PRNGKey(0) - -true_model = LinearGaussianConjugateSSM(state_dim, emission_dim) -key, key_root = jr.split(key) -true_params, param_props = true_model.initialize(key) - -key, key_root = jr.split(key) -true_states, emissions = true_model.sample(true_params, key, num_timesteps) - -# Plot the true states and emissions -fig, ax = plt.subplots(figsize=(10, 8)) -ax.plot(emissions + 3 * jnp.arange(emission_dim)) -ax.set_ylabel("data") -ax.set_xlabel("time") -ax.set_xlim(0, num_timesteps - 1) - - -# ## Plot results - -# In[4]: - - -def plot_learning_curve(marginal_lls, true_model, true_params, test_model, test_params, emissions): - plt.figure() - plt.xlabel("iteration") - nsteps = len(marginal_lls) - plt.plot(marginal_lls, label="estimated") - true_logjoint = (true_model.log_prior(true_params) + true_model.marginal_log_prob(true_params, emissions)) - plt.axhline(true_logjoint, color = 'k', linestyle = ':', label="true") - plt.ylabel("marginal joint probability") - plt.legend() - - -# In[5]: - - -def plot_predictions(true_model, true_params, test_model, test_params, emissions): - smoothed_emissions, smoothed_emissions_std = test_model.posterior_predictive(test_params, emissions) - - spc = 3 - plt.figure(figsize=(10, 4)) - for i in range(emission_dim): - plt.plot(emissions[:, i] + spc * i, "--k", label="observed" if i == 0 else None) - ln = plt.plot(smoothed_emissions[:, i] + spc * i, - label="smoothed" if i == 0 else None)[0] - plt.fill_between( - 
jnp.arange(num_timesteps), - spc * i + smoothed_emissions[:, i] - 2 * smoothed_emissions_std[i], - spc * i + smoothed_emissions[:, i] + 2 * smoothed_emissions_std[i], - color=ln.get_color(), - alpha=0.25, - ) - plt.xlabel("time") - plt.xlim(0, num_timesteps - 1) - plt.ylabel("true and predicted emissions") - plt.legend() - plt.show() - - -# In[6]: - - -# Plot predictions from a random, untrained model - -test_model = LinearGaussianConjugateSSM(state_dim, emission_dim) -key = jr.PRNGKey(42) -test_params, param_props = test_model.initialize(key) - -plot_predictions(true_model, true_params, test_model, test_params, emissions) - - -# ## Fit with EM - -# In[23]: - - -test_model = LinearGaussianConjugateSSM(state_dim, emission_dim) -key = jr.PRNGKey(42) -test_params, param_props = test_model.initialize(key) -num_iters = 100 -test_params, marginal_lls = test_model.fit_em(test_params, param_props, emissions, num_iters=num_iters) - -assert monotonically_increasing(marginal_lls, atol=1e-2, rtol=1e-2) - - -# In[7]: - - -test_model = LinearGaussianConjugateSSM(state_dim, emission_dim) -key = jr.PRNGKey(42) -test_params, param_props = test_model.initialize(key) -num_iters = 100 -test_params, marginal_lls = test_model.fit_em(test_params, param_props, emissions, num_iters=num_iters) - -assert monotonically_increasing(marginal_lls, atol=1e-2, rtol=1e-2) - - -# In[ ]: - - -plot_learning_curve(marginal_lls, true_model, true_params, test_model, test_params, emissions) -plot_predictions(true_model, true_params, test_model, test_params, emissions) - - -# ## Fit with SGD - -# In[ ]: - - -test_model = LinearGaussianConjugateSSM(state_dim, emission_dim) -key = jr.PRNGKey(42) -num_iters = 100 -test_params, param_props = test_model.initialize(key) - -test_params, neg_marginal_lls = test_model.fit_sgd(test_params, param_props, emissions, num_epochs=num_iters * 20) -marginal_lls = -neg_marginal_lls * emissions.size - - -# In[ ]: - - -plot_learning_curve(marginal_lls, true_model, true_params, test_model, test_params, emissions) -plot_predictions(true_model, true_params, test_model, test_params, emissions) - - -# In[ ]: - - - - From b16c1d911f5c1b68e59c8c7db23e67baec9c4de2 Mon Sep 17 00:00:00 2001 From: Kei Ishikawa Date: Sun, 13 Aug 2023 20:46:16 +0100 Subject: [PATCH 09/17] undo modification of a notebook for PR --- docs/notebooks/linear_gaussian_ssm/kf_linreg.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/notebooks/linear_gaussian_ssm/kf_linreg.ipynb b/docs/notebooks/linear_gaussian_ssm/kf_linreg.ipynb index ea352631..ea705f19 100644 --- a/docs/notebooks/linear_gaussian_ssm/kf_linreg.ipynb +++ b/docs/notebooks/linear_gaussian_ssm/kf_linreg.ipynb @@ -10,7 +10,7 @@ "\n", "We perform sequential (recursive) Bayesian inference for the parameters of a linear regression model\n", "using the Kalman filter. 
(This algorithm is also known as recursive least squares.)\n", - "To do this, we treat the parameers of the model as the unknown hidden states.\n", + "To do this, we treat the parameters of the model as the unknown hidden states.\n", "We assume that these are constant over time.\n", "The graphical model is shown below.\n", "\n", From c64ddd6b3887d2677875afe9f02e477909b57d9e Mon Sep 17 00:00:00 2001 From: Kei Ishikawa Date: Sun, 13 Aug 2023 21:44:53 +0100 Subject: [PATCH 10/17] apply black --- dynamax/linear_gaussian_ssm/inference.py | 138 +++++++++++------- .../linear_gaussian_ssm/parallel_inference.py | 74 +++++----- .../parallel_inference_test.py | 108 +++++++------- 3 files changed, 184 insertions(+), 136 deletions(-) diff --git a/dynamax/linear_gaussian_ssm/inference.py b/dynamax/linear_gaussian_ssm/inference.py index f3663e15..cb693333 100644 --- a/dynamax/linear_gaussian_ssm/inference.py +++ b/dynamax/linear_gaussian_ssm/inference.py @@ -12,6 +12,7 @@ from dynamax.parameters import ParameterProperties from dynamax.types import PRNGKey, Scalar + class ParamsLGSSMInitial(NamedTuple): r"""Parameters of the initial distribution @@ -41,10 +42,27 @@ class ParamsLGSSMDynamics(NamedTuple): :param cov: dynamics covariance $Q$ """ - weights: Union[Float[Array, "state_dim state_dim"], Float[Array, "ntime state_dim state_dim"], ParameterProperties] - bias: Union[Float[Array, "state_dim"], Float[Array, "ntime state_dim"], ParameterProperties] - input_weights: Union[Float[Array, "state_dim input_dim"], Float[Array, "ntime state_dim input_dim"], ParameterProperties] - cov: Union[Float[Array, "state_dim state_dim"], Float[Array, "ntime state_dim state_dim"], Float[Array, "state_dim_triu"], ParameterProperties] + weights: Union[ + Float[Array, "state_dim state_dim"], + Float[Array, "ntime state_dim state_dim"], + ParameterProperties, + ] + bias: Union[ + Float[Array, "state_dim"], + Float[Array, "ntime state_dim"], + ParameterProperties, + ] + input_weights: Union[ + Float[Array, "state_dim input_dim"], + Float[Array, "ntime state_dim input_dim"], + ParameterProperties, + ] + cov: Union[ + Float[Array, "state_dim state_dim"], + Float[Array, "ntime state_dim state_dim"], + Float[Array, "state_dim_triu"], + ParameterProperties, + ] class ParamsLGSSMEmissions(NamedTuple): @@ -60,11 +78,27 @@ class ParamsLGSSMEmissions(NamedTuple): :param cov: emission covariance $R$ """ - weights: Union[Float[Array, "emission_dim state_dim"], Float[Array, "ntime emission_dim state_dim"], ParameterProperties] - bias: Union[Float[Array, "emission_dim"], Float[Array, "ntime emission_dim"], ParameterProperties] - input_weights: Union[Float[Array, "emission_dim input_dim"], Float[Array, "ntime emission_dim input_dim"], ParameterProperties] - cov: Union[Float[Array, "emission_dim emission_dim"], Float[Array, "ntime emission_dim emission_dim"], Float[Array, "emission_dim_triu"], ParameterProperties] - + weights: Union[ + Float[Array, "emission_dim state_dim"], + Float[Array, "ntime emission_dim state_dim"], + ParameterProperties, + ] + bias: Union[ + Float[Array, "emission_dim"], + Float[Array, "ntime emission_dim"], + ParameterProperties, + ] + input_weights: Union[ + Float[Array, "emission_dim input_dim"], + Float[Array, "ntime emission_dim input_dim"], + ParameterProperties, + ] + cov: Union[ + Float[Array, "emission_dim emission_dim"], + Float[Array, "ntime emission_dim emission_dim"], + Float[Array, "emission_dim_triu"], + ParameterProperties, + ] class ParamsLGSSM(NamedTuple): @@ -124,42 +158,44 @@ def _get_params(x, 
dim, t): else: return x + _zeros_if_none = lambda x, shape: x if x is not None else jnp.zeros(shape) -def make_lgssm_params(initial_mean, - initial_cov, - dynamics_weights, - dynamics_cov, - emissions_weights, - emissions_cov, - dynamics_bias=None, - dynamics_input_weights=None, - emissions_bias=None, - emissions_input_weights=None): +def make_lgssm_params( + initial_mean, + initial_cov, + dynamics_weights, + dynamics_cov, + emissions_weights, + emissions_cov, + dynamics_bias=None, + dynamics_input_weights=None, + emissions_bias=None, + emissions_input_weights=None, +): """Helper function to construct a ParamsLGSSM object from arguments.""" state_dim = len(initial_mean) emission_dim = emissions_cov.shape[-1] - input_dim = max(dynamics_input_weights.shape[-1] if dynamics_input_weights is not None else 0, - emissions_input_weights.shape[-1] if emissions_input_weights is not None else 0) + input_dim = max( + dynamics_input_weights.shape[-1] if dynamics_input_weights is not None else 0, + emissions_input_weights.shape[-1] if emissions_input_weights is not None else 0, + ) params = ParamsLGSSM( - initial=ParamsLGSSMInitial( - mean=initial_mean, - cov=initial_cov - ), + initial=ParamsLGSSMInitial(mean=initial_mean, cov=initial_cov), dynamics=ParamsLGSSMDynamics( weights=dynamics_weights, - bias=_zeros_if_none(dynamics_bias,state_dim), + bias=_zeros_if_none(dynamics_bias, state_dim), input_weights=_zeros_if_none(dynamics_input_weights, (state_dim, input_dim)), - cov=dynamics_cov + cov=dynamics_cov, ), emissions=ParamsLGSSMEmissions( weights=emissions_weights, bias=_zeros_if_none(emissions_bias, emission_dim), input_weights=_zeros_if_none(emissions_input_weights, (emission_dim, input_dim)), - cov=emissions_cov - ) + cov=emissions_cov, + ), ) return params @@ -249,20 +285,20 @@ def preprocess_params_and_inputs(params, num_timesteps, inputs): emissions_bias = _zeros_if_none(params.emissions.bias, (emission_dim,)) full_params = ParamsLGSSM( - initial=ParamsLGSSMInitial( - mean=params.initial.mean, - cov=params.initial.cov), + initial=ParamsLGSSMInitial(mean=params.initial.mean, cov=params.initial.cov), dynamics=ParamsLGSSMDynamics( weights=params.dynamics.weights, bias=dynamics_bias, input_weights=dynamics_input_weights, - cov=params.dynamics.cov), + cov=params.dynamics.cov, + ), emissions=ParamsLGSSMEmissions( weights=params.emissions.weights, bias=emissions_bias, input_weights=emissions_input_weights, - cov=params.emissions.cov) - ) + cov=params.emissions.cov, + ), + ) return full_params, inputs @@ -275,14 +311,15 @@ def wrapper(*args, **kwargs): # Extract the arguments by name bound_args = sig.bind(*args, **kwargs) bound_args.apply_defaults() - params = bound_args.arguments['params'] - emissions = bound_args.arguments['emissions'] - inputs = bound_args.arguments['inputs'] + params = bound_args.arguments["params"] + emissions = bound_args.arguments["emissions"] + inputs = bound_args.arguments["inputs"] num_timesteps = len(emissions) full_params, inputs = preprocess_params_and_inputs(params, num_timesteps, inputs) return f(full_params, emissions, inputs=inputs) + return wrapper @@ -290,11 +327,10 @@ def lgssm_joint_sample( params: ParamsLGSSM, key: PRNGKey, num_timesteps: int, - inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None -)-> Tuple[Float[Array, "num_timesteps state_dim"], - Float[Array, "num_timesteps emission_dim"]]: + inputs: Optional[Float[Array, "num_timesteps input_dim"]] = None, +) -> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps 
emission_dim"]]: r"""Sample from the joint distribution to produce state and emission trajectories. - + Args: params: model parameters inputs: optional array of inputs. @@ -303,7 +339,7 @@ def lgssm_joint_sample( latent states and emissions """ - + params, inputs = preprocess_params_and_inputs(params, num_timesteps, inputs) def _sample_transition(key, F, B, b, Q, x, u): @@ -313,7 +349,7 @@ def _sample_transition(key, F, B, b, Q, x, u): def _sample_emission(key, H, D, d, R, x, u): mean = H @ x + D @ u + d return MVN(mean, R).sample(seed=key) - + def _sample_initial(key, params, inputs): key1, key2 = jr.split(key) @@ -350,7 +386,7 @@ def _step(state_prev, args): # Sample the initial state key1, key2 = jr.split(key) - + initial_state, initial_emission = _sample_initial(key1, params, inputs) # Sample the remaining emissions and states @@ -370,8 +406,8 @@ def _step(state_prev, args): @preprocess_args def lgssm_filter( params: ParamsLGSSM, - emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + emissions: Float[Array, "ntime emission_dim"], + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> PosteriorGSSMFiltered: r"""Run a Kalman filter to produce the marginal likelihood and filtered state estimates. @@ -424,7 +460,7 @@ def _step(carry, t): def lgssm_smoother( params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> PosteriorGSSMSmoothed: r"""Run forward-filtering, backward-smoother to compute expectations under the posterior distribution on latent states. Technically, this @@ -494,9 +530,9 @@ def _step(carry, args): def lgssm_posterior_sample( key: PRNGKey, params: ParamsLGSSM, - emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None, - jitter: Optional[Scalar]=0 + emissions: Float[Array, "ntime emission_dim"], + inputs: Optional[Float[Array, "ntime input_dim"]] = None, + jitter: Optional[Scalar] = 0, ) -> Float[Array, "ntime state_dim"]: r"""Run forward-filtering, backward-sampling to draw samples from $p(z_{1:T} \mid y_{1:T}, u_{1:T})$. diff --git a/dynamax/linear_gaussian_ssm/parallel_inference.py b/dynamax/linear_gaussian_ssm/parallel_inference.py index f2548018..b0ba984e 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference.py @@ -1,4 +1,4 @@ -''' +""" Parallel filtering and smoothing for a lgssm. This implementation is adapted from the work of Adrien Correnflos: @@ -28,7 +28,7 @@ | | | | Y₀ Y₁ Y₂ Y₃ -''' +""" import jax.numpy as jnp from jax import vmap, lax @@ -51,10 +51,12 @@ def _get_params(x, dim, t): return x[t] else: return x - -#---------------------------------------------------------------------------# + + +# ---------------------------------------------------------------------------# # Filtering # -#---------------------------------------------------------------------------# +# ---------------------------------------------------------------------------# + class FilterMessage(NamedTuple): """ @@ -67,11 +69,12 @@ class FilterMessage(NamedTuple): J: P(z_{i-1} | y_{i:j}) covariance. eta: P(z_{i-1} | y_{i:j}) mean. 
""" - A: Float[Array, "ntime state_dim state_dim"] - b: Float[Array, "ntime state_dim"] - C: Float[Array, "ntime state_dim state_dim"] - J: Float[Array, "ntime state_dim state_dim"] - eta: Float[Array, "ntime state_dim"] + + A: Float[Array, "ntime state_dim state_dim"] + b: Float[Array, "ntime state_dim"] + C: Float[Array, "ntime state_dim state_dim"] + J: Float[Array, "ntime state_dim state_dim"] + eta: Float[Array, "ntime state_dim"] logZ: Float[Array, "ntime"] @@ -102,7 +105,6 @@ def _first_message(params, y, u): logZ = -MVN(loc=H @ m + d, covariance_matrix=H @ P @ H.T + R).log_prob(y) return A, b, C, J, eta, logZ - @partial(vmap, in_axes=(None, 0, 0)) def _generic_message(params, y, t): F_prev = _get_params(params.dynamics.weights, 2, t - 1) @@ -136,7 +138,6 @@ def _generic_message(params, y, t): logZ = -MVN(loc=mu_y, covariance_matrix=S).log_prob(y) return A, b_prev, C, J, eta, logZ - A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0], inputs[0]) At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], jnp.arange(1, len(emissions))) @@ -146,21 +147,21 @@ def _generic_message(params, y, t): C=jnp.concatenate([C0[None], Ct]), J=jnp.concatenate([J0[None], Jt]), eta=jnp.concatenate([eta0[None], etat]), - logZ=jnp.concatenate([logZ0[None], logZt]) + logZ=jnp.concatenate([logZ0[None], logZt]), ) - @preprocess_args def lgssm_filter( params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> PosteriorGSSMFiltered: """A parallel version of the lgssm filtering algorithm. See S. Särkkä and Á. F. García-Fernández (2021) - https://arxiv.org/abs/1905.13002. """ + @vmap def _operator(elem1, elem2): A1, b1, C1, J1, eta1, logZ1 = elem1 @@ -179,8 +180,8 @@ def _operator(elem1, elem2): J = symmetrize(temp @ J2 @ A1 + J1) mu = jnp.linalg.solve(C1, b1) - t1 = (b1 @ mu - (eta2 + mu) @ jnp.linalg.solve(I_C1J2, C1 @ eta2 + b1)) - logZ = (logZ1 + logZ2 + 0.5 * jnp.linalg.slogdet(I_C1J2)[1] + 0.5 * t1) + t1 = b1 @ mu - (eta2 + mu) @ jnp.linalg.solve(I_C1J2, C1 @ eta2 + b1) + logZ = logZ1 + logZ2 + 0.5 * jnp.linalg.slogdet(I_C1J2)[1] + 0.5 * t1 return FilterMessage(A, b, C, J, eta, logZ) initial_messages = _initialize_filtering_messages(params, emissions, inputs) @@ -189,12 +190,14 @@ def _operator(elem1, elem2): return PosteriorGSSMFiltered( marginal_loglik=-final_messages.logZ[-1], filtered_means=final_messages.b, - filtered_covariances=final_messages.C) + filtered_covariances=final_messages.C, + ) -#---------------------------------------------------------------------------# +# ---------------------------------------------------------------------------# # Smoothing # -#---------------------------------------------------------------------------# +# ---------------------------------------------------------------------------# + class SmoothMessage(NamedTuple): """ @@ -205,6 +208,7 @@ class SmoothMessage(NamedTuple): g: P(z_i | y_{1:j}, z_{j+1}) bias. L: P(z_i | y_{1:j}, z_{j+1}) covariance. 
""" + E: Float[Array, "ntime state_dim state_dim"] g: Float[Array, "ntime state_dim"] L: Float[Array, "ntime state_dim state_dim"] @@ -229,17 +233,19 @@ def _generic_message(params, m, P, t): CF, low = cho_factor(F @ P @ F.T + Q) E = cho_solve((CF, low), F @ P).T - g = m - E @ (F @ m + b) - L = symmetrize(P - E @ F @ P) + g = m - E @ (F @ m + b) + L = symmetrize(P - E @ F @ P) return E, g, L - + En, gn, Ln = _last_message(filtered_means[-1], filtered_covariances[-1]) - Et, gt, Lt = _generic_message(params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_means)-1)) - + Et, gt, Lt = _generic_message( + params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_means) - 1) + ) + return SmoothMessage( E=jnp.concatenate([Et, En[None]]), g=jnp.concatenate([gt, gn[None]]), - L=jnp.concatenate([Lt, Ln[None]]) + L=jnp.concatenate([Lt, Ln[None]]), ) @@ -247,7 +253,7 @@ def _generic_message(params, m, P, t): def lgssm_smoother( params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> PosteriorGSSMSmoothed: """A parallel version of the lgssm smoothing algorithm. @@ -256,7 +262,7 @@ def lgssm_smoother( filtered_posterior = lgssm_filter(params, emissions, inputs) filtered_means = filtered_posterior.filtered_means filtered_covs = filtered_posterior.filtered_covariances - + @vmap def _operator(elem1, elem2): E1, g1, L1 = elem1 @@ -297,9 +303,10 @@ def compute_smoothed_cross_covariances( return G @ smoothed_cov_next + jnp.outer(smoothed_mean, smoothed_mean_next) -#---------------------------------------------------------------------------# +# ---------------------------------------------------------------------------# # Sampling # -#---------------------------------------------------------------------------# +# ---------------------------------------------------------------------------# + class SampleMessage(NamedTuple): """ @@ -309,14 +316,15 @@ class SampleMessage(NamedTuple): E: z_i ~ z_{j+1} weights. h: z_i ~ z_{j+1} bias. """ + E: Float[Array, "ntime state_dim state_dim"] h: Float[Array, "ntime state_dim"] def _initialize_sampling_messages(key, params, inputs, filtered_means, filtered_covariances): """A parallel version of the lgssm sampling algorithm. - - Given parallel smoothing messages `z_i ~ N(E_i z_{i+1} + g_i, L_i)`, + + Given parallel smoothing messages `z_i ~ N(E_i z_{i+1} + g_i, L_i)`, the parallel sampling messages are `(E_i,h_i)` where `h_i ~ N(g_i, L_i)`. """ E, g, L = _initialize_smoothing_messages(params, inputs, filtered_means, filtered_covariances) @@ -327,7 +335,7 @@ def lgssm_posterior_sample( key: PRNGKey, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> Float[Array, "ntime state_dim"]: """A parallel version of the lgssm sampling algorithm. 
diff --git a/dynamax/linear_gaussian_ssm/parallel_inference_test.py b/dynamax/linear_gaussian_ssm/parallel_inference_test.py index bcf7f30e..cd40a8f4 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference_test.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference_test.py @@ -11,12 +11,13 @@ from dynamax.linear_gaussian_ssm import parallel_lgssm_posterior_sample +from jax.config import config -from jax.config import config; config.update("jax_enable_x64", True) +config.update("jax_enable_x64", True) allclose = partial(jnp.allclose, atol=1e-2, rtol=1e-2) - + def make_static_lgssm_params(): latent_dim = 4 @@ -26,33 +27,33 @@ def make_static_lgssm_params(): dt = 0.1 F = jnp.eye(4) + dt * jnp.eye(4, k=2) b = 0.1 * jnp.arange(4) - Q = 1. * jnp.kron(jnp.array([[dt**3/3, dt**2/2], - [dt**2/2, dt]]), - jnp.eye(2)) - + Q = 1.0 * jnp.kron(jnp.array([[dt**3 / 3, dt**2 / 2], [dt**2 / 2, dt]]), jnp.eye(2)) + H = jnp.eye(2, 4) d = 0.1 * jnp.ones(2) - R = 0.5 ** 2 * jnp.eye(2) - μ0 = jnp.array([0.,1.,1.,-1.]) + R = 0.5**2 * jnp.eye(2) + μ0 = jnp.array([0.0, 1.0, 1.0, -1.0]) Σ0 = jnp.eye(4) B = jnp.eye(latent_dim, input_dim) * 0 D = jnp.eye(observation_dim, input_dim) * 0 lgssm = LinearGaussianSSM(latent_dim, observation_dim, input_dim) - params, _ = lgssm.initialize(jr.PRNGKey(0), - initial_mean=μ0, - initial_covariance= Σ0, - dynamics_weights=F, - dynamics_input_weights=B, - dynamics_bias=b, - dynamics_covariance=Q, - emission_weights=H, - emission_input_weights=D, - emission_bias=d, - emission_covariance=R) + params, _ = lgssm.initialize( + jr.PRNGKey(0), + initial_mean=μ0, + initial_covariance=Σ0, + dynamics_weights=F, + dynamics_input_weights=B, + dynamics_bias=b, + dynamics_covariance=Q, + emission_weights=H, + emission_input_weights=D, + emission_bias=d, + emission_covariance=R, + ) return params, lgssm - + def make_dynamic_lgssm_params(num_timesteps): latent_dim = 4 @@ -66,12 +67,10 @@ def make_dynamic_lgssm_params(num_timesteps): dt = 0.1 f_scale = jr.normal(key_f, (num_timesteps,)) * 0.5 - F = f_scale[:,None,None] * jnp.tile(jnp.eye(latent_dim), (num_timesteps, 1, 1)) + F = f_scale[:, None, None] * jnp.tile(jnp.eye(latent_dim), (num_timesteps, 1, 1)) F += dt * jnp.eye(latent_dim, k=2) - Q = 1. 
* jnp.kron(jnp.array([[dt**3/3, dt**2/2], - [dt**2/2, dt]]), - jnp.eye(latent_dim // 2)) + Q = 1.0 * jnp.kron(jnp.array([[dt**3 / 3, dt**2 / 2], [dt**2 / 2, dt]]), jnp.eye(latent_dim // 2)) assert Q.shape[-1] == latent_dim Q = Q[None] * jr.uniform(keys[3], (num_timesteps, 1, 1)) @@ -79,38 +78,42 @@ def make_dynamic_lgssm_params(num_timesteps): H = jr.normal(keys[4], (num_timesteps, observation_dim, latent_dim)) r_scale = jr.normal(key_r, (num_timesteps,)) * 0.1 - R = (r_scale**2)[:,None,None] * jnp.tile(jnp.eye(observation_dim), (num_timesteps, 1, 1)) - - μ0 = jnp.array([1.,-2.,1.,-1.]) + R = (r_scale**2)[:, None, None] * jnp.tile(jnp.eye(observation_dim), (num_timesteps, 1, 1)) + + μ0 = jnp.array([1.0, -2.0, 1.0, -1.0]) Σ0 = jnp.eye(latent_dim) B = jnp.eye(latent_dim, input_dim)[None] + 0.1 * jr.normal(keys[6], (num_timesteps, latent_dim, input_dim)) - D = jnp.eye(observation_dim, input_dim)[None] + 0.1 * jr.normal(keys[7], (num_timesteps, observation_dim, input_dim)) + D = jnp.eye(observation_dim, input_dim)[None] + 0.1 * jr.normal( + keys[7], (num_timesteps, observation_dim, input_dim) + ) b = jr.normal(keys[0], (num_timesteps, latent_dim)) d = jr.normal(keys[1], (num_timesteps, observation_dim)) lgssm = LinearGaussianSSM(latent_dim, observation_dim, input_dim) - params, _ = lgssm.initialize(key_init, - initial_mean=μ0, - initial_covariance=Σ0, - dynamics_weights=F, - dynamics_input_weights=B, - dynamics_bias=b, - dynamics_covariance=Q, - emission_weights=H, - emission_input_weights=D, - emission_bias=d, - emission_covariance=R) + params, _ = lgssm.initialize( + key_init, + initial_mean=μ0, + initial_covariance=Σ0, + dynamics_weights=F, + dynamics_input_weights=B, + dynamics_bias=b, + dynamics_covariance=Q, + emission_weights=H, + emission_input_weights=D, + emission_bias=d, + emission_covariance=R, + ) return params, lgssm class TestParallelLGSSMSmoother: - """ Compare parallel and serial lgssm smoothing implementations.""" - + """Compare parallel and serial lgssm smoothing implementations.""" + num_timesteps = 50 keys = jr.split(jr.PRNGKey(1), 2) - params, lgssm = make_static_lgssm_params() + params, lgssm = make_static_lgssm_params() inputs = jr.normal(keys[0], (num_timesteps, params.dynamics.input_weights.shape[-1])) _, emissions = lgssm_joint_sample(params, keys[1], num_timesteps, inputs) @@ -139,16 +142,16 @@ def test_marginal_loglik(self): assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik) - class TestTimeVaryingParallelLGSSMSmoother: """Compare parallel and serial time-varying lgssm smoothing implementations. - + Vary dynamics weights and observation covariances with time. 
""" + num_timesteps = 50 keys = jr.split(jr.PRNGKey(1), 2) - params, lgssm = make_dynamic_lgssm_params(num_timesteps) + params, lgssm = make_dynamic_lgssm_params(num_timesteps) inputs = jr.normal(keys[0], (num_timesteps, params.emissions.input_weights.shape[-1])) _, emissions = lgssm_joint_sample(params, keys[1], num_timesteps, inputs) @@ -177,14 +180,13 @@ def test_marginal_loglik(self): assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik, rtol=2e-2) - -class TestTimeVaryingParallelLGSSMSampler(): +class TestTimeVaryingParallelLGSSMSampler: """Compare parallel and serial lgssm posterior sampling implementations in expectation.""" - + num_timesteps = 50 keys = jr.split(jr.PRNGKey(1), 2) - params, lgssm = make_dynamic_lgssm_params(num_timesteps) + params, lgssm = make_dynamic_lgssm_params(num_timesteps) inputs = jr.normal(keys[0], (num_timesteps, params.emissions.input_weights.shape[-1])) _, emissions = lgssm_joint_sample(params, keys[1], num_timesteps, inputs) @@ -193,10 +195,12 @@ class TestTimeVaryingParallelLGSSMSampler(): parallel_keys = jr.split(jr.PRNGKey(3), num_samples) serial_samples = vmap(serial_lgssm_posterior_sample, in_axes=(0, None, None, None))( - serial_keys, params, emissions, inputs) - + serial_keys, params, emissions, inputs + ) + parallel_samples = vmap(parallel_lgssm_posterior_sample, in_axes=(0, None, None, None))( - parallel_keys, params, emissions, inputs) + parallel_keys, params, emissions, inputs + ) def test_sampled_means(self): serial_mean = self.serial_samples.mean(axis=0) From 73a827056c1287d03b5f94bbf5a437d2edac2aa3 Mon Sep 17 00:00:00 2001 From: Kei Ishikawa Date: Sun, 13 Aug 2023 22:23:55 +0100 Subject: [PATCH 11/17] use same indexing everywhere for LGSSM and apply black to these files --- dynamax/linear_gaussian_ssm/inference.py | 16 +- dynamax/linear_gaussian_ssm/models.py | 247 +++++++++--------- .../linear_gaussian_ssm/parallel_inference.py | 2 +- 3 files changed, 136 insertions(+), 129 deletions(-) diff --git a/dynamax/linear_gaussian_ssm/inference.py b/dynamax/linear_gaussian_ssm/inference.py index cb693333..e6063aa5 100644 --- a/dynamax/linear_gaussian_ssm/inference.py +++ b/dynamax/linear_gaussian_ssm/inference.py @@ -16,12 +16,12 @@ class ParamsLGSSMInitial(NamedTuple): r"""Parameters of the initial distribution - $$p(z_1) = \mathcal{N}(z_1 \mid \mu_1, Q_1)$$ + $$p(z_0) = \mathcal{N}(z_0 \mid \mu_0, Q_0)$$ The tuple doubles as a container for the ParameterProperties. - :param mean: $\mu_1$ - :param cov: $Q_1$ + :param mean: $\mu_0$ + :param cov: $Q_0$ """ mean: Union[Float[Array, "state_dim"], ParameterProperties] @@ -32,7 +32,7 @@ class ParamsLGSSMInitial(NamedTuple): class ParamsLGSSMDynamics(NamedTuple): r"""Parameters of the emission distribution - $$p(z_{t+1} \mid z_t, u_t) = \mathcal{N}(z_{t+1} \mid F z_t + B u_t + b, Q)$$ + $$p(z_{t+1} \mid z_t, u_{t+1}) = \mathcal{N}(z_{t+1} \mid F_{t+1} z_t + B_{t+1} u_{t+1} + b_{t+1}, Q_{t+1})$$ The tuple doubles as a container for the ParameterProperties. @@ -68,7 +68,7 @@ class ParamsLGSSMDynamics(NamedTuple): class ParamsLGSSMEmissions(NamedTuple): r"""Parameters of the emission distribution - $$p(y_t \mid z_t, u_t) = \mathcal{N}(y_t \mid H z_t + D u_t + d, R)$$ + $$p(y_t \mid z_t, u_t) = \mathcal{N}(y_t \mid H_t z_t + D_t u_t + d_t, R_t)$$ The tuple doubles as a container for the ParameterProperties. @@ -203,8 +203,8 @@ def make_lgssm_params( def _predict(m, S, F, B, b, Q, u): r"""Predict next mean and covariance under a linear Gaussian model. 
- p(z_{t+1}) = int N(z_t \mid m, S) N(z_{t+1} \mid Fz_t + Bu + b, Q) - = N(z_{t+1} \mid Fm + Bu + b, F S F^T + Q) + p(z_{t+1}) = \int N(z_t \mid m_t, S_t) N(z_{t+1} \mid F_{t+1} z_t + B_{t+1} u_{t+1} + b_{t+1}, Q_{t+1}) d z_t + = N(z_{t+1} \mid F_{t+1} m_t + B_{t+1} u_{t+1} + b_{t+1}, F_{t+1} S_t F_{t+1}^T + Q_{t+1}) Args: m (D_hid,): prior mean. @@ -228,7 +228,7 @@ def _condition_on(m, P, H, D, d, R, u, y): r"""Condition a Gaussian potential on a new linear Gaussian observation p(z_t \mid y_t, u_t, y_{1:t-1}, u_{1:t-1}) propto p(z_t \mid y_{1:t-1}, u_{1:t-1}) p(y_t \mid z_t, u_t) - = N(z_t \mid m, P) N(y_t \mid H_t z_t + D_t u_t + d_t, R_t) + = N(z_t \mid m_t, P_t) N(y_t \mid H_t z_t + D_t u_t + d_t, R_t) = N(z_t \mid mm, PP) where mm = m + K*(y - yhat) = mu_cond diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py index d3ee1583..cd382272 100644 --- a/dynamax/linear_gaussian_ssm/models.py +++ b/dynamax/linear_gaussian_ssm/models.py @@ -17,7 +17,12 @@ from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_filter as parallel_lgssm_filter from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_smoother as parallel_lgssm_smoother from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_posterior_sample as parallel_lgssm_posterior_sample -from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM, ParamsLGSSMInitial, ParamsLGSSMDynamics, ParamsLGSSMEmissions +from dynamax.linear_gaussian_ssm.inference import ( + ParamsLGSSM, + ParamsLGSSMInitial, + ParamsLGSSMDynamics, + ParamsLGSSMEmissions, +) from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed from dynamax.parameters import ParameterProperties, ParameterSet from dynamax.types import PRNGKey, Scalar @@ -27,8 +32,10 @@ from dynamax.utils.distributions import mniw_posterior_update, niw_posterior_update from dynamax.utils.utils import pytree_stack, psd_solve + class SuffStatsLGSSM(Protocol): """A :class:`NamedTuple` with sufficient statistics for LGSSM parameter estimation.""" + pass @@ -38,7 +45,7 @@ class LinearGaussianSSM(SSM): The model is defined as follows - $$p(z_1) = \mathcal{N}(z_1 \mid m, S)$$ + $$p(z_0) = \mathcal{N}(z_0 \mid m, S)$$ $$p(z_t \mid z_{t-1}, u_t) = \mathcal{N}(z_t \mid F_t z_{t-1} + B_t u_t + b_t, Q_t)$$ $$p(y_t \mid z_t) = \mathcal{N}(y_t \mid H_t z_t + D_t u_t + d_t, R_t)$$ @@ -61,6 +68,10 @@ class LinearGaussianSSM(SSM): The parameters of the model are stored in a :class:`ParamsLGSSM`. You can create the parameters manually, or by calling :meth:`initialize`. + Please note that we adopt the convention of Murphy, K. P. (2022), "Probabilistic machine learning: Advanced topics", + rather than Särkkä, S. (2013), "Bayesian Filtering and Smoothing" for indexing parameters of LGSSM with initial + index begin 0 instead of 1, which tends to be a source of confusion sometimes. + :param state_dim: Dimensionality of latent state. :param emission_dim: Dimensionality of observation vector. :param input_dim: Dimensionality of input vector. Defaults to 0. @@ -70,14 +81,15 @@ class LinearGaussianSSM(SSM): and sampling instead of sequential ones. Defaults to False. 
""" + def __init__( self, state_dim: int, emission_dim: int, - input_dim: int=0, - has_dynamics_bias: bool=True, - has_emissions_bias: bool=True, - use_parallel_inference: bool=False + input_dim: int = 0, + has_dynamics_bias: bool = True, + has_emissions_bias: bool = True, + use_parallel_inference: bool = False, ): self.state_dim = state_dim self.emission_dim = emission_dim @@ -96,8 +108,8 @@ def inputs_shape(self): def initialize( self, - key: PRNGKey =jr.PRNGKey(0), - initial_mean: Optional[Float[Array, "state_dim"]]=None, + key: PRNGKey = jr.PRNGKey(0), + initial_mean: Optional[Float[Array, "state_dim"]] = None, initial_covariance=None, dynamics_weights=None, dynamics_bias=None, @@ -106,7 +118,7 @@ def initialize( emission_weights=None, emission_bias=None, emission_input_weights=None, - emission_covariance=None + emission_covariance=None, ) -> Tuple[ParamsLGSSM, ParamsLGSSM]: r"""Initialize model parameters that are set to None, and their corresponding properties. @@ -145,42 +157,44 @@ def initialize( # Create nested dictionary of params params = ParamsLGSSM( initial=ParamsLGSSMInitial( - mean=default(initial_mean, _initial_mean), - cov=default(initial_covariance, _initial_covariance)), + mean=default(initial_mean, _initial_mean), cov=default(initial_covariance, _initial_covariance), + ), dynamics=ParamsLGSSMDynamics( weights=default(dynamics_weights, _dynamics_weights), bias=default(dynamics_bias, _dynamics_bias), input_weights=default(dynamics_input_weights, _dynamics_input_weights), - cov=default(dynamics_covariance, _dynamics_covariance)), + cov=default(dynamics_covariance, _dynamics_covariance), + ), emissions=ParamsLGSSMEmissions( weights=default(emission_weights, _emission_weights), bias=default(emission_bias, _emission_bias), input_weights=default(emission_input_weights, _emission_input_weights), - cov=default(emission_covariance, _emission_covariance)) - ) + cov=default(emission_covariance, _emission_covariance), + ), + ) # The keys of param_props must match those of params! 
props = ParamsLGSSM( initial=ParamsLGSSMInitial( - mean=ParameterProperties(), - cov=ParameterProperties(constrainer=RealToPSDBijector())), + mean=ParameterProperties(), cov=ParameterProperties(constrainer=RealToPSDBijector()), + ), dynamics=ParamsLGSSMDynamics( weights=ParameterProperties(), bias=ParameterProperties(), input_weights=ParameterProperties(), - cov=ParameterProperties(constrainer=RealToPSDBijector())), + cov=ParameterProperties(constrainer=RealToPSDBijector()), + ), emissions=ParamsLGSSMEmissions( weights=ParameterProperties(), bias=ParameterProperties(), input_weights=ParameterProperties(), - cov=ParameterProperties(constrainer=RealToPSDBijector())) - ) + cov=ParameterProperties(constrainer=RealToPSDBijector()), + ), + ) return params, props def initial_distribution( - self, - params: ParamsLGSSM, - inputs: Optional[Float[Array, "ntime input_dim"]]=None + self, params: ParamsLGSSM, inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> tfd.Distribution: return MVN(params.initial.mean, params.initial.cov) @@ -188,7 +202,7 @@ def transition_distribution( self, params: ParamsLGSSM, state: Float[Array, "state_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> tfd.Distribution: inputs = inputs if inputs is not None else jnp.zeros(self.input_dim) mean = params.dynamics.weights @ state + params.dynamics.input_weights @ inputs @@ -200,7 +214,7 @@ def emission_distribution( self, params: ParamsLGSSM, state: Float[Array, "state_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> tfd.Distribution: inputs = inputs if inputs is not None else jnp.zeros(self.input_dim) mean = params.emissions.weights @ state + params.emissions.input_weights @ inputs @@ -212,7 +226,7 @@ def marginal_log_prob( self, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]] = None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> Scalar: filtered_posterior = self.filter(params, emissions, inputs) return filtered_posterior.marginal_loglik @@ -221,7 +235,7 @@ def filter( self, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]] = None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> PosteriorGSSMFiltered: if self.use_parallel_inference: return parallel_lgssm_filter(params, emissions, inputs) @@ -232,7 +246,7 @@ def smoother( self, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]] = None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> PosteriorGSSMSmoothed: if self.use_parallel_inference: return parallel_lgssm_smoother(params, emissions, inputs) @@ -244,7 +258,7 @@ def posterior_sample( key: PRNGKey, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> Float[Array, "ntime state_dim"]: if use_parallel_inference: return parallel_lgssm_posterior_sample(key, params, emissions, inputs) @@ -255,7 +269,7 @@ def posterior_predictive( self, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> Tuple[Float[Array, "ntime emission_dim"], Float[Array, 
"ntime emission_dim"]]: r"""Compute marginal posterior predictive smoothing distribution for each observation. @@ -275,18 +289,19 @@ def posterior_predictive( emission_dim = R.shape[0] smoothed_emissions = posterior.smoothed_means @ H.T + b smoothed_emissions_cov = H @ posterior.smoothed_covariances @ H.T + R - smoothed_emissions_std = jnp.sqrt( - jnp.array([smoothed_emissions_cov[:, i, i] for i in range(emission_dim)])) + smoothed_emissions_std = jnp.sqrt(jnp.array([smoothed_emissions_cov[:, i, i] for i in range(emission_dim)])) return smoothed_emissions, smoothed_emissions_std # Expectation-maximization (EM) code def e_step( self, params: ParamsLGSSM, - emissions: Union[Float[Array, "num_timesteps emission_dim"], - Float[Array, "num_batches num_timesteps emission_dim"]], - inputs: Optional[Union[Float[Array, "num_timesteps input_dim"], - Float[Array, "num_batches num_timesteps input_dim"]]]=None, + emissions: Union[ + Float[Array, "num_timesteps emission_dim"], Float[Array, "num_batches num_timesteps emission_dim"], + ], + inputs: Optional[ + Union[Float[Array, "num_timesteps input_dim"], Float[Array, "num_batches num_timesteps input_dim"],] + ] = None, ) -> Tuple[SuffStatsLGSSM, Scalar]: num_timesteps = emissions.shape[0] if inputs is None: @@ -319,18 +334,17 @@ def e_step( # let zp[t] = [x[t], u[t]] for t = 0...T-2 # let xn[t] = x[t+1] for t = 0...T-2 sum_zpzpT = jnp.block([[Exp.T @ Exp, Exp.T @ up], [up.T @ Exp, up.T @ up]]) - sum_zpzpT = sum_zpzpT.at[:self.state_dim, :self.state_dim].add(Vxp.sum(0)) + sum_zpzpT = sum_zpzpT.at[: self.state_dim, : self.state_dim].add(Vxp.sum(0)) sum_zpxnT = jnp.block([[Expxn.sum(0)], [up.T @ Exn]]) sum_xnxnT = Vxn.sum(0) + Exn.T @ Exn dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, num_timesteps - 1) if not self.has_dynamics_bias: - dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT, - num_timesteps - 1) + dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT, num_timesteps - 1) # more expected sufficient statistics for the emissions # let z[t] = [x[t], u[t]] for t = 0...T-1 sum_zzT = jnp.block([[Ex.T @ Ex, Ex.T @ u], [u.T @ Ex, u.T @ u]]) - sum_zzT = sum_zzT.at[:self.state_dim, :self.state_dim].add(Vx.sum(0)) + sum_zzT = sum_zzT.at[: self.state_dim, : self.state_dim].add(Vx.sum(0)) sum_zyT = jnp.block([[Ex.T @ y], [u.T @ y]]) sum_yyT = emissions.T @ emissions emission_stats = (sum_zzT, sum_zyT, sum_yyT, num_timesteps) @@ -339,22 +353,12 @@ def e_step( return (init_stats, dynamics_stats, emission_stats), posterior.marginal_loglik - - def initialize_m_step_state( - self, - params: ParamsLGSSM, - props: ParamsLGSSM - ) -> Any: + def initialize_m_step_state(self, params: ParamsLGSSM, props: ParamsLGSSM) -> Any: return None def m_step( - self, - params: ParamsLGSSM, - props: ParamsLGSSM, - batch_stats: SuffStatsLGSSM, - m_step_state: Any + self, params: ParamsLGSSM, props: ParamsLGSSM, batch_stats: SuffStatsLGSSM, m_step_state: Any ) -> Tuple[ParamsLGSSM, Any]: - def fit_linear_regression(ExxT, ExyT, EyyT, N): # Solve a linear regression given sufficient statistics W = psd_solve(ExxT, ExyT).T @@ -371,19 +375,17 @@ def fit_linear_regression(ExxT, ExyT, EyyT, N): m = sum_x0 / N FB, Q = fit_linear_regression(*dynamics_stats) - F = FB[:, :self.state_dim] - B, b = (FB[:, self.state_dim:-1], FB[:, -1]) if self.has_dynamics_bias \ - else (FB[:, self.state_dim:], None) + F = FB[:, : self.state_dim] + B, b = (FB[:, self.state_dim : -1], FB[:, -1]) if self.has_dynamics_bias else (FB[:, self.state_dim :], None) HD, R = 
fit_linear_regression(*emission_stats) - H = HD[:, :self.state_dim] - D, d = (HD[:, self.state_dim:-1], HD[:, -1]) if self.has_emissions_bias \ - else (HD[:, self.state_dim:], None) + H = HD[:, : self.state_dim] + D, d = (HD[:, self.state_dim : -1], HD[:, -1]) if self.has_emissions_bias else (HD[:, self.state_dim :], None) params = ParamsLGSSM( initial=ParamsLGSSMInitial(mean=m, cov=S), dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q), - emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R) + emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R), ) return params, m_step_state @@ -405,40 +407,51 @@ class LinearGaussianConjugateSSM(LinearGaussianSSM): :param has_emissions_bias: Whether model contains an offset term d. Defaults to True. """ - def __init__(self, - state_dim, - emission_dim, - input_dim=0, - has_dynamics_bias=True, - has_emissions_bias=True, - **kw_priors): - super().__init__(state_dim=state_dim, emission_dim=emission_dim, input_dim=input_dim, - has_dynamics_bias=has_dynamics_bias, has_emissions_bias=has_emissions_bias) + + def __init__( + self, state_dim, emission_dim, input_dim=0, has_dynamics_bias=True, has_emissions_bias=True, **kw_priors, + ): + super().__init__( + state_dim=state_dim, + emission_dim=emission_dim, + input_dim=input_dim, + has_dynamics_bias=has_dynamics_bias, + has_emissions_bias=has_emissions_bias, + ) # Initialize prior distributions def default_prior(arg, default): return kw_priors[arg] if arg in kw_priors else default self.initial_prior = default_prior( - 'initial_prior', - NIW(loc=jnp.zeros(self.state_dim), - mean_concentration=1., + "initial_prior", + NIW( + loc=jnp.zeros(self.state_dim), + mean_concentration=1.0, df=self.state_dim + 0.1, - scale=jnp.eye(self.state_dim))) + scale=jnp.eye(self.state_dim), + ), + ) self.dynamics_prior = default_prior( - 'dynamics_prior', - MNIW(loc=jnp.zeros((self.state_dim, self.state_dim + self.input_dim + self.has_dynamics_bias)), - col_precision=jnp.eye(self.state_dim + self.input_dim + self.has_dynamics_bias), - df=self.state_dim + 0.1, - scale=jnp.eye(self.state_dim))) + "dynamics_prior", + MNIW( + loc=jnp.zeros((self.state_dim, self.state_dim + self.input_dim + self.has_dynamics_bias)), + col_precision=jnp.eye(self.state_dim + self.input_dim + self.has_dynamics_bias), + df=self.state_dim + 0.1, + scale=jnp.eye(self.state_dim), + ), + ) self.emission_prior = default_prior( - 'emission_prior', - MNIW(loc=jnp.zeros((self.emission_dim, self.state_dim + self.input_dim + self.has_emissions_bias)), - col_precision=jnp.eye(self.state_dim + self.input_dim + self.has_emissions_bias), - df=self.emission_dim + 0.1, - scale=jnp.eye(self.emission_dim))) + "emission_prior", + MNIW( + loc=jnp.zeros((self.emission_dim, self.state_dim + self.input_dim + self.has_emissions_bias)), + col_precision=jnp.eye(self.state_dim + self.input_dim + self.has_emissions_bias), + df=self.emission_dim + 0.1, + scale=jnp.eye(self.emission_dim), + ), + ) @property def emission_shape(self): @@ -448,39 +461,23 @@ def emission_shape(self): def covariates_shape(self): return dict(inputs=(self.input_dim,)) if self.input_dim > 0 else dict() - def log_prior( - self, - params: ParamsLGSSM - ) -> Scalar: + def log_prior(self, params: ParamsLGSSM) -> Scalar: lp = self.initial_prior.log_prob((params.initial.cov, params.initial.mean)) # dynamics dynamics_bias = params.dynamics.bias if self.has_dynamics_bias else jnp.zeros((self.state_dim, 0)) - dynamics_matrix = 
jnp.column_stack((params.dynamics.weights, - params.dynamics.input_weights, - dynamics_bias)) + dynamics_matrix = jnp.column_stack((params.dynamics.weights, params.dynamics.input_weights, dynamics_bias)) lp += self.dynamics_prior.log_prob((params.dynamics.cov, dynamics_matrix)) emission_bias = params.emissions.bias if self.has_emissions_bias else jnp.zeros((self.emission_dim, 0)) - emission_matrix = jnp.column_stack((params.emissions.weights, - params.emissions.input_weights, - emission_bias)) + emission_matrix = jnp.column_stack((params.emissions.weights, params.emissions.input_weights, emission_bias)) lp += self.emission_prior.log_prob((params.emissions.cov, emission_matrix)) return lp - def initialize_m_step_state( - self, - params: ParamsLGSSM, - props: ParamsLGSSM - ) -> Any: + def initialize_m_step_state(self, params: ParamsLGSSM, props: ParamsLGSSM) -> Any: return None - def m_step( - self, - params: ParamsLGSSM, - props: ParamsLGSSM, - batch_stats: SuffStatsLGSSM, - m_step_state: Any): + def m_step(self, params: ParamsLGSSM, props: ParamsLGSSM, batch_stats: SuffStatsLGSSM, m_step_state: Any): # Sum the statistics across all batches stats = tree_map(partial(jnp.sum, axis=0), batch_stats) init_stats, dynamics_stats, emission_stats = stats @@ -491,20 +488,26 @@ def m_step( dynamics_posterior = mniw_posterior_update(self.dynamics_prior, dynamics_stats) Q, FB = dynamics_posterior.mode() - F = FB[:, :self.state_dim] - B, b = (FB[:, self.state_dim:-1], FB[:, -1]) if self.has_dynamics_bias \ - else (FB[:, self.state_dim:], jnp.zeros(self.state_dim)) + F = FB[:, : self.state_dim] + B, b = ( + (FB[:, self.state_dim : -1], FB[:, -1]) + if self.has_dynamics_bias + else (FB[:, self.state_dim :], jnp.zeros(self.state_dim)) + ) emission_posterior = mniw_posterior_update(self.emission_prior, emission_stats) R, HD = emission_posterior.mode() - H = HD[:, :self.state_dim] - D, d = (HD[:, self.state_dim:-1], HD[:, -1]) if self.has_emissions_bias \ - else (HD[:, self.state_dim:], jnp.zeros(self.emission_dim)) + H = HD[:, : self.state_dim] + D, d = ( + (HD[:, self.state_dim : -1], HD[:, -1]) + if self.has_emissions_bias + else (HD[:, self.state_dim :], jnp.zeros(self.emission_dim)) + ) params = ParamsLGSSM( initial=ParamsLGSSMInitial(mean=m, cov=S), dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q), - emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R) + emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R), ) return params, m_step_state @@ -514,7 +517,7 @@ def fit_blocked_gibbs( initial_params: ParamsLGSSM, sample_size: int, emissions: Float[Array, "nbatch ntime emission_dim"], - inputs: Optional[Float[Array, "nbatch ntime input_dim"]]=None + inputs: Optional[Float[Array, "nbatch ntime input_dim"]] = None, ) -> ParamsLGSSM: r"""Estimate parameter posterior using block-Gibbs sampler. 
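
For context on the `fit_blocked_gibbs` hunks being reformatted here, a hypothetical usage sketch (illustration
only, not part of the patch): the sampler alternates drawing states with `lgssm_posterior_sample` and drawing
parameters from their conjugate NIW/MNIW posteriors. The leading PRNG-key argument is assumed from the function
body below (it is not visible in this hunk), and the emissions are simulated for the example.

import jax.random as jr
from dynamax.linear_gaussian_ssm import LinearGaussianConjugateSSM

# Simulate emissions from a randomly initialized conjugate model.
model = LinearGaussianConjugateSSM(state_dim=2, emission_dim=5)
params, _ = model.initialize(jr.PRNGKey(0))
_, emissions = model.sample(params, jr.PRNGKey(1), num_timesteps=100)

# Draw 50 posterior samples of the parameters by blocked Gibbs,
# starting the chain at `params`.
param_samples = model.fit_blocked_gibbs(jr.PRNGKey(2), params, sample_size=50, emissions=emissions)
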
@@ -550,8 +553,7 @@ def sufficient_stats_from_sample(states): sum_xnxnT = xn.T @ xn dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, num_timesteps - 1) if not self.has_dynamics_bias: - dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT, - num_timesteps - 1) + dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT, num_timesteps - 1) # Quantities for the emissions # Let z[t] = [x[t], u[t]] for t = 0...T-1 @@ -576,21 +578,27 @@ def lgssm_params_sample(rng, stats): # Sample the dynamics params dynamics_posterior = mniw_posterior_update(self.dynamics_prior, dynamics_stats) Q, FB = dynamics_posterior.sample(seed=next(rngs)) - F = FB[:, :self.state_dim] - B, b = (FB[:, self.state_dim:-1], FB[:, -1]) if self.has_dynamics_bias \ - else (FB[:, self.state_dim:], jnp.zeros(self.state_dim)) + F = FB[:, : self.state_dim] + B, b = ( + (FB[:, self.state_dim : -1], FB[:, -1]) + if self.has_dynamics_bias + else (FB[:, self.state_dim :], jnp.zeros(self.state_dim)) + ) # Sample the emission params emission_posterior = mniw_posterior_update(self.emission_prior, emission_stats) R, HD = emission_posterior.sample(seed=next(rngs)) - H = HD[:, :self.state_dim] - D, d = (HD[:, self.state_dim:-1], HD[:, -1]) if self.has_emissions_bias \ - else (HD[:, self.state_dim:], jnp.zeros(self.emission_dim)) + H = HD[:, : self.state_dim] + D, d = ( + (HD[:, self.state_dim : -1], HD[:, -1]) + if self.has_emissions_bias + else (HD[:, self.state_dim :], jnp.zeros(self.emission_dim)) + ) params = ParamsLGSSM( initial=ParamsLGSSMInitial(mean=m, cov=S), dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q), - emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R) + emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R), ) return params @@ -603,7 +611,6 @@ def one_sample(_params, rng): _stats = sufficient_stats_from_sample(states) return lgssm_params_sample(rngs[1], _stats) - sample_of_params = [] keys = iter(jr.split(key, sample_size)) current_params = initial_params diff --git a/dynamax/linear_gaussian_ssm/parallel_inference.py b/dynamax/linear_gaussian_ssm/parallel_inference.py index b0ba984e..3a55974f 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference.py @@ -21,7 +21,7 @@ Dynamax - F₀,Q₀ F₁,Q₁ F₂,Q₂ + F₁,Q₁ F₂,Q₂ F₃,Q₃ Z₀ ─────────── Z₁ ─────────── Z₂ ─────────── Z₃ ─────... 
| | | | | H₀,R₀ | H₁,R₁ | H₂,R₂ | H₃,R₃ From 7659d54b5b5386e66834208af1c7b1c3bb291722 Mon Sep 17 00:00:00 2001 From: Kei Ishikawa Date: Mon, 14 Aug 2023 01:16:43 +0100 Subject: [PATCH 12/17] hopefully the final fix of indices --- dynamax/linear_gaussian_ssm/inference.py | 2 +- dynamax/linear_gaussian_ssm/parallel_inference_test.py | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/dynamax/linear_gaussian_ssm/inference.py b/dynamax/linear_gaussian_ssm/inference.py index e949ef35..03f55696 100644 --- a/dynamax/linear_gaussian_ssm/inference.py +++ b/dynamax/linear_gaussian_ssm/inference.py @@ -551,7 +551,7 @@ def _step(carry, args): # Get parameters and inputs for time index t + 1 F_next, B_next, b_next, Q_next = _get_params(params, num_timesteps, t + 1)[:4] - u_next = inputs[t] + u_next = inputs[t + 1] # This is like the Kalman gain but in reverse # See Eq 8.11 of Särkkä's "Bayesian Filtering and Smoothing" diff --git a/dynamax/linear_gaussian_ssm/parallel_inference_test.py b/dynamax/linear_gaussian_ssm/parallel_inference_test.py index 10b0788c..7f73f7ad 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference_test.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference_test.py @@ -96,9 +96,9 @@ def make_dynamic_lgssm_params(num_timesteps): b = jr.normal(keys[0], (num_timesteps, latent_dim)) d = jr.normal(keys[1], (num_timesteps, observation_dim)) - B = B * 0 - b = b * 0 - D = D * 0 + # B = B * 0 + # b = b * 0 + # D = D * 0 # d = d * 0 lgssm = LinearGaussianSSM(latent_dim, observation_dim, input_dim) @@ -186,10 +186,7 @@ class TestTimeVaryingParallelLGSSMSmoother: serial_posterior = serial_lgssm_smoother(params, emissions, inputs) parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs) - params.emissions.bias.at[:].set(0) - dzero_parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs) - assert jnp.allclose(dzero_parallel_posterior.marginal_loglik, parallel_posterior.marginal_loglik, rtol=2e-2) # ======= # params, lgssm = make_dynamic_lgssm_params(num_timesteps) # params_diag = flatten_diagonal_emission_cov(params) From c23b0768c1031ee5ec2e152484594b6e3346b0ac Mon Sep 17 00:00:00 2001 From: Kei Ishikawa Date: Mon, 14 Aug 2023 01:28:50 +0100 Subject: [PATCH 13/17] update docstring a bit --- dynamax/linear_gaussian_ssm/models.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py index 88012fd7..b968d9a1 100644 --- a/dynamax/linear_gaussian_ssm/models.py +++ b/dynamax/linear_gaussian_ssm/models.py @@ -69,9 +69,10 @@ class LinearGaussianSSM(SSM): You can create the parameters manually, or by calling :meth:`initialize`. Please note that we adopt the convention of Murphy, K. P. (2022), "Probabilistic machine learning: Advanced topics", - rather than Särkkä, S. (2013), "Bayesian Filtering and Smoothing" for indexing parameters of LGSSM with initial - index begin 0 instead of 1, which tends to be a source of confusion sometimes. - As such, F_0, B_0, b_0, Q_0 are always ignored and the provided prior of the initial state is used. + rather than Särkkä, S. (2013), "Bayesian Filtering and Smoothing" for indexing parameters of LGSSM, where we start + the initial index at 0 instead of 1, which tends to be a source of confusion. As such, $F_0$, $B_0$, $b_0$, $Q_0$ + are always ignored and the prior specified by $m$ and $S$ is used as the distribution of the initial state.
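To make the indexing convention concrete: time-varying parameter arrays still carry an entry at index 0, but under this convention inference never uses it, since $x_0$ is drawn from the prior $N(m, S)$ and $F_t$, $B_t$, $b_t$, $Q_t$ govern the transition into $x_t$ only for $t \geq 1$. A hedged sketch with hypothetical shapes, not code from this patch:

import jax.numpy as jnp
import jax.random as jr

num_timesteps, state_dim = 100, 2
# One dynamics matrix per time step, indexed 0, ..., T-1.
F = 0.9 * jnp.eye(state_dim)[None] + 0.01 * jr.normal(jr.PRNGKey(0), (num_timesteps, state_dim, state_dim))
# F[0] is carried along for shape consistency but is ignored by convention:
# x_0 ~ N(m, S), and x_t depends on F[t] only for t = 1, ..., T-1,
# so overwriting F[0] should leave the posterior unchanged.
F = F.at[0].set(0.0)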
:param state_dim: Dimensionality of latent state. :param emission_dim: Dimensionality of observation vector. From 3b25adc26c0a23c0a971b6405a83eabcd615a6ac Mon Sep 17 00:00:00 2001 From: Kei Ishikawa Date: Mon, 14 Aug 2023 20:14:51 +0100 Subject: [PATCH 14/17] remove a duplicated function --- .../linear_gaussian_ssm/parallel_inference.py | 61 ++++--------------- 1 file changed, 13 insertions(+), 48 deletions(-) diff --git a/dynamax/linear_gaussian_ssm/parallel_inference.py b/dynamax/linear_gaussian_ssm/parallel_inference.py index 328d645a..a84aa6c6 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference.py @@ -46,12 +46,12 @@ from jax.scipy.linalg import cho_solve, cho_factor from dynamax.utils.utils import symmetrize, psd_solve from dynamax.linear_gaussian_ssm import PosteriorGSSMFiltered, PosteriorGSSMSmoothed, ParamsLGSSM -from dynamax.linear_gaussian_ssm.inference import preprocess_args, _get_one_param, _get_params +from dynamax.linear_gaussian_ssm.inference import preprocess_args, _get_one_param, _get_params, _log_likelihood -# ---------------------------------------------------------------------------# +# --------------------------------------------------------------------------# # Filtering # -# ---------------------------------------------------------------------------# +# --------------------------------------------------------------------------# def _emissions_scale(Q, H, R): @@ -80,35 +80,6 @@ def _emissions_scale(Q, H, R): return S_inv -# TODO: remove both and use the one defined in inference.py -def _log_likelihood(pred_mean, pred_cov, H, D, d, R, u, y): - m = H @ pred_mean + D @ u + d - if R.ndim == 2: - S = R + H @ pred_cov @ H.T - return MVN(m, S).log_prob(y) - else: - L = H @ jnp.linalg.cholesky(pred_cov) - return MVNLowRank(m, R, L).log_prob(y) - - -def _marginal_loglik_elem(Q, H, R, y, y_loc): - """Compute marginal log-likelihood elements. - - Args: - Q (state_dim, state_dim): State covariance. - H (emission_dim, state_dim): Emission matrix. - R (emission_dim, emission_dim) or (emission_dim,): Emission covariance. - y (emission_dim,): Emission. - y_loc (emission_dim,): Emission mean. - """ - if R.ndim == 2: - S = H @ Q @ H.T + R - return -MVN(y_loc, S).log_prob(y) - else: - L = H @ jnp.linalg.cholesky(Q) - return -MVNLowRank(y_loc, R, L).log_prob(y) - - class FilterMessage(NamedTuple): """ Filtering associative scan elements. 
@@ -139,19 +110,15 @@ def _first_message(params, y, u): m = params.initial.mean P = params.initial.cov - # Adjust the bias term according to the input - d = d + D @ u - S = H @ P @ H.T + (R if R.ndim == 2 else jnp.diag(R)) S_inv = _emissions_scale(P, H, R) K = P @ H.T @ S_inv A = jnp.zeros_like(P) - b = m + K @ (y - H @ m - d) + b = m + K @ (y - H @ m - D @ u - d) C = symmetrize(P - K @ S @ K.T) eta = jnp.zeros_like(b) J = jnp.eye(len(b)) - - logZ = _marginal_loglik_elem(P, H, R, y, y_loc=H @ m + d) + logZ = -_log_likelihood(m, P, H, D, d, R, u, y) return A, b, C, J, eta, logZ @partial(vmap, in_axes=(None, 0, 0)) @@ -160,22 +127,20 @@ def _generic_message(params, y, t): u = inputs[t] # Adjust the bias terms according to the input - d = d + D @ u b = b + B @ u - - y_loc = H @ b + d # mean of p(y_t|x_{t-1}=0) + m = b S_inv = _emissions_scale(Q, H, R) K = Q @ H.T @ S_inv - eta = F.T @ H.T @ S_inv @ (y - H @ b - d) + eta = F.T @ H.T @ S_inv @ (y - H @ b - D @ u - d) J = symmetrize(F.T @ H.T @ S_inv @ H @ F) A = F - K @ H @ F - b = b + K @ (y - H @ b - d) + b = b + K @ (y - H @ b - D @ u - d) C = symmetrize(Q - K @ H @ Q) - logZ = _marginal_loglik_elem(Q, H, R, y, y_loc) + logZ = -_log_likelihood(m, Q, H, D, d, R, u, y) return A, b, C, J, eta, logZ A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0], inputs[0]) @@ -234,9 +199,9 @@ def _operator(elem1, elem2): ) -# ---------------------------------------------------------------------------# +# --------------------------------------------------------------------------# # Smoothing # -# ---------------------------------------------------------------------------# +# --------------------------------------------------------------------------# class SmoothMessage(NamedTuple): @@ -342,9 +307,9 @@ def compute_smoothed_cross_covariances( return G @ smoothed_cov_next + jnp.outer(smoothed_mean, smoothed_mean_next) -# ---------------------------------------------------------------------------# +# --------------------------------------------------------------------------# # Sampling # -# ---------------------------------------------------------------------------# +# --------------------------------------------------------------------------# class SampleMessage(NamedTuple): From a818f0917b5c7b622ef73f86cb11f1c1d7e0d651 Mon Sep 17 00:00:00 2001 From: Kei Ishikawa Date: Mon, 14 Aug 2023 22:55:56 +0100 Subject: [PATCH 15/17] solve merge conflict of parallel_inference_test.py --- .../parallel_inference_test.py | 181 ++++++------------ 1 file changed, 63 insertions(+), 118 deletions(-) diff --git a/dynamax/linear_gaussian_ssm/parallel_inference_test.py b/dynamax/linear_gaussian_ssm/parallel_inference_test.py index 7f73f7ad..09456358 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference_test.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference_test.py @@ -25,23 +25,26 @@ def make_static_lgssm_params(): observation_dim = 2 input_dim = 3 + keys = jr.split(jr.PRNGKey(0), 3) + dt = 0.1 F = jnp.eye(4) + dt * jnp.eye(4, k=2) - b = 0.1 * jnp.arange(4) + B = 0.2 * jr.normal(keys[0], (4, 3)) + b = 0.2 * jnp.arange(4) Q = 1.0 * jnp.kron(jnp.array([[dt**3 / 3, dt**2 / 2], [dt**2 / 2, dt]]), jnp.eye(2)) H = jnp.eye(2, 4) - d = 0.1 * jnp.ones(2) + D = 0.2 * jr.normal(keys[1], (observation_dim, input_dim)) + d = 0.2 * jnp.ones(2) R = 0.5**2 * jnp.eye(2) + μ0 = jnp.array([0.0, 1.0, 1.0, -1.0]) Σ0 = jnp.eye(4) - B = jnp.eye(latent_dim, input_dim) * 0 - D = jnp.eye(observation_dim, input_dim) * 0 lgssm = LinearGaussianSSM(latent_dim, observation_dim, input_dim)
params, _ = lgssm.initialize( - jr.PRNGKey(0), + keys[2], initial_mean=μ0, initial_covariance=Σ0, dynamics_weights=F, @@ -56,54 +59,33 @@ def make_static_lgssm_params(): return params, lgssm -# <<<<<<< HEAD def make_dynamic_lgssm_params(num_timesteps): latent_dim = 4 observation_dim = 2 input_dim = 3 - keys = jr.split(jr.PRNGKey(1), 100) - - key = jr.PRNGKey(0) - # ======= - # def make_dynamic_lgssm_params(num_timesteps, latent_dim=4, observation_dim=2, seed=0): - # key = jr.PRNGKey(seed) - # >>>>>>> main - key, key_f, key_r, key_init = jr.split(key, 4) + keys = jr.split(jr.PRNGKey(1), 9) dt = 0.1 - f_scale = jr.normal(key_f, (num_timesteps,)) * 0.5 - F = f_scale[:, None, None] * jnp.tile(jnp.eye(latent_dim), (num_timesteps, 1, 1)) - F += dt * jnp.eye(latent_dim, k=2) - Q = 1.0 * jnp.kron(jnp.array([[dt**3 / 3, dt**2 / 2], [dt**2 / 2, dt]]), jnp.eye(latent_dim // 2)) - assert Q.shape[-1] == latent_dim - Q = Q[None] * jr.uniform(keys[3], (num_timesteps, 1, 1)) + F = jnp.eye(4)[None] + dt * jnp.eye(4, k=2)[None] + 0.1 * jr.normal(keys[0], (num_timesteps, latent_dim, latent_dim)) + B = 0.2 * jr.normal(keys[4], (num_timesteps, latent_dim, input_dim)) + b = 0.2 * jr.normal(keys[6], (num_timesteps, latent_dim)) + q_scale = jr.normal(keys[1], (num_timesteps, 1, 1)) ** 2 + Q = q_scale * jnp.kron(jnp.array([[dt**3 / 3, dt**2 / 2], [dt**2 / 2, dt]]), jnp.eye(2))[None] - # H = jnp.eye(observation_dim, latent_dim) - H = jr.normal(keys[4], (num_timesteps, observation_dim, latent_dim)) - - r_scale = jr.normal(key_r, (num_timesteps,)) * 0.1 - R = (r_scale**2)[:, None, None] * jnp.tile(jnp.eye(observation_dim), (num_timesteps, 1, 1)) + H = jnp.eye(2, 4)[None] * 0.1 * jr.normal(keys[2], (num_timesteps, observation_dim, latent_dim)) + D = 0.2 * jr.normal(keys[5], (num_timesteps, observation_dim, input_dim)) + d = 0.2 * jr.normal(keys[7], (num_timesteps, observation_dim)) + r_scale = jr.normal(keys[3], (num_timesteps, 1, 1)) ** 2 + R = r_scale * jnp.eye(2)[None] μ0 = jnp.array([1.0, -2.0, 1.0, -1.0]) Σ0 = jnp.eye(latent_dim) - B = jnp.eye(latent_dim, input_dim)[None] + 0.1 * jr.normal(keys[6], (num_timesteps, latent_dim, input_dim)) - D = jnp.eye(observation_dim, input_dim)[None] + 0.1 * jr.normal( - keys[7], (num_timesteps, observation_dim, input_dim) - ) - b = jr.normal(keys[0], (num_timesteps, latent_dim)) - d = jr.normal(keys[1], (num_timesteps, observation_dim)) - - # B = B * 0 - # b = b * 0 - # D = D * 0 - # d = d * 0 - lgssm = LinearGaussianSSM(latent_dim, observation_dim, input_dim) params, _ = lgssm.initialize( - key_init, + keys[8], initial_mean=μ0, initial_covariance=Σ0, dynamics_weights=F, @@ -124,50 +106,45 @@ class TestParallelLGSSMSmoother: num_timesteps = 50 keys = jr.split(jr.PRNGKey(1), 2) - # <<<<<<< HEAD params, lgssm = make_static_lgssm_params() - # inputs = jr.normal(keys[0], (num_timesteps, params.dynamics.input_weights.shape[-1])) - # _, emissions = lgssm_joint_sample(params, keys[1], num_timesteps, inputs) - - # serial_posterior = serial_lgssm_smoother(params, emissions, inputs) - # parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs) - # ======= - # params, lgssm = make_static_lgssm_params() - # params_diag = flatten_diagonal_emission_cov(params) - # _, emissions = lgssm_joint_sample(params, key, num_timesteps) - # - # serial_posterior = serial_lgssm_smoother(params, emissions) - # parallel_posterior = parallel_lgssm_smoother(params, emissions) - # parallel_posterior_diag = parallel_lgssm_smoother(params_diag, emissions) - # >>>>>>> main + params_diag = 
flatten_diagonal_emission_cov(params) + inputs = jr.normal(keys[0], (num_timesteps, params.dynamics.input_weights.shape[-1])) + _, emissions = lgssm_joint_sample(params, keys[1], num_timesteps, inputs) + + serial_posterior = serial_lgssm_smoother(params, emissions, inputs) + parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs) + parallel_posterior_diag = parallel_lgssm_smoother(params_diag, emissions, inputs) def test_filtered_means(self): assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior.filtered_means) - # assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior_diag.filtered_means) + assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior_diag.filtered_means) def test_filtered_covariances(self): assert allclose(self.serial_posterior.filtered_covariances, self.parallel_posterior.filtered_covariances) - # assert allclose(self.serial_posterior.filtered_covariances, self.parallel_posterior_diag.filtered_covariances) + assert allclose(self.serial_posterior.filtered_covariances, self.parallel_posterior_diag.filtered_covariances) def test_smoothed_means(self): assert allclose(self.serial_posterior.smoothed_means, self.parallel_posterior.smoothed_means) - # assert allclose(self.serial_posterior.smoothed_means, self.parallel_posterior_diag.smoothed_means) + assert allclose(self.serial_posterior.smoothed_means, self.parallel_posterior_diag.smoothed_means) def test_smoothed_covariances(self): assert allclose(self.serial_posterior.smoothed_covariances, self.parallel_posterior.smoothed_covariances) - # assert allclose(self.serial_posterior.smoothed_covariances, self.parallel_posterior_diag.smoothed_covariances) + assert allclose(self.serial_posterior.smoothed_covariances, self.parallel_posterior_diag.smoothed_covariances) def test_smoothed_cross_covariances(self): x = self.serial_posterior.smoothed_cross_covariances y = self.parallel_posterior.smoothed_cross_covariances + z = self.parallel_posterior_diag.smoothed_cross_covariances matrix_norm_rel_diff = jnp.linalg.norm(x - y, axis=(1, 2)) / jnp.linalg.norm(x, axis=(1, 2)) assert allclose(matrix_norm_rel_diff, 0) + matrix_norm_rel_diff = jnp.linalg.norm(x - z, axis=(1, 2)) / jnp.linalg.norm(x, axis=(1, 2)) + assert allclose(matrix_norm_rel_diff, 0) def test_marginal_loglik(self): assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik, atol=2e-1) - # assert jnp.allclose( - # self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, atol=2e-1 - # ) + assert jnp.allclose( + self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, atol=2e-1 + ) class TestTimeVaryingParallelLGSSMSmoother: @@ -179,113 +156,81 @@ class TestTimeVaryingParallelLGSSMSmoother: num_timesteps = 50 keys = jr.split(jr.PRNGKey(1), 2) - # <<<<<<< HEAD params, lgssm = make_dynamic_lgssm_params(num_timesteps) + params_diag = flatten_diagonal_emission_cov(params) inputs = jr.normal(keys[0], (num_timesteps, params.emissions.input_weights.shape[-1])) _, emissions = lgssm_joint_sample(params, keys[1], num_timesteps, inputs) serial_posterior = serial_lgssm_smoother(params, emissions, inputs) parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs) - - # ======= - # params, lgssm = make_dynamic_lgssm_params(num_timesteps) - # params_diag = flatten_diagonal_emission_cov(params) - # _, emissions = lgssm_joint_sample(params, key, num_timesteps) - # - # serial_posterior = 
serial_lgssm_smoother(params, emissions) - # parallel_posterior = parallel_lgssm_smoother(params, emissions) - # parallel_posterior_diag = parallel_lgssm_smoother(params_diag, emissions) - # >>>>>>> main + parallel_posterior_diag = parallel_lgssm_smoother(params_diag, emissions, inputs) def test_filtered_means(self): assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior.filtered_means) - # assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior_diag.filtered_means) + assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior_diag.filtered_means) def test_filtered_covariances(self): assert allclose(self.serial_posterior.filtered_covariances, self.parallel_posterior.filtered_covariances) - # assert allclose(self.serial_posterior.filtered_covariances, self.parallel_posterior_diag.filtered_covariances) + assert allclose(self.serial_posterior.filtered_covariances, self.parallel_posterior_diag.filtered_covariances) def test_smoothed_means(self): assert allclose(self.serial_posterior.smoothed_means, self.parallel_posterior.smoothed_means) - # assert allclose(self.serial_posterior.smoothed_means, self.parallel_posterior_diag.smoothed_means) + assert allclose(self.serial_posterior.smoothed_means, self.parallel_posterior_diag.smoothed_means) def test_smoothed_covariances(self): assert allclose(self.serial_posterior.smoothed_covariances, self.parallel_posterior.smoothed_covariances) - # assert allclose(self.serial_posterior.smoothed_covariances, self.parallel_posterior_diag.smoothed_covariances) + assert allclose(self.serial_posterior.smoothed_covariances, self.parallel_posterior_diag.smoothed_covariances) def test_smoothed_cross_covariances(self): x = self.serial_posterior.smoothed_cross_covariances y = self.parallel_posterior.smoothed_cross_covariances + z = self.parallel_posterior_diag.smoothed_cross_covariances matrix_norm_rel_diff = jnp.linalg.norm(x - y, axis=(1, 2)) / jnp.linalg.norm(x, axis=(1, 2)) assert allclose(matrix_norm_rel_diff, 0) + matrix_norm_rel_diff = jnp.linalg.norm(x - z, axis=(1, 2)) / jnp.linalg.norm(x, axis=(1, 2)) + assert allclose(matrix_norm_rel_diff, 0) def test_marginal_loglik(self): - # <<<<<<< HEAD assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik, rtol=2e-2) - - -# ======= -# assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik, atol=2e-1) -# assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, atol=2e-1) -# >>>>>>> main + assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, rtol=2e-2) class TestTimeVaryingParallelLGSSMSampler: """Compare parallel and serial lgssm posterior sampling implementations in expectation.""" - # <<<<<<< HEAD num_timesteps = 50 keys = jr.split(jr.PRNGKey(1), 2) params, lgssm = make_dynamic_lgssm_params(num_timesteps) - # inputs = jr.normal(keys[0], (num_timesteps, params.emissions.input_weights.shape[-1])) - # _, emissions = lgssm_joint_sample(params, keys[1], num_timesteps, inputs) - # ======= - # params, lgssm = make_dynamic_lgssm_params(num_timesteps) - # params_diag = flatten_diagonal_emission_cov(params) - # _, emissions = lgssm_joint_sample(params_diag, key, num_timesteps) - # >>>>>>> main + params_diag = flatten_diagonal_emission_cov(params) + inputs = jr.normal(keys[0], (num_timesteps, params.emissions.input_weights.shape[-1])) + _, emissions = lgssm_joint_sample(params, 
keys[1], num_timesteps, inputs) num_samples = 1000 serial_keys = jr.split(jr.PRNGKey(2), num_samples) parallel_keys = jr.split(jr.PRNGKey(3), num_samples) - # <<<<<<< HEAD - # serial_samples = vmap(serial_lgssm_posterior_sample, in_axes=(0, None, None, None))( - # serial_keys, params, emissions, inputs - # ) - - # parallel_samples = vmap(parallel_lgssm_posterior_sample, in_axes=(0, None, None, None))( - # parallel_keys, params, emissions, inputs - # ) - # ======= - # serial_samples = vmap(serial_lgssm_posterior_sample, in_axes=(0,None,None))( - # serial_keys, params, emissions) - # - # parallel_samples = vmap(parallel_lgssm_posterior_sample, in_axes=(0, None, None))( - # parallel_keys, params, emissions) - # - # parallel_samples_diag = vmap(parallel_lgssm_posterior_sample, in_axes=(0, None, None))( - # parallel_keys, params, emissions) - # >>>>>>> main + serial_samples = vmap(serial_lgssm_posterior_sample, in_axes=(0, None, None, None))( + serial_keys, params, emissions, inputs + ) + parallel_samples = vmap(parallel_lgssm_posterior_sample, in_axes=(0, None, None, None))( + parallel_keys, params, emissions, inputs + ) + parallel_samples_diag = vmap(parallel_lgssm_posterior_sample, in_axes=(0, None, None, None))( + parallel_keys, params_diag, emissions, inputs + ) def test_sampled_means(self): serial_mean = self.serial_samples.mean(axis=0) parallel_mean = self.parallel_samples.mean(axis=0) parallel_mean_diag = self.parallel_samples_diag.mean(axis=0) - assert allclose(serial_mean, parallel_mean, atol=1e-1) - assert allclose(serial_mean, parallel_mean_diag, atol=1e-1) + assert allclose(serial_mean, parallel_mean, atol=1e-1, rtol=1e-1) + assert allclose(serial_mean, parallel_mean_diag, atol=1e-1, rtol=1e-1) def test_sampled_covariances(self): # samples have shape (N, T, D): vmap over the T axis, calculate cov over N axis serial_cov = vmap(partial(jnp.cov, rowvar=False), in_axes=1)(self.serial_samples) parallel_cov = vmap(partial(jnp.cov, rowvar=False), in_axes=1)(self.parallel_samples) - # <<<<<<< HEAD + parallel_cov_diag = vmap(partial(jnp.cov, rowvar=False), in_axes=1)(self.parallel_samples_diag) assert allclose(serial_cov, parallel_cov, atol=1e-1) - - -# ======= -# parallel_cov_diag = vmap(partial(jnp.cov, rowvar=False), in_axes=1)(self.parallel_samples) -# assert allclose(serial_cov, parallel_cov, atol=1e-1) -# assert allclose(serial_cov, parallel_cov_diag, atol=1e-1) -# >>>>>>> main + assert allclose(serial_cov, parallel_cov_diag, atol=1e-1) From 1b19bc73b4e6c8e39f476229553316c5b8ef6b71 Mon Sep 17 00:00:00 2001 From: Kei Ishikawa Date: Mon, 14 Aug 2023 23:39:26 +0100 Subject: [PATCH 16/17] black --- dynamax/linear_gaussian_ssm/inference.py | 4 +++- dynamax/linear_gaussian_ssm/models.py | 2 +- dynamax/linear_gaussian_ssm/models_test.py | 3 ++- .../linear_gaussian_ssm/parallel_inference_test.py | 11 ++++++++--- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/dynamax/linear_gaussian_ssm/inference.py b/dynamax/linear_gaussian_ssm/inference.py index 03f55696..b782cf7a 100644 --- a/dynamax/linear_gaussian_ssm/inference.py +++ b/dynamax/linear_gaussian_ssm/inference.py @@ -621,7 +621,9 @@ def _step(carry, args): u_next = inputs[t + 1] # Condition on next state - smoothed_mean, smoothed_cov = _condition_on(filtered_mean, filtered_cov, F_next, B_next, b_next, Q_next, u_next, next_state) + smoothed_mean, smoothed_cov = _condition_on( + filtered_mean, filtered_cov, F_next, B_next, b_next, Q_next, u_next, next_state + ) smoothed_cov = smoothed_cov + jnp.eye(smoothed_cov.shape[-1]) * jitter state
= MVN(smoothed_mean, smoothed_cov).sample(seed=key) return state, state diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py index b968d9a1..0efd0b06 100644 --- a/dynamax/linear_gaussian_ssm/models.py +++ b/dynamax/linear_gaussian_ssm/models.py @@ -69,7 +69,7 @@ class LinearGaussianSSM(SSM): You can create the parameters manually, or by calling :meth:`initialize`. Please note that we adopt the convention of Murphy, K. P. (2022), "Probabilistic machine learning: Advanced topics", - rather than Särkkä, S. (2013), "Bayesian Filtering and Smoothing" for indexing parameters of LGSSM, where we start + rather than Särkkä, S. (2013), "Bayesian Filtering and Smoothing" for indexing parameters of LGSSM, where we start the initial index at 0 instead of 1, which tends to be a source of confusion. As such, $F_0$, $B_0$, $b_0$, $Q_0$ are always ignored and the prior specified by $m$ and $S$ is used as the distribution of the initial state. diff --git a/dynamax/linear_gaussian_ssm/models_test.py b/dynamax/linear_gaussian_ssm/models_test.py index 6ac68ddd..1cee7358 100644 --- a/dynamax/linear_gaussian_ssm/models_test.py +++ b/dynamax/linear_gaussian_ssm/models_test.py @@ -14,10 +14,11 @@ (LinearGaussianConjugateSSM, dict(state_dim=2, emission_dim=10, use_parallel_inference=True), None), ] + @pytest.mark.parametrize(["cls", "kwargs", "inputs"], CONFIGS) def test_sample_and_fit(cls, kwargs, inputs): model = cls(**kwargs) - #key1, key2 = jr.split(jr.PRNGKey(int(datetime.now().timestamp()))) + # key1, key2 = jr.split(jr.PRNGKey(int(datetime.now().timestamp()))) key1, key2 = jr.split(jr.PRNGKey(0)) params, param_props = model.initialize(key1) states, emissions = model.sample(params, key2, num_timesteps=NUM_TIMESTEPS, inputs=inputs) diff --git a/dynamax/linear_gaussian_ssm/parallel_inference_test.py b/dynamax/linear_gaussian_ssm/parallel_inference_test.py index 09456358..1fa67eea 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference_test.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference_test.py @@ -41,7 +41,6 @@ def make_static_lgssm_params(): μ0 = jnp.array([0.0, 1.0, 1.0, -1.0]) Σ0 = jnp.eye(4) - lgssm = LinearGaussianSSM(latent_dim, observation_dim, input_dim) params, _ = lgssm.initialize( keys[2], @@ -68,7 +67,11 @@ def make_dynamic_lgssm_params(num_timesteps): dt = 0.1 - F = jnp.eye(4)[None] + dt * jnp.eye(4, k=2)[None] + 0.1 * jr.normal(keys[0], (num_timesteps, latent_dim, latent_dim)) + F = ( + jnp.eye(4)[None] + + dt * jnp.eye(4, k=2)[None] + + 0.1 * jr.normal(keys[0], (num_timesteps, latent_dim, latent_dim)) + ) B = 0.2 * jr.normal(keys[4], (num_timesteps, latent_dim, input_dim)) b = 0.2 * jr.normal(keys[6], (num_timesteps, latent_dim)) q_scale = jr.normal(keys[1], (num_timesteps, 1, 1)) ** 2 @@ -192,7 +195,9 @@ def test_smoothed_cross_covariances(self): def test_marginal_loglik(self): assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik, rtol=2e-2) - assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, rtol=2e-2) + assert jnp.allclose( + self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, rtol=2e-2 + ) From 70ab368f3df0083a3b41d2b4f99e7339ceb285eb Mon Sep 17 00:00:00 2001 From: Kei Ishikawa Date: Tue, 15 Aug 2023 18:55:47 +0100 Subject: [PATCH 17/17] add parallel inference flag to conjugate LGSSM ---
dynamax/linear_gaussian_ssm/models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py index 0efd0b06..ed73e431 100644 --- a/dynamax/linear_gaussian_ssm/models.py +++ b/dynamax/linear_gaussian_ssm/models.py @@ -425,6 +425,7 @@ def __init__( input_dim=0, has_dynamics_bias=True, has_emissions_bias=True, + use_parallel_inference=False, **kw_priors, ): super().__init__( @@ -433,6 +434,7 @@ def __init__( input_dim=input_dim, has_dynamics_bias=has_dynamics_bias, has_emissions_bias=has_emissions_bias, + use_parallel_inference=use_parallel_inference, ) # Initialize prior distributions
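For reference, a usage sketch of the new flag, mirroring the models_test.py configuration above. This assumes the whole patch series is applied; with the flag set, filtering and smoothing are presumably routed to the associative-scan implementations in parallel_inference.py:

import jax.random as jr
from dynamax.linear_gaussian_ssm import LinearGaussianConjugateSSM

model = LinearGaussianConjugateSSM(state_dim=2, emission_dim=10, use_parallel_inference=True)
params, props = model.initialize(jr.PRNGKey(0))
states, emissions = model.sample(params, jr.PRNGKey(1), num_timesteps=100)
posterior = model.smoother(params, emissions)      # parallel smoother under the hood
print(model.marginal_log_prob(params, emissions))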