[RLlib] Enhance add_policy test case. #28405

Merged
4 commits merged on Sep 12, 2022
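For context, the test enhanced in this PR exercises RLlib's `Algorithm.add_policy()` / `Algorithm.remove_policy()` API end to end. Below is a minimal, hedged sketch of that flow; the `PGConfig` setup, environment, worker counts, and policy IDs are illustrative assumptions for this note, not code taken from the PR:

# Illustrative sketch only (assumes Ray ~2.x RLlib APIs); the environment and
# config values below are made up for this example.
from ray.rllib.algorithms.pg import PGConfig
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

config = (
    PGConfig()
    .environment(env=MultiAgentCartPole, env_config={"num_agents": 4})
    .rollouts(num_rollout_workers=2)
    .evaluation(evaluation_num_workers=1)
    .multi_agent(
        policies={"p0"},
        policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "p0",
    )
)
algo = config.build()

# Add a second policy while the algorithm is live; a plain list is enough for
# `policies_to_train` (see the annotation change in algorithm.py below).
algo.add_policy(
    policy_id="p1",
    policy_cls=type(algo.get_policy("p0")),
    policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "p1",
    policies_to_train=["p0", "p1"],
)

# Remove it again, remapping all agents back to "p0".
algo.remove_policy(
    "p1",
    policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "p0",
    policies_to_train=["p0"],
)
algo.stop()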
5 changes: 4 additions & 1 deletion rllib/algorithms/algorithm.py
@@ -1640,7 +1640,10 @@ def remove_policy(
         *,
         policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
         policies_to_train: Optional[
-            Union[Set[PolicyID], Callable[[PolicyID, Optional[SampleBatchType]], bool]]
+            Union[
+                Container[PolicyID],
+                Callable[[PolicyID, Optional[SampleBatchType]], bool],
+            ]
         ] = None,
         evaluation_workers: bool = True,
     ) -> None:
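The hunk above widens `policies_to_train` from `Set[PolicyID]` to any `Container[PolicyID]` while keeping the callable alternative. A hedged illustration of the two accepted forms, assuming a running multi-agent Algorithm `algo` with a policy "p0" as in the sketch near the top ("p1" is an illustrative policy ID):

# Container[PolicyID]: a plain list (or set, tuple, ...) is now a valid value.
algo.add_policy(
    policy_id="p1",
    policy_cls=type(algo.get_policy("p0")),
    policies_to_train=["p0", "p1"],
)

# Callable[[PolicyID, Optional[SampleBatchType]], bool]: decide per policy
# (and, optionally, per train batch) whether it should be trained.
algo.remove_policy(
    "p1",
    policies_to_train=lambda pid, batch=None: pid == "p0",
)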
58 changes: 52 additions & 6 deletions rllib/algorithms/tests/test_algorithm.py
@@ -104,6 +104,24 @@ def new_mapping_fn(agent_id, episode, worker, **kwargs):
                 # Change the list of policies to train.
                 policies_to_train=[f"p{i}", f"p{i-1}"],
             )
+            # Make sure new policy is part of remote workers in the
+            # worker set and the eval worker set.
+            assert pid in (
+                ray.get(
+                    algo.workers.remote_workers()[0].apply.remote(
+                        lambda w: list(w.policy_map.keys())
+                    )
+                )
+            )
+            assert pid in (
+                ray.get(
+                    algo.evaluation_workers.remote_workers()[0].apply.remote(
+                        lambda w: list(w.policy_map.keys())
+                    )
+                )
+            )
+            # Assert new policy is part of local worker (eval worker set does NOT
+            # have a local worker, only the main WorkerSet does).
             pol_map = algo.workers.local_worker().policy_map
             self.assertTrue(new_pol is not pol0)
             for j in range(i + 1):
@@ -117,12 +135,14 @@ def new_mapping_fn(agent_id, episode, worker, **kwargs):
             test = pg.PG(config=config)
             test.restore(checkpoint)

-            # Make sure evaluation worker also gets the restored policy.
-            def _has_policy(w):
-                return w.get_policy("p0") is not None
+            # Make sure evaluation worker also got the restored, added policy.
+            def _has_policies(w):
+                return (
+                    w.get_policy("p0") is not None and w.get_policy(pid) is not None
+                )

             self.assertTrue(
-                all(test.evaluation_workers.foreach_worker(_has_policy))
+                all(test.evaluation_workers.foreach_worker(_has_policies))
             )

             # Make sure algorithm can continue training the restored policy.
@@ -137,13 +157,39 @@ def _has_policy(w):

         # Delete all added policies again from Algorithm.
         for i in range(2, 0, -1):
+            pid = f"p{i}"
             algo.remove_policy(
-                f"p{i}",
+                pid,
                 # Note that the complete signature of a policy_mapping_fn
                 # is: `agent_id, episode, worker, **kwargs`.
-                policy_mapping_fn=lambda aid, eps, **kwargs: f"p{i - 1}",
+                policy_mapping_fn=(
+                    lambda agent_id, worker, episode, **kwargs: f"p{i - 1}"
+                ),
                 # Update list of policies to train.
                 policies_to_train=[f"p{i - 1}"],
             )
+            # Make sure removed policy is no longer part of remote workers in the
+            # worker set and the eval worker set.
+            assert pid not in (
+                ray.get(
+                    algo.workers.remote_workers()[0].apply.remote(
+                        lambda w: list(w.policy_map.keys())
+                    )
+                )
+            )
+            assert pid not in (
+                ray.get(
+                    algo.evaluation_workers.remote_workers()[0].apply.remote(
+                        lambda w: list(w.policy_map.keys())
+                    )
+                )
+            )
+            # Assert removed policy is no longer part of local worker
+            # (eval worker set does NOT have a local worker, only the main WorkerSet
+            # does).
+            pol_map = algo.workers.local_worker().policy_map
+            self.assertTrue(pid not in pol_map)
+            self.assertTrue(len(pol_map) == i)

         algo.stop()

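Finally, a hedged sketch of the checkpoint round trip that the middle hunk of the test verifies: after `add_policy()`, a saved checkpoint should carry the added policy into a freshly built Algorithm, including its evaluation workers. Here `config` and `algo` are the illustrative objects from the sketch near the top of this page, with an added policy "p1" assumed to still be registered:

# Illustrative sketch; save()/restore() follow the path-style checkpoint API
# used by the test itself.
checkpoint = algo.save()

restored = config.build()          # built from the original, "p0"-only config
restored.restore(checkpoint)

def _has_policies(w):
    # Both the original and the added policy should be present on every
    # evaluation worker after the restore.
    return w.get_policy("p0") is not None and w.get_policy("p1") is not None

assert all(restored.evaluation_workers.foreach_worker(_has_policies))
restored.stop()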