diff --git a/rllib/utils/filter.py b/rllib/utils/filter.py index e4f94ede5510..ff29c25d7fc4 100644 --- a/rllib/utils/filter.py +++ b/rllib/utils/filter.py @@ -79,16 +79,23 @@ def as_serializable(self) -> "NoFilter": # http://www.johndcook.com/blog/standard_deviation/ @DeveloperAPI class RunningStat: - def __init__(self, shape=None): + def __init__(self, shape=()): self.num_pushes = 0 self.mean_array = np.zeros(shape) self.std_array = np.zeros(shape) def copy(self): other = RunningStat() - other.num_pushes = self.num_pushes - other.mean_array = np.copy(self.mean_array) - other.std_array = np.copy(self.std_array) + # TODO: Remove these safe-guards if not needed anymore. + other.num_pushes = self.num_pushes if hasattr(self, "num_pushes") else self._n + other.mean_array = ( + np.copy(self.mean_array) + if hasattr(self, "mean_array") + else np.copy(self._M) + ) + other.std_array = ( + np.copy(self.std_array) if hasattr(self, "std_array") else np.copy(self._S) + ) return other def push(self, x): @@ -267,8 +274,10 @@ def sync(self, other: "MeanStdFilter") -> None: self.demean = other.demean self.destd = other.destd self.clip = other.clip + # TODO: Remove these safe-guards if not needed anymore. self.running_stats = tree.map_structure( - lambda rs: rs.copy(), other.running_stats + lambda rs: rs.copy(), + other.running_stats if hasattr(other, "running_stats") else other.rs, ) self.buffer = tree.map_structure(lambda b: b.copy(), other.buffer)