Add further type annotations to arhmm
gileshd committed Oct 16, 2024
1 parent 4c84ddc commit 7e04c9e
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions dynamax/hidden_markov_model/models/arhmm.py
@@ -1,8 +1,10 @@
+from typing import NamedTuple, Optional, Tuple, Union
 import jax.numpy as jnp
 import jax.random as jr
 from jax import lax
 from jax.tree_util import tree_map
-from jaxtyping import Float, Array
+from jaxtyping import Int, Float, Array
 
 from dynamax.hidden_markov_model.models.abstractions import HMM, HMMParameterSet, HMMPropertySet
 from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
 from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
@@ -11,7 +13,6 @@
 from dynamax.types import Scalar
 from dynamax.utils.bijectors import RealToPSDBijector
 from tensorflow_probability.substrates import jax as tfp
-from typing import NamedTuple, Optional, Tuple, Union
 
 tfd = tfp.distributions
 tfb = tfp.bijectors
@@ -25,21 +26,22 @@ class ParamsLinearAutoregressiveHMM(NamedTuple):
 
 class LinearAutoregressiveHMMEmissions(LinearRegressionHMMEmissions):
     def __init__(self,
-                 num_states,
-                 emission_dim,
-                 num_lags=1):
+                 num_states: int,
+                 emission_dim: int,
+                 num_lags: int=1):
         self.num_lags = num_lags
         self.emission_dim = emission_dim
         input_dim = num_lags * emission_dim
         super().__init__(num_states, input_dim, emission_dim)
 
     def initialize(self,
-                   key=jr.PRNGKey(0),
-                   method="prior",
-                   emission_weights=None,
-                   emission_biases=None,
-                   emission_covariances=None,
-                   emissions=None):
+                   key: Array=jr.PRNGKey(0),
+                   method: str="prior",
+                   emission_weights: Optional[Float[Array, "num_states emission_dim input_dim"]]=None,
+                   emission_biases: Optional[Float[Array, "num_states emission_dim"]]=None,
+                   emission_covariances: Optional[Float[Array, "num_states emission_dim emission_dim"]]=None,
+                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
+                   ) -> Tuple[ParamsLinearRegressionHMMEmissions, ParamsLinearRegressionHMMEmissions]:
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
             from sklearn.cluster import KMeans
@@ -166,7 +168,7 @@ def sample(self,
                key: Array,
                num_timesteps: int,
                prev_emissions: Optional[Float[Array, "num_lags emission_dim"]]=None,
-               ) -> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps emission_dim"]]:
+               ) -> Tuple[Int[Array, " num_timesteps"], Float[Array, "num_timesteps emission_dim"]]:
         r"""Sample states $z_{1:T}$ and emissions $y_{1:T}$ given parameters $\theta$.
 
         Args:
@@ -211,7 +213,7 @@ def _step(carry, key):
     def compute_inputs(self,
                        emissions: Float[Array, "num_timesteps emission_dim"],
                        prev_emissions: Optional[Float[Array, "num_lags emission_dim"]]=None
-                       ) -> Float[Array, "num_timesteps emission_dim_times_num_lags"]:
+                       ) -> Float[Array, "num_timesteps {self.num_lags}*{self.emission_dim}"]:
         r"""Helper function to compute the matrix of lagged emissions.
 
         Args:
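Taken together, the updated annotations say that `sample` returns an integer-valued state sequence alongside the real-valued emissions, and that `compute_inputs` returns a matrix of lagged emissions with `num_lags * emission_dim` columns. A minimal usage sketch, assuming the public `LinearAutoregressiveHMM` wrapper exposes these methods as in the dynamax documentation (the sizes below are hypothetical):

```python
import jax.random as jr
from dynamax.hidden_markov_model import LinearAutoregressiveHMM

num_states, emission_dim, num_lags, num_timesteps = 3, 2, 2, 100

arhmm = LinearAutoregressiveHMM(num_states, emission_dim, num_lags=num_lags)
params, props = arhmm.initialize(jr.PRNGKey(0))

# states:    Int[Array, " num_timesteps"]               -- discrete state sequence z_{1:T}
# emissions: Float[Array, "num_timesteps emission_dim"] -- sampled observations y_{1:T}
states, emissions = arhmm.sample(params, jr.PRNGKey(1), num_timesteps)

# inputs: Float[Array, "num_timesteps num_lags*emission_dim"] -- lagged emissions per timestep
inputs = arhmm.compute_inputs(emissions)
```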
