[RLlib][Training iteration fn] APEX conversion #22937
Conversation
… some code that needs to be deleted here, and some long-running benchmarks that need to be run, but it's looking good.
I have a few questions, thanks
```diff
@@ -19,7 +19,8 @@
 class TestAlphaStar(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        ray.init(num_cpus=20)
+        # ray.init(num_cpus=20)
+        ray.init(local_mode=True)
```
debug code?
Whoops, yeah, I'll get rid of this, my b.
I'm curious to see if there is an impact on metrics. Really cool stuff. I'm not the most helpful reviewer yet, I guess, but please keep including me on this in the future since it even touches the buffer topic! :)
…sult dictionaries
…led from number of replay actors:
Looks good to me now. Just one comment left: the default arg for `remote_fn` should be `sample.remote()`, not `sample()`.
I'll run some benchmarks for APEX on Atari before merging ...
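For illustration, a minimal sketch of the distinction being pointed out; the name `default_remote_fn` and its exact signature are hypothetical, not the actual RLlib default:

```python
def default_remote_fn(actor, *args, **kwargs):
    # Schedules `sample` on the actor and returns an ObjectRef that
    # ray.wait() can poll later.
    return actor.sample.remote()

# By contrast, `actor.sample()` on a Ray actor handle does not run the method
# remotely; actor methods must be invoked through `.remote()`.
```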
```python
def setup(self, config: PartialTrainerConfigDict):
    super().setup(config)
    num_replay_buffer_shards = self.config["optimizer"]["num_replay_buffer_shards"]
```
Let's add an if-block here for `_disable_execution_plan=True`, such that for the `execution_plan` version of APEX, we don't create the learner thread and the actors twice.
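A hedged sketch of what such a guard could look like in `setup()`; the config key spelling and the commented-out construction steps are assumptions, not the merged code:

```python
def setup(self, config: PartialTrainerConfigDict):
    super().setup(config)
    # Only build the learner thread and the replay-buffer actors here when the
    # new training-iteration path is active; the legacy execution_plan path
    # constructs its own copies, so this guard avoids creating them twice.
    # (Config key name assumed from the review comment.)
    if self.config.get("_disable_execution_plan_api", False):
        num_replay_buffer_shards = self.config["optimizer"]["num_replay_buffer_shards"]
        # ... create the replay actors and start the learner thread here ...
```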
rllib/agents/dqn/apex.py
```python
training_intensity = int(self.config["training_intensity"] or 1)
for _ in range(training_intensity):
    temp = ray.get([actor.replay.remote() for actor in self.replay_actors])
    temp = [
```
Could this be simplified into only one for-loop?
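One possible reading of that simplification (a sketch only; the post-processing of the replayed batches is guessed, since the snippet above is cut off):

```python
training_intensity = int(self.config["training_intensity"] or 1)
# Launch all replay requests up front, then resolve them with a single
# ray.get() instead of one blocking get per training_intensity iteration.
replay_refs = [
    actor.replay.remote()
    for _ in range(training_intensity)
    for actor in self.replay_actors
]
samples = [s for s in ray.get(replay_refs) if s is not None]
```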
```diff
@@ -14,9 +15,13 @@ def asynchronous_parallel_requests(
     actors: List[ActorHandle],
     ray_wait_timeout_s: Optional[float] = None,
     max_remote_requests_in_flight_per_actor: int = 2,
-    remote_fn: Optional[Callable[[Any, Optional[Any], Optional[Any]], Any]] = None,
+    remote_fn: Optional[
```
Nice! Let's remove the `Optional[...]` then.
```diff
@@ -51,6 +56,10 @@ def asynchronous_parallel_requests(
             (kwargs) as **kwargs to be passed to the `remote_fn`.
             E.g.: actors=[A, B],
             remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B].
+        return_result_obj_ref_ids: If True, return the object ref IDs of the ready
+            results, otherwise return the actual results.
+        num_requests_to_launch: Number of remote requests to launch on each of the
```
Can we describe this more precisely? Like add: "if we have not reached `max_remote_requests_in_flight_per_actor` yet, launch exactly `num_requests_to_launch` on the respective actor, regardless of the current values of `remote_requests_in_flight`". Something like this.
Gotcha -- yeah, this is still "rate limited" by `max_remote_requests_in_flight_per_actor`.
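Roughly, the launch policy under discussion could be read like this (a simplified, hypothetical sketch consistent with the reply above; the real `asynchronous_parallel_requests` also handles timeouts and result collection):

```python
from typing import Any, Callable, Dict, List, Set

def launch_requests(
    actors: List[Any],
    remote_fn: Callable[[Any], Any],
    remote_requests_in_flight: Dict[Any, Set],
    max_remote_requests_in_flight_per_actor: int = 2,
    num_requests_to_launch: int = 1,
) -> None:
    """Launch up to `num_requests_to_launch` new requests per actor, capped so
    that no actor ever exceeds `max_remote_requests_in_flight_per_actor`."""
    for actor in actors:
        capacity = max_remote_requests_in_flight_per_actor - len(
            remote_requests_in_flight[actor]
        )
        for _ in range(min(num_requests_to_launch, max(capacity, 0))):
            ref = remote_fn(actor)  # e.g. lambda a: a.sample.remote()
            remote_requests_in_flight[actor].add(ref)
```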
Sorry, I know I already approved, but there are a few minor nits that we should address before merging (see my new comments). Nothing big, this is a great PR! Thanks @avnishn. I'm currently running some benchmarks on Atari + APEX and will report the results here, but it looks promising. ...
```diff
@@ -283,6 +375,260 @@ def add_apex_metrics(result: dict) -> dict:
             merged_op, workers, config, selected_workers=selected_workers
         ).for_each(add_apex_metrics)

+    def get_samples_and_store_to_replay_buffers(self):
```
Could we add a docstring here and in all new APEXTrainer methods?
sounds good to me
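For example, a docstring along these lines (a sketch; the wording is inferred from the method name and the surrounding discussion, not taken from the merged code):

```python
def get_samples_and_store_to_replay_buffers(self):
    """Collects samples from rollout workers and stores them in replay actors.

    Launches asynchronous sample requests on the remote rollout workers and,
    for each batch that comes back, asynchronously adds it to one of the
    replay-buffer shard actors.
    """
```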
…to apex_training_itr
…_training_itr # Conflicts: # rllib/agents/dqn/apex.py
Part of the training iteration function overhaul.
There's probably stuff broken with the logging (some keys are not getting logged right now).
I had to modify `asynchronous_parallel_requests` to add some extra functionality for convenience. The bug I fixed was that if there were multiple in-flight requests for a single actor that were ready in the same call to `asynchronous_parallel_requests`, previously all but the last result was dropped from the return value. I made the return value a dict mapping actor IDs to lists of results, as opposed to just single results, which is how I caught this. We definitely need some standalone tests for this function.
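A minimal illustration of the return-value change described above (hypothetical helper names; the real function keys the dict by actor handles and works with Ray ObjectRefs):

```python
from collections import defaultdict

# Old shape: one slot per actor, so when two of an actor's in-flight requests
# become ready in the same call, the earlier result gets overwritten.
def collect_last_only(ready):  # ready: iterable of (actor_id, result) pairs
    out = {}
    for actor_id, result in ready:
        out[actor_id] = result
    return out

# New shape: a list of results per actor, so nothing is dropped.
def collect_all(ready):
    out = defaultdict(list)
    for actor_id, result in ready:
        out[actor_id].append(result)
    return dict(out)
```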
Why are these changes needed?
Related issue number
Checks
I've run `scripts/format.sh` to lint the changes in this PR.

Tested on the tuned cartpole examples, and with minor hparam tuning (mainly training intensity). This implementation beats the execution-plan implementation in runtime per number of sampled timesteps, and in overall number of timesteps to reach optimal reward.