From e2388c4ba7882ee65693c155eb69784e96472356 Mon Sep 17 00:00:00 2001 From: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Date: Tue, 25 Oct 2022 01:05:19 -0700 Subject: [PATCH] [RLlib] Fix: Added dtype to torch deterministic action. (#29648) Signed-off-by: Weichen Xu --- rllib/models/torch/torch_distributions.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/rllib/models/torch/torch_distributions.py b/rllib/models/torch/torch_distributions.py index c72dc082fed7..aaa372823cd6 100644 --- a/rllib/models/torch/torch_distributions.py +++ b/rllib/models/torch/torch_distributions.py @@ -219,8 +219,11 @@ def sample( if sample_shape is None: sample_shape = torch.Size() - loc_shape = self.loc.shape - return torch.ones(sample_shape + loc_shape, device=self.loc.device) * self.loc + + device = self.loc.device + dtype = self.loc.dtype + shape = sample_shape + self.loc.shape + return torch.ones(shape, device=device, dtype=dtype) * self.loc def rsample( self,