parallelization in jax when the data are more than the number of devices/cores #19452

Answered by jakevdp
yyang97 asked this question in Q&A

You can do this with shard_map, the more general successor to pmap. Note, however, that the size of the sharded axis must be an integer multiple of the number of devices. For example:

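import os
# Emulate 8 CPU devices; this flag must be set before jax is imported.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

import jax
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# One-dimensional mesh over all 8 devices; the mesh shape must be a tuple.
mesh = Mesh(mesh_utils.create_device_mesh((8,)), axis_names=('i',))

key = jax.random.PRNGKey(0)
N = 800  # a multiple of 8, so each device gets 100 elements
x = jax.random.normal(key, (N,))

# Shard x along axis 'i', square each shard on its own device, and reassemble.
y = shard_map(jax.numpy.square, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(x)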
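
If the axis size is not an integer multiple of the device count, one possible workaround is to pad the input up to the next multiple, run shard_map, and trim the result. The following is only a sketch under some assumptions: `padded_square` is a hypothetical helper name, zero-padding is assumed to be harmless for the function being applied (it is for an elementwise square), and the import fallback assumes newer JAX releases expose `shard_map` at the top level while older ones only have it under `jax.experimental`.

```python
import os
# Assumption: emulate 8 CPU devices; must be set before jax is imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # assumed location in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # location used above

# One-dimensional mesh over however many devices are actually available.
devices = np.array(jax.devices())
n_dev = devices.size
mesh = Mesh(devices, axis_names=('i',))

def padded_square(x):
    """Hypothetical helper: pad to a multiple of n_dev, square per shard, trim."""
    n = x.shape[0]
    pad = (-n) % n_dev                   # elements needed to reach the next multiple
    x_padded = jnp.pad(x, (0, pad))      # zero-pad at the end of the axis
    y_padded = shard_map(jnp.square, mesh=mesh,
                         in_specs=P('i'), out_specs=P('i'))(x_padded)
    return y_padded[:n]                  # drop the padded tail

# 803 is not a multiple of 8, so the input is padded to 808 internally.
x = jax.random.normal(jax.random.PRNGKey(0), (803,))
y = padded_square(x)
```

For functions where padded values would affect the result (for example reductions across the sharded axis), the padding value or a mask would need to be chosen accordingly.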

Answer selected by yyang97