Tackle Typing and Linting Errors #379

Status: Draft. Wants to merge 32 commits into base: main

Commits (32)
df1c1d4  Change input and shape of LinRegHMM test to fix failure (gileshd, Sep 12, 2024)
7c252c2  Add ruff ignore F722 (gileshd, Sep 12, 2024)
b7927fb  Prepend space in uni-dim jaxtyping hints (gileshd, Sep 12, 2024)
72bf8f5  Remove unused imports (gileshd, Sep 12, 2024)
a890d62  Fix LinearGaussianSSM.sample type hint (gileshd, Sep 12, 2024)
c4efc96  Add runtime_checkable decorator to Protocols (gileshd, Sep 12, 2024)
1ded8e9  Fix jr.PRNGKey type hints (gileshd, Sep 12, 2024)
c9109e8  Rename and change PRNGKey Type (gileshd, Sep 12, 2024)
1c2a9ed  Add IntScalar type (gileshd, Sep 23, 2024)
c79c23e  Fix type annotations in hmm parallel inference (gileshd, Sep 12, 2024)
d18c908  Convert strings with latex to raw-strings (gileshd, Sep 13, 2024)
1e60c84  Change type hints to jaxtyping (gileshd, Sep 13, 2024)
7d5b96f  Change attributes of `HHM*Set` to read-only. (gileshd, Sep 17, 2024)
5818b89  Change default arguments to floats (gileshd, Sep 17, 2024)
2e006b2  Update `keepdims` default value to match `jnp.sum` (gileshd, Sep 17, 2024)
7321b62  Help type-checker with max of max of arrays (gileshd, Sep 17, 2024)
36b4472  Fix `StandardHMMInitialState` __init__ typing (gileshd, Sep 18, 2024)
985a557  Enforce that either key or initial_probs specified (gileshd, Sep 18, 2024)
188aa97  [MAYBE] Call `cast` to handle tfd type spaghetti. (gileshd, Sep 18, 2024)
af3045c  Update type annotations in hmm base classes (gileshd, Sep 23, 2024)
cccec0d  Add futher type hints to hmm inference code (gileshd, Sep 18, 2024)
9d03900  Add further type annotations to categorical hmm (gileshd, Sep 18, 2024)
d15fe23  Add further type annotations to arhmm (gileshd, Sep 20, 2024)
b9604b1  Add further type annotations to linreghmm (gileshd, Sep 20, 2024)
92a1559  Add further type annotations to Bernoulli HMM (gileshd, Sep 20, 2024)
f117fbf  Add further type annotations to Gamma HMM (gileshd, Sep 20, 2024)
b099c43  Add further type annotations to Gaussian HMMs (gileshd, Sep 20, 2024)
25d19da  Add further type annotations to gmhmms (gileshd, Sep 20, 2024)
4646b42  Add further type annotations to logreg hmm (gileshd, Sep 21, 2024)
bc8c78c  Add further type annotations to multinomialhmm (gileshd, Sep 21, 2024)
da2395f  Add further type annotations to poisson hmm (gileshd, Sep 23, 2024)
47fe33f  Update type annotations in hmm inference code. (gileshd, Sep 24, 2024)
7 changes: 3 additions & 4 deletions dynamax/generalized_gaussian_ssm/inference.py
@@ -3,7 +3,6 @@
 from jax import jacfwd, vmap, lax
 import jax.numpy as jnp
-from jax import lax
 from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
 from jaxtyping import Array, Float
 from typing import NamedTuple, Optional, Union, Callable

@@ -83,7 +82,7 @@ def compute_weights_and_sigmas(self, m, P):
 def _predict(m, P, f, Q, u, g_ev, g_cov):
-    """Predict next mean and covariance under an additive-noise Gaussian filter
+    r"""Predict next mean and covariance under an additive-noise Gaussian filter

         p(x_{t+1}) = N(x_{t+1} | mu_pred, Sigma_pred)
     where

@@ -117,7 +116,7 @@ def _predict(m, P, f, Q, u, g_ev, g_cov):
 def _condition_on(m, P, y_cond_mean, y_cond_cov, u, y, g_ev, g_cov, num_iter, emission_dist):
-    """Condition a Gaussian potential on a new observation with arbitrary
+    r"""Condition a Gaussian potential on a new observation with arbitrary
     likelihood with given functions for conditional moments and make a
     Gaussian approximation.

         p(x_t | y_t, u_t, y_{1:t-1}, u_{1:t-1})

@@ -172,7 +171,7 @@ def _step(carry, _):
 def _statistical_linear_regression(mu, Sigma, m, S, C):
-    """Return moment-matching affine coefficients and approximation noise variance
+    r"""Return moment-matching affine coefficients and approximation noise variance
     given joint moments.

         g(x) \approx Ax + b + e where e ~ N(0, Omega)
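The docstring changes above swap plain strings for raw strings wherever LaTeX appears. A minimal sketch of why this matters (illustrative only, not part of the diff): in a non-raw string Python interprets backslash sequences, so the "\a" in "\approx" silently becomes the ASCII bell character, while invalid sequences like the "\h" in "\hat" trigger a SyntaxWarning on recent Pythons.

# Non-raw: "\h" draws a SyntaxWarning ("invalid escape sequence '\h'"),
# and "\a" is silently replaced by the one-character ASCII bell.
bad = "$\hat{z} \approx Ax + b$"
# Raw string: every backslash is kept verbatim, so the LaTeX survives intact.
good = r"$\hat{z} \approx Ax + b$"
print(len(bad) == len(good))  # False: "\a" collapsed to a single character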
20 changes: 10 additions & 10 deletions dynamax/generalized_gaussian_ssm/models.py
@@ -11,11 +11,11 @@
 from dynamax.nonlinear_gaussian_ssm.models import FnStateToState, FnStateAndInputToState
 from dynamax.nonlinear_gaussian_ssm.models import FnStateToEmission, FnStateAndInputToEmission

-FnStateToEmission2 = Callable[[Float[Array, "state_dim"]], Float[Array, "emission_dim emission_dim"]]
-FnStateAndInputToEmission2 = Callable[[Float[Array, "state_dim"], Float[Array, "input_dim"]], Float[Array, "emission_dim emission_dim"]]
+FnStateToEmission2 = Callable[[Float[Array, " state_dim"]], Float[Array, "emission_dim emission_dim"]]
+FnStateAndInputToEmission2 = Callable[[Float[Array, " state_dim"], Float[Array, " input_dim"]], Float[Array, "emission_dim emission_dim"]]

 # emission distribution takes a mean vector and covariance matrix and returns a distribution
-EmissionDistFn = Callable[ [Float[Array, "state_dim"], Float[Array, "state_dim state_dim"]], tfd.Distribution]
+EmissionDistFn = Callable[ [Float[Array, " state_dim"], Float[Array, "state_dim state_dim"]], tfd.Distribution]


 class ParamsGGSSM(NamedTuple):
@@ -42,7 +42,7 @@ class ParamsGGSSM(NamedTuple):
     """

-    initial_mean: Float[Array, "state_dim"]
+    initial_mean: Float[Array, " state_dim"]
     initial_covariance: Float[Array, "state_dim state_dim"]
     dynamics_function: Union[FnStateToState, FnStateAndInputToState]
     dynamics_covariance: Float[Array, "state_dim state_dim"]

@@ -97,15 +97,15 @@ def covariates_shape(self):
     def initial_distribution(
         self,
         params: ParamsGGSSM,
-        inputs: Optional[Float[Array, "input_dim"]]=None
+        inputs: Optional[Float[Array, " input_dim"]]=None
     ) -> tfd.Distribution:
         return MVN(params.initial_mean, params.initial_covariance)

     def transition_distribution(
         self,
         params: ParamsGGSSM,
-        state: Float[Array, "state_dim"],
-        inputs: Optional[Float[Array, "input_dim"]]=None
+        state: Float[Array, " state_dim"],
+        inputs: Optional[Float[Array, " input_dim"]]=None
     ) -> tfd.Distribution:
         f = params.dynamics_function
         if inputs is None:

@@ -117,8 +117,8 @@ def transition_distribution(
     def emission_distribution(
         self,
         params: ParamsGGSSM,
-        state: Float[Array, "state_dim"],
-        inputs: Optional[Float[Array, "input_dim"]]=None
+        state: Float[Array, " state_dim"],
+        inputs: Optional[Float[Array, " input_dim"]]=None
     ) -> tfd.Distribution:
         h = params.emission_mean_function
         R = params.emission_cov_function

@@ -128,4 +128,4 @@ def emission_distribution(
         else:
             mean = h(state, inputs)
             cov = R(state, inputs)
-        return params.emission_dist(mean, cov)
\ No newline at end of file
+        return params.emission_dist(mean, cov)
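The recurring edit in this file, "state_dim" becoming " state_dim", follows a jaxtyping convention for single-axis shapes: a one-word string annotation is a syntactically valid forward reference, so static type checkers try to resolve a type named state_dim and report it as undefined, while a leading space makes the string an invalid identifier that only jaxtyping parses. A small sketch (assumes jaxtyping is installed; the variable names are illustrative):

from jaxtyping import Array, Float

# A one-word shape string can be mistaken for a forward reference to a
# type called `state_dim` by static checkers such as pyright:
x_bad: Float[Array, "state_dim"]

# With a leading space the string is not a valid identifier, so the
# checker leaves it to jaxtyping, which ignores the extra whitespace:
x_good: Float[Array, " state_dim"]

# Multi-axis strings already contain spaces and need no special care:
cov: Float[Array, "state_dim state_dim"]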
4 changes: 2 additions & 2 deletions dynamax/hidden_markov_model/demos/categorical_glm_hmm_demo.py
@@ -58,12 +58,12 @@
 plt.figure()
 plt.imshow(jnp.vstack((states[None, :], most_likely_states[None, :])),
            aspect="auto", interpolation='none', cmap="Greys")
-plt.yticks([0.0, 1.0], ["$z$", "$\hat{z}$"])
+plt.yticks([0.0, 1.0], ["$z$", r"$\hat{z}$"])
 plt.xlabel("time")
 plt.xlim(0, 500)


 print("true log prob: ", hmm.marginal_log_prob(true_params, emissions, inputs=inputs))
 print("test log prob: ", test_hmm.marginal_log_prob(params, emissions, inputs=inputs))

-plt.show()
\ No newline at end of file
+plt.show()
140 changes: 75 additions & 65 deletions dynamax/hidden_markov_model/inference.py
@@ -1,25 +1,29 @@
+from functools import partial
+from typing import Callable, NamedTuple, Optional, Tuple, Union
 import jax.numpy as jnp
 import jax.random as jr
-from jax import lax
-from jax import vmap
-from jax import jit
-from functools import partial
-
-from typing import Callable, Optional, Tuple, Union, NamedTuple
+from jax import jit, lax, vmap
 from jaxtyping import Int, Float, Array

-from dynamax.types import Scalar, PRNGKey
+from dynamax.types import IntScalar, Scalar

 _get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x

-def get_trans_mat(transition_matrix, transition_fn, t):
+def get_trans_mat(
+    transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
+                                      Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
+    transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]],
+    t: IntScalar
+) -> Float[Array, "num_states num_states"]:
     if transition_fn is not None:
         return transition_fn(t)
-    else:
-        if transition_matrix.ndim == 3: # (T,K,K)
+    elif transition_matrix is not None:
+        if transition_matrix.ndim == 3: # (T-1,K,K)
             return transition_matrix[t]
         else:
             return transition_matrix
+    else:
+        raise ValueError("Either `transition_matrix` or `transition_fn` must be specified.")
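A usage sketch for the retyped get_trans_mat (assumes it is imported from dynamax.hidden_markov_model.inference; shapes and values are illustrative):

import jax.numpy as jnp

num_states, num_timesteps = 3, 100
A_fixed = jnp.full((num_states, num_states), 1.0 / num_states)

# Stationary transitions: a (K, K) matrix is returned unchanged for any t.
A_t = get_trans_mat(A_fixed, None, t=10)            # (3, 3)

# Time-varying transitions: a (T-1, K, K) stack is indexed at t.
A_stack = jnp.tile(A_fixed, (num_timesteps - 1, 1, 1))
A_t = get_trans_mat(A_stack, None, t=10)            # (3, 3)

# Or supply a callable from timestep to matrix.
A_t = get_trans_mat(None, lambda t: A_fixed, t=10)  # (3, 3)

# Supplying neither now raises a ValueError, instead of failing later
# with an AttributeError on None.ndim.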

 class HMMPosteriorFiltered(NamedTuple):
     r"""Simple wrapper for properties of an HMM filtering posterior.

@@ -50,12 +54,12 @@ class HMMPosterior(NamedTuple):
     filtered_probs: Float[Array, "num_timesteps num_states"]
     predicted_probs: Float[Array, "num_timesteps num_states"]
     smoothed_probs: Float[Array, "num_timesteps num_states"]
-    initial_probs: Float[Array, "num_states"]
-    trans_probs: Optional[Union[Float[Array, "num_timesteps num_states num_states"],
-                                Float[Array, "num_states num_states"]]] = None
+    initial_probs: Float[Array, " num_states"]
+    trans_probs: Optional[Union[Float[Array, "num_states num_states"],
+                                Float[Array, "num_timesteps_minus_1 num_states num_states"]]] = None


-def _normalize(u, axis=0, eps=1e-15):
+def _normalize(u: Array, axis=0, eps=1e-15):
     """Normalizes the values within the axis in a way that they sum up to 1.

     Args:
@@ -97,11 +101,11 @@ def _predict(probs, A):

 @partial(jit, static_argnames=["transition_fn"])
 def hmm_filter(
-    initial_distribution: Float[Array, "num_states"],
-    transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
-                             Float[Array, "num_states num_states"]],
+    initial_distribution: Float[Array, " num_states"],
+    transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
+                                      Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
     log_likelihoods: Float[Array, "num_timesteps num_states"],
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
+    transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None
 ) -> HMMPosteriorFiltered:
     r"""Forwards filtering
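A sketch of calling the re-annotated hmm_filter on a toy two-state chain (assumes the inference functions are imported from dynamax.hidden_markov_model; field names follow dynamax's HMMPosteriorFiltered, and all numbers are illustrative):

import jax.numpy as jnp
import jax.random as jr

num_states, num_timesteps = 2, 5
initial_distribution = jnp.array([0.5, 0.5])
transition_matrix = jnp.array([[0.9, 0.1],
                               [0.2, 0.8]])
# Per-timestep log p(y_t | z_t = k), e.g. from an emission model.
log_likelihoods = jr.normal(jr.PRNGKey(0), (num_timesteps, num_states))

posterior = hmm_filter(initial_distribution, transition_matrix, log_likelihoods)
print(posterior.filtered_probs.shape)  # (5, 2)
print(posterior.marginal_loglik)       # scalar log p(y_{1:T})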
@@ -145,11 +149,11 @@ def _step(carry, t):

 @partial(jit, static_argnames=["transition_fn"])
 def hmm_backward_filter(
-    transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
-                             Float[Array, "num_states num_states"]],
+    transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
+                                      Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
     log_likelihoods: Float[Array, "num_timesteps num_states"],
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
-) -> Tuple[Float, Float[Array, "num_timesteps num_states"]]:
+    transition_fn: Optional[Callable[[int], Float[Array, "num_states num_states"]]]= None
+) -> Tuple[Scalar, Float[Array, "num_timesteps num_states"]]:
     r"""Run the filter backwards in time. This is the second step of the forward-backward algorithm.

     Transition matrix may be either 2D (if transition probabilities are fixed) or 3D
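The corrected return annotation (Tuple[Scalar, ...] rather than the invalid bare Float) matches the function's two outputs, a scalar log normalizer and the backward probabilities. Continuing the toy setup above:

backward_log_norm, backward_probs = hmm_backward_filter(transition_matrix, log_likelihoods)
print(backward_probs.shape)  # (5, 2)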
@@ -191,11 +195,11 @@ def _step(carry, t):

 @partial(jit, static_argnames=["transition_fn"])
 def hmm_two_filter_smoother(
-    initial_distribution: Float[Array, "num_states"],
-    transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
-                             Float[Array, "num_states num_states"]],
+    initial_distribution: Float[Array, " num_states"],
+    transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
+                                      Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
     log_likelihoods: Float[Array, "num_timesteps num_states"],
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
+    transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None,
     compute_trans_probs: bool = True
 ) -> HMMPosterior:
     r"""Computed the smoothed state probabilities using the two-filter
@@ -245,11 +249,11 @@ def hmm_two_filter_smoother(

 @partial(jit, static_argnames=["transition_fn"])
 def hmm_smoother(
-    initial_distribution: Float[Array, "num_states"],
-    transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
-                             Float[Array, "num_states num_states"]],
+    initial_distribution: Float[Array, " num_states"],
+    transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
+                                      Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
     log_likelihoods: Float[Array, "num_timesteps num_states"],
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
+    transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None,
     compute_trans_probs: bool = True
 ) -> HMMPosterior:
     r"""Computed the smoothed state probabilities using a general
@@ -325,12 +329,12 @@ def _step(carry, args):

 @partial(jit, static_argnames=["transition_fn", "window_size"])
 def hmm_fixed_lag_smoother(
-    initial_distribution: Float[Array, "num_states"],
-    transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
-                             Float[Array, "num_states num_states"]],
+    initial_distribution: Float[Array, " num_states"],
+    transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
+                                      Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
     log_likelihoods: Float[Array, "num_timesteps num_states"],
-    window_size: Int,
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
+    window_size: int,
+    transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None
 ) -> HMMPosterior:
     r"""Compute the smoothed state probabilities using the fixed-lag smoother.

@@ -439,12 +443,12 @@ def compute_posterior(filtered_probs, beta):

 @partial(jit, static_argnames=["transition_fn"])
 def hmm_posterior_mode(
-    initial_distribution: Float[Array, "num_states"],
-    transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
-                             Float[Array, "num_states num_states"]],
+    initial_distribution: Float[Array, " num_states"],
+    transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
+                                      Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
     log_likelihoods: Float[Array, "num_timesteps num_states"],
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
-) -> Int[Array, "num_timesteps"]:
+    transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None
+) -> Int[Array, " num_timesteps"]:
     r"""Compute the most likely state sequence. This is called the Viterbi algorithm.

     Args:
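The Viterbi decoder now advertises a 1-D integer return; with the toy inputs above:

most_likely_states = hmm_posterior_mode(initial_distribution, transition_matrix, log_likelihoods)
print(most_likely_states.shape)  # (5,), integer dtype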
@@ -486,13 +490,13 @@ def _forward_pass(state, best_next_state):

 @partial(jit, static_argnames=["transition_fn"])
 def hmm_posterior_sample(
-    rng: jr.PRNGKey,
-    initial_distribution: Float[Array, "num_states"],
-    transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
-                             Float[Array, "num_states num_states"]],
+    key: Array,
+    initial_distribution: Float[Array, " num_states"],
+    transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
+                                      Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
     log_likelihoods: Float[Array, "num_timesteps num_states"],
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
-) -> Int[Array, "num_timesteps"]:
+    transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None
+) -> Tuple[Scalar, Int[Array, " num_timesteps"]]:
     r"""Sample a latent sequence from the posterior.

     Args:
@@ -515,7 +519,7 @@ def hmm_posterior_sample(
     # Run the sampler backward in time
     def _step(carry, args):
         next_state = carry
-        t, rng, filtered_probs = args
+        t, subkey, filtered_probs = args

         A = get_trans_mat(transition_matrix, transition_fn, t)

@@ -524,15 +528,15 @@ def _step(carry, args):
         smoothed_probs /= smoothed_probs.sum()

         # Sample current state
-        state = jr.choice(rng, a=num_states, p=smoothed_probs)
+        state = jr.choice(subkey, a=num_states, p=smoothed_probs)

         return state, state

     # Run the HMM smoother
-    rngs = jr.split(rng, num_timesteps)
-    last_state = jr.choice(rngs[-1], a=num_states, p=filtered_probs[-1])
+    keys = jr.split(key, num_timesteps)
+    last_state = jr.choice(keys[-1], a=num_states, p=filtered_probs[-1])
     _, states = lax.scan(
-        _step, last_state, (jnp.arange(1, num_timesteps), rngs[:-1], filtered_probs[:-1]),
+        _step, last_state, (jnp.arange(1, num_timesteps), keys[:-1], filtered_probs[:-1]),
         reverse=True
     )

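The rng-to-key renaming matches JAX's convention that a PRNG key is consumed once: the caller's key is split into one subkey per timestep, so no subkey is reused across jr.choice calls. A sketch with the toy inputs above (per the new Tuple[Scalar, ...] annotation, the first return value is the scalar log normalizer):

key = jr.PRNGKey(42)
log_prob, sampled_states = hmm_posterior_sample(key, initial_distribution,
                                                transition_matrix, log_likelihoods)
print(sampled_states.shape)  # (5,), one sampled state per timestep
# Internally: keys = jr.split(key, num_timesteps) yields a fresh subkey
# per timestep, so the sampler never reuses randomness.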
@@ -544,12 +548,13 @@ def _compute_sum_transition_probs(
     transition_matrix: Float[Array, "num_states num_states"],
     hmm_posterior: HMMPosterior) -> Float[Array, "num_states num_states"]:
     """Compute the transition probabilities from the HMM posterior messages.
+
     Args:
         transition_matrix (_type_): _description_
         hmm_posterior (_type_): _description_
     """

-    def _step(carry, args):
+    def _step(carry, args: Tuple[Array, Array, Array, Int[Array, ""]]):
         filtered_probs, smoothed_probs_next, predicted_probs_next, t = args

         # Get parameters for time t
@@ -580,11 +585,13 @@ def _step(carry, args):


 def _compute_all_transition_probs(
-    transition_matrix: Float[Array, "num_timesteps num_states num_states"],
+    transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
+                                      Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
     hmm_posterior: HMMPosterior,
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
+    transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None
 ) -> Float[Array, "num_timesteps num_states num_states"]:
     """Compute the transition probabilities from the HMM posterior messages.
+
     Args:
         transition_matrix (_type_): _description_
         hmm_posterior (_type_): _description_

@@ -598,20 +605,21 @@ def _compute_probs(t):
         A = get_trans_mat(transition_matrix, transition_fn, t)
         return jnp.einsum('i,ij,j->ij', filtered_probs[t], A, relative_probs_next[t])

-    transition_probs = vmap(_compute_probs)(jnp.arange(len(filtered_probs)-1))
+    transition_probs = vmap(_compute_probs)(jnp.arange(len(filtered_probs)))
     return transition_probs

-# TODO: Consider alternative annotation for return type:
-#  Float[Array, "*num_timesteps num_states num_states"] I think this would allow multiple prepended dims.
-#  Float[Array, "#num_timesteps num_states num_states"] this might accept (1, sd, sd) but not (sd, sd).
+# TODO: This is a candidate for @overload, however at present I think we would need to use
+#  `@beartype.typing.overload` and beartype is currently not a core dependency.
+#  Support for `typing.overload` might change in the future:
+#  https://github.com/beartype/beartype/issues/54
 def compute_transition_probs(
-    transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
-                             Float[Array, "num_states num_states"]],
+    transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
+                                      Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
     hmm_posterior: HMMPosterior,
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
-) -> Union[Float[Array, "num_timesteps num_states num_states"],
-           Float[Array, "num_states num_states"]]:
+    transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None
+) -> Union[Float[Array, "num_states num_states"],
+           Float[Array, "num_timesteps_minus_1 num_states num_states"]]:
     r"""Compute the posterior marginal distributions $p(z_{t+1}, z_t \mid y_{1:T}, u_{1:T}, \theta)$.

     Args:

@@ -622,8 +630,10 @@ def compute_transition_probs(
     Returns:
         array of smoothed transition probabilities.
     """
-    reduce_sum = transition_matrix is not None and transition_matrix.ndim == 2
-    if reduce_sum:
+    if transition_matrix is None and transition_fn is None:
+        raise ValueError("Either `transition_matrix` or `transition_fn` must be specified.")
+
+    if transition_matrix is not None and transition_matrix.ndim == 2:
         return _compute_sum_transition_probs(transition_matrix, hmm_posterior)
     else:
         return _compute_all_transition_probs(transition_matrix, hmm_posterior, transition_fn=transition_fn)
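A sketch of the two branches, continuing the toy example (posterior from hmm_smoother above; the exact leading dimension of the time-varying output follows the annotations):

# Stationary case: a 2D matrix triggers the summed statistic.
summed = compute_transition_probs(transition_matrix, posterior)
print(summed.shape)   # (2, 2): expected transition counts, as used in EM

# Time-varying case: a transition_fn (or a 3D stack) yields per-timestep
# transition probabilities instead.
per_step = compute_transition_probs(None, posterior,
                                    transition_fn=lambda t: transition_matrix)
print(per_step.ndim)  # 3: one (num_states, num_states) slice per transition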
1 change: 0 additions & 1 deletion dynamax/hidden_markov_model/inference_test.py
@@ -1,4 +1,3 @@
-import pytest
 import itertools as it
 import jax.numpy as jnp
 import jax.random as jr