
Parallel inference of LGSSM in the EM algorithm (+ some bug fixes) #336

Open · wants to merge 20 commits into main
Conversation


@kstoneriv3 kstoneriv3 commented Aug 13, 2023

Hi, I wanted to use parallel filtering and smoothing of the LGSSM in the EM algorithm, so I brought the parallel inference functions up to feature parity with the serial filtering and smoothing.

During the implementation I also found a couple of bugs, so this PR includes those fixes as well (the joint sampling logic in inference.py and a missing emission bias term in the log likelihood of parallel_inference.py).

I thought this branch was almost ready for review, but it had a large conflict with the recent diagonal covariance PR, so I planned to mark it ready once the conflict was resolved. Now ready for review!
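To make the intended parity concrete, here is a rough sketch of mine (not code from this PR; the module layout and function names are assumed from dynamax) showing the parallel smoother used as a drop-in replacement for the serial one, e.g. inside the E-step of EM:

import jax.numpy as jnp
import jax.random as jr
from dynamax.linear_gaussian_ssm import LinearGaussianSSM
from dynamax.linear_gaussian_ssm import inference, parallel_inference

# Sample a small synthetic dataset from a randomly initialized LGSSM.
lgssm = LinearGaussianSSM(state_dim=2, emission_dim=3)
params, _ = lgssm.initialize(jr.PRNGKey(0))
_, emissions = lgssm.sample(params, jr.PRNGKey(1), num_timesteps=100)

# The serial and parallel smoothers should agree up to numerical error.
serial_post = inference.lgssm_smoother(params, emissions)
parallel_post = parallel_inference.lgssm_smoother(params, emissions)
assert jnp.allclose(serial_post.marginal_loglik,
                    parallel_post.marginal_loglik, atol=1e-2)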

@kstoneriv3 kstoneriv3 changed the title WIP: Use parallel inference of LGSSM in the EM algorithm (+ some bug fix) WIP: Use parallel inference of LGSSM in the EM algorithm (+ some bug fixes) Aug 13, 2023

@kstoneriv3 kstoneriv3 marked this pull request as ready for review August 14, 2023 21:56
Comment on lines 70 to 76

Please note that for indexing the parameters of the LGSSM we adopt the convention of Murphy, K. P. (2022), "Probabilistic Machine Learning: Advanced Topics", rather than Särkkä, S. (2013), "Bayesian Filtering and Smoothing", except that the initial index starts at 0 instead of 1, which is not exactly in line with the former book. This tends to be a source of confusion. As such, $F_0$, $B_0$, $b_0$, $Q_0$ are always ignored and the prior specified by $m$ and $S$ is used as the distribution of the initial state.

Author

There are several conflicting indexing styles in the codebase. I added this note here to make sure that this code base follows the notation of Murphy (2023), which is $z_t \sim \mathcal{N}(F_t z_{t - 1} + B_t u_t + b_t, Q_t)$.

By the way, I personally prefer Särkkä's style $z_{t+1} \sim \mathcal{N}(F_t z_t + B_t u_t + b_t, Q_t)$; I believe it makes the codebase look a little simpler, and I don't mind modifying the whole LGSSM codebase to adopt his indexing if the code owners prefer it.
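Spelled out (a summary of mine; the emission line is the standard LGSSM form, assumed rather than quoted from the diff), the adopted convention is

$$
\begin{aligned}
z_0 &\sim \mathcal{N}(m, S), \\
z_t &\sim \mathcal{N}(F_t z_{t-1} + B_t u_t + b_t,\; Q_t), \quad t = 1, \dots, T, \\
y_t &\sim \mathcal{N}(H_t z_t + d_t,\; R_t), \quad t = 0, \dots, T,
\end{aligned}
$$

so $F_0$, $B_0$, $b_0$, $Q_0$ never enter the model.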

Comment on lines -10 to +11
MultivariateNormalFullCovariance as MVN)
MultivariateNormalFullCovariance as MVN,
)
Author

There are quite a few lines of changes introduced by applying black to the modified files. Maybe it's better to first merge a separate PR that applies black, to make the diff here easier to read.

Author
@kstoneriv3 kstoneriv3 Aug 14, 2023

I created a PR for the formatting here: #337, so feel free to use it if you like. (Please close it otherwise!)

The diff between the formatted branch in #337 and this PR is 278 insertions(+) and 201 deletions(-), which is less than half of the current diff from the main branch.

Comment on lines +49 to -84
from dynamax.linear_gaussian_ssm.inference import preprocess_args, _get_one_param, _get_params, _log_likelihood


def _get_one_param(x, dim, t):
    """Helper function to get one parameter at time t."""
    if callable(x):
        return x(t)
    elif x.ndim == dim + 1:
        return x[t]
    else:
        return x

def _get_params(params, num_timesteps, t):
    """Helper function to get parameters at time t."""
    assert not callable(params.emissions.cov), "Emission covariance cannot be a callable."

    F = _get_one_param(params.dynamics.weights, 2, t)
    b = _get_one_param(params.dynamics.bias, 1, t)
    Q = _get_one_param(params.dynamics.cov, 2, t)
    H = _get_one_param(params.emissions.weights, 2, t+1)
    d = _get_one_param(params.emissions.bias, 1, t+1)

    if len(params.emissions.cov.shape) == 1:
        R = _get_one_param(params.emissions.cov, 1, t+1)
    elif len(params.emissions.cov.shape) > 2:
        R = _get_one_param(params.emissions.cov, 2, t+1)
    elif params.emissions.cov.shape[0] != num_timesteps:
        R = _get_one_param(params.emissions.cov, 2, t+1)
    elif params.emissions.cov.shape[1] != num_timesteps:
        R = _get_one_param(params.emissions.cov, 1, t+1)
    else:
        R = _get_one_param(params.emissions.cov, 2, t+1)
        warnings.warn(
            "Emission covariance has shape (N,N) where N is the number of timesteps. "
            "The covariance will be interpreted as static and non-diagonal. To "
            "specify a dynamic and diagonal covariance, pass it as a 3D array.")

    return F, b, Q, H, d, R
Author
@kstoneriv3 kstoneriv3 Aug 14, 2023

I removed these duplicated utility functions from parallel_inference.py and used the ones defined in inference.py.
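As a quick illustration of the shared helper's behavior (a hypothetical snippet of mine, not code from the PR, assuming _get_one_param from inference.py is in scope), it dispatches on whether a parameter is callable, time-varying, or static:

import jax.numpy as jnp

# Three ways to specify a 2D parameter (dim=2), read here at time t=3:
F_static = jnp.eye(2)                         # ndim == dim: returned unchanged
F_varying = jnp.tile(jnp.eye(2), (10, 1, 1))  # ndim == dim + 1: indexed at t
F_callable = lambda t: t * jnp.eye(2)         # callable: evaluated at t

assert (_get_one_param(F_static, 2, 3) == F_static).all()
assert (_get_one_param(F_varying, 2, 3) == F_varying[3]).all()
assert (_get_one_param(F_callable, 2, 3) == 3 * jnp.eye(2)).all()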


from jax.config import config

config.update("jax_enable_x64", True)
Author

Tests for the marginal likelihood were quite unstable in float32, probably due to instability of the log-determinant computation. I'd suggest enabling float64 by default for that reason.
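A minimal sketch of the failure mode (my illustration, not from the PR): log-determinants of nearly singular covariances lose enough precision in float32 to destabilize marginal-likelihood comparisons.

from jax.config import config
config.update("jax_enable_x64", True)  # as in the test file above
import jax.numpy as jnp
import jax.random as jr

# A nearly singular 50x50 covariance: low rank plus a small jitter.
A = jr.normal(jr.PRNGKey(0), (50, 5))
cov = A @ A.T + 1e-4 * jnp.eye(50)

_, logdet64 = jnp.linalg.slogdet(cov)                      # float64
_, logdet32 = jnp.linalg.slogdet(cov.astype(jnp.float32))  # float32
print(logdet64, logdet32)  # the float32 value can drift noticeably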

@kstoneriv3 kstoneriv3 changed the title WIP: Use parallel inference of LGSSM in the EM algorithm (+ some bug fixes) Use parallel inference of LGSSM in the EM algorithm (+ some bug fixes) Aug 14, 2023
@kstoneriv3 kstoneriv3 changed the title Use parallel inference of LGSSM in the EM algorithm (+ some bug fixes) Parallel inference of LGSSM in the EM algorithm (+ some bug fixes) Aug 14, 2023
"""
if R.ndim == 2:
S = H @ Q @ H.T + R
return -MVN(jnp.zeros_like(y), S).log_prob(y)
Author

A bug fix: the emission bias term was missing here.
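For reference (my summary, not text from the PR): with an emission bias $d$, the marginal likelihood of an emission $y$ given a Gaussian state with mean $\mu$ and covariance $\Sigma$ is

$$p(y) = \mathcal{N}\!\left(y \mid H \mu + d,\; H \Sigma H^\top + R\right),$$

so dropping $d$ shifts the predictive mean and biases the log likelihood whenever the emission bias is nonzero.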

Comment on lines -524 to +554
# Get parameters and inputs for time index t
F, B, b, Q = _get_params(params, num_timesteps, t)[:4]
u = inputs[t]
# Get parameters and inputs for time index t + 1
F_next, B_next, b_next, Q_next = _get_params(params, num_timesteps, t + 1)[:4]
u_next = inputs[t + 1]
Author
@kstoneriv3 kstoneriv3 Aug 14, 2023

A bug fix: calculating the mean at the next time step requires the parameters at the next time step (unless you use Särkkä (2013)'s indexing instead of Murphy (2023)'s).
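In symbols (my paraphrase of the point above): under the Murphy-style indexing adopted here,

$$\mathbb{E}[z_{t+1} \mid z_t] = F_{t+1} z_t + B_{t+1} u_{t+1} + b_{t+1},$$

whereas under Särkkä's indexing it would be $F_t z_t + B_t u_t + b_t$. Reading the parameters at $t$ instead of $t+1$ silently applies the wrong convention whenever they are time-varying.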

@@ -12,86 +12,111 @@
from dynamax.linear_gaussian_ssm.inference_test import flatten_diagonal_emission_cov

Author

Test cases are updated to check that the parallel inference can handle inputs.

Author

Also, the synthetic data for testing the time-varying case was too simplistic to catch some of the bugs I hit while developing the code, so I updated the test case to have more time variation in the parameters.
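A hypothetical sketch in the spirit of the updated tests (mine, not the PR's actual test data): give every timestep its own dynamics matrix so that a bug that reads parameters at the wrong time index actually changes the result.

import jax.numpy as jnp

num_timesteps, state_dim = 10, 2
angles = jnp.linspace(0.0, jnp.pi / 4, num_timesteps)

# One rotation matrix per timestep, stacked into a (T, D, D) array.
F = jnp.stack([jnp.array([[jnp.cos(a), -jnp.sin(a)],
                          [jnp.sin(a),  jnp.cos(a)]]) for a in angles])
assert F.shape == (num_timesteps, state_dim, state_dim)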
