Skip to content

Commit

Permalink
[RLlib] Discussion 4986: OU Exploration (torch) crashes when restorin…
Browse files Browse the repository at this point in the history
…g from checkpoint. (ray-project#22245)
  • Loading branch information
sven1977 authored and simonsays1980 committed Feb 27, 2022
1 parent f4d3b91 commit c80684b
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions rllib/utils/exploration/ornstein_uhlenbeck_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,10 @@ def get_state(self, sess: Optional["tf.Session"] = None):
def set_state(self, state: dict, sess: Optional["tf.Session"] = None) -> None:
if self.framework == "tf":
self.ou_state.load(state["ou_state"], session=sess)
elif isinstance(self.ou_state, np.ndarray) or (
torch and torch.is_tensor(self.ou_state)
):
elif isinstance(self.ou_state, np.ndarray):
self.ou_state = state["ou_state"]
elif torch and torch.is_tensor(self.ou_state):
self.ou_state = torch.from_numpy(state["ou_state"])
else:
self.ou_state.assign(state["ou_state"])
super().set_state(state, sess=sess)

0 comments on commit c80684b

Please sign in to comment.