[RLlib] Fixed a bug with kl divergence calculation of torch.Dirichlet distribution within RLlib (#34209)

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
kouroshHakha authored Apr 11, 2023
1 parent f3bd6c0 commit 4168b9b
6 changes: 1 addition & 5 deletions rllib/models/torch/torch_action_dist.py
@@ -622,7 +622,7 @@ def __init__(self, inputs, model):
     @override(ActionDistribution)
     def deterministic_sample(self) -> TensorType:
-        self.last_sample = nn.functional.softmax(self.dist.concentration)
+        self.last_sample = nn.functional.softmax(self.dist.concentration, dim=-1)
         return self.last_sample
 
     @override(ActionDistribution)
@@ -638,10 +638,6 @@ def logp(self, x):
     def entropy(self):
         return self.dist.entropy()
 
-    @override(ActionDistribution)
-    def kl(self, other):
-        return self.dist.kl_divergence(other.dist)
-
     @staticmethod
     @override(ActionDistribution)
     def required_model_output_shape(action_space, model_config):
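For context, a minimal sketch (not code from this commit) of the two behaviors the change touches, using plain PyTorch: softmax over the Dirichlet concentration parameters with an explicit dim=-1, and KL divergence between two Dirichlet distributions via torch.distributions.kl.kl_divergence. The override removed above called self.dist.kl_divergence(other.dist), but torch.distributions.Dirichlet has no such instance method; with the override gone, the KL presumably falls back to the generic wrapper-level kl(), which uses the registered function shown below. The tensors here are illustrative stand-ins for RLlib model outputs.

# Sketch only, not code from the commit; assumes a recent plain PyTorch.
import torch
import torch.nn as nn
from torch.distributions import Dirichlet
from torch.distributions.kl import kl_divergence

# Batched concentration parameters, standing in for what the action
# distribution would hold in self.dist.concentration.
concentration = torch.tensor([[1.0, 2.0, 3.0],
                              [4.0, 1.0, 1.0]])

# Deterministic sample: softmax over the last (per-action) axis. Omitting
# dim is deprecated and picks a dimension by a shape-based heuristic that
# is not always the last axis.
deterministic = nn.functional.softmax(concentration, dim=-1)

# KL between two Dirichlet distributions: use the registered kl_divergence
# function. Dirichlet has no kl_divergence() method, which is why the
# removed override (self.dist.kl_divergence(other.dist)) raised an error.
p = Dirichlet(concentration)
q = Dirichlet(torch.ones_like(concentration))
kl = kl_divergence(p, q)  # shape (2,): one KL value per batch row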
