diff --git a/dynamax/linear_gaussian_ssm/inference.py b/dynamax/linear_gaussian_ssm/inference.py
index b499f166..4ff4733b 100644
--- a/dynamax/linear_gaussian_ssm/inference.py
+++ b/dynamax/linear_gaussian_ssm/inference.py
@@ -14,7 +14,7 @@ from typing import NamedTuple, Optional, Union, Tuple
 from dynamax.utils.utils import psd_solve, symmetrize
 from dynamax.parameters import ParameterProperties
-from dynamax.types import PRNGKey, Scalar
+from dynamax.types import PRNGKeyT, Scalar

 class ParamsLGSSMInitial(NamedTuple):
     r"""Parameters of the initial distribution
@@ -363,7 +363,7 @@ def wrapper(*args, **kwargs):

 def lgssm_joint_sample(
     params: ParamsLGSSM,
-    key: PRNGKey,
+    key: PRNGKeyT,
     num_timesteps: int,
     inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None
 )-> Tuple[Float[Array, "num_timesteps state_dim"],
@@ -559,7 +559,7 @@ def _step(carry, args):

 def lgssm_posterior_sample(
-    key: PRNGKey,
+    key: PRNGKeyT,
     params: ParamsLGSSM,
     emissions: Float[Array, "ntime emission_dim"],
     inputs: Optional[Float[Array, "ntime input_dim"]]=None,
diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py
index 2d6e78b8..dc340f26 100644
--- a/dynamax/linear_gaussian_ssm/models.py
+++ b/dynamax/linear_gaussian_ssm/models.py
@@ -15,7 +15,7 @@ from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM, ParamsLGSSMInitial, ParamsLGSSMDynamics, ParamsLGSSMEmissions
 from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed
 from dynamax.parameters import ParameterProperties
-from dynamax.types import PRNGKey, Scalar
+from dynamax.types import PRNGKeyT, Scalar
 from dynamax.utils.bijectors import RealToPSDBijector
 from dynamax.utils.distributions import MatrixNormalInverseWishart as MNIW
 from dynamax.utils.distributions import NormalInverseWishart as NIW
@@ -88,7 +88,7 @@ def inputs_shape(self):

     def initialize(
         self,
-        key: PRNGKey =jr.PRNGKey(0),
+        key: PRNGKeyT =jr.PRNGKey(0),
         initial_mean: Optional[Float[Array, " state_dim"]]=None,
         initial_covariance=None,
         dynamics_weights=None,
@@ -203,7 +203,7 @@ def emission_distribution(
     def sample(
         self,
         params: ParamsLGSSM,
-        key: PRNGKey,
+        key: PRNGKeyT,
         num_timesteps: int,
         inputs: Optional[Float[Array, "ntime input_dim"]] = None,
     ) -> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps emission_dim"]]:
@@ -236,7 +236,7 @@ def smoother(

     def posterior_sample(
         self,
-        key: PRNGKey,
+        key: PRNGKeyT,
         params: ParamsLGSSM,
         emissions: Float[Array, "ntime emission_dim"],
         inputs: Optional[Float[Array, "ntime input_dim"]]=None
@@ -501,7 +501,7 @@ def m_step(

     def fit_blocked_gibbs(
         self,
-        key: PRNGKey,
+        key: PRNGKeyT,
         initial_params: ParamsLGSSM,
         sample_size: int,
         emissions: Float[Array, "nbatch ntime emission_dim"],
diff --git a/dynamax/linear_gaussian_ssm/parallel_inference.py b/dynamax/linear_gaussian_ssm/parallel_inference.py
index 84374664..a212f3b2 100644
--- a/dynamax/linear_gaussian_ssm/parallel_inference.py
+++ b/dynamax/linear_gaussian_ssm/parallel_inference.py
@@ -34,7 +34,7 @@ from jax import vmap, lax
 from jaxtyping import Array, Float
 from typing import NamedTuple
-from dynamax.types import PRNGKey
+from dynamax.types import PRNGKeyT
 from functools import partial
 import warnings
@@ -354,7 +354,7 @@ def _initialize_sampling_messages(key, params, filtered_means, filtered_covarian

 def lgssm_posterior_sample(
-    key: PRNGKey,
+    key: PRNGKeyT,
     params: ParamsLGSSM,
     emissions: Float[Array, "ntime emission_dim"]
 ) -> Float[Array, "ntime state_dim"]:
diff --git a/dynamax/nonlinear_gaussian_ssm/inference_ekf.py b/dynamax/nonlinear_gaussian_ssm/inference_ekf.py
index d936ed45..2577a610 100644
--- a/dynamax/nonlinear_gaussian_ssm/inference_ekf.py
+++ b/dynamax/nonlinear_gaussian_ssm/inference_ekf.py
@@ -9,7 +9,7 @@ from dynamax.utils.utils import psd_solve, symmetrize
 from dynamax.nonlinear_gaussian_ssm.models import ParamsNLGSSM
 from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed
-from dynamax.types import PRNGKey
+from dynamax.types import PRNGKeyT

 # Helper functions
 _get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x
@@ -258,7 +258,7 @@ def _step(carry, args):

 def extended_kalman_posterior_sample(
-    key: PRNGKey,
+    key: PRNGKeyT,
     params: ParamsNLGSSM,
     emissions: Float[Array, "ntime emission_dim"],
     inputs: Optional[Float[Array, "ntime input_dim"]] = None
diff --git a/dynamax/nonlinear_gaussian_ssm/inference_test_utils.py b/dynamax/nonlinear_gaussian_ssm/inference_test_utils.py
index ef6794ab..52af8331 100644
--- a/dynamax/nonlinear_gaussian_ssm/inference_test_utils.py
+++ b/dynamax/nonlinear_gaussian_ssm/inference_test_utils.py
@@ -15,7 +15,7 @@ from dynamax.parameters import ParameterProperties
 from dynamax.ssm import SSM
 from dynamax.utils.bijectors import RealToPSDBijector
-from dynamax.types import PRNGKey
+from dynamax.types import PRNGKeyT

 tfd = tfp.distributions
@@ -37,7 +37,7 @@ def lgssm_to_nlgssm(params: ParamsLGSSM) -> ParamsNLGSSM:

 def random_lgssm_args(
-    key: Union[int, PRNGKey] = 0,
+    key: Union[int, PRNGKeyT] = 0,
     num_timesteps: int = 15,
     state_dim: int = 4,
     emission_dim: int = 2
diff --git a/dynamax/slds/inference.py b/dynamax/slds/inference.py
index ddd2b3e4..6f7ea39e 100644
--- a/dynamax/slds/inference.py
+++ b/dynamax/slds/inference.py
@@ -6,7 +6,7 @@ from jaxtyping import Array, Float, Int
 from typing import NamedTuple, Optional
 from dynamax.utils.utils import psd_solve
-from dynamax.types import PRNGKey
+from dynamax.types import PRNGKeyT

 class DiscreteParamsSLDS(NamedTuple):
     initial_distribution: Float[Array, " num_states"]
@@ -164,7 +164,7 @@ def rbpfilter(
     num_particles: int,
     params: ParamsSLDS,
     emissions: Float[Array, "ntime emission_dim"],
-    key: PRNGKey = jr.PRNGKey(0),
+    key: PRNGKeyT = jr.PRNGKey(0),
     inputs: Optional[Float[Array, "ntime input_dim"]] = None,
     ess_threshold: float = 0.5
 ):
@@ -253,7 +253,7 @@ def rbpfilter_optimal(
     num_particles: int,
     params: ParamsSLDS,
     emissions: Float[Array, "ntime emission_dim"],
-    key: PRNGKey = jr.PRNGKey(0),
+    key: PRNGKeyT = jr.PRNGKey(0),
     inputs: Optional[Float[Array, "ntime input_dim"]]=None
 ):
     '''
diff --git a/dynamax/slds/models.py b/dynamax/slds/models.py
index 67858fa5..52a0a6cf 100644
--- a/dynamax/slds/models.py
+++ b/dynamax/slds/models.py
@@ -9,7 +9,7 @@ from dynamax.ssm import SSM
 from dynamax.slds.inference import ParamsSLDS
-from dynamax.types import PRNGKey
+from dynamax.types import PRNGKeyT

 class SLDS(SSM):
@@ -71,7 +71,7 @@ def emission_distribution(
     def sample(
         self,
         params: ParamsSLDS,
-        key: PRNGKey,
+        key: PRNGKeyT,
         num_timesteps: int,
         inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None
     ) -> Tuple[Float[Array, "num_timesteps state_dim"],
diff --git a/dynamax/ssm.py b/dynamax/ssm.py
index e55f751f..31cd574f 100644
--- a/dynamax/ssm.py
+++ b/dynamax/ssm.py
@@ -14,7 +14,7 @@ from dynamax.parameters import to_unconstrained, from_unconstrained
 from dynamax.parameters import ParameterSet, PropertySet
-from dynamax.types import PRNGKey, Scalar
+from dynamax.types import PRNGKeyT, Scalar
 from dynamax.utils.optimize import run_sgd
 from dynamax.utils.utils import ensure_array_has_batch_dim
@@ -173,7 +173,7 @@ def inputs_shape(self) -> Optional[Tuple[int]]:
     def sample(
         self,
         params: ParameterSet,
-        key: PRNGKey,
+        key: PRNGKeyT,
         num_timesteps: int,
         inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None
     ) -> Tuple[Float[Array, "num_timesteps state_dim"],
@@ -414,7 +414,7 @@ def fit_sgd(
         batch_size: int=1,
         num_epochs: int=50,
         shuffle: bool=False,
-        key: PRNGKey=jr.PRNGKey(0)
+        key: PRNGKeyT=jr.PRNGKey(0)
     ) -> Tuple[ParameterSet, Float[Array, " niter"]]:
         r"""Compute parameter MLE/ MAP estimate using Stochastic Gradient Descent (SGD).
diff --git a/dynamax/types.py b/dynamax/types.py
index 4847471f..d5fd26a2 100644
--- a/dynamax/types.py
+++ b/dynamax/types.py
@@ -1,8 +1,6 @@
 from typing import Union
 from jaxtyping import Array, Float

-import jax._src.random as prng
-
-PRNGKey = prng.KeyArray
+PRNGKeyT = Array

 Scalar = Union[float, Float[Array, ""]] # python float or scalar jax device array with dtype float
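
Not part of the diff above: a minimal usage sketch of the renamed alias. Since PRNGKeyT is simply jaxtyping's Array, any signature that previously took key: PRNGKey can switch to key: PRNGKeyT unchanged, and it accepts keys produced by jax.random.PRNGKey (and the newer typed keys from jax.random.key in recent JAX releases). The function name draw_gaussian_sample below is hypothetical, not part of dynamax.

import jax.random as jr
from jaxtyping import Array, Float

PRNGKeyT = Array  # mirrors the new definition in dynamax/types.py

def draw_gaussian_sample(key: PRNGKeyT, dim: int) -> Float[Array, " dim"]:
    # Under this alias a PRNG key is just a JAX array, so either key flavor type-checks.
    return jr.normal(key, (dim,))

sample = draw_gaussian_sample(jr.PRNGKey(0), 3)  # legacy uint32[2] key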