Skip to content

Commit

Permalink
[RLlib] DQN (Rainbow): Fix torch noisy layer support and loss (#16716)
Browse files Browse the repository at this point in the history
  • Loading branch information
gbartyzel authored Jul 13, 2021
1 parent 1fd0eb8 commit d553d4d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
8 changes: 7 additions & 1 deletion rllib/agents/dqn/dqn_torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,13 @@ def __init__(

# Value layer (nodes=1).
if self.dueling:
value_module.add_module("V", SlimFC(ins, 1, activation_fn=None))
if use_noisy:
value_module.add_module(
"V",
NoisyLayer(ins, self.num_atoms, sigma0, activation=None))
elif q_hiddens:
value_module.add_module(
"V", SlimFC(ins, self.num_atoms, activation_fn=None))
self.value_module = value_module

def get_q_value_distributions(self, model_out):
Expand Down
4 changes: 2 additions & 2 deletions rllib/agents/dqn/dqn_torch_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self,
# Indispensable judgement which is missed in most implementations
# when b happens to be an integer, lb == ub, so pr_j(s', a*) will
# be discarded because (ub-b) == (b-lb) == 0.
floor_equal_ceil = (ub - lb < 0.5).float()
floor_equal_ceil = ((ub - lb) < 0.5).float()

# (batch_size, num_atoms, num_atoms)
l_project = F.one_hot(lb.long(), num_atoms)
Expand All @@ -79,7 +79,7 @@ def __init__(self,
# Rainbow paper claims that using this cross entropy loss for
# priority is robust and insensitive to `prioritized_replay_alpha`
self.td_error = softmax_cross_entropy_with_logits(
logits=q_logits_t_selected, labels=m)
logits=q_logits_t_selected, labels=m.detach())
self.loss = torch.mean(self.td_error * importance_weights)
self.stats = {
# TODO: better Q stats for dist dqn
Expand Down

0 comments on commit d553d4d

Please sign in to comment.