[RLlib] Agent collector time complexity reduction #31693
@@ -295,7 +296,7 @@ def postprocess_episode(

         if (
             not pre_batch.is_single_trajectory()
-            or len(set(pre_batch[SampleBatch.EPS_ID])) > 1
+            or len(np.unique(pre_batch[SampleBatch.EPS_ID])) > 1
Because of the above changes, EPS_IDs turn out to be np arrays as well, so set does not work here anymore.
LGTM if the tests pass.
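To illustrate the check discussed in this thread, here is a minimal sketch. It assumes the episode IDs arrive as a 1-D NumPy integer array; the sample values and the helper name `spans_multiple_episodes` are made up for this example. `np.unique` returns the sorted distinct values, so its length tells whether a batch spans more than one episode.

```python
import numpy as np

# Hypothetical episode-ID columns; real batches store these under SampleBatch.EPS_ID.
single_episode = np.array([7, 7, 7, 7])
mixed_episodes = np.array([7, 7, 9, 9])


def spans_multiple_episodes(eps_ids: np.ndarray) -> bool:
    # np.unique returns the sorted distinct values of the array.
    return len(np.unique(eps_ids)) > 1


print(spans_multiple_episodes(single_episode))  # False
print(spans_multiple_episodes(mixed_episodes))  # True
```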
rllib/connectors/action/pipeline.py
            timer = self.timers[str(c)]
            with timer:
                ac_data = c(ac_data)
            timer.push_units_processed(1)
Can you remind me what timer.push_units_processed(1) does?
I went back to look at the implementation and found that it is only needed for throughput measurements, since the mean time is calculated over the number of timings, not over the units processed. Thanks!
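The distinction can be sketched with a toy timer. `SketchTimer` below is a hypothetical stand-in written for this illustration, not RLlib's `_Timer`. It assumes only what the comment above states: the mean is taken over the number of timings, while pushed units feed a throughput figure.

```python
import time


class SketchTimer:
    """Hypothetical context-manager timer, for illustration only."""

    def __init__(self):
        self.durations = []  # one entry per `with` block
        self.units = 0       # only used for throughput

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.durations.append(time.perf_counter() - self._start)

    def push_units_processed(self, n):
        self.units += n

    def mean(self):
        # Averaged over the number of timings, not over units processed.
        return sum(self.durations) / len(self.durations)

    def throughput(self):
        # Units per second; this is the only place `units` matters.
        return self.units / sum(self.durations)


timer = SketchTimer()
for _ in range(3):
    with timer:
        time.sleep(0.001)

mean_before = timer.mean()
timer.push_units_processed(100)
assert timer.mean() == mean_before  # pushing units leaves the mean unchanged
```

In this sketch, dropping the `push_units_processed` call changes only `throughput()`, which is the behavior the thread describes.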
rllib/connectors/action/pipeline.py
@@ -19,10 +21,17 @@
class ActionConnectorPipeline(ConnectorPipeline, ActionConnector):
    def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
        super().__init__(ctx, connectors)
        self.timers = defaultdict(_Timer)

    def reset(self, env_id: str):
What are your thoughts on implementing timer capabilities in the baseclass connectors vs. here? (not saying we should, just want to hear your argument).
Pro implementing in the base class:
- It's generally a good idea to pull functionality down to a lower level if it does not increase complexity.
Against implementing in the base class:
- It enlarges the interface between the pipeline and the connectors if we call the timers from the pipeline. Alternatively, we'd have to assume that anyone subclassing Connectors handles the timers correctly in a transform method.
I think the solution in this PR introduces less complexity than putting this in Connectors.
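A minimal sketch of the pipeline-level approach, under stated assumptions: `SketchTimer`, `SketchPipeline`, and the two toy connectors are hypothetical names invented here, connectors are modeled as plain callables, and timers are keyed by function name rather than by `str(c)` as in the PR. The point is that connectors stay unaware of timing while the pipeline wraps each call.

```python
import time
from collections import defaultdict


class SketchTimer:
    """Hypothetical timer; records one duration per `with` block."""

    def __init__(self):
        self.durations = []

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.durations.append(time.perf_counter() - self._start)


class SketchPipeline:
    """Hypothetical pipeline that times its connectors from the outside."""

    def __init__(self, connectors):
        self.connectors = connectors
        # A timer is created lazily the first time a connector's key is seen.
        self.timers = defaultdict(SketchTimer)

    def __call__(self, data):
        for c in self.connectors:
            with self.timers[c.__name__]:
                data = c(data)
        return data


def clip(x):
    # Toy connector: clamp the value to [-1, 1].
    return max(-1.0, min(1.0, x))


def scale(x):
    # Toy connector: double the value.
    return 2.0 * x


pipeline = SketchPipeline([clip, scale])
result = pipeline(3.0)  # clip gives 1.0, scale then gives 2.0
```

With this layout, a new connector needs no timing code of its own, which is the reduced interface argued for above.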
Signed-off-by: Artur Niederfahrenhorst <[email protected]>
b85c5da
to
60614a1
Compare
Signed-off-by: Artur Niederfahrenhorst <[email protected]>
2 quick questions
@@ -278,8 +277,7 @@ def add_action_reward_next_obs(self, input_values: Dict[str, TensorType]) -> Non
             AgentCollector._next_unroll_id += 1

         # Next obs -> obs.
-        # TODO @kourosh: remove the in-place operations and get rid of this deepcopy.
-        values = deepcopy(input_values)
+        values = {k: v for k, v in input_values.items()}
just use copy.copy()?
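The trade-off between the three options can be sketched as follows (the dictionary contents are made up). A shallow copy, whether via `copy.copy()` or a dict comprehension, creates a new dict but shares the values. It is therefore only safe if later code rebinds keys rather than mutating values in place; whether that holds in `add_action_reward_next_obs` is an assumption here.

```python
import copy

import numpy as np

input_values = {"obs": np.zeros(2), "reward": 1.0}

deep = copy.deepcopy(input_values)              # copies the arrays too
shallow = copy.copy(input_values)               # new dict, shared arrays
comp = {k: v for k, v in input_values.items()}  # same effect as copy.copy for a plain dict

# Rebinding a key in a copy does not touch the original dict.
shallow["reward"] = 0.0

# Mutating a shared array in place is visible through the original.
comp["obs"][0] = 5.0
```

The deep copy is the only one that stays isolated from the in-place write, which is also why it costs more.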
@@ -28,7 +28,6 @@ def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
             timer = self.timers[str(c)]
             with timer:
                 ac_data = c(ac_data)
-            timer.push_units_processed(1)
wait, we shouldn't get rid of these? same below.
I realized we actually don't need these after kourosh asked me about them -> #31693 (comment)
I just executed code from this pr and took the following screenshot just to make sure that the timer actually works as expected.
When calling .mean(), the timer does not take the processed units into account, so we don't need it.
oh that's right, it's part of the with statements.
# length. This branch takes more time than simply picking
# slices we try to avoid it.
element_at_t = []
for index in inds:
Maybe extract this for loop into a small inline function so the code looks a little better.
The level of nestedness is a bit nuts :)
Done! Thanks :)
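One way the suggested extraction could look, as a sketch only: the thread does not show the loop body, so `build_windows`, `gather`, and the data layout below are hypothetical. The idea is to move the per-index loop into a small local helper so the outer code loses one level of nesting.

```python
import numpy as np


def build_windows(data, inds_per_step):
    """For each timestep, gather and stack the elements at that step's indices."""

    def gather(inds):
        # Hypothetical helper holding the loop that used to be nested inline.
        return np.stack([data[i] for i in inds])

    return [gather(inds) for inds in inds_per_step]


data = np.arange(10).reshape(5, 2)  # five made-up timesteps, two features each
windows = build_windows(data, [[0, 1], [1, 2], [3, 4]])
```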
Signed-off-by: Artur Niederfahrenhorst <[email protected]>
Signed-off-by: Artur Niederfahrenhorst <[email protected]> Signed-off-by: Andrea Pisoni <[email protected]>
Why are these changes needed?
In light of the recent QMix regression with connectors, we found that the regression affects QMix because of the very short episode lengths in the two-step game. These lead to very frequent calls to AgentCollector.build_for_training().
This PR tries to reduce the time complexity of build_for_training() (and of some other things that I found along the way).
The changes in this PR speed up AgentCollector.build_for_training(), as roughly indicated by the following metrics (averaged over 500 samples):
two-step: Single-agent episodes of length two, no recurrency.
ten-step: Single-agent episodes of length ten, no recurrency.
sixtyone-step: Single-agent episodes of length ten, with recurrency (and padding).
mean_raw_obs_processing is where we spend much of our time in env_runner_v2 and is the source of the regression in question.
For the two-step game, this pans out as follows for the mean_obs_preprocessing time:
blue w/o connectors
red w/ connectors
orange w/ connectors and optimizations
... and as follows for the overall throughput:
For the r2d2 compilation test (mean episode length ~ 20), this pans out as follows:
blue is w/o connectors
orange is w/ connectors
light-blue w/ connectors and optimizations
Checks
- I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.