diff --git a/rllib/utils/exploration/ornstein_uhlenbeck_noise.py b/rllib/utils/exploration/ornstein_uhlenbeck_noise.py index e4e161002642..5fe2ef7b0b31 100644 --- a/rllib/utils/exploration/ornstein_uhlenbeck_noise.py +++ b/rllib/utils/exploration/ornstein_uhlenbeck_noise.py @@ -263,10 +263,10 @@ def get_state(self, sess: Optional["tf.Session"] = None): def set_state(self, state: dict, sess: Optional["tf.Session"] = None) -> None: if self.framework == "tf": self.ou_state.load(state["ou_state"], session=sess) - elif isinstance(self.ou_state, np.ndarray) or ( - torch and torch.is_tensor(self.ou_state) - ): + elif isinstance(self.ou_state, np.ndarray): self.ou_state = state["ou_state"] + elif torch and torch.is_tensor(self.ou_state): + self.ou_state = torch.from_numpy(state["ou_state"]) else: self.ou_state.assign(state["ou_state"]) super().set_state(state, sess=sess)