[sgd] add placement group support (#17037)
* [sgd] add placement group support

* add logic for removing placement group upon shutdown

* set placement group; add tests

* address comments - add timeout and improve error handling

* remove unused import

* mock SGD_PLACEMENT_GROUP_TIMEOUT_S
matthewdeng authored Jul 20, 2021
1 parent 9ca6bda commit fef74aa
Showing 4 changed files with 281 additions and 23 deletions.
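
For context, the scheduling pattern this change introduces can be sketched as
follows. This is a minimal, illustrative example rather than the trainer's
actual code: the worker count and bundle size are placeholders, and the
trainer also adds a "GPU" entry to each bundle when use_gpu is set.

import ray
from ray.util.placement_group import placement_group, remove_placement_group

ray.init()

# Illustrative values; TorchTrainer derives these from its configuration.
num_workers = 4
num_cpus_per_worker = 1

# One bundle per worker; SPREAD distributes the bundles evenly across nodes.
bundles = [{"CPU": num_cpus_per_worker}] * num_workers
pg = placement_group(bundles, strategy="SPREAD")

# Block until the group is scheduled, or surface a clear error, mirroring
# the timeout handling added in _create_placement_group() below.
ready, _ = ray.wait([pg.ready()], timeout=100)
if not ready:
    raise TimeoutError(
        "Placement group creation timed out. Requested bundles: "
        f"{pg.bundle_specs}, available: {ray.available_resources()}")

# ... start training actors with .options(placement_group=pg) ...

# Mirror the cleanup added to shutdown(): remove the group when finished.
remove_placement_group(pg)
ray.shutdown()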
8 changes: 8 additions & 0 deletions python/ray/util/sgd/BUILD
@@ -10,6 +10,14 @@ py_test(
deps = [":sgd_lib"]
)

py_test(
name = "test_placement_groups",
size = "medium",
srcs = ["tests/test_placement_groups.py"],
tags = ["exclusive"],
deps = [":sgd_lib"]
)

py_test(
name = "test_ptl",
size = "large",
206 changes: 206 additions & 0 deletions python/ray/util/sgd/tests/test_placement_groups.py
@@ -0,0 +1,206 @@
from unittest.mock import patch

import pytest
import ray
import torch.nn as nn
from ray import tune
from ray.cluster_utils import Cluster
from ray.tune.utils import merge_dicts
from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.torch.examples.train_example import (
model_creator, optimizer_creator, data_creator)
from ray.util.sgd.torch.training_operator import TrainingOperator

Operator = TrainingOperator.from_creators(
model_creator, optimizer_creator, data_creator, loss_creator=nn.MSELoss)


@pytest.fixture
def ray_4_node_1_cpu():
cluster = Cluster()
for _ in range(4):
cluster.add_node(num_cpus=1)

ray.init(address=cluster.address)

yield

ray.shutdown()
cluster.shutdown()


@pytest.fixture
def ray_8_node_2_cpu():
cluster = Cluster()
for _ in range(8):
cluster.add_node(num_cpus=2)

ray.init(address=cluster.address)

yield

ray.shutdown()
cluster.shutdown()


@pytest.fixture
def ray_4_node_8_cpu():
cluster = Cluster()
for _ in range(4):
cluster.add_node(num_cpus=8)

ray.init(address=cluster.address)

yield

ray.shutdown()
cluster.shutdown()


def test_train_spread(ray_8_node_2_cpu):
"""Tests if workers are spread across nodes."""
assert ray.available_resources()["CPU"] == 16
trainer = TorchTrainer(
training_operator_cls=Operator,
num_workers=7,
use_gpu=False,
)

assert ray.available_resources()["CPU"] == 9

node_id_set = set()
for actor_info in ray.state.actors().values():
node_id = actor_info["Address"]["NodeID"]
node_id_set.add(node_id)
assert len(node_id_set) == 7

trainer.shutdown()
assert ray.available_resources()["CPU"] == 16


@pytest.mark.parametrize("num_workers", [1, 7, 8, 15])
def test_tune_train_pack(ray_4_node_8_cpu, num_workers):
"""Tests if workers are colocated when running Tune."""

def custom_train_func(trainer, info):
train_stats = trainer.train(profile=True)
val_stats = trainer.validate(profile=True)
stats = merge_dicts(train_stats, val_stats)

actors = ray.state.actors().values()
assert len(actors) == num_workers + 1

node_id_set = set()
for actor_info in actors:
node_id = actor_info["Address"]["NodeID"]
node_id_set.add(node_id)

assert len(node_id_set) == 1 + num_workers // 8
return stats

TorchTrainable = TorchTrainer.as_trainable(
override_tune_step=custom_train_func,
**{
"training_operator_cls": Operator,
"num_workers": num_workers,
"use_gpu": False,
"backend": "gloo",
"config": {
"batch_size": 512,
"lr": 0.001
}
})

tune.run(
TorchTrainable,
num_samples=1,
stop={"training_iteration": 2},
verbose=1)


def test_shutdown(ray_8_node_2_cpu):
"""Tests if placement group is removed when worker group is shut down."""
assert ray.available_resources()["CPU"] == 16
placement_group_table = ray.state.state.placement_group_table()
assert len(placement_group_table) == 0

trainer = TorchTrainer(
training_operator_cls=Operator,
num_workers=7,
use_gpu=False,
)
assert ray.available_resources()["CPU"] == 9
placement_group_table = ray.state.state.placement_group_table()
assert len(placement_group_table) == 1
placement_group_id = list(placement_group_table)[0]
placement_group = placement_group_table[placement_group_id]
assert placement_group["strategy"] == "SPREAD"
assert placement_group["state"] == "CREATED"

trainer.shutdown()

assert ray.available_resources()["CPU"] == 16
placement_group_table = ray.state.state.placement_group_table()
assert len(placement_group_table) == 1
placement_group = placement_group_table[placement_group_id]
assert placement_group["strategy"] == "SPREAD"
assert placement_group["state"] == "REMOVED"


def test_resize(ray_8_node_2_cpu):
"""Tests if placement group is removed when trainer is resized."""
assert ray.available_resources()["CPU"] == 16
placement_group_table = ray.state.state.placement_group_table()
assert len(placement_group_table) == 0

trainer = TorchTrainer(
training_operator_cls=Operator,
num_workers=7,
use_gpu=False,
)

assert ray.available_resources()["CPU"] == 9
placement_group_table = ray.state.state.placement_group_table()
assert len(placement_group_table) == 1
placement_group_id = list(placement_group_table)[0]
placement_group = placement_group_table[placement_group_id]
assert placement_group["state"] == "CREATED"

trainer._resize_worker_group(trainer.state_dict())

assert ray.available_resources()["CPU"] == 9
placement_group_table = ray.state.state.placement_group_table()
assert len(placement_group_table) == 2
placement_group = placement_group_table[placement_group_id]
assert placement_group["state"] == "REMOVED"
placement_group_table_keys = list(placement_group_table)
placement_group_table_keys.remove(placement_group_id)
second_placement_group_id = placement_group_table_keys[0]
second_placement_group = placement_group_table[second_placement_group_id]
assert second_placement_group["state"] == "CREATED"

trainer.shutdown()

assert ray.available_resources()["CPU"] == 16
placement_group_table = ray.state.state.placement_group_table()
assert len(placement_group_table) == 2
placement_group = placement_group_table[placement_group_id]
assert placement_group["state"] == "REMOVED"
second_placement_group = placement_group_table[second_placement_group_id]
assert second_placement_group["state"] == "REMOVED"


@patch("ray.util.sgd.torch.worker_group.SGD_PLACEMENT_GROUP_TIMEOUT_S", 5)
def test_timeout(ray_4_node_1_cpu):
"""Tests that an error is thrown when placement group setup times out."""
with pytest.raises(TimeoutError):
trainer = TorchTrainer(
training_operator_cls=Operator, num_workers=7, use_gpu=False)
trainer.shutdown()


if __name__ == "__main__":
import pytest
import sys

sys.exit(pytest.main(["-v", "-x", __file__]))
2 changes: 2 additions & 0 deletions python/ray/util/sgd/torch/constants.py
@@ -7,6 +7,8 @@
SCHEDULER_STEP_EPOCH = "epoch"
SCHEDULER_STEP_MANUAL = "manual"
NCCL_TIMEOUT_S = env_integer("NCCL_TIMEOUT_S", 1800)
SGD_PLACEMENT_GROUP_TIMEOUT_S = env_integer("SGD_PLACEMENT_GROUP_TIMEOUT_S",
100)

VALID_SCHEDULER_STEP = {
SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH, SCHEDULER_STEP_MANUAL
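
Usage note (not part of the diff): because SGD_PLACEMENT_GROUP_TIMEOUT_S is
read through env_integer at module import time, the default of 100 seconds
can be raised for slow-starting or autoscaling clusters by exporting the
environment variable before the SGD modules are imported, for example:

import os

# Must be set before ray.util.sgd.torch is imported, because env_integer()
# is evaluated when constants.py is first loaded.
os.environ["SGD_PLACEMENT_GROUP_TIMEOUT_S"] = "600"

from ray.util.sgd.torch.constants import SGD_PLACEMENT_GROUP_TIMEOUT_S

print(SGD_PLACEMENT_GROUP_TIMEOUT_S)  # 600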
88 changes: 65 additions & 23 deletions python/ray/util/sgd/torch/worker_group.py
@@ -7,6 +7,9 @@
import ray
import torch
from ray.exceptions import RayActorError
from ray.util.placement_group import get_current_placement_group, \
remove_placement_group
from ray.util.sgd.torch.constants import SGD_PLACEMENT_GROUP_TIMEOUT_S
from ray.util.sgd.torch.distributed_torch_runner import \
LocalDistributedRunner, DistributedTorchRunner
from ray.util.sgd.torch.torch_runner import TorchRunner
@@ -135,16 +138,57 @@ def __init__(self, max_workers, params, dist_params, initialization_hook,
# The last time when this worker group was resized.
self._last_resize = float("-inf")

# This is set only if a placement group is created by the worker group.
self._worker_placement_group = None

def _create_placement_group(self, num_workers):
"""Creates a placement group for the workers.
If this worker is already in a placement group then a new one will
not be created. This is primarily for when Tune is the upstream and
will allocate resources for SGD workers.
If this worker is not in a placement group, a new one will be created
and set. The placement group will have a single bundle for each worker
and use the SPREAD strategy for an even distribution.
"""
pg = get_current_placement_group()
if pg is None:
bundle = {
"CPU": self._num_cpus_per_worker,
"GPU": int(self._use_gpu)
}
bundles = [bundle] * num_workers
pg = ray.util.placement_group(bundles, strategy="SPREAD")
logger.debug("Waiting for placement group to start.")
ready, _ = ray.wait(
[pg.ready()], timeout=SGD_PLACEMENT_GROUP_TIMEOUT_S)
if ready:
logger.debug("Placement group has started.")
else:
raise TimeoutError(
"Placement group creation timed out. Make sure "
"your cluster either has enough resources or use "
"an autoscaling cluster. Current resources "
"available: {}, resources requested by the "
"placement group: {}".format(ray.available_resources(),
pg.bundle_specs))
self._worker_placement_group = pg

def _init_dist_workers(self, num_workers):
"""Create `num_workers` remote workers."""
# Generate actor class
RemoteRunner = ray.remote(
num_cpus=self._num_cpus_per_worker,
num_gpus=int(self._use_gpu))(DistributedTorchRunner)

# Get placement group
self._create_placement_group(num_workers)
pg = self._worker_placement_group or "default"

# Start workers
self.remote_workers = [
-            RemoteRunner.remote(**{
+            RemoteRunner.options(placement_group=pg).remote(**{
**self._params,
**self._dist_params
}) for _ in range(num_workers)
@@ -352,34 +396,37 @@ def _shutdown_remote_workers(self):
def _terminate_remote_workers(self, cleanup):
"""Blocks on worker shutdown and then terminates each worker actor.
If graceful shutdown fails, forcefully kills all actors.
Return:
Whether or not workers were shutdown gracefully.
"""
try:
ray.get(cleanup)
[
worker.__ray_terminate__.remote()
for worker in self.remote_workers
]
return True
except RayActorError:
logger.warning("Failed to shutdown gracefully, forcing a "
"shutdown.")
self.reset()
logger.warning("Failed to shutdown gracefully.")
return False

def shutdown(self, force=False):
-        if not force:
-            cleanup = [
-                worker.shutdown.remote() for worker in self.remote_workers
-            ]
-            self._terminate_remote_workers(cleanup)
-        else:
-            self.reset()
+        force_kill = force
+        if not force_kill:
+            cleanup = self._shutdown_remote_workers()
+            force_kill = not self._terminate_remote_workers(cleanup)
+        if force_kill:
+            for worker in self.remote_workers:
+                logger.debug(f"Killing worker {worker}.")
+                ray.kill(worker)
+        self.remote_workers = []
+        # Remove worker placement group.
+        if self._worker_placement_group:
+            remove_placement_group(self._worker_placement_group)
+            self._worker_placement_group = None

def reset(self):
-        for worker in self.remote_workers:
-            logger.debug(f"Killing worker {worker}.")
-            ray.kill(worker)
-        self.remote_workers = []
+        self.shutdown(force=True)

def should_scale_up(self):
worker_gap = self._max_workers - self.num_workers
@@ -584,13 +631,8 @@ def validate(self, num_steps=None, profile=False, info=None):
return worker_stats

def shutdown(self, force=False):
-        if not force:
-            cleanup = self.remote_worker_group._shutdown_remote_workers()
-            self.local_worker.shutdown()
-            self.remote_worker_group._terminate_remote_workers(cleanup)
-        else:
-            self.local_worker.shutdown()
-            self.remote_worker_group.reset()
+        self.local_worker.shutdown()
+        self.remote_worker_group.shutdown(force=force)

self.local_worker = None
self.remote_worker_group = DeactivatedWorkerGroup()
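
To illustrate the get_current_placement_group() branch in
_create_placement_group above: when the calling actor is already scheduled
inside a placement group (as it is when Tune launches the trainable), the
existing group is reused and no new one is created. The sketch below is not
part of the change, and the Probe actor is purely hypothetical; it only
demonstrates the detection.

import ray
from ray.util.placement_group import (get_current_placement_group,
                                      placement_group)


@ray.remote(num_cpus=1)
class Probe:
    def in_group(self):
        # True when this actor was scheduled inside a placement group,
        # which is the condition _create_placement_group checks.
        return get_current_placement_group() is not None


ray.init(num_cpus=2)

pg = placement_group([{"CPU": 1}], strategy="SPREAD")
ray.get(pg.ready())

inside = Probe.options(placement_group=pg).remote()
outside = Probe.remote()

print(ray.get(inside.in_group.remote()))   # True: existing group is reused
print(ray.get(outside.in_group.remote()))  # False: a new group is created
ray.shutdown()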
