Replies: 1 comment
-
I'd probably check the output of `pjit(fn).lower(*args).compile().as_text()` and see whether it contains what you're looking for (perhaps a `ppermute`, which should show up as a collective-permute op in the compiled HLO). If it's not behaving as expected, you could check out `shard_map`, which gives you more manual control over parallelism within a function you want to jit/pjit.
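To make the inspection step concrete, here is a minimal sketch, not a definitive recipe. It assumes `jax.jit` in place of `pjit`, a 1-D mesh whose axis name `"seq"` is made up for this example, and a hypothetical `shift_to_next` function standing in for your own. The `XLA_FLAGS` line only matters when trying this on a CPU-only machine.

```python
import os
# Assumption: fake 4 host devices so this can be tried on CPU.
# The flag only affects the CPU backend and must be set before jax is imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map

devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices, ("seq",))  # "seq" is a hypothetical axis name

def shift_to_next(x):
    # Cyclic shift: device i sends its shard to device (i + 1) % n.
    perm = [(i, (i + 1) % n) for i in range(n)]
    return jax.lax.ppermute(x, "seq", perm)

fn = jax.jit(shard_map(shift_to_next, mesh=mesh,
                       in_specs=P("seq"), out_specs=P("seq")))

x = jnp.arange(n * 2 * 3, dtype=jnp.float32).reshape(n * 2, 3)

# Text of the compiled (post-optimization) program.
hlo_text = fn.lower(x).compile().as_text()

# ppermute is expected to appear as a collective-permute op.
has_permute = "collective-permute" in hlo_text
print("collective-permute in compiled HLO:", has_permute)
```

If the op is there and no all-gather or all-to-all sits next to it, that suggests the exchange really is neighbor-to-neighbor.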
-
I need to implement a form of sequence parallelism in which each core works on one segment of a long sequence. For local attention, each core must send its input to the core that handles the next segment, with all cores doing so at the same time. This can be described as a process like the following:
In principle, communication happens only from core_n to core_n+1, which is cheap. But does this also hold in practice? I haven't seen an example of a process like this yet, so I'm concerned that this behavior may not be supported.
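For reference, the neighbor exchange described above could be sketched roughly as follows with `shard_map` and `jax.lax.ppermute`. This is only an illustration under assumptions: the mesh axis name `"seq"`, the helper `previous_segment`, and the segment and feature sizes are all invented for the example. It uses a cyclic shift so that it runs with any device count. Dropping the wrap-around pair from `perm` would leave the first core with zeros instead, since `ppermute` fills non-receiving devices with zeros.

```python
import os
# Assumption: fake 4 host devices so this can be tried on CPU.
# The flag only affects the CPU backend and must be set before jax is imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map

devices = np.array(jax.devices())
n = devices.size                # one sequence segment per device
mesh = Mesh(devices, ("seq",))  # "seq" is a hypothetical axis name

seg, d = 2, 3                   # illustrative segment length and feature size

def previous_segment(x_local):
    # x_local is this device's segment, shape (seg, d).
    # Every device sends its segment to the next one in a single collective.
    perm = [(i, (i + 1) % n) for i in range(n)]
    return jax.lax.ppermute(x_local, "seq", perm)

exchange = jax.jit(shard_map(previous_segment, mesh=mesh,
                             in_specs=P("seq"), out_specs=P("seq")))

x = jnp.arange(n * seg * d, dtype=jnp.float32).reshape(n * seg, d)

# Row block i of `prev` now holds what row block i - 1 of `x` held.
prev = exchange(x)
print(prev.shape)
```

Inside `previous_segment`, the received block could then be combined with `x_local` to form the local attention context for each segment.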