Hi JAX lovers and developers, I am trying out parallel computing in JAX. In my case I have 100 data items but only 8 cores/devices, and it looks like `jax.pmap()` does not support this. What is the easiest way to parallelize in JAX when I have more data than devices?
Of course, `pmap` reports the error: `compiling computation that requires 100 logical devices, but only 8 XLA devices are available (num_replicas=100)`. Is there any way to parallelize the calculation with 100 data items on only 8 devices? For example, I would like each core/device to automatically start the next job as soon as it finishes one.
Replies: 1 comment 1 reply
You can do something like this with `shard_map`, which is the more general successor of `pmap`. Note however that it requires operating on an axis size which is an integer multiple of the number of devices. For example:

```python
import os
# Simulate 8 CPU devices; this flag must be set before jax is imported.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

import jax
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# One-dimensional mesh over the 8 devices, with the axis named 'i'.
mesh = Mesh(mesh_utils.create_device_mesh((8,)), axis_names=('i',))

key = jax.random.PRNGKey(0)
N = 800  # a multiple of the device count, so each device gets 100 elements
x = jax.random.normal(key, (N,))

# Split x along 'i', square each shard on its own device, and reassemble.
y = shard_map(jax.numpy.square, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(x)
```
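For the case in the question, where 100 is not a multiple of 8, one possible workaround is to pad the input up to the next multiple of the device count, run `shard_map`, and slice the padding off again. The following is only a minimal sketch under a few assumptions: `padded_shard_map` is a hypothetical helper name (not part of JAX), the function being mapped is elementwise and safe to evaluate on zero padding, and the mesh is built from whatever `jax.devices()` returns rather than a fixed count of 8.

```python
import os
# Assumption: request 8 simulated CPU devices unless XLA_FLAGS is already set.
# This has to happen before jax is imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    # Newer JAX releases expose shard_map at the top level.
    from jax import shard_map
except ImportError:
    # Older releases only have the experimental location used in the reply.
    from jax.experimental.shard_map import shard_map

devices = np.array(jax.devices())
n_dev = devices.size
mesh = Mesh(devices, axis_names=("i",))


def padded_shard_map(fn, x):
    """Hypothetical helper: pad axis 0 to a multiple of n_dev, map fn, trim."""
    n = x.shape[0]
    pad = (-n) % n_dev  # extra rows needed to reach the next multiple
    pad_width = [(0, pad)] + [(0, 0)] * (x.ndim - 1)
    x_padded = jnp.pad(x, pad_width)  # zero padding at the end of axis 0
    y_padded = shard_map(fn, mesh=mesh, in_specs=P("i"), out_specs=P("i"))(x_padded)
    return y_padded[:n]  # drop the results computed on the padding


x = jax.random.normal(jax.random.PRNGKey(0), (100,))
y = padded_shard_map(jnp.square, x)
```

Note that this splits the work statically into equal chunks per device; it does not give the dynamic "start the next job when one finishes" scheduling asked about in the question.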