diff --git a/rllib/models/torch/torch_action_dist.py b/rllib/models/torch/torch_action_dist.py
index 12f725b9e10d..743cd323e7e9 100644
--- a/rllib/models/torch/torch_action_dist.py
+++ b/rllib/models/torch/torch_action_dist.py
@@ -622,7 +622,7 @@ def __init__(self, inputs, model):
 
     @override(ActionDistribution)
     def deterministic_sample(self) -> TensorType:
-        self.last_sample = nn.functional.softmax(self.dist.concentration)
+        self.last_sample = nn.functional.softmax(self.dist.concentration, dim=-1)
         return self.last_sample
 
     @override(ActionDistribution)
@@ -638,10 +638,6 @@ def logp(self, x):
     def entropy(self):
         return self.dist.entropy()
 
-    @override(ActionDistribution)
-    def kl(self, other):
-        return self.dist.kl_divergence(other.dist)
-
     @staticmethod
     @override(ActionDistribution)
     def required_model_output_shape(action_space, model_config):