Support for kmeans initialization with vmap #315
Comments
Hi @ghuckins, thanks for showing interest in the library! Yes, unfortunately the […] From what I can tell, updating the […] Depending on your precise use case, there might be some reasonably straightforward workarounds. For instance, it might be possible to use the […] I hope that provides some idea of the way forward. If you want to share more details about your use case, I would be happy to try to give more precise advice.
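One general pattern that might serve as a workaround (a sketch only, not necessarily what was meant above) is to run the NumPy-based initialization eagerly, outside any JAX transform, and then `vmap` only the pure-JAX part, passing the precomputed values in as batched inputs. `numpy_init_means` and `assign_to_means` are hypothetical helpers: the first is a stand-in for a NumPy/scikit-learn initializer such as KMeans, and the second is a placeholder for whatever traceable JAX code follows initialization.

```python
import numpy as np
import jax
import jax.numpy as jnp


def numpy_init_means(data, num_clusters, seed=0):
    # Hypothetical stand-in for a NumPy/scikit-learn initializer such as KMeans:
    # it simply picks num_clusters distinct rows at random.
    rng = np.random.default_rng(seed)
    idx = rng.choice(data.shape[0], size=num_clusters, replace=False)
    return data[idx]


def assign_to_means(data, means):
    # Pure-JAX step that is safe under vmap: nearest-mean label for each point.
    d2 = jnp.sum((data[:, None, :] - means[None, :, :]) ** 2, axis=-1)
    return jnp.argmin(d2, axis=1)


# Three independent datasets, each with 50 points in 2 dimensions.
datasets = np.random.default_rng(1).normal(size=(3, 50, 2)).astype(np.float32)

# Step 1: run the NumPy-based initializer in an ordinary Python loop (no tracing).
init_means = np.stack([numpy_init_means(d, num_clusters=4) for d in datasets])

# Step 2: vmap only the traceable JAX code, feeding in the precomputed means.
labels = jax.vmap(assign_to_means)(jnp.asarray(datasets), jnp.asarray(init_means))
print(init_means.shape, labels.shape)  # (3, 4, 2) (3, 50)
```

The trade-off is that the initialization step runs once per dataset in Python rather than being vectorized, which is usually acceptable when it only happens once before fitting.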
Hey Giles, thanks for the reply! I actually found a JAX implementation of k-means online and am working on incorporating it into my own codebase. I could share it once I'm done, if that would be helpful. Just let me know!
Hi @ghuckins. It would be great if you added your JAX implementation of k-means.
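For reference, a k-means written entirely with `jax.numpy` can be traced by `vmap`. The sketch below is an illustrative Lloyd's-algorithm version, not the implementation mentioned in this thread; the function name `kmeans`, the fixed iteration count, and the random-point initialization (rather than k-means++) are all assumptions made to keep it short and trace-friendly.

```python
import jax
import jax.numpy as jnp


def kmeans(key, data, num_clusters, num_iters=20):
    """Lloyd's k-means in pure JAX. data: (num_points, dim) -> (centers, labels)."""
    num_points = data.shape[0]
    # Initialize centers by sampling distinct data points (not k-means++).
    init_idx = jax.random.choice(key, num_points, shape=(num_clusters,), replace=False)
    centers = data[init_idx]

    def assign(centers):
        # Squared distance from every point to every center: (num_points, num_clusters).
        d2 = jnp.sum((data[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
        return jnp.argmin(d2, axis=1)

    def step(_, centers):
        labels = assign(centers)
        one_hot = jax.nn.one_hot(labels, num_clusters)  # (num_points, num_clusters)
        counts = one_hot.sum(axis=0)                     # points per cluster
        sums = one_hot.T @ data                          # (num_clusters, dim)
        # Keep the previous center when a cluster is empty, avoiding division by zero.
        means = sums / jnp.maximum(counts, 1.0)[:, None]
        return jnp.where(counts[:, None] > 0, means, centers)

    # A fixed number of iterations keeps the loop shape static under jit/vmap.
    centers = jax.lax.fori_loop(0, num_iters, step, centers)
    return centers, assign(centers)


# Three datasets, each with two well-separated clusters, shifted by different offsets.
base = jnp.concatenate([jnp.zeros((5, 2)), 10.0 * jnp.ones((5, 2))])
batch = jnp.stack([base, base + 1.0, base - 1.0])    # (3, 10, 2)
keys = jax.random.split(jax.random.PRNGKey(0), 3)

# num_clusters is bound as a Python int so it stays static inside vmap.
centers, labels = jax.vmap(lambda k, x: kmeans(k, x, num_clusters=2))(keys, batch)
```

Because the number of clusters and iterations are fixed Python integers, every array shape is known at trace time, which is what lets `vmap` batch the whole procedure. A convergence check with `jax.lax.while_loop` would also be traceable, at the cost of a little more code.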
Hi there,
When I try to use vmap to vectorize a function that includes a kmeans initialization, I get the following error:
jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[11396,7])>with<BatchTrace(level=1/0)>
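The error can be reproduced without scikit-learn. Under `vmap`, the function receives a tracer instead of a concrete array, and any call that forces a NumPy array (here `np.asarray`, which invokes `__array__()`) raises this error, while the equivalent `jax.numpy` code batches fine. The two helper functions are hypothetical, chosen only to isolate the failure.

```python
import numpy as np
import jax
import jax.numpy as jnp


def numpy_column_means(x):
    # np.asarray calls x.__array__(); under vmap, x is a tracer with no concrete value.
    return np.asarray(x).mean(axis=0)


def jax_column_means(x):
    # jnp operations are traceable, so vmap can batch them.
    return jnp.mean(x, axis=0)


batch = jnp.ones((4, 100, 7))  # 4 datasets, each with 100 points in 7 dimensions

try:
    jax.vmap(numpy_column_means)(batch)
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True

means = jax.vmap(jax_column_means)(batch)
print(raised, means.shape)  # True (4, 7)
```

scikit-learn's `KMeans` converts its input to a NumPy array internally, so it hits the same conversion when called on a traced value.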
And here's the code that produces the error:
The error traces back to scikit-learn and KMeans. The problem seems to be that scikit-learn uses NumPy functions rather than JAX functions. Would it be possible to update hmm.initialize so that it could be used in vectorized functions?
Thanks!