
Should I be using jax.pmap? #22645

Answered by yashk2810
skushnir123 asked this question in Q&A

Instead of jax.pmap, there are two newer APIs you can use:

  1. Use jax.jit together with the jax.sharding APIs for compiler-based parallelism: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html

  2. If you also want to write explicit collectives inside that jitted function, use shard_map: https://jax.readthedocs.io/en/latest/notebooks/shard_map.html
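As a minimal sketch of the first option (assuming a recent JAX release; the mesh axis name "data" and the function f are illustrative choices, not required names), the snippet below shards the batch dimension of an array across all available devices and lets jax.jit partition the computation. Unlike jax.pmap, the function body is written as ordinary whole-array code.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D device mesh over every available device; "data" is an arbitrary axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
n = len(devices)

# Place the array so its leading (batch) dimension is split across the "data" axis.
x_host = jnp.arange(n * 4 * 8, dtype=jnp.float32).reshape(n * 4, 8)
x = jax.device_put(x_host, NamedSharding(mesh, P("data", None)))

@jax.jit
def f(x):
    # Plain array code: the compiler partitions it and inserts any communication
    # needed, e.g. for the global x.sum() reduction.
    return jnp.sin(x) * 2.0 + x.sum()

y = f(x)
print(y.sharding)  # shows how the compiler laid out the result across devices
```

On a single CPU device the mesh simply has one entry, so the same code runs unchanged; with more devices, only the mesh changes.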

jit and shard_map compose well with each other, and together they are the recommended APIs for parallelism in JAX.
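To illustrate how the two compose, here is a hedged sketch in which a jitted function calls a shard_map region that performs an explicit jax.lax.psum collective and then continues with ordinary array code. It assumes a JAX version that ships shard_map either as jax.shard_map or under jax.experimental.shard_map (the import location has moved between releases); the names global_sum and center are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:  # newer JAX releases export shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n = len(jax.devices())

@partial(shard_map, mesh=mesh, in_specs=P("data"), out_specs=P())
def global_sum(x_block):
    # Inside shard_map each device sees only its local block of x.
    local = jnp.sum(x_block)
    # Explicit collective: add the per-device partial sums over the "data" axis,
    # so the (replicated) result matches out_specs=P().
    return jax.lax.psum(local, axis_name="data")

@jax.jit
def center(x):
    # Ordinary jitted array code wrapped around the manual shard_map region.
    return x - global_sum(x) / x.shape[0]

x = jax.device_put(jnp.arange(n * 4, dtype=jnp.float32), NamedSharding(mesh, P("data")))
y = center(x)
```

Everything outside global_sum is partitioned automatically by the compiler, while inside it each device works on its own block and communication happens only where psum is written.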
