Should I be using jax.pmap? #22645
I have a numerical integration application that I am using JAX to speed up. I want to run this integration on many samples in parallel, using a different GPU for each sample. I have seen conflicting information about jax.pmap: I saw somewhere that it is being deprecated in favor of xmap, but all of xmap's docs have been taken down. What is the recommended way to parallelize computation across devices right now? Is it safe to use pmap without worrying about deprecation?
Replies: 1 comment
There are 2 new APIs for it:

1. Use the `jax.jit` and `jax.Sharding` APIs for compiler-based parallelism: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
2. If you then want to write collectives inside that jitted function, you can use `shard_map`: https://jax.readthedocs.io/en/latest/notebooks/shard_map.html
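As a minimal sketch of the `jax.jit` + sharding approach for the use case in the question (many independent integrations, with a slice of the samples on each GPU): the integrand `exp(-a * x**2)`, the trapezoidal rule, the function name `integrate_one`, and the mesh axis name `"samples"` are illustrative assumptions, not part of the original answer. It uses `jax.sharding.Mesh`, `NamedSharding`, and `PartitionSpec` to spread the batch of parameters over all visible devices and leaves the partitioning of the computation to `jit`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def integrate_one(a, num_points=1001):
    # Hypothetical integrand: integrate exp(-a * x**2) over [0, 1]
    # with the trapezoidal rule.
    x = jnp.linspace(0.0, 1.0, num_points)
    y = jnp.exp(-a * x**2)
    dx = x[1] - x[0]
    return dx * (jnp.sum(y) - 0.5 * (y[0] + y[-1]))


# vmap turns the per-sample function into a batched one; jit compiles it.
batched_integrate = jax.jit(jax.vmap(integrate_one))

# A 1-D mesh over every visible device, with the batch axis named "samples".
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("samples",))
sharding = NamedSharding(mesh, P("samples"))

# Assumption: the number of samples is a multiple of the device count,
# so each device holds an equal slice of the batch.
n_dev = len(devices)
params = jax.device_put(jnp.linspace(0.5, 2.0, 4 * n_dev), sharding)

# The input is sharded along the batch axis, so the compiled program is
# partitioned across devices; no per-device code is written by hand.
results = batched_integrate(params)
```

`results` is a single `jax.Array` whose shards live on different devices, and `np.asarray(results)` gathers it back to the host. The same code also runs on a single CPU device, which makes it easy to test before moving to several GPUs.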
Both `jit` and `shard_map` compose nicely with each other, and they are the recommended APIs for parallelism in JAX.
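To illustrate how `jit` and `shard_map` compose, here is a sketch in which each device integrates its own shard of samples and an explicit `jax.lax.psum` collective adds up the per-device totals. The integrand `exp(-a * x**2)` (trapezoidal rule), the function name `local_integrals`, and the axis name `"samples"` are hypothetical. The import location of `shard_map` has changed between JAX releases, so the `try`/`except` below is an assumption meant to cover both newer (`jax.shard_map`) and older (`jax.experimental.shard_map`) versions.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("samples",))


def local_integrals(a_block):
    # a_block is only this device's shard of the parameters.
    x = jnp.linspace(0.0, 1.0, 1001)
    y = jnp.exp(-a_block[:, None] * x[None, :] ** 2)
    dx = x[1] - x[0]
    per_sample = dx * (jnp.sum(y, axis=1) - 0.5 * (y[:, 0] + y[:, -1]))
    # Explicit collective: sum each device's local total across the mesh.
    total = jax.lax.psum(jnp.sum(per_sample), axis_name="samples")
    return per_sample, total


# shard_map describes the per-device program; jit compiles the whole thing.
sharded_fn = jax.jit(
    shard_map(
        local_integrals,
        mesh=mesh,
        in_specs=(P("samples"),),
        out_specs=(P("samples"), P()),  # per-sample results stay sharded; total is replicated
    )
)

# Assumption: the number of samples is a multiple of the device count.
n_dev = len(devices)
params = jax.device_put(
    jnp.linspace(0.5, 2.0, 4 * n_dev), NamedSharding(mesh, P("samples"))
)
per_sample, total = sharded_fn(params)
```

Inside `shard_map` the function only sees its local block (`a_block`), which is why the cross-device reduction has to be written as a collective; `jit` still handles compilation and the sharded inputs and outputs around it.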