Skip to content

Commit

Permalink
[RLlib] Fix: Added dtype to torch deterministic action. (ray-project#…
Browse files Browse the repository at this point in the history
…29648)

Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
kouroshHakha authored and WeichenXu123 committed Dec 19, 2022
1 parent 67625d6 commit e2388c4
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions rllib/models/torch/torch_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,11 @@ def sample(

if sample_shape is None:
sample_shape = torch.Size()
loc_shape = self.loc.shape
return torch.ones(sample_shape + loc_shape, device=self.loc.device) * self.loc

device = self.loc.device
dtype = self.loc.dtype
shape = sample_shape + self.loc.shape
return torch.ones(shape, device=device, dtype=dtype) * self.loc

def rsample(
self,
Expand Down

0 comments on commit e2388c4

Please sign in to comment.