Skip to content

Commit

Permalink
Fix LinearGaussianSSM.sample type hint
Browse files Browse the repository at this point in the history
  • Loading branch information
gileshd committed Sep 12, 2024
1 parent 72bf8f5 commit a890d62
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions dynamax/linear_gaussian_ssm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,21 +191,21 @@ def emission_distribution(
self,
params: ParamsLGSSM,
state: Float[Array, " state_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
inputs: Optional[Float[Array, "ntime input_dim"]] = None
) -> tfd.Distribution:
inputs = inputs if inputs is not None else jnp.zeros(self.input_dim)
mean = params.emissions.weights @ state + params.emissions.input_weights @ inputs
if self.has_emissions_bias:
mean += params.emissions.bias
return MVN(mean, params.emissions.cov)

def sample(
self,
params: ParamsLGSSM,
key: PRNGKey,
num_timesteps: int,
inputs: Optional[Float[Array, "ntime input_dim"]] = None
) -> PosteriorGSSMFiltered:
inputs: Optional[Float[Array, "ntime input_dim"]] = None,
) -> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps emission_dim"]]:
return lgssm_joint_sample(params, key, num_timesteps, inputs)

def marginal_log_prob(
Expand Down

0 comments on commit a890d62

Please sign in to comment.