
Support for kmeans initialization with vmap #315

Open
ghuckins opened this issue May 16, 2023 · 3 comments · May be fixed by #371

Comments

@ghuckins

ghuckins commented May 16, 2023

Hi there,

When I try to use vmap to vectorize a function that includes a kmeans initialization, I get the following error:

    jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[11396,7])>with<BatchTrace(level=1/0)>

And here's the code that produces the error:

    import jax.numpy as jnp
    from jax import vmap
    from dynamax.hidden_markov_model import GaussianHMM

    # latdim, obsdim, data1, data2, length and get_key() are defined elsewhere in my code.
    hmm = GaussianHMM(latdim, obsdim)
    data1 = jnp.array(data1)
    data2 = jnp.array(data2)
    # Leave-one-out training sets: fold i drops sequence i.
    data1_train = jnp.stack([jnp.concatenate([data1[:i], data1[i+1:]]) for i in range(len(data1))])
    data2_train = jnp.stack([jnp.concatenate([data2[:i], data2[i+1:]]) for i in range(len(data2))])

    base_params1, props1 = hmm.initialize(key=get_key(), method="kmeans", emissions=data1[:length,:,:])
    params1, _ = hmm.fit_em(base_params1, props1, data1[:length,:,:], num_iters=100, verbose=False)
    base_params2, props2 = hmm.initialize(key=get_key(), method="kmeans", emissions=data2[:length,:,:])
    params2, _ = hmm.fit_em(base_params2, props2, data2[:length,:,:], num_iters=100, verbose=False)

    def _fit_fold(train, test, params):
        # Under vmap, train is a tracer; the "kmeans" initialization below is where the error is raised.
        base_params, props = hmm.initialize(key=get_key(), method="kmeans", emissions=train[:length,:,:])
        fit_params, _ = hmm.fit_em(base_params, props, train[:length,:,:], num_iters=100, verbose=False)
        return (hmm.marginal_log_prob(fit_params, test) > hmm.marginal_log_prob(params, test)).astype(int)

    correct1 = jnp.sum(vmap(_fit_fold, in_axes=(0, 0, None))(data1_train, data1, params2))

The error traces back to scikit-learn's KMeans. The problem seems to be that scikit-learn uses NumPy functions rather than JAX functions, so it cannot operate on the tracers that vmap passes in. Would it be possible to update hmm.initialize so that it can be used inside vectorized functions?
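A minimal sketch of this failure mode, independent of dynamax and scikit-learn: any function that calls np.asarray on its input (which, as described above, is what happens inside scikit-learn's KMeans) works on a concrete array but raises TracerArrayConversionError once it is wrapped in jax.vmap. The function name numpy_based_step is hypothetical and only stands in for the NumPy-based code path.

```python
import numpy as np
import jax
import jax.numpy as jnp


def numpy_based_step(x):
    # Hypothetical stand-in for NumPy-based library code.
    # np.asarray calls x.__array__(): fine for a concrete jax.Array,
    # but it raises TracerArrayConversionError when x is a vmap tracer.
    return jnp.asarray(np.asarray(x).mean(axis=0))


batch = jnp.ones((3, 5, 2))

# Works: called eagerly on a single (5, 2) slice.
print(numpy_based_step(batch[0]).shape)  # (2,)

# Fails: under vmap, x is a BatchTracer with no concrete NumPy value.
try:
    jax.vmap(numpy_based_step)(batch)
except jax.errors.TracerArrayConversionError as err:
    print("vmap failed with", type(err).__name__)
```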

Thanks!

@gileshd
Collaborator

gileshd commented May 23, 2023

Hi @ghuckins, thanks for showing interest in the library! Yes, unfortunately the sklearn parts of the code won't naturally play nicely with many of JAX's transformations, such as vmap.

From what I can tell, updating the "kmeans" option in hmm.initialize to use a jax compatible implementation of the kmeans algorithm would involve writing, testing, and maintaining our own jax kmeans implementation which might be outside the scope of this library unfortunately (unless there is a really great demand for it).

Depending on your precise use case, there might be some reasonably straightforward workarounds. For instance, it might be possible to use the sklearn "kmeans" initialization to generate the initial parameters for your HMMs outside of vmap, and then pass those parameters as input to a function that is vmapped over your data.
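A sketch of that workaround, under stated assumptions: run the non-JAX initialization eagerly in a plain Python loop (one call per fold), stack the per-fold parameter pytrees along a new leading axis, and vmap only the pure-JAX part. Here host_init_means is a hypothetical NumPy stand-in for the scikit-learn KMeans step, and fold_objective is a hypothetical placeholder for the vmapped fitting code. In the snippet from the issue those roles would presumably be played by hmm.initialize(..., method="kmeans") and hmm.fit_em, which this sketch does not call.

```python
import numpy as np
import jax
import jax.numpy as jnp


def host_init_means(train, num_clusters, seed):
    # Hypothetical stand-in for a scikit-learn KMeans call. It runs eagerly
    # in NumPy, outside any JAX transformation, so np.asarray is fine here.
    rng = np.random.default_rng(seed)
    points = np.asarray(train).reshape(-1, train.shape[-1])
    idx = rng.choice(points.shape[0], size=num_clusters, replace=False)
    return {"means": jnp.asarray(points[idx])}


def fold_objective(train, params):
    # Hypothetical pure-JAX fitting/scoring step, so it is safe to vmap:
    # mean squared distance from each point to its nearest mean.
    points = train.reshape(-1, train.shape[-1])
    d2 = jnp.sum((points[:, None, :] - params["means"][None, :, :]) ** 2, axis=-1)
    return jnp.mean(jnp.min(d2, axis=1))


# Toy data shaped like (folds, sequences, time steps, obs dim).
folds = jnp.asarray(np.random.default_rng(0).normal(size=(4, 10, 20, 3)), dtype=jnp.float32)

# 1) Initialize each fold eagerly (Python loop, no vmap involved).
per_fold = [host_init_means(folds[i], num_clusters=2, seed=i) for i in range(folds.shape[0])]

# 2) Stack the per-fold parameter pytrees along a new leading axis.
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *per_fold)

# 3) vmap only the JAX-compatible part over (data, initial params).
scores = jax.vmap(fold_objective)(folds, stacked)
print(scores.shape)  # (4,)
```

The tree_map call turns a list of per-fold pytrees into a single pytree whose leaves carry a leading fold axis, so the default in_axes=0 maps over the parameters alongside the data.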

I hope that gives some idea of the way forward. If you want to share more details about your use case, I would be happy to try to give more precise advice.

@ghuckins
Author

Hey Giles, thanks for the reply! I actually found a JAX implementation of k-means online and am working on incorporating it into my own codebase. I could share it once I'm done, if that would be helpful. Just let me know!
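For reference, a vmap-compatible k-means (Lloyd's algorithm) in pure JAX can be fairly short. This is a sketch, not the implementation mentioned above and not dynamax code. It initializes centroids by sampling distinct data points rather than using k-means++ (scikit-learn's default), runs a fixed number of iterations with jax.lax.fori_loop instead of checking a convergence tolerance (so every batch element does the same amount of work), and, like any Lloyd's implementation, it can settle in a local optimum. The name kmeans and its parameters are hypothetical; num_clusters must be a Python int because it fixes array shapes.

```python
import jax
import jax.numpy as jnp


def kmeans(key, points, num_clusters, num_iters=50):
    """Lloyd's algorithm in pure JAX; points has shape (N, D)."""
    n = points.shape[0]
    # Initialize centroids by sampling distinct data points (not k-means++).
    init_idx = jax.random.choice(key, n, shape=(num_clusters,), replace=False)
    centroids = points[init_idx]

    def step(_, centroids):
        # Assignment step: index of the nearest centroid for each point.
        d2 = jnp.sum((points[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
        labels = jnp.argmin(d2, axis=1)
        one_hot = jax.nn.one_hot(labels, num_clusters, dtype=points.dtype)  # (N, K)
        counts = one_hot.sum(axis=0)  # (K,)
        sums = one_hot.T @ points  # (K, D)
        # Update step: cluster means; an empty cluster keeps its old centroid.
        new = sums / jnp.maximum(counts, 1.0)[:, None]
        return jnp.where(counts[:, None] > 0, new, centroids)

    centroids = jax.lax.fori_loop(0, num_iters, step, centroids)
    d2 = jnp.sum((points[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    return centroids, jnp.argmin(d2, axis=1)


# Usage: two separated blobs, then the same routine batched over 4 folds.
a = jax.random.normal(jax.random.PRNGKey(0), (50, 2)) + 5.0
b = jax.random.normal(jax.random.PRNGKey(1), (50, 2)) - 5.0
points = jnp.concatenate([a, b])
centroids, labels = kmeans(jax.random.PRNGKey(2), points, num_clusters=2)

keys = jax.random.split(jax.random.PRNGKey(3), 4)
batch = jnp.stack([points] * 4)  # (4, 100, 2)
# num_clusters is closed over as a Python int, so only keys and data are mapped.
c_b, l_b = jax.vmap(lambda k, x: kmeans(k, x, 2))(keys, batch)
print(c_b.shape, l_b.shape)  # (4, 2, 2) (4, 100)
```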

@murphyk
Member

murphyk commented May 25, 2023

Hi @ghuckins. It would be great if you added your JAX implementation of kmeans.

@gileshd gileshd linked a pull request Jul 21, 2024 that will close this issue