diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py index 7b05c788..f5b5f019 100644 --- a/dynamax/linear_gaussian_ssm/models.py +++ b/dynamax/linear_gaussian_ssm/models.py @@ -191,21 +191,21 @@ def emission_distribution( self, params: ParamsLGSSM, state: Float[Array, " state_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None ) -> tfd.Distribution: inputs = inputs if inputs is not None else jnp.zeros(self.input_dim) mean = params.emissions.weights @ state + params.emissions.input_weights @ inputs if self.has_emissions_bias: mean += params.emissions.bias return MVN(mean, params.emissions.cov) - + def sample( self, params: ParamsLGSSM, key: PRNGKey, num_timesteps: int, - inputs: Optional[Float[Array, "ntime input_dim"]] = None - ) -> PosteriorGSSMFiltered: + inputs: Optional[Float[Array, "ntime input_dim"]] = None, + ) -> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps emission_dim"]]: return lgssm_joint_sample(params, key, num_timesteps, inputs) def marginal_log_prob(