So if I understand your question correctly, you have data that is replicated across hosts but sharded across the devices within each host? If that's the case, you can just pull it to the host via … Note that you will need jax and jaxlib 0.4.7 for this.
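If the data really is replicated across hosts and sharded only across each host's local devices, then the shards a single process can address already tile the whole global array, so every host can rebuild a full NumPy copy by itself and the copies can then be compared across hosts. A minimal sketch of doing that by hand, assuming a `jax.Array` input. The helper name `host_copy` is hypothetical, and this is not necessarily the call the reply refers to:

```python
import numpy as np
import jax


def host_copy(arr: jax.Array) -> np.ndarray:
    """Assemble a full NumPy copy of `arr` from this process's addressable shards.

    Assumes the array is replicated across hosts and sharded only across the
    devices within a host, so the local shards cover the whole global array.
    Raises ValueError if that assumption does not hold on this process.
    """
    out = np.empty(arr.shape, dtype=arr.dtype)
    covered = np.zeros(arr.shape, dtype=bool)
    for shard in arr.addressable_shards:
        # shard.index is a tuple of slices locating this shard in the global array;
        # np.asarray copies the shard's device buffer to host memory.
        out[shard.index] = np.asarray(shard.data)
        covered[shard.index] = True
    if not covered.all():
        raise ValueError("this process's shards do not cover the global array")
    return out
```

The coverage mask is only there to fail loudly when the replication assumption is wrong (for example, when the array is also sharded across hosts); for large arrays you may prefer to drop it.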
Hi,
Sorry if this is discussed elsewhere, but I couldn't immediately find it.
I'm debugging some sharded data-loading code and want to check whether hosts that should be loading the same data are in fact loading the same data (and, conversely, whether hosts that should be loading different data are in fact loading different data).
ATM, the code is still GDA-based (GDA.from_callback), but I'm happy to do the migration to jax.Array (which I've been putting off).
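So the code is basically:

```python
from jax.experimental.global_device_array import GlobalDeviceArray
array = GlobalDeviceArray.from_callback(shape, mesh, pspec, local_data)  # local_data: index -> data callback
```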
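For the migration to `jax.Array`, the closest counterpart of `GlobalDeviceArray.from_callback` is `jax.make_array_from_callback(shape, sharding, data_callback)`, where the mesh and partition spec are bundled into a `NamedSharding`. A minimal single-process sketch, assuming `local_data` keeps the same contract (it receives a tuple of slices into the global array and returns that slice). The concrete mesh, axis name `"data"`, shape and `global_data` below are hypothetical stand-ins for the question's values:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical stand-ins for the question's `mesh`, `pspec` and `shape`.
mesh = Mesh(np.array(jax.devices()), ("data",))
pspec = P("data")
shape = (2 * len(jax.devices()), 4)
global_data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)


def local_data(index):
    # Called for each addressable device with a tuple of slices into the
    # global array; must return exactly the data for that slice.
    return global_data[index]


array = jax.make_array_from_callback(shape, NamedSharding(mesh, pspec), local_data)
```

In a real multi-host loader, `local_data` would read only the requested slice from storage instead of indexing an in-memory global array.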
I've seen process_allgather and copy_to_host_async, but, understandably, they don't give me access to the individual replicas.
My current idea is roughly: convert to jax.Array, loop through global_shards, take the addressable buffer where one exists (doing something reasonable for the shards that aren't addressable), then do what is essentially a tiled process_allgather for each shard index, and run my checks on the result.
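The per-process half of that plan can be sketched as follows, assuming the data is already a `jax.Array`. Instead of walking `global_shards`, it only visits `addressable_shards`, so the non-addressable case never arises on a given process, and it records each shard's replica id plus a SHA-256 digest keyed by the shard's global index. The helper names `shard_fingerprints` and `local_replicas_agree` are hypothetical, and exchanging the per-process dicts across hosts (for example, logging them together with `jax.process_index()`, or the allgather step above) is not shown:

```python
import hashlib

import numpy as np
import jax


def shard_fingerprints(arr: jax.Array) -> dict:
    """Map each addressable shard's global index to a list of (replica_id, sha256).

    Index keys are normalized to tuples of (start, stop, step) ints, because
    slice objects are not hashable on older Python versions.
    """
    result = {}
    for shard in arr.addressable_shards:
        key = tuple(s.indices(dim) for s, dim in zip(shard.index, arr.shape))
        digest = hashlib.sha256(np.asarray(shard.data).tobytes()).hexdigest()
        # Several local devices may hold replicas of the same index.
        result.setdefault(key, []).append((shard.replica_id, digest))
    return result


def local_replicas_agree(fps: dict) -> bool:
    """True if every replica of each index on this process has the same digest."""
    return all(len({digest for _, digest in entries}) == 1 for entries in fps.values())
```

Because the keys are plain integer tuples, dicts produced on different processes can be merged or compared directly: any index that appears on two hosts should map to the same digest.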
Is there anything better/preexisting?