Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

[RUNTIME] initialize move_worker in driver process #513

Merged
merged 3 commits into from
Jun 16, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions alpa/device_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,12 @@ class MeshHostWorker:
host."""

def __init__(self, server_address: str, num_hosts: int, host_id: int,
mesh_id: int, node_resource: str, runtime_random_seed: int):
mesh_id: int, move_worker: DaemonMoveWorker,
runtime_random_seed: int):
self.num_hosts = num_hosts
self.host_id = host_id
self.mesh_id = mesh_id
self.move_worker = move_worker
self.launched = False
self.distributed_client = (
xla_client._xla.get_distributed_runtime_client(
Expand Down Expand Up @@ -147,10 +149,6 @@ def __init__(self, server_address: str, num_hosts: int, host_id: int,
jax_tensor_to_cupy(device_put(
jnp.ones((1,), dtype=jnp.int8), d),
take_ownership=True))

# Launch the DaemonMoveWorker
cls = ray.remote(resources={node_resource: 1e-3})(DaemonMoveWorker)
self.move_worker = cls.remote()
self.launched = True

##### Buffer Related Functions #####
Expand Down Expand Up @@ -1146,14 +1144,18 @@ def _launch_xla_servers(self):
""), # For libnccl-net.so
})

# Launch a ray actor
# Launch the DaemonMoveWorker
node_resource = "node:" + self.host_info[i]["NodeManagerAddress"]
cls = ray.remote(resources={node_resource: 1e-3})(DaemonMoveWorker)
move_worker = cls.remote()

# Launch the MeshHostWorker
cls = ray.remote(num_gpus=self.num_devices_per_host,
resources={node_resource: 1e-3})(MeshHostWorker)
worker = cls.options(runtime_env={
"env_vars": env_vars
}).remote(self.server_address, self.num_hosts, i, self.mesh_id,
node_resource, global_config.runtime_random_seed)
move_worker, global_config.runtime_random_seed)
self.workers.append(worker)
self.launched = True

Expand Down