-
Hello, I am getting this error in JAX distributed code (only when I use jit).
My function uses a GSPMD-sharded array, and there are a few shard_maps as well. Without jit the function runs, but obviously I want a jitted function that runs in a multi-host setup. How can I debug this?
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 2 replies
-
I made the most minimal example possible:
import jax
from functools import partial

import jax.lax as lax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# Join the multi-process run before asking for this process's index/count.
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()

# 2 x 2 device grid with mesh axes named ('y', 'z'); the global field is 4^3.
pdims = (2, 2)
mesh_shape = [4, 4, 4]
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('y', 'z'))
# Array axis 0 is split over mesh axis 'z', array axis 1 over mesh axis 'y'.
sharding = jax.sharding.NamedSharding(mesh, P('z', 'y'))

# Per-shard block shape: the first two axes are divided by the process grid,
# the last axis is kept whole.
local_mesh_shape = [
    mesh_shape[0] // pdims[1],
    mesh_shape[1] // pdims[0],
    mesh_shape[2],
]

# One PRNG key per process, selected by this process's rank.
master_key = jax.random.PRNGKey(42)
key = jax.random.split(master_key, size)[rank]

# Global random field assembled from the single local block of this process.
# NOTE(review): passing exactly one array presumably assumes one addressable
# device per process -- confirm against the launch configuration.
local_block = jax.random.normal(key, local_mesh_shape, dtype='float32')
z = jax.make_array_from_single_device_arrays(
    shape=mesh_shape,
    sharding=sharding,
    arrays=[local_block])


def _global_positions(index):
    """Build the (i, j, k) grid coordinates for the block selected by `index`.

    `index[0]` and `index[1]` select the coordinate ranges of the first two
    axes; the third axis is always taken whole.
    """
    axes = (
        jnp.arange(mesh_shape[0])[index[0]],
        jnp.arange(mesh_shape[1])[index[1]],
        jnp.arange(mesh_shape[2]),
    )
    return jnp.stack(jnp.meshgrid(*axes, indexing='ij'), axis=-1)


# Global array holding, for every cell, its own 3-component grid coordinate.
pos = jax.make_array_from_callback(
    shape=(*mesh_shape, 3),
    sharding=sharding,
    data_callback=_global_positions)

# local position correction
# NOTE(review): both offsets are derived from `rank`, so every process closes
# over different constants when it traces `cic_read` -- confirm this is
# compatible with multi-host jit (computing the shard offset inside the
# shard_map from the mesh axis index may be required instead).
correct_y = -local_mesh_shape[1] * (rank // pdims[0])
correct_z = -local_mesh_shape[0] * (rank % pdims[0])
@partial(shard_map,
         mesh=mesh,
         in_specs=(P('z', 'y', None), P('z', 'y', None)),
         out_specs=P('z', 'y', None))
def cic_read(particle_mesh, positions):
    """Read `particle_mesh` at `positions` using 8-neighbour trilinear weights.

    Runs once per shard under `shard_map`, so both arguments are the local
    blocks of the global arrays.

    Args:
        particle_mesh: local 3-D block of the field to sample.
        positions: local block of coordinates whose last axis holds the three
            components. It must contain exactly one position per cell of
            `particle_mesh`, because the result is reshaped to the block shape.

    Returns:
        The interpolated values, with the same shape as `particle_mesh`.
    """
    # Shift the first two coordinate components by the module-level per-rank
    # offsets (`correct_z` for component 0, `correct_y` for component 1).
    positions = positions.at[:, :, :, 1].add(correct_y)
    positions = positions.at[:, :, :, 0].add(correct_z)

    # Flatten to (n_points, 1, 3) so the 8 neighbour offsets broadcast on axis 1.
    positions = positions.reshape([-1, 3])
    positions = jnp.expand_dims(positions, 1)
    floor = jnp.floor(positions)
    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
                             [1., 1, 0], [1., 0, 1], [0., 1, 1],
                             [1., 1, 1]]])
    neighboor_coords = floor + connection

    # Per-neighbour weight: product over components of (1 - |distance|).
    kernel = 1. - jnp.abs(positions - neighboor_coords)
    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]

    # `neighboor_coords` is floating point (it includes the float `connection`
    # offsets); array indexing needs integer indexers, so cast before wrapping
    # the indices periodically into the local block.
    neighboor_index = neighboor_coords.astype('int32')
    indices = [
        jnp.mod(neighboor_index[..., i], particle_mesh.shape[i])
        for i in range(3)
    ]
    field = (particle_mesh[indices[0], indices[1], indices[2]] *
             kernel).sum(axis=-1)
    # print(f"Field shape {field.shape}")
    # Return the interpolated values (not the untouched input mesh), laid out
    # like the local block.
    return field.reshape(particle_mesh.shape)
# Apply the shard-mapped read to the global field `z` and coordinate array `pos`.
field = cic_read(z, pos)
print(f"Rank {rank} Field shape {field.shape}")
The error occurs here.
I am working only with slices. Similar code works on a single host; does classic NumPy indexing not work here?
Beta Was this translation helpful? Give feedback.
Answered in issue #22218
By @yashk2810