[RLlib] Agent collector time complexity reduction #31693

Merged

@ArturNiederfahrenhorst ArturNiederfahrenhorst commented Jan 16, 2023

Why are these changes needed?

In the light of the recent QMix regression for connectors, we have found that this regression affects QMix because of the very short episode lengths in the two-step-game. These lead to very frequent calls to AgentCollector.build_for_training().
This PR tries to optimize build_for_training() (and some other things that I found along the way) for time complexity.

The changes in this PR speed up AgentCollector.build_for_training(), as roughly indicated by the following metrics (averaged over 500 samples):
two-step: Single-agent episodes of length two, no recurrence.
ten-step: Single-agent episodes of length ten, no recurrence.
sixtyone-step: Single-agent episodes of length sixty-one, with recurrence (and padding).
Screenshot 2023-01-16 at 12 46 36

mean_raw_obs_processing is where we spend much of our time in env_runner_v2 and the source of the regression in question.

For two step game this pans out as follows for the mean_obs_preprocessing time:
blue w/o connectors
red w/ connectors
orange w/ connectors and optimizations
Screenshot 2023-01-16 at 23 54 17
... and as follows for the overall throughput:
Screenshot 2023-01-16 at 23 55 58

For the r2d2 compilation test (mean episode length ~ 20), this pans out as follows:
blue is w/o connectors
orange is w/ connectors
light-blue w/ connectors and optimizations
Screenshot 2023-01-17 at 00 04 07

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@@ -295,7 +296,7 @@ def postprocess_episode(

 if (
     not pre_batch.is_single_trajectory()
-    or len(set(pre_batch[SampleBatch.EPS_ID])) > 1
+    or len(np.unique(pre_batch[SampleBatch.EPS_ID])) > 1
Contributor Author

Because of the above changes, EPS_IDs turn out to be np arrays as well, so set does not work here anymore.
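
A minimal sketch of the check above, assuming the episode IDs arrive as a NumPy array. The ID values and the helper name `spans_multiple_episodes` are made up for illustration and are not part of RLlib. It also shows one case where `set()` breaks on array input, which may or may not be the exact failure seen in this PR:

```python
import numpy as np

# Made-up episode IDs, purely for illustration.
single_episode = np.array([7, 7, 7, 7])
mixed_episodes = np.array([7, 7, 8, 8])


def spans_multiple_episodes(eps_ids: np.ndarray) -> bool:
    """Return True if the array holds more than one distinct episode ID."""
    return len(np.unique(eps_ids)) > 1


print(spans_multiple_episodes(single_episode))  # False
print(spans_multiple_episodes(mixed_episodes))  # True

# np.unique flattens its input, so it also handles shapes where set() fails:
# iterating a 2-D array yields 1-D arrays, which are not hashable.
column_ids = np.array([[7], [7], [8]])
try:
    set(column_ids)
except TypeError as err:
    print("set() failed:", err)
print(spans_multiple_episodes(column_ids))  # True
```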

@kouroshHakha kouroshHakha left a comment

LGTM if the tests pass.

timer = self.timers[str(c)]
with timer:
ac_data = c(ac_data)
timer.push_units_processed(1)
Contributor

can you remind me what timer.push_units_processed(1) does?

Contributor Author

Went back to look at the implementation for this and found that it is only needed for throughput measurements, since the mean time is calculated over the number of timings, not over the units processed. Thanks!
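
To illustrate the distinction, here is a hypothetical timer sketch. `SketchTimer` is an assumption written for this example and is not RLlib's `_Timer`. Its mean depends only on how many `with` blocks were timed, while the unit counter feeds throughput alone:

```python
import time


class SketchTimer:
    """Hypothetical timer for illustration; not RLlib's _Timer."""

    def __init__(self):
        self.durations = []       # one entry per completed `with` block
        self.units_processed = 0  # only consulted for throughput

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.durations.append(time.perf_counter() - self._start)

    def push_units_processed(self, n):
        self.units_processed += n

    def mean(self):
        # Averages over the number of timings, not over units processed.
        return sum(self.durations) / len(self.durations)

    def throughput(self):
        # Units per second; stays 0 if push_units_processed is never called.
        total = sum(self.durations)
        return self.units_processed / total if total > 0 else 0.0


timer = SketchTimer()
for _ in range(3):
    with timer:
        sum(range(1000))

print(len(timer.durations))   # 3
print(timer.units_processed)  # 0
```

In this sketch, dropping the `push_units_processed` call leaves `mean()` unchanged and only affects `throughput()`.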

@@ -19,10 +21,17 @@
class ActionConnectorPipeline(ConnectorPipeline, ActionConnector):
    def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
        super().__init__(ctx, connectors)
        self.timers = defaultdict(_Timer)

    def reset(self, env_id: str):
Contributor

What are your thoughts on implementing timer capabilities in the baseclass connectors vs. here? (not saying we should, just want to hear your argument).

Contributor Author

Pro implementing in baseclass:

  • It's generally a good idea to pull functionality down to a lower level if it does not increase complexity.

Against implementing in baseclass:

  • It enlarges the interface between the pipeline and the connectors if we call the timers from the pipeline. Alternatively, we'd have to assume that timers are correctly handled by whoever subclasses Connectors in a transform method.

I think the solution in this PR introduces less complexity than putting this in Connectors.
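
A rough sketch of the pipeline-level option, under stated assumptions: `SketchTimer`, `SketchPipeline`, and the two toy connectors are hypothetical stand-ins, and the timers are keyed by function name here rather than by `str(c)` as in the diff. The point is that the connectors stay unaware of timing:

```python
import time
from collections import defaultdict


class SketchTimer:
    """Hypothetical timer for illustration; not RLlib's _Timer."""

    def __init__(self):
        self.durations = []

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.durations.append(time.perf_counter() - self._start)


class SketchPipeline:
    """Hypothetical pipeline that owns one timer per connector."""

    def __init__(self, connectors):
        self.connectors = connectors
        # A timer is created lazily the first time a connector is called.
        self.timers = defaultdict(SketchTimer)

    def __call__(self, data):
        for c in self.connectors:
            with self.timers[c.__name__]:
                data = c(data)
        return data


# Toy connectors: plain callables with no timing code of their own.
def add_one(x):
    return x + 1


def double(x):
    return x * 2


pipeline = SketchPipeline([add_one, double])
print(pipeline(3))              # 8
print(sorted(pipeline.timers))  # ['add_one', 'double']
```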

@ArturNiederfahrenhorst ArturNiederfahrenhorst added tests-ok The tagger certifies test failures are unrelated and assumes personal liability. and removed do-not-merge Do not merge this PR! labels Jan 18, 2023
Signed-off-by: Artur Niederfahrenhorst <[email protected]>
@ArturNiederfahrenhorst ArturNiederfahrenhorst removed the tests-ok The tagger certifies test failures are unrelated and assumes personal liability. label Jan 18, 2023
Signed-off-by: Artur Niederfahrenhorst <[email protected]>
@gjoliver gjoliver left a comment

2 quick questions

@@ -278,8 +277,7 @@ def add_action_reward_next_obs(self, input_values: Dict[str, TensorType]) -> Non
 AgentCollector._next_unroll_id += 1

 # Next obs -> obs.
-# TODO @kourosh: remove the in-place operations and get rid of this deepcopy.
-values = deepcopy(input_values)
+values = {k: v for k, v in input_values.items()}
Member

just use copy.copy()?
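
For reference, the dict comprehension and `copy.copy()` both produce a shallow copy of a plain dict, unlike `deepcopy`. A small sketch with made-up keys and values (not the actual contents of `input_values` in RLlib):

```python
import copy
from copy import deepcopy

import numpy as np

# Made-up stand-in for the input dict.
input_values = {"obs": np.zeros(3), "rewards": np.ones(1)}

shallow_a = {k: v for k, v in input_values.items()}
shallow_b = copy.copy(input_values)
deep = deepcopy(input_values)

# Rebinding a key in a shallow copy leaves the original dict untouched.
shallow_a["obs"] = np.full(3, 5.0)
print(input_values["obs"])  # [0. 0. 0.]

# The values are shared, so an in-place write is visible through the original.
shallow_b["rewards"][0] = 42.0
print(input_values["rewards"])  # [42.]

# A deep copy owns its arrays (it was taken before the write above).
print(deep["rewards"])  # [1.]
```

The shallow variants avoid copying the arrays, so they are only safe when no later step mutates the values in place.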

@@ -28,7 +28,6 @@ def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
 timer = self.timers[str(c)]
 with timer:
     ac_data = c(ac_data)
-timer.push_units_processed(1)
Member

wait, we shouldn't get rid of these? same below.

Contributor Author

I realized we actually don't need these after kourosh asked me about them -> #31693 (comment)

I just executed the code from this PR and took the following screenshot to make sure that the timer actually works as expected.
When calling .mean(), the timer does not care about processed units, so we don't need it.
Screenshot 2023-01-18 at 15 54 33

Member

oh that's right, it's part of the with statements.

# length. This branch takes more time than simply picking
# slices we try to avoid it.
element_at_t = []
for index in inds:
Member

maybe extract this for loop into a small inline function, so the code looks a little better.
the level of nestedness is a bit nuts :)

Contributor Author

Done! Thanks :)

Signed-off-by: Artur Niederfahrenhorst <[email protected]>
@gjoliver gjoliver merged commit fb3c2b5 into ray-project:master Jan 19, 2023
andreapiso pushed a commit to andreapiso/ray that referenced this pull request Jan 22, 2023
Signed-off-by: Artur Niederfahrenhorst <[email protected]>
Signed-off-by: Andrea Pisoni <[email protected]>