Make HMM learning work with variable length time series #99

Open
slinderman opened this issue Jul 14, 2022 · 6 comments
Labels
help wanted Extra attention is needed

Comments

@slinderman
Collaborator

I don't think the current hmm_fit_sgd function is using the length of the time series as we hoped. At least with the default loss function, the length only rescales the loss. Really, we need to change marginal_log_prob to compute the log probability of only the observations up to the specified length.
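A minimal sketch of what "only up to the specified length" could mean, assuming the per-time-step emission log likelihoods are already available as a `(T, K)` array. The helper names `hmm_marginal_log_prob` and `masked_by_length` are hypothetical, not the dynamax API. Zeroing the rows at `t >= length` leaves the marginal log prob of the prefix unchanged, because each padded step only pushes the forward messages through a row-stochastic transition matrix.

```python
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import logsumexp


def hmm_marginal_log_prob(log_init, log_trans, lls):
    """Log-space forward algorithm returning log p(y_1, ..., y_T).

    log_init: (K,) log initial probs; log_trans: (K, K) log transition
    matrix with rows indexed by the previous state; lls: (T, K).
    """

    def step(log_alpha, ll_t):
        # log sum_i alpha_i * A_ij, then add the emission log likelihood
        log_pred = logsumexp(log_alpha[:, None] + log_trans, axis=0)
        return log_pred + ll_t, None

    log_alpha, _ = lax.scan(step, log_init + lls[0], lls[1:])
    return logsumexp(log_alpha)


def masked_by_length(lls, length):
    """Hypothetical helper: zero out log likelihoods of padded steps t >= length."""
    valid = jnp.arange(lls.shape[0]) < length
    return jnp.where(valid[:, None], lls, 0.0)


# Hypothetical usage: a length-4 padded sequence whose valid length is 2
log_init = jnp.log(jnp.array([0.6, 0.4]))
log_trans = jnp.log(jnp.array([[0.9, 0.1], [0.2, 0.8]]))
lls = jnp.log(jnp.array([[0.5, 0.1], [0.2, 0.7], [0.3, 0.3], [0.9, 0.05]]))
print(jnp.exp(hmm_marginal_log_prob(log_init, log_trans, masked_by_length(lls, 2))))  # ≈ 0.099
```

Since `length` only enters through a comparison, it can be a traced value under `jit`, so no dynamic slicing is needed.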

Following up on our slack conversation, I see two ways of doing that:

  1. Pad time series with NaNs and modify _conditional_logliks to put zeros wherever the emission is NaN. That way hmm_filter will still compute the marginal log prob of just the observed data. I think this trick should also leave the hmm_smoother computations unchanged. A cool added benefit is that it would let us interpolate over chunks of missing data.
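To sanity-check the claim that zeroing NaN log likelihoods makes the filter return the marginal log prob of just the observed data, here is a small sketch with scalar Gaussian emissions (an assumption made only for illustration). `conditional_logliks` is a hypothetical standalone stand-in for `_conditional_logliks`, and the brute-force enumeration over state sequences is only a reference for tiny problems.

```python
import itertools

import jax.numpy as jnp
from jax import lax
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def conditional_logliks(emissions, means, scale):
    """(T,) scalar emissions -> (T, K) Gaussian log likelihoods, with NaN -> 0."""
    lls = norm.logpdf(emissions[:, None], loc=means[None, :], scale=scale)
    # A missing emission contributes log p = 0 (likelihood 1 for every state)
    return jnp.where(jnp.isnan(lls), 0.0, lls)


def marginal_log_prob(log_init, log_trans, lls):
    """Log-space forward algorithm returning log p(observed emissions)."""

    def step(log_alpha, ll_t):
        log_pred = logsumexp(log_alpha[:, None] + log_trans, axis=0)
        return log_pred + ll_t, None

    log_alpha, _ = lax.scan(step, log_init + lls[0], lls[1:])
    return logsumexp(log_alpha)


def brute_force_log_prob(log_init, log_trans, emissions, means, scale):
    """Enumerate every state sequence, scoring only the observed time steps."""
    num_timesteps, num_states = emissions.shape[0], means.shape[0]
    terms = []
    for z in itertools.product(range(num_states), repeat=num_timesteps):
        lp = log_init[z[0]]
        for t in range(1, num_timesteps):
            lp = lp + log_trans[z[t - 1], z[t]]
        for t in range(num_timesteps):
            if not jnp.isnan(emissions[t]):
                lp = lp + norm.logpdf(emissions[t], loc=means[z[t]], scale=scale)
        terms.append(lp)
    return logsumexp(jnp.stack(terms))


log_init = jnp.log(jnp.array([0.6, 0.4]))
log_trans = jnp.log(jnp.array([[0.9, 0.1], [0.2, 0.8]]))
means = jnp.array([-1.0, 1.0])
emissions = jnp.array([0.3, jnp.nan, -1.2, jnp.nan])  # NaN marks missing steps
lls = conditional_logliks(emissions, means, scale=1.0)
print(marginal_log_prob(log_init, log_trans, lls))
print(brute_force_log_prob(log_init, log_trans, emissions, means, scale=1.0))  # should match
```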

It would look like this:

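        import jax.numpy as jnp
        from jax import vmap

        def _conditional_logliks(self, emissions):
            # Perform a nested vmap over time steps and states
            f = lambda emission: vmap(
                lambda state: self.emission_distribution(state).log_prob(emission)
            )(jnp.arange(self.num_states))

            lls = vmap(f)(emissions)
            # Put zeros wherever the emission is NaN
            return jnp.where(jnp.isnan(lls), 0, lls)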

I tested this out, and the only problem is that we can't take gradients back through this function with respect to the model parameters. The gradients come out as NaN because one of the two branches of the where is NaN. See jax-ml/jax#1052.
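A minimal reproduction of the gradient problem, using a scalar Gaussian log likelihood as a hypothetical stand-in for `emission_distribution(...).log_prob`. The forward value is finite, but on the backward pass `where` sends a cotangent of 0 to the NaN entry, that 0 gets multiplied by a NaN derivative, and 0 * NaN is NaN.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Toy data: the second emission is missing
emissions = jnp.array([0.5, jnp.nan, -0.3])


def loss_single_where(mean):
    """Sum of Gaussian log likelihoods with NaN entries zeroed by a single where."""
    lls = norm.logpdf(emissions, loc=mean, scale=1.0)
    return jnp.sum(jnp.where(jnp.isnan(lls), 0.0, lls))


print(loss_single_where(0.0))            # finite
print(jax.grad(loss_single_where)(0.0))  # nan
```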

There's a somewhat clunky fix: find the NaNs first, replace them with a default emission value, compute the log likelihoods, and then zero out the entries that were originally NaN. That would look something like this:

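        # Find time steps where any emission dimension is NaN
        bad = jnp.any(jnp.isnan(emissions), axis=1)
        # Replace missing emissions with a placeholder so log_prob stays finite
        tmp = jnp.where(jnp.broadcast_to(bad[:, None], emissions.shape), 0.0, emissions)
        lls = vmap(f)(tmp)
        # Zero out the log likelihoods at the originally missing time steps
        return jnp.where(jnp.broadcast_to(bad[:, None], lls.shape), 0.0, lls)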

It's not the prettiest, but it works.
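Here is the same toy loss with the double-`where` fix, sketched under the same scalar Gaussian assumption. Because the placeholder keeps the masked branch finite, the gradient only reflects the observed entries: with respect to the mean at 0 it is (0.5 - 0) + (-0.3 - 0) = 0.2.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

emissions = jnp.array([0.5, jnp.nan, -0.3])


def loss_double_where(mean):
    """Same loss, but NaNs are replaced before log_prob and masked after."""
    bad = jnp.isnan(emissions)
    # Placeholder value keeps the log_prob (and its derivative) finite
    safe = jnp.where(bad, 0.0, emissions)
    lls = norm.logpdf(safe, loc=mean, scale=1.0)
    # Zero out the placeholder terms so they don't affect the loss
    return jnp.sum(jnp.where(bad, 0.0, lls))


print(jax.grad(loss_double_where)(0.0))  # ≈ 0.2
```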

  2. Alternatively, we could pass the length of the time series to the underlying inference functions like hmm_filter. Those functions would then need a while loop to stop the message passing once the length has been reached. (I tried implementing this by calling filter on a dynamic slice of the data, but JAX barfed on that...) This approach is totally doable, but it would add a lot of extra logic to the inference code.
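A sketch of approach 2 with one stated substitution: instead of a while loop (JAX's `lax.while_loop` does not support reverse-mode differentiation, so it would not play well with SGD fitting), it scans over the full padded sequence and freezes the carry once `t >= length`. `hmm_filter_with_length` is a hypothetical function, not dynamax's `hmm_filter`.

```python
import jax.numpy as jnp
from jax import lax


def hmm_filter_with_length(init_probs, trans, lls, length):
    """Normalized forward filter that ignores time steps t >= length.

    Returns (log p(y_0, ..., y_{length-1}), filtered probs at step length - 1).
    """
    num_timesteps = lls.shape[0]

    def step(carry, inputs):
        log_norm, filtered = carry
        t, ll_t = inputs
        # Predict (step 0 uses the initial distribution instead)
        pred = jnp.where(t == 0, init_probs, filtered @ trans)
        # Update with the emission likelihood, shifted by its max for stability
        ll_max = jnp.max(ll_t)
        unnorm = pred * jnp.exp(ll_t - ll_max)
        z = unnorm.sum()
        new_log_norm = log_norm + jnp.log(z) + ll_max
        new_filtered = unnorm / z
        # Freeze the carry once we are past the valid length
        keep = t < length
        log_norm = jnp.where(keep, new_log_norm, log_norm)
        filtered = jnp.where(keep, new_filtered, filtered)
        return (log_norm, filtered), None

    carry0 = (jnp.zeros(()), init_probs)
    (log_prob, filtered), _ = lax.scan(step, carry0, (jnp.arange(num_timesteps), lls))
    return log_prob, filtered


init_probs = jnp.array([0.6, 0.4])
trans = jnp.array([[0.9, 0.1], [0.2, 0.8]])
lls = jnp.log(jnp.array([[0.5, 0.1], [0.2, 0.7], [0.3, 0.3], [0.9, 0.05]]))
log_prob, filtered = hmm_filter_with_length(init_probs, trans, lls, 2)
```

One caveat: the padded rows still go through the arithmetic, so they should hold finite placeholder values. If they are NaN, the same jax-ml/jax#1052 gradient issue shows up again.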

I'm working on a demo of approach 1 right now. Will keep you posted!

@murphyk
Member

murphyk commented Nov 14, 2022

See also #50

@slinderman
Collaborator Author

Just commenting here to note that this request (or variants of it) has come up multiple times in the past few weeks. A simple first step would be to make the low-level inference code support missing data, and then update the model-based code when time allows.

The HMM inference code is simple enough: you can indicate missing data by passing zeros in the corresponding rows of log_likelihoods. The *GSSM code could handle missing data similarly, by "zeroing out" the potentials (i.e., making the emission covariance effectively infinite) wherever the emissions are NaN.
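For the *GSSM case, here is a minimal sketch with a hypothetical 1-D local-level model (`kalman_filter_1d` is not a dynamax function). Skipping the measurement update at a NaN time step is the limit of letting the emission variance go to infinity, since the Kalman gain `p / (p + r)` goes to 0. The placeholder substitution is the same double-`where` trick, so gradients stay finite.

```python
import jax.numpy as jnp
from jax import lax


def kalman_filter_1d(emissions, q, r, m0, p0):
    """Local-level model: z_t = z_{t-1} + N(0, q), y_t = z_t + N(0, r).

    (m0, p0) is the prior on the first state. NaN emissions are treated as
    missing and skip the update. Returns filtered means and variances.
    """

    def step(carry, y):
        m, p = carry
        missing = jnp.isnan(y)
        # Substitute a placeholder so the update stays finite (and differentiable)
        y_safe = jnp.where(missing, 0.0, y)
        # Measurement update
        k = p / (p + r)
        m_upd = m + k * (y_safe - m)
        p_upd = (1.0 - k) * p
        # Skipping the update == infinite emission variance (gain -> 0)
        m_post = jnp.where(missing, m, m_upd)
        p_post = jnp.where(missing, p, p_upd)
        # Predict the next state
        return (m_post, p_post + q), (m_post, p_post)

    carry0 = (jnp.asarray(m0, dtype=jnp.float32), jnp.asarray(p0, dtype=jnp.float32))
    _, (means, variances) = lax.scan(step, carry0, emissions)
    return means, variances


means, variances = kalman_filter_1d(jnp.array([1.0, jnp.nan, 2.0]), q=0.1, r=0.5, m0=0.0, p0=1.0)
```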

@murphyk
Member

murphyk commented Feb 4, 2023

If we pass the valid length of each sequence, we can lax.scan over only that prefix.
Missing data at random times could be handled with an if statement for a conditional update, or with a local evidence vector that is all 1s at the missing time steps.
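The two options for missing data at random times can be sketched side by side, assuming per-step likelihoods (local evidence) in a `(T, K)` array and a boolean `missing` mask; `forward_with_cond` is a hypothetical name. With `lax.cond` the update is skipped outright; with an all-1s evidence row the update divides by a normalizer of 1, so both give the same answer.

```python
import jax.numpy as jnp
from jax import lax


def forward_with_cond(init_probs, trans, evidence, missing):
    """Normalized HMM forward pass over (T, K) local evidence.

    At steps where missing[t] is True, lax.cond skips the update entirely.
    Returns (log marginal likelihood, (T, K) filtered probabilities).
    """

    def step(carry, inputs):
        log_norm, pred = carry
        ev_t, miss_t = inputs

        def update(args):
            log_norm, pred = args
            unnorm = pred * ev_t
            z = unnorm.sum()
            return log_norm + jnp.log(z), unnorm / z

        # Conditional update: identity when the observation is missing
        log_norm, filtered = lax.cond(miss_t, lambda args: args, update, (log_norm, pred))
        return (log_norm, filtered @ trans), filtered

    carry0 = (jnp.zeros(()), init_probs)
    (log_prob, _), filtered = lax.scan(step, carry0, (evidence, missing))
    return log_prob, filtered


init_probs = jnp.array([0.6, 0.4])
trans = jnp.array([[0.9, 0.1], [0.2, 0.8]])
evidence = jnp.array([[0.5, 0.1], [0.2, 0.7], [0.3, 0.3]])
missing = jnp.array([False, True, False])

# Option 1: if statement (lax.cond) for a conditional update
lp_cond, _ = forward_with_cond(init_probs, trans, evidence, missing)
# Option 2: local evidence of all 1s at missing steps, no conditional needed
evidence_ones = jnp.where(missing[:, None], 1.0, evidence)
lp_ones, _ = forward_with_cond(init_probs, trans, evidence_ones, jnp.zeros_like(missing))
```

One design note: under `vmap` with a batched predicate, JAX turns `lax.cond` into a select that evaluates both branches, so for batches of sequences the all-1s evidence version does the same work and is arguably simpler.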

@KeAWang
Copy link

KeAWang commented Mar 16, 2023

I actually have a fork of dynamax that handles missingness for the EKF (and also allows time-varying transitions and emissions): KeAWang@a991219. It's not for the HMM, but I'm happy to open a PR for it.

@slinderman
Copy link
Collaborator Author

Sure, that would be great Alex!

@SkepticRaven
Copy link

Hey, I just came across this library and thread, and I'm wondering how close this is to being ready. Our group works a lot with sparse annotations, and it would be great to be able to build HMMs that work with missing data (on both the training side and the filtering side).

5 participants