[RLlib] Algorithm.add_policy() should alternatively accept an already instantiated policy object. #28637
Conversation
thanks for the really nice UX change.
rllib/algorithms/algorithm.py (Outdated)
```python
        policy_mapping_fn=policy_mapping_fn,
    )
    # Then add a new instance to each remote worker.
    ray.get([w.apply.remote(fn) for w in self.workers.remote_workers()])
```
should we simply do
```python
ray.get([w.add_policy.remote(**kwargs) for w in self.workers.remote_workers()])
```
?
rllib/algorithms/algorithm.py (Outdated)
```python
        policy_mapping_fn=policy_mapping_fn,
    )
else:
    fn(worker)
```
why not simply:
```python
worker.add_policy(**kwargs)
```
?
cleaned this up a little.
rllib/algorithms/algorithm.py (Outdated)
```python
        f"{list(local_worker.policy_map.keys())}"
    )

if policy_cls is not None and policy is not None:
```
probably need to check if both are None?
great catch! :)
done
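A minimal sketch of the check this thread converged on (the helper name and the exact error message are assumptions, not the PR's actual wording):
```python
from typing import Optional, Type

def _validate_policy_args(policy_cls: Optional[Type] = None, policy=None) -> None:
    # Exactly one of `policy_cls` / `policy` must be given: this catches both
    # the "both None" case flagged above and the "both provided" case.
    if (policy_cls is None) == (policy is None):
        raise ValueError("Provide exactly one of `policy_cls` or `policy`!")
```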
rllib/algorithms/algorithm.py (Outdated)
```python
    # Run foreach_worker fn on all workers.
    self.workers.foreach_worker(fn)
else:
    self.workers.foreach_worker(fn)

# Update evaluation workers, if necessary.
if evaluation_workers and self.evaluation_workers is not None:
```
one thing that feels a bit weird is how come we can handle this for eval_workers with one line of code, but we need this many if/elses for rollout workers.
I wonder if we can do something similar for rollout workers, basically `self.workers.foreach_worker(fn)`, and WorkerSet.foreach_worker() will apply the fn locally on the local worker, or apply it remotely on all the remote workers.
that way, the logic about individual workers is encapsulated behind the WorkerSet abstraction.
does that work??
Eval workers don't have a local worker :)
For the (only one!) local worker, we should insert the policy directly into its policy_map; no re-creation of a new instance is required. That's the whole point of this PR, I guess.
The `foreach_worker` utility is actually fine (it handles the local worker properly) and has nothing to do with this.
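For context, a minimal sketch of the foreach_worker() semantics being referenced (simplified; not the actual RLlib implementation):
```python
import ray

def foreach_worker(self, fn):
    """Apply fn to the local worker (in-process) and all remote workers (via ray)."""
    results = []
    if self.local_worker() is not None:
        # Local worker: call fn directly, no actor round-trip needed.
        results.append(fn(self.local_worker()))
    # Remote workers: ship fn to each actor and block on all results.
    results.extend(ray.get([w.apply.remote(fn) for w in self.remote_workers()]))
    return results
```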
Let me try to simplify the rest ...
Cleaned up a little (removed the helper function entirely, not needed).
wait, eval workers can have a local worker too??
I often set evaluation_num_workers=0 for OPE, since we do OPE on the trainer node anyways. will that cause the evaluation_workers to use only a local worker?
also, one thing I am always a bit confused about: if we already have a WorkerSet abstraction, why should Algorithm still manipulate individual local and remote workers itself? feels like it's better to have WorkerSet handle the underlying details.
for this specific case, I wonder if we should simply do something like:
```python
rollout_workers.local_worker().add_policy(policy=policy)
rollout_workers.remote_worker().add_policy(policy_cls=type(policy), policy_state=...)
evaluation_workers.local_worker().add_policy(policy_cls=type(policy), policy_state=...)
evaluation_workers.remote_worker().add_policy(policy_cls=type(policy), policy_state=...)
```
then, the policy is only claimed by the local rollout worker.
The evaluation worker set only has a local worker if evaluation_num_workers=0. Otherwise, we'll skip generating it.
If you do Algorithm.evaluate(), it will:
- first try to use the evaluation worker set (be it with a local worker (evaluation_num_workers=0) or without one).
- then, if there is NO evaluation worker set at all, use the regular local worker. Note that this only works if that local worker has an env (by default, we don't create one on the regular local worker). See the sketch below.
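Roughly, the fallback order described above (the `_evaluate_with*` helper names are hypothetical; this is a simplification, not the actual Algorithm.evaluate()):
```python
def evaluate(self):
    if self.evaluation_workers is not None:
        # Evaluation worker set exists: use it, whether it holds a local
        # worker (evaluation_num_workers=0) or only remote workers.
        return self._evaluate_with(self.evaluation_workers)
    # No evaluation worker set at all -> fall back to the regular local
    # worker. This only works if that worker was created with an env.
    local_worker = self.workers.local_worker()
    assert local_worker.env is not None, "Regular local worker needs an env!"
    return self._evaluate_with_worker(local_worker)
```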
I think you are right and we should create a WorkerSet.add/remove_policy API. Then we can move all the code that's currently in Algorithm.add_policy into the WorkerSet, and in the algo, simply do:
```python
def add_policy(...):
    self.workers.add_policy()
    self.evaluation_workers.add_policy()
    return
```
done
rllib/connectors/util.py (Outdated)
```diff
@@ -78,8 +78,10 @@ def create_connectors_for_policy(policy: "Policy", config: TrainerConfigDict):
     """
     ctx: ConnectorContext = ConnectorContext.from_policy(policy)
 
-    policy.agent_connectors = get_agent_connectors_from_config(ctx, config)
-    policy.action_connectors = get_action_connectors_from_config(ctx, config)
+    if policy.agent_connectors is None:
```
hmm, I wonder why we need to check this.
are we calling this function for an existing policy that already has connectors restored??
What if you have recovered this policy here from a checkpoint? Then you would also already have the connectors inside this policy, correct?
In this case, you wouldn't want to re-create the connectors. Let me know if this chain of thought is wrong.
we shouldn't be calling this util if we are recovering a policy. this is only used when a policy is constructed from scratch.
do you mind removing these checks? if things fail somehow, I'd rather get an explicit signal than have an actual problem concealed by these.
Got it, so you are saying we can assume 100% that every time we are recovering a policy, say from a checkpoint, the connectors should already be in there? In this case, I added an assert to the utility and fixed the add_policy method to NOT call this utility iff `policy` was provided as an already instantiated one.
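In other words, the utility now asserts it is only used for freshly constructed policies. A sketch of the agreed-upon shape (names taken from the diff above; the exact control flow is an assumption):
```python
def create_connectors_for_policy(policy: "Policy", config: TrainerConfigDict):
    # This util must only be called for policies built from scratch; a policy
    # restored from a checkpoint already carries its connectors.
    assert policy.agent_connectors is None and policy.action_connectors is None
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)
    policy.agent_connectors = get_agent_connectors_from_config(ctx, config)
    policy.action_connectors = get_action_connectors_from_config(ctx, config)
```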
rllib/evaluation/rollout_worker.py (Outdated)
```python
    " Policy IDs that are already in your policy map: "
    f"{list(self.workers.local_worker().policy_map.keys())}"
)
if policy_cls is not None and policy is not None:
```
same, one of these needs to be not None?
done
```diff
     policy_id: PolicyID,
-    policy_cls: Type[Policy],
+    policy_cls: Optional[Type[Policy]] = None,
+    policy: Optional[Policy] = None,
```
just want to mention a debate I had with myself while looking at this.
if we are willing to sacrifice a little bit of efficiency for the local worker, we can actually make PolicySpec the narrow waist of all this. then we won't need to change rollout_worker or policy_map. and in Algorithm.add_policy(), we would simply get the policy_spec and policy_state if we get passed a policy instead of a policy_cls. a sketch of the idea is below.
I understand that will cause the policy to be created again for the local worker, wasting compute and memory, but it seems like we can greatly simplify all this logic if we can have a narrow waist for add_policy().
just some thoughts that I figured I'd mention, since I can't make up my mind either. let me know what you think.
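A hedged sketch of this "narrow waist" alternative (hypothetical; this is the reviewer's proposal, not what the PR ultimately implements):
```python
def add_policy(self, policy_id, policy_cls=None, policy=None, **kwargs):
    # Narrow waist: reduce both call styles to (policy_cls, policy_state).
    if policy is not None:
        policy_cls = type(policy)
        policy_state = policy.get_state()  # weights + connector/view-req state
    else:
        policy_state = None
    # All workers (incl. the local one) re-create the policy from class +
    # state; the passed-in instance itself is never inserted anywhere.
    self.workers.foreach_worker(
        lambda w: w.add_policy(
            policy_id, policy_cls=policy_cls, policy_state=policy_state, **kwargs
        )
    )
```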
Hmm, I'm not sure about this. Was thinking about this, too :)
But the problem is that the expectation (mental model of the user) of doing my_algo.add_policy(my_policy_instance) is that my_policy_instance is actually incorporated as-is into the algorithm.
yeah maybe.
if we were writing C++, we could make this super clear by declaring the argument `const Policy& policy` if the policy is meant to be duplicated, or `Policy* policy` if the ownership of the policy is supposed to be transferred.
python is 🤷♂️
🤷♂️
:)
```diff
@@ -231,6 +244,204 @@ def sync_weights(
         elif self.local_worker() is not None and global_vars is not None:
             self.local_worker().set_global_vars(global_vars)
 
+    def add_policy(
```
New API for WorkerSet (a possible call pattern is sketched below):
- WorkerSet.add_policy(): Similar to Algorithm.add_policy().
- WorkerSet.add_policy_to_workers(): New static helper utility for adding a new policy (by instance or options) to a list of (local and/or remote) workers.
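A possible call pattern for the new API (argument names assumed from the discussion; not copied from the final diff):
```python
# Add by already instantiated policy object ...
workers.add_policy("new_pol", policy=my_policy_instance)
# ... or, as before, by class plus (optional) config overrides.
workers.add_policy("new_pol", policy_cls=MyPolicyClass, config=my_config)
```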
```python
    policy_spec.config,  # overrides.
    merged_conf,
)
if policy is not None:
```
Depending on whether the policy is given as an already instantiated object or not, use either create_policy() or insert_policy(). Note that create_policy also uses insert_policy internally now.
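The described branching, as a rough sketch (method signatures simplified; the real argument lists may differ):
```python
if policy is not None:
    # Instance given -> insert it into the policy map as-is.
    self.policy_map.insert_policy(policy_id, policy)
else:
    # Class given -> build a new instance; create_policy() itself ends by
    # calling insert_policy() on the freshly built object.
    self.policy_map.create_policy(
        policy_id, policy_cls, obs_space, action_space,
        policy_spec.config,  # overrides.
        merged_conf,
    )
```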
```python
    )
):
    create_connectors_for_policy(self.policy_map[policy_id], self.policy_config)
# Create connectors for the new policy, if necessary.
```
Simplified this if-block here. Some checks were superfluous.
```python
    # Change the list of policies to train.
    policies_to_train=[f"p{i}", f"p{i-1}"],
)
print(f"Adding policy {pid} ...")
```
Also test adding a new policy by instance now.
```diff
@@ -1547,7 +1547,8 @@ def set_weights(self, weights: Dict[PolicyID, dict]):
     def add_policy(
         self,
         policy_id: PolicyID,
-        policy_cls: Type[Policy],
+        policy_cls: Optional[Type[Policy]] = None,
```
Simplified the Algorithm.add_policy() method by only using the new WorkerSet APIs. No more micro-handling of individual workers' policy_maps here. A rough sketch of the resulting structure is below.
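The resulting structure, roughly (a sketch with abbreviated parameters, not the exact final signature):
```python
def add_policy(self, policy_id, policy_cls=None, policy=None, *,
               evaluation_workers=True, **kwargs):
    # Delegate all per-worker handling to the WorkerSet.
    self.workers.add_policy(policy_id, policy_cls=policy_cls, policy=policy, **kwargs)
    # Optionally mirror the change onto the evaluation WorkerSet.
    if evaluation_workers and self.evaluation_workers is not None:
        self.evaluation_workers.add_policy(
            policy_id, policy_cls=policy_cls, policy=policy, **kwargs
        )
```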
```diff
@@ -231,7 +231,7 @@ class directly. Note that this arg can also be specified via
     """
 
     # User provided (partial) config (this may be w/o the default
-    # Trainer's Config object). Will get merged with AlgorithmConfig()
+    # Algorithm's Config object). Will get merged with AlgorithmConfig()
```
Found a few of these old `Trainer` mentions in the comments.
rllib/algorithms/algorithm.py (Outdated)
```diff
@@ -1588,46 +1596,69 @@ def add_policy(
             returns False) will not be updated.
         evaluation_workers: Whether to add the new policy also
             to the evaluation WorkerSet.
-        workers: A list of RolloutWorker/ActorHandles (remote
+        worker_list: A list of RolloutWorker/ActorHandles (remote
```
I have a random request: do you mind staying with `workers` for now, since this is a simple name change?
as part of the elastic training PR, I am getting rid of all these places where we are accessing underlying RolloutWorkers outside of WorkerSet. so if we deprecate this today, in a few days I am gonna have to deprecate worker_list too, and we will have 2 deprecated fields here.
rllib/algorithms/algorithm.py (Outdated)
```python
# Worker list is explicitly provided -> Use only those workers (local or remote)
# specified.
if worker_list is not None:
    RolloutWorker.add_policy_to_workers(
```
I like this change; it actually works well with my fault tolerance PR. I will make WorkerSet.add_policy_to_workers() the only way to go about this in the future, actually.
one problem though: this should be `WorkerSet.`, not `RolloutWorker.`?
Oh, yeah, great catch!
Please make sure all the tests pass! Thanks.
Why are these changes needed?
Algorithm.add_policy() should alternatively accept an already instantiated policy object. Same for RolloutWorker.add_policy().
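As a hedged illustration of the new UX (the policy class name and constructor arguments here are simplified placeholders):
```python
# Before: only a class (plus config etc.) could be passed in.
algo.add_policy(policy_id="new_agent", policy_cls=MyTorchPolicy)

# After this PR: a ready-made instance is accepted as well and inserted
# as-is into the local worker's policy_map.
my_policy = MyTorchPolicy(obs_space, act_space, config)
algo.add_policy(policy_id="new_agent", policy=my_policy)
```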
Related issue number
Checks
- I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.