Skip to content

Commit

Permalink
Update type annotations in hmm base classes
Browse files Browse the repository at this point in the history
  • Loading branch information
gileshd committed Oct 16, 2024
1 parent 62df3ba commit 45134e8
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 33 deletions.
74 changes: 44 additions & 30 deletions dynamax/hidden_markov_model/models/abstractions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import abstractmethod, ABC
from typing import Any, Optional, Tuple, runtime_checkable, Union
from typing_extensions import Protocol
from dynamax.ssm import SSM
from dynamax.types import Scalar
from dynamax.types import IntScalar, Scalar
from dynamax.parameters import to_unconstrained, from_unconstrained
from dynamax.parameters import ParameterSet, PropertySet
from dynamax.hidden_markov_model.inference import HMMPosterior
Expand All @@ -11,14 +13,12 @@
from dynamax.utils.optimize import run_gradient_descent
from dynamax.utils.utils import pytree_slice
import jax.numpy as jnp
import jax.random as jr
from jax import vmap
from jax.tree_util import tree_map
from jaxtyping import Float, Array, PyTree
from jaxtyping import Float, Array, PyTree, Real
import optax
from tensorflow_probability.substrates.jax import distributions as tfd
from typing import Any, Optional, Tuple, runtime_checkable
from typing_extensions import Protocol



@runtime_checkable
Expand Down Expand Up @@ -113,14 +113,14 @@ def log_prior(self, params: ParameterSet) -> Scalar:
"""
raise NotImplementedError

def _compute_initial_probs(self, params, inputs=None):
return self.initial_distribution(params, inputs).probs_parameter()
def _compute_initial_probs(self, params, inputs:Optional[Array] = None):
return self.distribution(params, inputs).probs_parameter()

def collect_suff_stats(self,
params: ParameterSet,
posterior: HMMPosterior,
inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None
) -> PyTree:
) -> Tuple[Float[Array, " num_states"], Optional[Float[Array, " input_dim"]]]:
"""Collect sufficient statistics for updating the initial distribution parameters.
Args:
Expand Down Expand Up @@ -152,7 +152,7 @@ def m_step(self,
batch_stats: PyTree,
m_step_state: Any,
scale: float=1.0
) -> ParameterSet:
) -> Tuple[ParameterSet, Any]:
"""Perform an M-step on the initial distribution parameters.
Args:
Expand Down Expand Up @@ -211,7 +211,7 @@ def __init__(self,
@abstractmethod
def distribution(self,
params: ParameterSet,
state: int,
state: IntScalar,
inputs: Optional[Float[Array, " input_dim"]]=None
) -> tfd.Distribution:
"""Return a distribution over the next latent state
Expand Down Expand Up @@ -255,7 +255,7 @@ def log_prior(self, params: ParameterSet) -> Scalar:
"""
raise NotImplementedError

def _compute_transition_matrices(self, params, inputs=None):
def _compute_transition_matrices(self, params, inputs:Optional[Array] = None):
if inputs is not None:
f = lambda inpt: \
vmap(lambda state: \
Expand All @@ -271,7 +271,7 @@ def collect_suff_stats(self,
params: ParameterSet,
posterior: HMMPosterior,
inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None
) -> PyTree:
) -> Tuple[Float[Array, "..."], Optional[Float[Array, "num_timesteps-1 input_dim"]]]:
"""Collect sufficient statistics for updating the transition distribution parameters.
Args:
Expand Down Expand Up @@ -301,7 +301,7 @@ def m_step(self,
batch_stats: PyTree,
m_step_state: Any,
scale: float=1.0
) -> ParameterSet:
) -> Tuple[ParameterSet, Any]:
"""Perform an M-step on the transition distribution parameters.
Args:
Expand Down Expand Up @@ -367,7 +367,7 @@ def emission_shape(self) -> Tuple[int]:
@abstractmethod
def distribution(self,
params: ParameterSet,
state: int,
state: IntScalar,
inputs: Optional[Float[Array, " input_dim"]]=None
) -> tfd.Distribution:
"""Return a distribution over the emission
Expand Down Expand Up @@ -411,7 +411,7 @@ def log_prior(self, params: ParameterSet) -> Scalar:
"""
raise NotImplementedError

def _compute_conditional_logliks(self, params, emissions, inputs=None):
def _compute_conditional_logliks(self, params, emissions, inputs:Optional[Array] = None):
# Compute the log probability for each time step by
# performing a nested vmap over emission time steps and states.
f = lambda emission, inpt: \
Expand All @@ -422,9 +422,12 @@ def _compute_conditional_logliks(self, params, emissions, inputs=None):
def collect_suff_stats(self,
params: ParameterSet,
posterior: HMMPosterior,
emissions: Float[Array, "num_timesteps emission_dim"],
emissions: Union[Real[Array, "num_timesteps emission_dim"],
Real[Array, " num_timesteps"]],
inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None
) -> PyTree:
) -> Tuple[Float[Array, "num_timesteps num_states"],
Union[Real[Array, "num_timesteps emission_dim"], Real[Array, " num_timesteps"]],
Optional[Float[Array, "num_timesteps input_dim"]]]:
"""Collect sufficient statistics for updating the emission distribution parameters.
Args:
Expand Down Expand Up @@ -455,7 +458,7 @@ def m_step(self,
batch_stats: PyTree,
m_step_state: Any,
scale: float=1.0
) -> ParameterSet:
) -> Tuple[ParameterSet, Any]:
"""Perform an M-step on the emission distribution parameters.
Args:
Expand Down Expand Up @@ -549,43 +552,48 @@ def __init__(self,
def emission_shape(self):
return self.emission_component.emission_shape

def initial_distribution(self, params, inputs=None):
def initial_distribution(self, params: HMMParameterSet, inputs:Optional[Array] = None) -> tfd.Distribution:
return self.initial_component.distribution(params.initial, inputs=inputs)

def transition_distribution(self, params, state, inputs=None):
def transition_distribution(self, params: HMMParameterSet, state: IntScalar, inputs:Optional[Array] = None) -> tfd.Distribution:
return self.transition_component.distribution(params.transitions, state, inputs=inputs)

def emission_distribution(self, params, state, inputs=None):
def emission_distribution(self, params: HMMParameterSet, state: IntScalar, inputs:Optional[Array] = None):
return self.emission_component.distribution(params.emissions, state, inputs=inputs)

def log_prior(self, params):
def log_prior(self, params: HMMParameterSet) -> Scalar:
lp = self.initial_component.log_prior(params.initial)
lp += self.transition_component.log_prior(params.transitions)
lp += self.emission_component.log_prior(params.emissions)
return lp

# The inference functions all need the same arguments
def _inference_args(self, params, emissions, inputs):
def _inference_args(self, params: HMMParameterSet, emissions: Array, inputs: Optional[Array]):
return (self.initial_component._compute_initial_probs(params.initial, inputs),
self.transition_component._compute_transition_matrices(params.transitions, inputs),
self.emission_component._compute_conditional_logliks(params.emissions, emissions, inputs))

# Convenience wrappers for the inference code
def marginal_log_prob(self, params, emissions, inputs=None):
def marginal_log_prob(self, params: HMMParameterSet, emissions: Array, inputs: Optional[Array]=None):
post = hmm_filter(*self._inference_args(params, emissions, inputs))
return post.marginal_loglik

def most_likely_states(self, params, emissions, inputs=None):
def most_likely_states(self, params: HMMParameterSet, emissions: Array, inputs: Optional[Array]=None):
return hmm_posterior_mode(*self._inference_args(params, emissions, inputs))

def filter(self, params, emissions, inputs=None):
def filter(self, params: HMMParameterSet, emissions: Array, inputs: Optional[Array]=None):
return hmm_filter(*self._inference_args(params, emissions, inputs))

def smoother(self, params, emissions, inputs=None):
def smoother(self, params: HMMParameterSet, emissions: Array, inputs: Optional[Array]=None):
return hmm_smoother(*self._inference_args(params, emissions, inputs))

# Expectation-maximization (EM) code
def e_step(self, params, emissions, inputs=None):
def e_step(
self,
params: HMMParameterSet,
emissions: Array,
inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None
) -> Tuple[PyTree, Scalar]:
"""The E-step computes expected sufficient statistics under the
posterior. In the generic case, we simply return the posterior itself.
"""
Expand All @@ -597,7 +605,7 @@ def e_step(self, params, emissions, inputs=None):
emission_stats = self.emission_component.collect_suff_stats(params.emissions, posterior, emissions, inputs)
return (initial_stats, transition_stats, emission_stats), posterior.marginal_loglik

def initialize_m_step_state(self, params, props):
def initialize_m_step_state(self, params: HMMParameterSet, props: HMMPropertySet):
"""Initialize any required state for the M step.
For example, this might include the optimizer state for Adam.
Expand All @@ -607,7 +615,13 @@ def initialize_m_step_state(self, params, props):
emissions_m_step_state = self.emission_component.initialize_m_step_state(params.emissions, props.emissions)
return initial_m_step_state, transitions_m_step_state, emissions_m_step_state

