[RLlib] Fix: Recovered eval worker should use eval-config's policy_mapping_fn and policy_to_train fn, not the main train workers' ones. (ray-project#33648)

Signed-off-by: sven1977 <[email protected]>
Signed-off-by: elliottower <[email protected]>
sven1977 authored and elliottower committed Apr 22, 2023
1 parent 9c948ce commit 8d0ae4c
Showing 8 changed files with 120 additions and 84 deletions.
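For context, the bug fixed here shows up when the evaluation config overrides the multi-agent policy_mapping_fn (and/or policies_to_train). Below is a minimal sketch of such a setup, not part of this commit; the env registration and both mapping fns are illustrative placeholders, assuming the Ray 2.x AlgorithmConfig API. Per the commit title, a recovered eval worker should be rebuilt with the eval config's mapping fn, not the training workers' one.

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.pg import PGConfig
from ray.rllib.env.multi_agent_env import make_multi_agent
from ray.tune.registry import register_env

# Illustrative multi-agent env (placeholder).
register_env(
    "multi_agent_cartpole", lambda cfg: make_multi_agent("CartPole-v1")(cfg)
)

config = (
    PGConfig()
    .environment(env="multi_agent_cartpole")
    .multi_agent(
        policies={"main", "eval_only"},
        # Mapping fn used by the (training) rollout workers.
        policy_mapping_fn=lambda aid, episode, worker, **kw: "main",
    )
    .evaluation(
        evaluation_num_workers=1,
        evaluation_interval=1,
        # The eval workers get their own mapping fn. With this fix, a recovered
        # eval worker is restored from this evaluation config, not from the
        # training workers' one.
        evaluation_config=AlgorithmConfig.overrides(
            policy_mapping_fn=lambda aid, episode, worker, **kw: "eval_only",
        ),
    )
)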
Empty file removed rllib/agents/__init__.py
Empty file removed rllib/agents/trainer.py
56 changes: 20 additions & 36 deletions rllib/algorithms/algorithm.py
@@ -643,8 +643,6 @@ def setup(self, config: AlgorithmConfig) -> None:
default_policy_class=self.get_default_policy_class(self.config),
config=self.evaluation_config,
num_workers=self.config.evaluation_num_workers,
# Don't even create a local worker if num_workers > 0.
local_worker=False,
logdir=self.logdir,
)

@@ -1301,14 +1299,14 @@ def restore_workers(self, workers: WorkerSet):
workers: The WorkerSet to restore. This may be Rollout or Evaluation
workers.
"""
# If `workers` is None, or
# 1. `workers` (WorkerSet) does not have a local worker, and
# 2. `self.workers` (WorkerSet used for training) does not have a local worker
# -> we don't have a local worker to get state from, so we can't recover
# remote worker in this case.
if not workers or (
not workers.local_worker() and not self.workers.local_worker()
):
# If workers does not exist, or
# 1. this WorkerSet does not have a local worker, and
# 2. self.workers (rollout worker set) does not have a local worker,
# we don't have a local worker to get state from.
# We can't recover remote workers in this case.
return

# This is really cheap, since probe_unhealthy_workers() is a no-op
@@ -1317,12 +1315,16 @@ def restore_workers(self, workers: WorkerSet):

if restored:
from_worker = workers.local_worker() or self.workers.local_worker()
state = ray.put(from_worker.get_state())
# Get the state of the correct (reference) worker, e.g. the local worker
# of the main WorkerSet.
state_ref = ray.put(from_worker.get_state())

# By default, entire local worker state is synced after restoration
# to bring these workers up to date.
workers.foreach_worker(
func=lambda w: w.set_state(ray.get(state)),
func=lambda w: w.set_state(ray.get(state_ref)),
remote_worker_ids=restored,
# Don't update the local_worker, because it's the one we are syncing from.
local_worker=False,
timeout_seconds=self.config.worker_restore_timeout_s,
# Bring back actor after successful state syncing.
@@ -1861,12 +1863,12 @@ def add_policy(

if workers is not DEPRECATED_VALUE:
deprecation_warning(
old="workers",
old="Algorithm.add_policy(.., workers=..)",
help=(
"The `workers` argument to `Algorithm.add_policy()` is deprecated "
"and no-op now. Please do not use it anymore."
"The `workers` argument to `Algorithm.add_policy()` is deprecated! "
"Please do not use it anymore."
),
error=False,
error=True,
)

self.workers.add_policy(
@@ -2580,7 +2582,6 @@ def __setstate__(self, state) -> None:
# there in case they are used for evaluation purpose.
self.evaluation_workers.foreach_worker(
lambda w: w.set_state(ray.get(remote_state)),
local_worker=False,
healthy_only=False,
)
# If necessary, restore replay data as well.
@@ -3079,29 +3080,12 @@ def _record_usage(self, config):
def compute_action(self, *args, **kwargs):
return self.compute_single_action(*args, **kwargs)

@Deprecated(new="construct WorkerSet(...) instance directly", error=False)
def _make_workers(
self,
*,
env_creator: EnvCreator,
validate_env: Optional[Callable[[EnvType, EnvContext], None]],
policy_class: Type[Policy],
config: AlgorithmConfigDict,
num_workers: int,
local_worker: bool = True,
) -> WorkerSet:
return WorkerSet(
env_creator=env_creator,
validate_env=validate_env,
default_policy_class=policy_class,
config=config,
num_workers=num_workers,
local_worker=local_worker,
logdir=self.logdir,
)
@Deprecated(new="construct WorkerSet(...) instance directly", error=True)
def _make_workers(self, *args, **kwargs):
pass

def validate_config(self, config) -> None:
# TODO: Deprecate. All logic has been moved into the AlgorithmConfig classes.
@Deprecated(new="AlgorithmConfig.validate()", error=False)
def validate_config(self, config):
pass

@staticmethod
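The restore_workers() change above keeps the existing pattern of serializing the reference worker's state once via ray.put() and letting every restored remote worker fetch and apply it. A bare-bones sketch of that put-once/fetch-many pattern with plain Ray actors (illustrative only, not RLlib internals):

import ray

@ray.remote
class Worker:
    def __init__(self):
        self.state = {}

    def get_state(self):
        return self.state

    def set_state(self, state):
        self.state = state

ray.init()
reference = Worker.remote()                     # stands in for the local (reference) worker
restored = [Worker.remote() for _ in range(2)]  # stand-ins for the restored remote workers

# Serialize the reference state into the object store once ...
state_ref = ray.put(ray.get(reference.get_state.remote()))
# ... then let every restored worker resolve the same ObjectRef and apply it.
ray.get([w.set_state.remote(state_ref) for w in restored])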
10 changes: 6 additions & 4 deletions rllib/algorithms/tests/test_algorithm.py
@@ -109,12 +109,14 @@ def new_mapping_fn(agent_id, episode, worker, **kwargs):
# Make sure new policy is part of remote workers in the
# worker set and the eval worker set.
self.assertTrue(
algo.workers.foreach_worker(func=lambda w: pid in w.policy_map)[0]
all(algo.workers.foreach_worker(func=lambda w: pid in w.policy_map))
)
self.assertTrue(
algo.evaluation_workers.foreach_worker(
func=lambda w: pid in w.policy_map
)[0]
all(
algo.evaluation_workers.foreach_worker(
func=lambda w: pid in w.policy_map
)
)
)

# Assert new policy is part of local worker (eval worker set does NOT
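The assertion change above reflects that WorkerSet.foreach_worker() returns one result per worker rather than a single value, hence the all(...) wrapper instead of only checking element [0]. A short usage sketch, assuming an already-built Algorithm algo and a policy id "my_policy" (both hypothetical):

# One entry per worker; the local worker's result comes first when it is included.
results = algo.workers.foreach_worker(func=lambda w: "my_policy" in w.policy_map)
# Require that every worker, not just the first one, knows about the policy.
assert all(results)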
100 changes: 75 additions & 25 deletions rllib/algorithms/tests/test_worker_failures.py
@@ -237,7 +237,7 @@ def setUpClass(cls) -> None:

register_env("fault_env", lambda c: FaultInjectEnv(c))
register_env(
"multi-agent-fault_env", lambda c: make_multi_agent(FaultInjectEnv)(c)
"multi_agent_fault_env", lambda c: make_multi_agent(FaultInjectEnv)(c)
)

@classmethod
@@ -305,50 +305,70 @@ def _do_test_fault_fatal(self, config, fail_eval=False):
self.assertRaises(Exception, lambda: a.train())
a.stop()

def _do_test_fault_fatal_but_recreate(self, config):
def _do_test_fault_fatal_but_recreate(self, config, multi_agent=False):
# Counter that will survive restarts.
COUNTER_NAME = "_do_test_fault_fatal_but_recreate"
COUNTER_NAME = (
f"_do_test_fault_fatal_but_recreate{'_ma' if multi_agent else ''}"
)
counter = Counter.options(name=COUNTER_NAME).remote()

# Test raises real error when out of workers.
config.num_rollout_workers = 1
config.evaluation_num_workers = 1
config.evaluation_interval = 1
config.env = "fault_env"
config.evaluation_config = {
"recreate_failed_workers": True,
config.env = "fault_env" if not multi_agent else "multi_agent_fault_env"
config.evaluation_config = AlgorithmConfig.overrides(
recreate_failed_workers=True,
# 0 delay for testing purposes.
"delay_between_worker_restarts_s": 0,
delay_between_worker_restarts_s=0,
# Make eval worker (index 1) fail.
"env_config": {
env_config={
"bad_indices": [1],
"failure_start_count": 3,
"failure_stop_count": 4,
"counter": COUNTER_NAME,
},
}
**(
dict(
policy_mapping_fn=(
lambda aid, episode, worker, **kwargs: (
# Allows this test to query this
# different-from-training-workers policy mapping fn.
"This is the eval mapping fn"
if episode is None
else "main"
if episode.episode_id % 2 == aid
else "p{}".format(np.random.choice([0, 1]))
)
)
)
if multi_agent
else {}
),
)

for _ in framework_iterator(config, frameworks=("tf2", "torch")):
# Reset interaction counter.
ray.wait([counter.reset.remote()])

a = config.build()

a.train()
wait_for_restore()
a.train()

self.assertEqual(a.workers.num_healthy_remote_workers(), 1)
self.assertEqual(a.evaluation_workers.num_healthy_remote_workers(), 1)

# This should also work several times.
a.train()
wait_for_restore()
a.train()

self.assertEqual(a.workers.num_healthy_remote_workers(), 1)
self.assertEqual(a.evaluation_workers.num_healthy_remote_workers(), 1)

for _ in range(2):
a.train()
wait_for_restore()
a.train()

self.assertEqual(a.workers.num_healthy_remote_workers(), 1)
self.assertEqual(a.evaluation_workers.num_healthy_remote_workers(), 1)
if multi_agent:
# Make a dummy call to the eval worker's policy_mapping_fn and
# make sure the restored eval worker received the correct one from
# the eval config (not the main workers' one).
test = a.evaluation_workers.foreach_worker(
lambda w: w.policy_mapping_fn(0, None, None)
)
self.assertEqual(test[0], "This is the eval mapping fn")
a.stop()

def test_fatal(self):
@@ -450,6 +470,36 @@ def test_recreate_eval_workers_parallel_to_training_w_actor_manager(self):

self._do_test_fault_fatal_but_recreate(config)

def test_recreate_eval_workers_parallel_to_training_w_actor_manager_and_multi_agent(
self,
):
# Test the case where all eval workers fail on a multi-agent env with
# different `policy_mapping_fn` in eval- vs train workers, but we chose
# to recover.
config = (
PGConfig()
.multi_agent(
policies={"main", "p0", "p1"},
policy_mapping_fn=(
lambda aid, episode, worker, **kwargs: (
"main"
if episode.episode_id % 2 == aid
else "p{}".format(np.random.choice([0, 1]))
)
),
)
.evaluation(
evaluation_num_workers=1,
enable_async_evaluation=True,
evaluation_parallel_to_training=True,
evaluation_duration="auto",
)
.training(model={"fcnet_hiddens": [4]})
.debugging(worker_cls=ForwardHealthCheckToEnvWorker)
)

self._do_test_fault_fatal_but_recreate(config, multi_agent=True)

def test_eval_workers_failing_fatal(self):
# Test the case where all eval workers fail (w/o recovery).
self._do_test_fault_fatal(
@@ -527,7 +577,7 @@ def test_policies_are_restored_on_recovered_worker(self):
model={"fcnet_hiddens": [4]},
)
.environment(
env="multi-agent-fault_env",
env="multi_agent_fault_env",
env_config={
# Make both worker idx=1 and 2 fail.
"bad_indices": [1, 2],
@@ -839,7 +889,7 @@ def test_multi_agent_env_eval_workers_fault_but_restore_env(self):
model={"fcnet_hiddens": [4]},
)
.environment(
env="multi-agent-fault_env",
env="multi_agent_fault_env",
# Workers do not fault and no fault tolerance.
env_config={},
disable_env_checking=True,
10 changes: 6 additions & 4 deletions rllib/evaluation/worker_set.py
@@ -522,12 +522,12 @@ def add_policy(

if workers is not DEPRECATED_VALUE:
deprecation_warning(
old="workers",
old="WorkerSet.add_policy(.., workers=..)",
help=(
"The `workers` argument to `WorkerSet.add_policy()` is deprecated "
"and a no-op now. Please do not use it anymore."
"The `workers` argument to `WorkerSet.add_policy()` is deprecated! "
"Please do not use it anymore."
),
error=False,
error=True,
)

if (policy_cls is None) == (policy is None):
@@ -576,6 +576,7 @@ def _create_new_policy_fn(worker: RolloutWorker):
worker.add_policy(**new_policy_instance_kwargs)

if self.local_worker() is not None:
# Add policy directly by (already instantiated) object.
if policy is not None:
self.local_worker().add_policy(
policy_id=policy_id,
@@ -584,6 +585,7 @@ def _create_new_policy_fn(worker: RolloutWorker):
policies_to_train=policies_to_train,
module_spec=module_spec,
)
# Add policy by constructor kwargs.
else:
self.local_worker().add_policy(**new_policy_instance_kwargs)

20 changes: 8 additions & 12 deletions rllib/examples/custom_eval.py
@@ -118,23 +118,19 @@ def custom_eval_function(algorithm, eval_workers):
Returns:
metrics: Evaluation metrics dict.
"""

# We configured 2 eval workers in the training config.
funcs = [
lambda w: w.foreach_env(lambda env: env.set_corridor_length(4)),
lambda w: w.foreach_env(lambda env: env.set_corridor_length(7)),
]

# Set different env settings for each worker. Here we use a fixed config,
# which also could have been computed in each worker by looking at
# env_config.worker_index (printed in SimpleCorridor class above).
eval_workers.foreach_worker(func=funcs)
# Set different env settings for each worker. Here we use the worker's
# `worker_index` property.
eval_workers.foreach_worker(
func=lambda w: w.foreach_env(
lambda env: env.set_corridor_length(4 if w.worker_index == 1 else 7)
)
)

for i in range(5):
print("Custom evaluation round", i)
# Calling .sample() runs exactly one episode per worker due to how the
# eval workers are configured.
eval_workers.foreach_worker(func=lambda w: w.sample())
eval_workers.foreach_worker(func=lambda w: w.sample(), local_worker=False)

# Collect the accumulated episodes on the workers, and then summarize the
# episode stats into a metrics dict.
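A custom evaluation function such as custom_eval_function above is hooked in through the algorithm config. A minimal sketch of the wiring (assuming the Ray 2.x AlgorithmConfig API; the env name is a hypothetical placeholder for the corridor env used by the example):

from ray.rllib.algorithms.pg import PGConfig

config = (
    PGConfig()
    .environment(env="my_corridor_env")  # hypothetical registered env
    .evaluation(
        evaluation_num_workers=2,
        evaluation_interval=1,
        # Run the user-defined evaluation loop instead of the built-in one.
        custom_evaluation_function=custom_eval_function,
    )
)
algo = config.build()
result = algo.train()
# The custom metrics dict is reported under the "evaluation" key of the result.
print(result["evaluation"])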
8 changes: 5 additions & 3 deletions rllib/utils/actor_manager.py
@@ -341,7 +341,7 @@ def num_actors(self) -> int:
@DeveloperAPI
def num_healthy_actors(self) -> int:
"""Return the number of healthy remote actors."""
return sum([s.is_healthy for s in self.__remote_actor_states.values()])
return sum(s.is_healthy for s in self.__remote_actor_states.values())

@DeveloperAPI
def total_num_restarts(self) -> int:
@@ -785,10 +785,12 @@ def probe_unhealthy_actors(
) -> List[int]:
"""Ping all unhealthy actors to try bringing them back.
Returns:
A list of actor ids that are restored.
Args:
timeout_seconds: Timeout to avoid pinging hanging workers indefinitely.
mark_healthy: Whether to mark actors healthy if they respond to the ping.
Returns:
A list of actor ids that are restored.
"""
unhealthy_actor_ids = [
actor_id
