From 4168b9bdae689f92bf37c5b780bba3e788106801 Mon Sep 17 00:00:00 2001
From: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com>
Date: Mon, 10 Apr 2023 20:50:13 -0700
Subject: [PATCH] [RLlib] Fixed a bug with kl divergence calculation of
 torch.Dirichlet distribution within RLlib (#34209)

Signed-off-by: Kourosh Hakhamaneshi
---
 rllib/models/torch/torch_action_dist.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/rllib/models/torch/torch_action_dist.py b/rllib/models/torch/torch_action_dist.py
index 12f725b9e10d..743cd323e7e9 100644
--- a/rllib/models/torch/torch_action_dist.py
+++ b/rllib/models/torch/torch_action_dist.py
@@ -622,7 +622,7 @@ def __init__(self, inputs, model):
 
     @override(ActionDistribution)
     def deterministic_sample(self) -> TensorType:
-        self.last_sample = nn.functional.softmax(self.dist.concentration)
+        self.last_sample = nn.functional.softmax(self.dist.concentration, dim=-1)
         return self.last_sample
 
     @override(ActionDistribution)
@@ -638,10 +638,6 @@ def logp(self, x):
     def entropy(self):
         return self.dist.entropy()
 
-    @override(ActionDistribution)
-    def kl(self, other):
-        return self.dist.kl_divergence(other.dist)
-
     @staticmethod
     @override(ActionDistribution)
     def required_model_output_shape(action_space, model_config):