[RLlib] Issue 21334: Fix APPO when kl_loss is enabled. #21855
Conversation
The bug is in the learner_info construction in our LearnerThread. It only shows up for APPO because APPO, IMPALA, and APEX are the only agents that use the async LearnerThread, and APPO is the only one of them that updates the KL loss.
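For context, here is a minimal, hypothetical sketch (not RLlib's actual code) of the mismatch: the KL-coefficient update expects the learner stats to be nested under a policy ID key, while the async LearnerThread used to report a flat learner_stats dict, so the lookup only breaks for async-learner agents that also update the KL coefficient (APPO with use_kl_loss=True).

    # Hypothetical illustration of the key mismatch; names are for demonstration only.
    DEFAULT_POLICY_ID = "default_policy"

    flat_info = {"learner_stats": {"kl": 0.02}}                         # what the LearnerThread produced
    nested_info = {DEFAULT_POLICY_ID: {"learner_stats": {"kl": 0.02}}}  # what the KL update expects

    def read_kl(learner_info, policy_id=DEFAULT_POLICY_ID):
        # The KL-coefficient update effectively does this lookup.
        return learner_info[policy_id]["learner_stats"]["kl"]

    print(read_kl(nested_info))  # 0.02
    # read_kl(flat_info) raises KeyError -> the APPO + use_kl_loss failure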
rllib/agents/ppo/tests/test_appo.py (Outdated)
@@ -47,6 +47,25 @@ def test_appo_compilation(self):
        check_compute_single_action(trainer)
        trainer.stop()

    def test_appo_compilation_use_kl_loss(self):
        """Test whether an APPOTrainer can be built with both frameworks."""
Nit: Fix the comment?
oops done.
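For reference, a rough sketch of what the added test might look like, assuming RLlib's APPOTrainer, DEFAULT_CONFIG, and the framework_iterator / check_compute_single_action test helpers already used in this file; the exact config keys and body in the PR may differ.

    # Rough sketch only; not the exact test added in this PR.
    import unittest
    import ray
    from ray.rllib.agents.ppo import appo
    from ray.rllib.utils.test_utils import (
        check_compute_single_action,
        framework_iterator,
    )

    class TestAPPOKLLoss(unittest.TestCase):
        @classmethod
        def setUpClass(cls):
            ray.init()

        @classmethod
        def tearDownClass(cls):
            ray.shutdown()

        def test_appo_compilation_use_kl_loss(self):
            """Test whether an APPOTrainer can be built with the KL loss enabled."""
            config = appo.DEFAULT_CONFIG.copy()
            config["num_workers"] = 1
            config["use_kl_loss"] = True  # assumed config key for enabling the KL loss
            num_iterations = 2

            for _ in framework_iterator(config, with_eager_tracing=True):
                trainer = appo.APPOTrainer(config=config, env="CartPole-v0")
                for _ in range(num_iterations):
                    print(trainer.train())
                check_compute_single_action(trainer)
                trainer.stop()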
rllib/agents/ppo/tests/test_appo.py (Outdated)
        num_iterations = 2

        for _ in framework_iterator(config, with_eager_tracing=True):
            print("w/ v-trace")
Not necessary here, no?
right, I got rid of it.
rllib/agents/ppo/tests/test_appo.py (Outdated)
        for _ in framework_iterator(config, with_eager_tracing=True):
            print("w/ v-trace")
            _config = config.copy()
            _config["vtrace"] = True
same
done
        if released:
            self.idle_tower_stacks.put(buffer_idx)

-       self.outqueue.put((get_num_samples_loaded_into_buffer, learner_stats))
+       self.outqueue.put((get_num_samples_loaded_into_buffer,
Nice!
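A hypothetical, self-contained sketch (not the actual multi-GPU learner thread code) of the idea behind this change: average each policy's per-tower stats and nest them under the policy ID before putting them on the outqueue, instead of forwarding a single flat learner_stats dict.

    # Hypothetical helper; variable and key names are illustrative only.
    import numpy as np

    def build_learner_info(per_tower_results):
        # per_tower_results: {policy_id: [tower_0_stats, tower_1_stats, ...]}
        learner_info = {}
        for policy_id, tower_stats in per_tower_results.items():
            learner_info[policy_id] = {
                "learner_stats": {
                    key: float(np.mean([float(stats[key]) for stats in tower_stats]))
                    for key in tower_stats[0]
                }
            }
        return learner_info

    # Two GPU towers reporting a KL stat for the default policy:
    info = build_learner_info({"default_policy": [{"kl": 0.018}, {"kl": 0.022}]})
    print(info)  # {"default_policy": {"learner_stats": {"kl": 0.02}}} (approximately)
    # The learner thread would then do something along the lines of:
    # self.outqueue.put((num_samples_loaded, info))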
Very cool! Thanks for the fix.
Is there an issue related to this PR? Could you change the title to: [RLlib] Issue xyz: ...
Sorry, saw the issue # now.
Why are these changes needed?
Fix the APPO agent when kl_loss is enabled.
The value is now saved under per-policy-ID keys in the learner info, and we also need to torch_mean() the stats for the torch policy.
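As a sketch of the consumer side (assumed names and wiring, not necessarily the exact RLlib code), the KL update then has to index into the per-policy-ID keys and reduce any per-tower torch stats to a plain float before calling update_kl():

    # Illustrative only; assumes learner_info is shaped as {policy_id: {"learner_stats": {...}}}.
    def update_kl_from_learner_info(workers, learner_info):
        def update(policy, policy_id):
            stats = learner_info.get(policy_id, {}).get("learner_stats", {})
            if "kl" in stats:
                kl = stats["kl"]
                # If the stat is still a per-tower list, average it first.
                if isinstance(kl, (list, tuple)):
                    kl = sum(float(k) for k in kl) / len(kl)
                # A torch policy may report a tensor; reduce it to a float.
                elif hasattr(kl, "item"):
                    kl = kl.item()
                policy.update_kl(kl)

        # foreach_trainable_policy / update_kl exist on RLlib workers and
        # PPO-family policies, but this wiring is a sketch, not the PR's code.
        workers.local_worker().foreach_trainable_policy(update)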
Related issue number
Closes #21334
Checks
- I've run scripts/format.sh to lint the changes in this PR.