I can't quite tell what you are trying to accomplish from that snippet of code, but one red flag is the fact that you are taking the … You can split the key safely inside the vmapped function, but you should hardcode the number of splits. The vmap over that function will take care of producing the RNG keys you need for each example in the batch. Maybe you can share a little more about what you are trying to do so we can help further?
-
Hi,
I am trying to use `mlx.core.vmap` to vectorise an implementation of the gamma distribution, and as part of that I want to vectorise splitting the PRNG keys.
Here is a simplified version of my code that reproduces the error:
When I run it, I get the following error:
I don't know what the problem is, or why `key` doesn't seem to be passed to the function `_split`.
Thanks!