JAX multi-host data gather not working #22863
-
I have been trying to build a multi-host training pipeline using JAX and Flax for training on TPU-v4-16/32+.
Now, I did write some example code, and with that it works:
The output is printed for both the pmap and shard_map cases without issues. The mesh is simply defined as:
The training code works on single-host TPU-v4s, but not on multi-host ones, and the same failure happens if I switch to pmap. The error:
What am I doing wrong? I've lost a lot of hair over this :(
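The mesh definition from the post was not preserved. For reference, a common multi-host setup (an assumption, not the poster's code) builds a 1-D mesh over `jax.devices()`, which lists the devices of all hosts, and assembles a global batch with `jax.make_array_from_callback`, so each process only materialises the shards that live on its own devices. The `"data"` axis name, the batch shape and `full_batch` are hypothetical placeholders; a real pipeline would usually load only the host's own slice rather than the full array.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Every process runs this same script. Depending on the setup,
# jax.distributed.initialize() may have to be called first.
# jax.devices() lists the devices of *all* hosts; jax.local_devices()
# would list only the ones attached to this host.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
batch_sharding = NamedSharding(mesh, P("data"))

# Hypothetical global batch: 2 rows per device, 3 features per row.
global_shape = (2 * jax.device_count(), 3)
full_batch = np.arange(np.prod(global_shape), dtype=np.float32).reshape(global_shape)

# The callback receives the index (a tuple of slices) of one shard and is
# only invoked for shards on this process's devices.
batch = jax.make_array_from_callback(
    global_shape, batch_sharding, lambda index: full_batch[index]
)
```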
Replies: 2 comments 1 reply
-
Can you please paste what the error is?
-
I think I might have figured out the issue. I was doing
as a way to make sure every process gets a unique random key. It turns out that calling process_index() inside the shard_mapped, jitted function caused this issue. Removing it and instead supplying global indices solved the problem, as follows:
Finally figured this out!
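The poster's fixed code was not preserved. A plausible explanation for the original failure is that `jax.process_index()` returns a plain Python int, which is baked into the traced program as a constant, so each host compiles a different program while multi-host JAX expects every process to run the same one. Below is a minimal sketch of the "supply global indices" idea, assuming a 1-D `"data"` mesh and hypothetical function names: the per-device index array is built outside the jitted function and sharded so each device receives its own entry, which is then folded into a shared key. `jax.lax.axis_index` is shown as an in-function alternative that is also identical across hosts.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n_dev = mesh.shape["data"]

# Global device indices 0..n_dev-1, sharded along "data" so each device
# holds exactly its own index. Built outside the jitted function, so the
# traced program contains no host-specific constants.
device_idx = jax.make_array_from_callback(
    (n_dev,), NamedSharding(mesh, P("data")),
    lambda index: np.arange(n_dev, dtype=np.int32)[index],
)

def noise_from_index(key, idx):
    # idx has local shape (1,): this device's global index.
    local_key = jax.random.fold_in(key, idx[0])
    return jax.random.normal(local_key, (1, 4))

def noise_from_axis_index(key):
    # Alternative: ask the mesh for this device's position directly.
    local_key = jax.random.fold_in(key, jax.lax.axis_index("data"))
    return jax.random.normal(local_key, (1, 4))

sample_a = jax.jit(shard_map(noise_from_index, mesh=mesh,
                             in_specs=(P(), P("data")), out_specs=P("data")))
sample_b = jax.jit(shard_map(noise_from_axis_index, mesh=mesh,
                             in_specs=(P(),), out_specs=P("data")))

key = jax.random.PRNGKey(0)
out_a = sample_a(key, device_idx)  # global shape (n_dev, 4), one row per device
out_b = sample_b(key)
```

In both variants the value that distinguishes devices comes from the sharding or the mesh rather than from host-local Python state, so the two produce the same per-device streams.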