Skip to content

Commit

Permalink
Rename and change PRNGKey Type
Browse files Browse the repository at this point in the history
- Rename to PRNGKeyT to differentiate from `jax.random` function
- Change type to Array
  - see https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html#type-annotations-for-prng-keys
  • Loading branch information
gileshd committed Sep 12, 2024
1 parent 1ded8e9 commit c9109e8
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 25 deletions.
6 changes: 3 additions & 3 deletions dynamax/linear_gaussian_ssm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import NamedTuple, Optional, Union, Tuple
from dynamax.utils.utils import psd_solve, symmetrize
from dynamax.parameters import ParameterProperties
from dynamax.types import PRNGKey, Scalar
from dynamax.types import PRNGKeyT, Scalar

class ParamsLGSSMInitial(NamedTuple):
r"""Parameters of the initial distribution
Expand Down Expand Up @@ -363,7 +363,7 @@ def wrapper(*args, **kwargs):

def lgssm_joint_sample(
params: ParamsLGSSM,
key: PRNGKey,
key: PRNGKeyT,
num_timesteps: int,
inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None
)-> Tuple[Float[Array, "num_timesteps state_dim"],
Expand Down Expand Up @@ -559,7 +559,7 @@ def _step(carry, args):


def lgssm_posterior_sample(
key: PRNGKey,
key: PRNGKeyT,
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None,
Expand Down
10 changes: 5 additions & 5 deletions dynamax/linear_gaussian_ssm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM, ParamsLGSSMInitial, ParamsLGSSMDynamics, ParamsLGSSMEmissions
from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed
from dynamax.parameters import ParameterProperties
from dynamax.types import PRNGKey, Scalar
from dynamax.types import PRNGKeyT, Scalar
from dynamax.utils.bijectors import RealToPSDBijector
from dynamax.utils.distributions import MatrixNormalInverseWishart as MNIW
from dynamax.utils.distributions import NormalInverseWishart as NIW
Expand Down Expand Up @@ -88,7 +88,7 @@ def inputs_shape(self):

def initialize(
self,
key: PRNGKey =jr.PRNGKey(0),
key: PRNGKeyT =jr.PRNGKey(0),
initial_mean: Optional[Float[Array, " state_dim"]]=None,
initial_covariance=None,
dynamics_weights=None,
Expand Down Expand Up @@ -203,7 +203,7 @@ def emission_distribution(
def sample(
self,
params: ParamsLGSSM,
key: PRNGKey,
key: PRNGKeyT,
num_timesteps: int,
inputs: Optional[Float[Array, "ntime input_dim"]] = None,
) -> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps emission_dim"]]:
Expand Down Expand Up @@ -236,7 +236,7 @@ def smoother(

def posterior_sample(
self,
key: PRNGKey,
key: PRNGKeyT,
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
Expand Down Expand Up @@ -501,7 +501,7 @@ def m_step(

def fit_blocked_gibbs(
self,
key: PRNGKey,
key: PRNGKeyT,
initial_params: ParamsLGSSM,
sample_size: int,
emissions: Float[Array, "nbatch ntime emission_dim"],
Expand Down
4 changes: 2 additions & 2 deletions dynamax/linear_gaussian_ssm/parallel_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from jax import vmap, lax
from jaxtyping import Array, Float
from typing import NamedTuple
from dynamax.types import PRNGKey
from dynamax.types import PRNGKeyT
from functools import partial
import warnings

Expand Down Expand Up @@ -354,7 +354,7 @@ def _initialize_sampling_messages(key, params, filtered_means, filtered_covarian


def lgssm_posterior_sample(
key: PRNGKey,
key: PRNGKeyT,
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"]
) -> Float[Array, "ntime state_dim"]:
Expand Down
4 changes: 2 additions & 2 deletions dynamax/nonlinear_gaussian_ssm/inference_ekf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dynamax.utils.utils import psd_solve, symmetrize
from dynamax.nonlinear_gaussian_ssm.models import ParamsNLGSSM
from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed
from dynamax.types import PRNGKey
from dynamax.types import PRNGKeyT

# Helper functions
_get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x
Expand Down Expand Up @@ -258,7 +258,7 @@ def _step(carry, args):


def extended_kalman_posterior_sample(
key: PRNGKey,
key: PRNGKeyT,
params: ParamsNLGSSM,
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]] = None
Expand Down
4 changes: 2 additions & 2 deletions dynamax/nonlinear_gaussian_ssm/inference_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from dynamax.parameters import ParameterProperties
from dynamax.ssm import SSM
from dynamax.utils.bijectors import RealToPSDBijector
from dynamax.types import PRNGKey
from dynamax.types import PRNGKeyT


tfd = tfp.distributions
Expand All @@ -37,7 +37,7 @@ def lgssm_to_nlgssm(params: ParamsLGSSM) -> ParamsNLGSSM:


def random_lgssm_args(
key: Union[int, PRNGKey] = 0,
key: Union[int, PRNGKeyT] = 0,
num_timesteps: int = 15,
state_dim: int = 4,
emission_dim: int = 2
Expand Down
6 changes: 3 additions & 3 deletions dynamax/slds/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from jaxtyping import Array, Float, Int
from typing import NamedTuple, Optional
from dynamax.utils.utils import psd_solve
from dynamax.types import PRNGKey
from dynamax.types import PRNGKeyT

class DiscreteParamsSLDS(NamedTuple):
initial_distribution: Float[Array, " num_states"]
Expand Down Expand Up @@ -164,7 +164,7 @@ def rbpfilter(
num_particles: int,
params: ParamsSLDS,
emissions: Float[Array, "ntime emission_dim"],
key: PRNGKey = jr.PRNGKey(0),
key: PRNGKeyT = jr.PRNGKey(0),
inputs: Optional[Float[Array, "ntime input_dim"]] = None,
ess_threshold: float = 0.5
):
Expand Down Expand Up @@ -253,7 +253,7 @@ def rbpfilter_optimal(
num_particles: int,
params: ParamsSLDS,
emissions: Float[Array, "ntime emission_dim"],
key: PRNGKey = jr.PRNGKey(0),
key: PRNGKeyT = jr.PRNGKey(0),
inputs: Optional[Float[Array, "ntime input_dim"]]=None
):
'''
Expand Down
4 changes: 2 additions & 2 deletions dynamax/slds/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from dynamax.ssm import SSM
from dynamax.slds.inference import ParamsSLDS
from dynamax.types import PRNGKey
from dynamax.types import PRNGKeyT

class SLDS(SSM):

Expand Down Expand Up @@ -71,7 +71,7 @@ def emission_distribution(
def sample(
self,
params: ParamsSLDS,
key: PRNGKey,
key: PRNGKeyT,
num_timesteps: int,
inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None
) -> Tuple[Float[Array, "num_timesteps state_dim"],
Expand Down
6 changes: 3 additions & 3 deletions dynamax/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from dynamax.parameters import to_unconstrained, from_unconstrained
from dynamax.parameters import ParameterSet, PropertySet
from dynamax.types import PRNGKey, Scalar
from dynamax.types import PRNGKeyT, Scalar
from dynamax.utils.optimize import run_sgd
from dynamax.utils.utils import ensure_array_has_batch_dim

Expand Down Expand Up @@ -173,7 +173,7 @@ def inputs_shape(self) -> Optional[Tuple[int]]:
def sample(
self,
params: ParameterSet,
key: PRNGKey,
key: PRNGKeyT,
num_timesteps: int,
inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None
) -> Tuple[Float[Array, "num_timesteps state_dim"],
Expand Down Expand Up @@ -414,7 +414,7 @@ def fit_sgd(
batch_size: int=1,
num_epochs: int=50,
shuffle: bool=False,
key: PRNGKey=jr.PRNGKey(0)
key: PRNGKeyT=jr.PRNGKey(0)
) -> Tuple[ParameterSet, Float[Array, " niter"]]:
r"""Compute parameter MLE/ MAP estimate using Stochastic Gradient Descent (SGD).
Expand Down
4 changes: 1 addition & 3 deletions dynamax/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from typing import Union
from jaxtyping import Array, Float
import jax._src.random as prng


PRNGKey = prng.KeyArray
PRNGKeyT = Array

Scalar = Union[float, Float[Array, ""]] # python float or scalar jax device array with dtype float

0 comments on commit c9109e8

Please sign in to comment.