Skip to content

Commit

Permalink
Fix type annotations in hmm parallel inference
Browse files Browse the repository at this point in the history
  • Loading branch information
gileshd committed Oct 16, 2024
1 parent 0d4ca74 commit cfd3bfc
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions dynamax/hidden_markov_model/parallel_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import jax.random as jr
from jax import lax, vmap, value_and_grad
from jaxtyping import Array, Float, Int
from typing import NamedTuple
from typing import NamedTuple, Tuple

from dynamax.hidden_markov_model.inference import HMMPosterior, HMMPosteriorFiltered
from dynamax.types import Scalar

#---------------------------------------------------------------------------#
# Filtering #
Expand Down Expand Up @@ -91,7 +92,7 @@ def marginalize(m_ij, m_jk):
def hmm_smoother(initial_probs: Float[Array, " num_states"],
transition_matrix: Float[Array, "num_states num_states"],
log_likelihoods: Float[Array, "num_timesteps num_states"]
) -> HMMPosteriorFiltered:
) -> HMMPosterior:
r"""Parallel implementation of HMM smoothing with `jax.lax.associative_scan`.
**Notes:**
Expand Down Expand Up @@ -138,36 +139,36 @@ def log_normalizer(log_initial_probs, log_transition_matrix, log_likelihoods):
for each possible value of $z_i$.
"""

def _initialize_sampling_messages(rng, transition_matrix, filtered_probs):
def _initialize_sampling_messages(key, transition_matrix, filtered_probs):
"""Preprocess filtering output to construct input for sampling assocative scan."""

T, K = filtered_probs.shape
rngs = jr.split(rng, T)
keys = jr.split(key, T)

def _last_message(rng, probs):
state = jr.choice(rng, K, p=probs)
def _last_message(key, probs):
state = jr.choice(key, K, p=probs)
return jnp.repeat(state, K)

@vmap
def _generic_message(rng, probs):
def _generic_message(key, probs):
smoothed_probs = probs * transition_matrix.T
smoothed_probs = smoothed_probs / smoothed_probs.sum(1).reshape(K,1)
return vmap(lambda p: jr.choice(rng, K, p=p))(smoothed_probs)
return vmap(lambda p: jr.choice(key, K, p=p))(smoothed_probs)

En = _last_message(rngs[-1], filtered_probs[-1])
Et = _generic_message(rngs[:-1], filtered_probs[:-1])
En = _last_message(keys[-1], filtered_probs[-1])
Et = _generic_message(keys[:-1], filtered_probs[:-1])
return jnp.concatenate([Et, En[None]])


def hmm_posterior_sample(rng: jr.PRNGKey,
def hmm_posterior_sample(key: Array,
initial_distribution: Float[Array, " num_states"],
transition_matrix: Float[Array, "num_states num_states"],
log_likelihoods: Float[Array, "num_timesteps num_states"]
) -> Int[Array, " num_timesteps"]:
) -> Tuple[Scalar, Int[Array, " num_timesteps"]]:
r"""Sample a sequence of hidden states from the posterior.
Args:
rng: random number generator
key: random number generator
initial_distribution: $p(z_1 \mid u_1, \theta)$
transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
Expand All @@ -176,8 +177,6 @@ def hmm_posterior_sample(rng: jr.PRNGKey,
log_normalizer: $\log P(y_{1:T} \mid u_{1:T}, \theta)$
states: sequence of hidden states $z_{1:T}$
"""
T, K = log_likelihoods.shape

# Run the HMM filter
post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods)
log_normalizer = post.marginal_loglik
Expand All @@ -187,7 +186,7 @@ def hmm_posterior_sample(rng: jr.PRNGKey,
def _operator(E_jk, E_ij):
return jnp.take(E_ij, E_jk)

initial_messages = _initialize_sampling_messages(rng, transition_matrix, filtered_probs)
initial_messages = _initialize_sampling_messages(key, transition_matrix, filtered_probs)
final_messages = lax.associative_scan(_operator, initial_messages, reverse=True)
states = final_messages[:,0]
return log_normalizer, states

0 comments on commit cfd3bfc

Please sign in to comment.