def m_step(self, params, props, batch_stats, m_step_state):
def m_step(
self,
params: HMMParameterSet,
props: HMMPropertySet,
batch_stats: PyTree,
m_step_state: Any
) -> Tuple[HMMParameterSet, Any]:
batch_initial_stats, batch_transition_stats, batch_emission_stats = batch_stats
initial_m_step_state, transitions_m_step_state, emissions_m_step_state = m_step_state

Expand Down
6 changes: 3 additions & 3 deletions dynamax/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jax.random as jr
from jax import jit, lax, vmap
from jax.tree_util import tree_map
from jaxtyping import Float, Array
from jaxtyping import Array, Float, Real
import optax
from tensorflow_probability.substrates.jax import distributions as tfd
from typing import Optional, Union, Tuple, Any, runtime_checkable
Expand Down Expand Up @@ -351,8 +351,8 @@ def fit_em(
self,
params: ParameterSet,
props: PropertySet,
emissions: Union[Float[Array, "num_timesteps emission_dim"],
Float[Array, "num_batches num_timesteps emission_dim"]],
emissions: Union[Real[Array, "num_timesteps emission_dim"],
Real[Array, "num_batches num_timesteps emission_dim"]],
inputs: Optional[Union[Float[Array, "num_timesteps input_dim"],
Float[Array, "num_batches num_timesteps input_dim"]]]=None,
num_iters: int=50,
Expand Down

0 comments on commit 45134e8

Please sign in to comment.