[RLlib] Fix type hints for original_batches in callbacks. #24214
Signed-off-by: Xuehai Pan <[email protected]>
@@ -196,7 +196,7 @@ def on_postprocess_trajectory(
         policy_id: PolicyID,
         policies: Dict[PolicyID, Policy],
         postprocessed_batch: SampleBatch,
-        original_batches: Dict[AgentID, SampleBatch],
+        original_batches: Dict[AgentID, Tuple[Policy, SampleBatch]],
what, that's surprising ...
how can we mix these 2 things together?
ray/rllib/evaluation/collectors/simple_list_collector.py
Lines 824 to 832 in 18c269c
pre_batches = {}
for (eps_id, agent_id), collector in self.agent_collectors.items():
    # Build only if there is data and agent is part of given episode.
    if collector.agent_steps == 0 or eps_id != episode_id:
        continue
    pid = self.agent_key_to_policy_id[(eps_id, agent_id)]
    policy = self.policy_map[pid]
    pre_batch = collector.build(policy.view_requirements)
    pre_batches[agent_id] = (policy, pre_batch)
pre_batches is a dict whose values have type Tuple[Policy, SampleBatch]. It is then passed to on_postprocess_trajectory under the name original_batches.
ray/rllib/evaluation/collectors/simple_list_collector.py
Lines 913 to 925 in 18c269c
for agent_id, post_batch in sorted(post_batches.items()):
    agent_key = (episode_id, agent_id)
    pid = self.agent_key_to_policy_id[agent_key]
    policy = self.policy_map[pid]
    self.callbacks.on_postprocess_trajectory(
        worker=get_global_worker(),
        episode=episode,
        agent_id=agent_id,
        policy_id=pid,
        policies=self.policy_map,
        postprocessed_batch=post_batch,
        original_batches=pre_batches,
    )
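To illustrate what the corrected hint means for callback authors, here is a minimal sketch of unpacking original_batches. The Policy and SampleBatch classes below are simplified stand-ins written for this example (not the real RLlib classes), and summarize_original_batches is a hypothetical helper, not part of RLlib.

```python
from typing import Dict, Tuple

AgentID = str  # stand-in alias for this example only


class Policy:
    """Hypothetical minimal stand-in for an RLlib policy."""

    def __init__(self, name: str) -> None:
        self.name = name


class SampleBatch(dict):
    """Hypothetical stand-in: a dict mapping column names to lists."""


def summarize_original_batches(
    original_batches: Dict[AgentID, Tuple[Policy, SampleBatch]],
) -> Dict[AgentID, Tuple[str, int]]:
    """Return (policy name, number of observations) for each agent."""
    summary = {}
    # Each value is a (policy, batch) pair, so it is unpacked in the loop header.
    for agent_id, (policy, pre_batch) in original_batches.items():
        summary[agent_id] = (policy.name, len(pre_batch["obs"]))
    return summary


original_batches = {
    "agent_0": (Policy("pol_a"), SampleBatch(obs=[1, 2, 3])),
    "agent_1": (Policy("pol_b"), SampleBatch(obs=[4])),
}
result = summarize_original_batches(original_batches)
```

With the old hint, Dict[AgentID, SampleBatch], a type checker would flag the tuple unpacking in the loop even though it matches what the collector actually passes.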
oh it's Tuple, not Union, I got scared.
Thanks!
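For readers less familiar with the distinction raised here, a short sketch of Tuple versus Union in Python type hints follows. The function names are made up for this illustration and do not come from RLlib.

```python
from typing import Tuple, Union


def describe_tuple(value: Tuple[str, int]) -> str:
    # A Tuple[str, int] always carries both items, in a fixed order.
    name, count = value
    return f"{name}:{count}"


def describe_union(value: Union[str, int]) -> str:
    # A Union[str, int] is one value of either type, so the caller
    # has to branch on the runtime type before using it.
    if isinstance(value, str):
        return f"str:{value}"
    return f"int:{value}"
```

A Tuple value can be unpacked directly, while a Union value forces every consumer to check which type it received, which is why mixing types that way is usually avoided.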
:) @gjoliver
I think there are only very few places in RLlib where we mix different types, e.g. in a return value (for example, in the sampler code's _process_observations()), and no, we probably shouldn't do this.
@sven1977 can you help merge?
Nice fix. Thanks @XuehaiPan, really appreciate all your help on RLlib :)
Why are these changes needed?
Fix type hints for original_batches in on_postprocess_trajectory.
Related issue number
N/A
Checks
scripts/format.sh to lint the changes in this PR.