[RLlib] Make sure SlateQ works with GPU. #22738
Conversation
@@ -41,14 +41,15 @@ def build_slateq_model_and_distribution(
     Returns:
         Tuple consisting of 1) Q-model and 2) an action distribution class.
     """
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
Actually, the TorchPolicy will take care of all this.
I think we only have to make sure all tensors in the loss function are on the right device. ...
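For reference, the idiom in the diff above can be sketched standalone as follows; note that, per the review comment, TorchPolicy already handles device placement inside RLlib, so the model builder itself should not need this:

```python
import torch

# Standard PyTorch device-selection idiom (illustrative sketch only; the
# review above argues this is redundant inside RLlib's TorchPolicy).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tensors created with an explicit device= land where we expect them.
t = torch.zeros(3, device=device)
```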
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, I understand everything now.
It's the target_model that was the issue. We just need to use the correct target_model out of policy.target_models for things to work.
Thanks.
Ah, yeah, sorry, I should have thought about this. Yes, you can always do:
correct_target_model_to_use = policy.target_models[model]
...
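A minimal sketch of the pattern being suggested (the class and function names below are illustrative stand-ins, not RLlib's actual internals): each model copy keys into a dict of matching target models, so the loss always uses the target network paired with the model it was handed:

```python
import torch
import torch.nn as nn

class ToyPolicy:
    """Illustrative stand-in for a policy holding per-copy target models."""
    def __init__(self, num_copies: int = 2):
        # One model copy per device tower; here all copies live on CPU.
        self.models = [nn.Linear(4, 2) for _ in range(num_copies)]
        # Each copy is paired with its own target network, keyed by the copy.
        self.target_models = {m: nn.Linear(4, 2) for m in self.models}

def toy_loss(policy, model, obs):
    # Look up the target model matching *this* model copy, as the review
    # comment suggests, instead of assuming a single global target model.
    target_model = policy.target_models[model]
    q = model(obs)
    target_q = target_model(obs)
    return torch.mean((q - target_q.detach()) ** 2)

policy = ToyPolicy()
obs = torch.randn(8, 4)
losses = [toy_loss(policy, m, obs) for m in policy.models]
```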
@@ -154,7 +155,7 @@ def build_slateq_losses(
     clicked = torch.sum(click_indicator, dim=1)
     mask_clicked_slates = clicked > 0
-    clicked_indices = torch.arange(batch_size)
+    clicked_indices = torch.arange(batch_size).to(policy.device)
Here, this seems (almost) correct. Even better:
clicked_indices = torch.arange(batch_size).to(clicked.device)  # <- some tensor that we know is already on one of the GPUs.
yep. this was actually how I did it originally. :)
updated.
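The resulting pattern can be sketched like this (the helper name and tensor shapes are illustrative, not the exact RLlib code):

```python
import torch

def clicked_slate_indices(click_indicator: torch.Tensor) -> torch.Tensor:
    """Return batch indices of slates that received at least one click."""
    batch_size = click_indicator.shape[0]
    clicked = torch.sum(click_indicator, dim=1)
    # Build the index tensor directly on the device the batch already
    # lives on, as suggested above, rather than a policy-level attribute.
    indices = torch.arange(batch_size, device=clicked.device)
    return indices[clicked > 0]

# Slates 0 and 2 were clicked; slate 1 was not.
idx = clicked_slate_indices(torch.tensor([[1.0, 0.0], [0.0, 0.0], [0.0, 1.0]]))
# idx -> tensor([0, 2])
```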
@@ -320,7 +321,10 @@ def score_documents(
         torch.multiply(user_obs.unsqueeze(1), torch.stack(doc_obs, dim=1)), dim=2
     )
     # Compile a constant no-click score tensor.
-    score_no_click = torch.full(size=[user_obs.shape[0], 1], fill_value=no_click_score)
+    # Make sure it lives on the same device as scores_per_candidate.
+    score_no_click = torch.full(
looks great.
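A self-contained sketch of the resulting pattern (the function below abbreviates the real score_documents signature and is illustrative only):

```python
import torch

def score_with_no_click(user_obs, doc_obs, no_click_score=0.0):
    # Dot-product score of the user embedding against each candidate doc.
    scores_per_candidate = torch.sum(
        torch.multiply(user_obs.unsqueeze(1), torch.stack(doc_obs, dim=1)), dim=2
    )
    # Compile a constant no-click score tensor, passing device= so it
    # lives on the same device as scores_per_candidate.
    score_no_click = torch.full(
        size=[user_obs.shape[0], 1],
        fill_value=no_click_score,
        device=scores_per_candidate.device,
    )
    return torch.cat([scores_per_candidate, score_no_click], dim=1)

# Batch of 2 users, 4 candidate documents, 3-dim embeddings.
out = score_with_no_click(torch.ones(2, 3), [torch.ones(2, 3)] * 4)
# out has shape [2, 5]: 4 candidate scores plus the no-click column.
```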
     # [1, AxS] Useful for torch.take_along_dim()
     policy.slates_indices = policy.slates.reshape(-1).unsqueeze(0).to(policy.device)

     setup_mixins(policy)
setup_late_mixins()
??
ah, reverted. I was trying to move policy.slates_indices to the correct device during late_setup,
but I am doing this the correct way now.
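To illustrate why the flattened [1, AxS] index tensor is useful with torch.take_along_dim (the shapes and score values below are made up for the sketch):

```python
import torch

# A = 3 candidate slates of size S = 2, enumerating document indices.
A, S = 3, 2
slates = torch.tensor([[0, 1], [1, 2], [0, 2]])        # [A, S]
# Flatten to [1, A*S] so a single take_along_dim call gathers the
# per-document scores for every slate at once.
slates_indices = slates.reshape(-1).unsqueeze(0)       # [1, A*S]

scores = torch.tensor([[0.1, 0.5, 0.4]])               # [batch=1, num_docs]
per_doc = torch.take_along_dim(scores, slates_indices, dim=1)  # [1, A*S]
# Re-group by slate and sum to score each whole slate.
slate_scores = per_doc.reshape(-1, A, S).sum(dim=2)    # [1, A]
```

In the actual fix, slates_indices is additionally moved to policy.device at creation time, so the gather runs on the same device as the scores.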
Looks great! Thanks for this important fix @gjoliver !
Why are these changes needed?
Create models and variables on proper device so SlateQ works with GPU training.
Related issue number
Checks
- I've run scripts/format.sh to lint the changes in this PR.