[RLlib; Offline RL] Implement twin-Q net option for CQL. #47105
Conversation
Signed-off-by: simonsays1980 <[email protected]>
@@ -90,8 +92,9 @@ def compute_loss_for_module(
    # Use the actions sampled from the current policy.
    Columns.ACTIONS: actions_curr,
}
# Note, if `twin_q` is `True`, `compute_q_values` computes the minimum
# of the two Q-value estimates.
Nice!
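For context on the note above: with `twin_q` enabled, two Q-heads are maintained and the element-wise minimum of their estimates is used (clipped double-Q). A minimal, purely illustrative sketch of that idea; the helper name and arguments here are assumptions, not RLlib's actual `compute_q_values` signature:

```python
import torch


def q_min_of_twin_heads(q_net, q_twin_net, batch):
    # Illustrative only: evaluate both Q-heads and take the element-wise
    # minimum so a single overestimating head cannot dominate.
    q = q_net(batch)
    q_twin = q_twin_net(batch)
    return torch.minimum(q, q_twin)
```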
if config.twin_q:
    td_error += torch.abs(q_twin_selected, q_selected_target)
    # Rescale the TD error
    td_error += 0.5
Should this be `* 0.5`?
Great catch :D This should be multiplied by 0.5
@@ -144,15 +150,24 @@ def compute_loss_for_module(
# Calculate the TD error.
td_error = torch.abs(q_selected - q_selected_target)
# TODO (simon): Add the Twin TD error
if config.twin_q:
    td_error += torch.abs(q_twin_selected, q_selected_target)
Should this be `torch.minimum`?
Sorry, meant: Should this be `torch.abs(q_twin_selected - q_selected_target)`?
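For reference, a corrected version of this block could look roughly as follows. This is a sketch that folds in both fixes raised in this review (subtracting before `torch.abs` and rescaling with `* 0.5`), using the variable names from the diff:

```python
# Calculate the TD error.
td_error = torch.abs(q_selected - q_selected_target)
if config.twin_q:
    # Add the twin-Q TD error (note the subtraction inside torch.abs).
    td_error += torch.abs(q_twin_selected - q_selected_target)
    # Rescale: average the absolute TD errors of the two Q-heads.
    td_error *= 0.5
```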
@@ -144,15 +150,24 @@ def compute_loss_for_module(
# Calculate the TD error.
td_error = torch.abs(q_selected - q_selected_target)
# TODO (simon): Add the Twin TD error
Remove this TODO.
Another good catch :)
    * config.min_q_weight
    * config.temperature
)
cql_twin_loss - (q_twin_selected.mean()) * config.min_q_weight
Should this be `-=`?
Could you also check all the math once more? I don't want to miss anything important here :)
Oh yes, good catch. I hope I just missed the `=` key and it's not my eyes :)
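A sketch of the fixed line in its likely surrounding context; the `logsumexp` input `q_twin_cat` is an assumed name reconstructed from the visible fragment, not necessarily the variable used in the PR:

```python
cql_twin_loss = (
    torch.logsumexp(q_twin_cat / config.temperature, dim=1).mean()
    * config.min_q_weight
    * config.temperature
)
# The fix discussed above: use subtract-and-assign (`-=`) so the penalty
# actually modifies the loss instead of being a no-op expression.
cql_twin_loss -= q_twin_selected.mean() * config.min_q_weight
```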
…the math. Signed-off-by: simonsays1980 <[email protected]>
Signed-off-by: simonsays1980 <[email protected]>
LGTM now. Thanks for this great PR and for double-checking, @simonsays1980.
Why are these changes needed?
This PR proposes the double Q trick for CQL to stabilize training. More specifically, it adds an optional twin Q-network and uses the minimum of the two Q-value estimates, which counteracts Q-value overestimation.
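As a rough, self-contained illustration of the double-Q (clipped double-Q) idea, assuming a SAC-style entropy term with coefficient `alpha` (all names below are illustrative, not RLlib internals):

```python
import torch


def clipped_double_q_target(q1_next, q2_next, next_logps, rewards, dones,
                            gamma=0.99, alpha=0.2):
    """Illustrative clipped double-Q Bellman target with a SAC-style
    entropy bonus. Taking the minimum of the two target Q-heads prevents
    a single overestimating head from inflating the bootstrap target."""
    q_next = torch.minimum(q1_next, q2_next) - alpha * next_logps
    return rewards + gamma * (1.0 - dones) * q_next
```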
Related issue number
Checks
- I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.