JAX multi-host data gather not working #22863

Answered by AshishKumar4
AshishKumar4 asked this question in Q&A
I think I might have figured out the issue. I was doing

subkey = jax.random.fold_in(subkey, jax.process_index())

as a way to make sure every process gets a unique random key. It turns out that calling jax.process_index() inside the shard_mapped, jitted function caused the issue. Removing the call and instead passing the global device indexes in as an argument solved it, as follows:

global_device_indexes = jnp.arange(jax.device_count())
.
.
.
<pass the global device indexes array as argument to the shmapped train function>
.
.
[inside the shard_mapped, jitted function, where local_device_index is this device's shard of global_device_indexes]
subkey = jax.random.fold_in(subkey, local_device_index.reshape())
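For anyone who wants to try the pattern end to end, here is a minimal self-contained sketch, not my actual training code. The function name per_device_sample, the mesh axis name "data", and the output shape are hypothetical choices for illustration, and the import fallback is only there because the location of shard_map differs between JAX versions. It assumes a single 1D mesh over all devices, so each shard receives exactly one index.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

# shard_map lives at jax.shard_map in newer releases and under
# jax.experimental.shard_map in older ones.
try:
    shard_map = jax.shard_map
except AttributeError:
    from jax.experimental.shard_map import shard_map

# Assumption: one 1D mesh over every device, with a hypothetical axis name "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# One index per global device, built outside the shard_mapped function.
global_device_indexes = jnp.arange(jax.device_count())


def per_device_sample(key, local_device_index):
    # Each shard sees a length-1 slice of global_device_indexes;
    # reshape it to a scalar before folding it into the key.
    subkey = jax.random.fold_in(key, local_device_index.reshape(()))
    # Hypothetical per-device work: draw a few random numbers.
    return jax.random.normal(subkey, (1, 4))


sample = jax.jit(
    shard_map(
        per_device_sample,
        mesh=mesh,
        in_specs=(P(), P("data")),  # key replicated, indexes split across devices
        out_specs=P("data"),        # per-device rows concatenated along axis 0
    )
)

key = jax.random.PRNGKey(0)
samples = sample(key, global_device_indexes)
print(samples.shape)
```

The point of the design is that the per-device identity enters as ordinary sharded data rather than through a host-side call such as jax.process_index() made while tracing.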

Finally figured this out!

Answer selected by AshishKumar4