I want to run a parallel optimization and stop once all devices have finished, using `psum()` in the while-loop condition to test whether every device has finished its run. However, the program hangs, possibly because one device is waiting for another (100% utilization on one device, 0% on the other). Code below:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.pmap, axis_name="num_devices", static_broadcasted_argnums=(1,))
def _run_opt(state: jnp.ndarray, num_devices: int) -> jnp.ndarray:
    def _parallel_step(s: jnp.ndarray) -> jnp.ndarray:
        return s * 0.999

    def _parallel_cond(s: jnp.ndarray) -> jnp.ndarray:
        # This device is finished once every element drops below the threshold.
        is_finish = jnp.all(s < 1e-3)
        # Sum the per-device flags; keep looping until all devices are finished.
        p_is_finish = jax.lax.psum(is_finish, axis_name="num_devices") == num_devices
        return ~p_is_finish

    state = jax.lax.while_loop(_parallel_cond, _parallel_step, state)
    return state


def main():
    size = 1000
    state = jnp.array([
        5.0 * jnp.ones(size),
        2.0 * jnp.ones(size)
    ])
    num_devices = jax.device_count()
    state = _run_opt(state, num_devices)
    print(state)


if __name__ == "__main__":
    main()
```
Answered by imoneoi, Feb 6, 2023
UPDATE: The issue seems to be resolved, as the above code works well in jax version 0.4.2.
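For anyone who wants to check the termination condition itself, separately from any multi-device runtime behaviour, here is a minimal sketch that runs the same loop under `jax.vmap` with an `axis_name`, so that `psum` is evaluated across a batch axis on a single device. This is not from the original post: the helper name `run_opt_single_device`, the smaller array size, and the explicit `int32` cast are assumptions made for illustration, and it says nothing about why the `pmap` version hung.

```python
import jax
import jax.numpy as jnp

AXIS = "num_devices"  # same axis name as in the question


def _run_opt(state: jnp.ndarray, num_replicas: int) -> jnp.ndarray:
    def _step(s: jnp.ndarray) -> jnp.ndarray:
        return s * 0.999

    def _cond(s: jnp.ndarray) -> jnp.ndarray:
        # Local flag: this replica is done once all its elements are small.
        is_finish = jnp.all(s < 1e-3)
        # Count finished replicas across the mapped axis (cast made explicit here).
        n_finished = jax.lax.psum(is_finish.astype(jnp.int32), axis_name=AXIS)
        # Every replica sees the same count, so all of them stop together.
        return n_finished != num_replicas

    return jax.lax.while_loop(_cond, _step, state)


def run_opt_single_device(state: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical helper: maps over the leading axis with vmap instead of pmap.
    num_replicas = state.shape[0]
    return jax.vmap(lambda s: _run_opt(s, num_replicas), axis_name=AXIS)(state)


state = jnp.stack([5.0 * jnp.ones(8), 2.0 * jnp.ones(8)])
result = run_opt_single_device(state)
print(result)
```

Because the loop only exits when every replica reports finished, the row that started at 2.0 keeps shrinking until the row that started at 5.0 also drops below the threshold.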