Skip to content

Commit

Permalink
Fix jr.PRNGKey type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
gileshd committed Oct 16, 2024
1 parent 20a1df7 commit ca9a75c
Show file tree
Hide file tree
Showing 11 changed files with 25 additions and 27 deletions.
16 changes: 7 additions & 9 deletions dynamax/hidden_markov_model/inference.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import jax.numpy as jnp
import jax.random as jr
from jax import lax
from jax import vmap
from jax import jit
from jax import jit, lax, vmap
from functools import partial

from typing import Callable, Optional, Tuple, Union, NamedTuple
Expand Down Expand Up @@ -486,7 +484,7 @@ def _forward_pass(state, best_next_state):

@partial(jit, static_argnames=["transition_fn"])
def hmm_posterior_sample(
rng: jr.PRNGKey,
key: Array,
initial_distribution: Float[Array, " num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
Expand Down Expand Up @@ -515,7 +513,7 @@ def hmm_posterior_sample(
# Run the sampler backward in time
def _step(carry, args):
next_state = carry
t, rng, filtered_probs = args
t, subkey, filtered_probs = args

A = get_trans_mat(transition_matrix, transition_fn, t)

Expand All @@ -524,15 +522,15 @@ def _step(carry, args):
smoothed_probs /= smoothed_probs.sum()

# Sample current state
state = jr.choice(rng, a=num_states, p=smoothed_probs)
state = jr.choice(subkey, a=num_states, p=smoothed_probs)

return state, state

# Run the HMM smoother
rngs = jr.split(rng, num_timesteps)
last_state = jr.choice(rngs[-1], a=num_states, p=filtered_probs[-1])
keys = jr.split(key, num_timesteps)
last_state = jr.choice(keys[-1], a=num_states, p=filtered_probs[-1])
_, states = lax.scan(
_step, last_state, (jnp.arange(1, num_timesteps), rngs[:-1], filtered_probs[:-1]),
_step, last_state, (jnp.arange(1, num_timesteps), keys[:-1], filtered_probs[:-1]),
reverse=True
)

Expand Down
6 changes: 3 additions & 3 deletions dynamax/hidden_markov_model/models/abstractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def distribution(self,

@abstractmethod
def initialize(self,
key: jr.PRNGKey=None,
key: Optional[Array]=None,
method: str="prior",
**kwargs
) -> Tuple[ParameterSet, PropertySet]:
Expand Down Expand Up @@ -212,7 +212,7 @@ def distribution(self,

@abstractmethod
def initialize(self,
key: jr.PRNGKey=None,
key: Optional[Array]=None,
method: str="prior",
**kwargs
) -> Tuple[ParameterSet, PropertySet]:
Expand Down Expand Up @@ -368,7 +368,7 @@ def distribution(self,

@abstractmethod
def initialize(self,
key: jr.PRNGKey=None,
key: Optional[Array]=None,
method: str="prior",
**kwargs
) -> Tuple[ParameterSet, PropertySet]:
Expand Down
4 changes: 2 additions & 2 deletions dynamax/hidden_markov_model/models/arhmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def inputs_shape(self):
return (self.num_lags * self.emission_dim,)

def initialize(self,
key: jr.PRNGKey=jr.PRNGKey(0),
key: Array=jr.PRNGKey(0),
method: str="prior",
initial_probs: Optional[Float[Array, " num_states"]]=None,
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
Expand Down Expand Up @@ -163,7 +163,7 @@ def initialize(self,

def sample(self,
params: HMMParameterSet,
key: jr.PRNGKey,
key: Array,
num_timesteps: int,
prev_emissions: Optional[Float[Array, "num_lags emission_dim"]]=None,
) -> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps emission_dim"]]:
Expand Down
2 changes: 1 addition & 1 deletion dynamax/hidden_markov_model/models/bernoulli_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(self, num_states: int,
super().__init__(num_states, initial_component, transition_component, emission_component)

def initialize(self,
key: jr.PRNGKey=jr.PRNGKey(0),
key: Array=jr.PRNGKey(0),
method: str="prior",
initial_probs: Optional[Float[Array, " num_states"]]=None,
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
Expand Down
2 changes: 1 addition & 1 deletion dynamax/hidden_markov_model/models/categorical_glm_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def inputs_shape(self):
return (self.input_dim,)

def initialize(self,
key: jr.PRNGKey=jr.PRNGKey(0),
key: Array=jr.PRNGKey(0),
method: str="prior",
initial_probs: Optional[Float[Array, " num_states"]]=None,
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
Expand Down
2 changes: 1 addition & 1 deletion dynamax/hidden_markov_model/models/categorical_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __init__(self, num_states: int,
super().__init__(num_states, initial_component, transition_component, emission_component)

def initialize(self,
key: jr.PRNGKey=jr.PRNGKey(0),
key: Array=jr.PRNGKey(0),
method: str="prior",
initial_probs: Optional[Float[Array, " num_states"]]=None,
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
Expand Down
2 changes: 1 addition & 1 deletion dynamax/hidden_markov_model/models/gamma_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def __init__(self,
super().__init__(num_states, initial_component, transition_component, emission_component)

def initialize(self,
key: jr.PRNGKey=jr.PRNGKey(0),
key: Array=jr.PRNGKey(0),
method: str="prior",
initial_probs: Optional[Float[Array, " num_states"]]=None,
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
Expand Down
10 changes: 5 additions & 5 deletions dynamax/hidden_markov_model/models/gaussian_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ def __init__(self, num_states: int,
super().__init__(num_states, initial_component, transition_component, emission_component)

def initialize(self,
key: jr.PRNGKey=jr.PRNGKey(0),
key: Array=jr.PRNGKey(0),
method: str="prior",
initial_probs: Optional[Float[Array, " num_states"]]=None,
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
Expand Down Expand Up @@ -711,7 +711,7 @@ def __init__(self, num_states: int,

super().__init__(num_states, initial_component, transition_component, emission_component)

def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
def initialize(self, key: Array=jr.PRNGKey(0),
method: str="prior",
initial_probs: Optional[Float[Array, " num_states"]]=None,
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
Expand Down Expand Up @@ -809,7 +809,7 @@ def __init__(self, num_states: int,

super().__init__(num_states, initial_component, transition_component, emission_component)

def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
def initialize(self, key: Array=jr.PRNGKey(0),
method: str="prior",
initial_probs: Optional[Float[Array, " num_states"]]=None,
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
Expand Down Expand Up @@ -899,7 +899,7 @@ def __init__(self, num_states: int,
emission_prior_extra_df=emission_prior_extra_df)
super().__init__(num_states, initial_component, transition_component, emission_component)

def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
def initialize(self, key: Array=jr.PRNGKey(0),
method: str="prior",
initial_probs: Optional[Float[Array, " num_states"]]=None,
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
Expand Down Expand Up @@ -995,7 +995,7 @@ def __init__(self, num_states: int,
m_step_num_iters=m_step_num_iters)
super().__init__(num_states, initial_component, transition_component, emission_component)

def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
def initialize(self, key: Array=jr.PRNGKey(0),
method: str="prior",
initial_probs: Optional[Float[Array, " num_states"]]=None,
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
Expand Down
4 changes: 2 additions & 2 deletions dynamax/hidden_markov_model/models/gmm_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def __init__(self,
super().__init__(num_states, initial_component, transition_component, emission_component)

def initialize(self,
key: jr.PRNGKey=jr.PRNGKey(0),
key: Array=jr.PRNGKey(0),
method: str="prior",
initial_probs: Optional[Float[Array, " num_states"]]=None,
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
Expand Down Expand Up @@ -463,7 +463,7 @@ def __init__(self,


def initialize(self,
key: jr.PRNGKey=jr.PRNGKey(0),
key: Array=jr.PRNGKey(0),
method: str="prior",
initial_probs: Optional[Float[Array, " num_states"]]=None,
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
Expand Down
2 changes: 1 addition & 1 deletion dynamax/hidden_markov_model/models/linreg_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def inputs_shape(self):
return (self.input_dim,)

def initialize(self,
key: jr.PRNGKey=jr.PRNGKey(0),
key: Array=jr.PRNGKey(0),
method: str="prior",
initial_probs: Optional[Float[Array, " num_states"]]=None,
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
Expand Down
2 changes: 1 addition & 1 deletion dynamax/hidden_markov_model/models/logreg_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def inputs_shape(self):
return (self.inputs_dim,)

def initialize(self,
key: jr.PRNGKey=jr.PRNGKey(0),
key: Array=jr.PRNGKey(0),
method: str="prior",
initial_probs: Optional[Float[Array, " num_states"]]=None,
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
Expand Down

0 comments on commit ca9a75c

Please sign in to comment.