[RLlib] Make sure SlateQ works with GPU. (#22738)
Jun Gong authored Mar 4, 2022
1 parent b609bdf · commit e765915
Showing 2 changed files with 11 additions and 10 deletions.
19 changes: 10 additions & 9 deletions rllib/agents/slateq/slateq_torch_policy.py
@@ -125,17 +125,16 @@ def build_slateq_losses(
     # TODO: Find out, whether it's correct here to use obs, not next_obs!
     # Dopamine uses obs, then next_obs only for the score.
     # next_q_values = policy.target_model.get_q_values(user_next_obs, doc_next_obs)
-    next_q_values = policy.target_model.get_q_values(user_obs, doc_obs)
+    next_q_values = policy.target_models[model].get_q_values(user_obs, doc_obs)
     scores, score_no_click = score_documents(user_next_obs, doc_next_obs)

     # next_q_values_slate.shape: [B, A, S]
-    next_q_values_slate = torch.take_along_dim(
-        next_q_values, policy.slates_indices, dim=1
-    ).reshape([-1, A, S])
-    # scores_slate.shape [B, A, S]
-    scores_slate = torch.take_along_dim(scores, policy.slates_indices, dim=1).reshape(
+    indices = policy.slates_indices.to(next_q_values.device)
+    next_q_values_slate = torch.take_along_dim(next_q_values, indices, dim=1).reshape(
         [-1, A, S]
     )
+    # scores_slate.shape [B, A, S]
+    scores_slate = torch.take_along_dim(scores, indices, dim=1).reshape([-1, A, S])
     # score_no_click_slate.shape: [B, A]
     score_no_click_slate = torch.reshape(
         torch.tile(score_no_click, policy.slates.shape[:1]), [batch_size, -1]
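Note: casting the cached policy.slates_indices onto whatever device next_q_values lives on before the gather is the core of the GPU fix in this hunk; torch.take_along_dim raises a device-mismatch RuntimeError if its input sits on CUDA while the index tensor stays on CPU. A minimal standalone sketch of the failure mode and fix (tensor shapes and names are illustrative, not RLlib's):

    import torch

    # Q-values end up on the GPU when training with num_gpus > 0.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    next_q_values = torch.randn(2, 6, device=device)  # [batch, num_candidates]

    # Index tensors built once at policy-init time default to CPU.
    slates_indices = torch.tensor([[0, 1, 2, 3, 0, 2]]).expand(2, -1)

    # Align indices with the data tensor before gathering (the commit's fix).
    indices = slates_indices.to(next_q_values.device)
    next_q_values_slate = torch.take_along_dim(next_q_values, indices, dim=1)
    print(next_q_values_slate.shape)  # torch.Size([2, 6])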
@@ -154,7 +153,7 @@ def build_slateq_losses(
     clicked = torch.sum(click_indicator, dim=1)
     mask_clicked_slates = clicked > 0
-    clicked_indices = torch.arange(batch_size)
+    clicked_indices = torch.arange(batch_size).to(mask_clicked_slates.device)
     clicked_indices = torch.masked_select(clicked_indices, mask_clicked_slates)
     # Clicked_indices is a vector and torch.gather selects the batch dimension.
     q_clicked = torch.gather(replay_click_q, 0, clicked_indices)
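torch.arange allocates on CPU unless given a device argument, while mask_clicked_slates inherits the device of the training batch, and torch.masked_select requires both tensors to be co-located. A sketch of the pattern in isolation (toy values; passing device=... to torch.arange directly would be an equivalent one-step form):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    clicked = torch.tensor([1.0, 0.0, 2.0], device=device)  # per-slate click counts
    mask_clicked_slates = clicked > 0                       # bool mask on `device`

    # Align the CPU-default arange with the mask, as in the diff.
    clicked_indices = torch.arange(clicked.shape[0]).to(mask_clicked_slates.device)
    clicked_indices = torch.masked_select(clicked_indices, mask_clicked_slates)
    print(clicked_indices)  # tensor([0, 2]) on `device`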
@@ -297,7 +296,7 @@ def action_distribution_fn(


 def get_per_slate_q_values(policy, score_no_click, scores, q_values):
-    indices = policy.slates_indices
+    indices = policy.slates_indices.to(scores.device)
     A, S = policy.slates.shape
     slate_q_values = torch.take_along_dim(scores * q_values, indices, dim=1).reshape(
         [-1, A, S]
@@ -320,7 +319,9 @@ def score_documents(
         torch.multiply(user_obs.unsqueeze(1), torch.stack(doc_obs, dim=1)), dim=2
     )
     # Compile a constant no-click score tensor.
-    score_no_click = torch.full(size=[user_obs.shape[0], 1], fill_value=no_click_score)
+    score_no_click = torch.full(
+        size=[user_obs.shape[0], 1], fill_value=no_click_score
+    ).to(scores_per_candidate.device)
     # Concatenate click and no-click scores.
     all_scores = torch.cat([scores_per_candidate, score_no_click], dim=1)
2 changes: 1 addition & 1 deletion rllib/utils/exploration/slate_epsilon_greedy.py
@@ -107,7 +107,7 @@ def _get_torch_exploration_action(

             # Pick either random or greedy.
             action = torch.where(
-                torch.empty((batch_size,)).uniform_().to(self.device) < epsilon,
+                torch.empty((batch_size,)).uniform_() < epsilon,
                 random_actions,
                 exploit_action,
             )
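torch.where likewise needs its condition and both value tensors on one device; this change keeps the freshly drawn uniform noise co-located with random_actions and exploit_action rather than casting it to self.device. A CPU-only sketch of the epsilon-greedy selection pattern itself (shapes and values are illustrative, not RLlib's):

    import torch

    batch_size, epsilon = 4, 0.3
    exploit_action = torch.tensor([0, 1, 2, 3])          # greedy slate picks
    random_actions = torch.randint(0, 4, (batch_size,))  # uniform random slate picks

    # Per-row coin flip: take the random action with probability epsilon.
    explore_mask = torch.empty((batch_size,)).uniform_() < epsilon
    action = torch.where(explore_mask, random_actions, exploit_action)
    print(action)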