From ec7943650d368fc46d76016edd434ac68a037754 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Wed, 9 Feb 2022 11:28:21 +0100 Subject: [PATCH] wip --- rllib/utils/exploration/ornstein_uhlenbeck_noise.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rllib/utils/exploration/ornstein_uhlenbeck_noise.py b/rllib/utils/exploration/ornstein_uhlenbeck_noise.py index e4e161002642..5fe2ef7b0b31 100644 --- a/rllib/utils/exploration/ornstein_uhlenbeck_noise.py +++ b/rllib/utils/exploration/ornstein_uhlenbeck_noise.py @@ -263,10 +263,10 @@ def get_state(self, sess: Optional["tf.Session"] = None): def set_state(self, state: dict, sess: Optional["tf.Session"] = None) -> None: if self.framework == "tf": self.ou_state.load(state["ou_state"], session=sess) - elif isinstance(self.ou_state, np.ndarray) or ( - torch and torch.is_tensor(self.ou_state) - ): + elif isinstance(self.ou_state, np.ndarray): self.ou_state = state["ou_state"] + elif torch and torch.is_tensor(self.ou_state): + self.ou_state = torch.from_numpy(state["ou_state"]) else: self.ou_state.assign(state["ou_state"]) super().set_state(state, sess=sess)