[RLlib] Add backward compatibility to MeanStdFilter to restore from older checkpoints. #30439

Merged
19 changes: 14 additions & 5 deletions rllib/utils/filter.py
@@ -79,16 +79,23 @@ def as_serializable(self) -> "NoFilter":
 # http://www.johndcook.com/blog/standard_deviation/
 @DeveloperAPI
 class RunningStat:
-    def __init__(self, shape=None):
+    def __init__(self, shape=()):
         self.num_pushes = 0
         self.mean_array = np.zeros(shape)
         self.std_array = np.zeros(shape)

     def copy(self):
         other = RunningStat()
-        other.num_pushes = self.num_pushes
-        other.mean_array = np.copy(self.mean_array)
-        other.std_array = np.copy(self.std_array)
+        # TODO: Remove these safe-guards if not needed anymore.
+        other.num_pushes = self.num_pushes if hasattr(self, "num_pushes") else self._n
+        other.mean_array = (
+            np.copy(self.mean_array)
+            if hasattr(self, "mean_array")
+            else np.copy(self._M)
+        )
+        other.std_array = (
+            np.copy(self.std_array) if hasattr(self, "std_array") else np.copy(self._S)
+        )
Contributor

For reference, it was #27864 that made these changes. All changed attribute names seem to be covered here.
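To illustrate why the fallback matters: default unpickling restores an object's saved `__dict__` without calling `__init__`, so an instance written before the rename comes back with only the old attribute names. The sketch below is a pure-Python stand-in, not the RLlib class: it uses lists instead of NumPy arrays, and `make_legacy_instance` is a hypothetical helper that imitates a restored pre-rename object.

```python
class RunningStat:
    """Pure-Python stand-in for the class in the diff (lists, not NumPy arrays)."""

    def __init__(self):
        self.num_pushes = 0
        self.mean_array = [0.0]
        self.std_array = [0.0]

    def copy(self):
        other = RunningStat()
        # Fall back to the pre-rename names (_n, _M, _S) when the new ones are missing.
        other.num_pushes = self.num_pushes if hasattr(self, "num_pushes") else self._n
        other.mean_array = list(
            self.mean_array if hasattr(self, "mean_array") else self._M
        )
        other.std_array = list(
            self.std_array if hasattr(self, "std_array") else self._S
        )
        return other


def make_legacy_instance():
    """Hypothetical helper: imitate an object restored from an old checkpoint."""
    legacy = RunningStat.__new__(RunningStat)  # skip __init__, as unpickling does
    legacy.__dict__.update({"_n": 3, "_M": [1.5], "_S": [0.25]})
    return legacy


restored = make_legacy_instance()
print(hasattr(restored, "num_pushes"))  # False

migrated = restored.copy()
print(migrated.num_pushes, migrated.mean_array, migrated.std_array)  # 3 [1.5] [0.25]
```

The copy always carries the new attribute names, so the legacy names only need to be handled at this one boundary.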

Contributor

Nit: Add a TODO that tells us to remove this additional code at some point. That said, it does not do any harm here either 😃

Collaborator Author

@simonsays1980 simonsays1980 Dec 2, 2022

I can add the nit :D

I also saw that we often get a DeprecationWarning with this class, which could probably be avoided by passing () instead of None as the shape argument of the NumPy array constructors.

I could push that together into one smooth PR if needed.
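As a rough illustration of the shape point: an empty tuple is NumPy's explicit spelling of a scalar (0-d) shape, so `np.zeros(())` builds a zero-dimensional array without relying on `None` as an alias. How `None` is treated depends on the installed NumPy version, so this sketch deliberately does not call it.

```python
import numpy as np

# Explicit scalar shape: a 0-d array, the default in the updated constructor.
scalar_stat = np.zeros(())
print(scalar_stat.shape, scalar_stat.ndim)  # () 0

# A non-empty tuple gives the usual n-d array.
vector_stat = np.zeros((3,))
print(vector_stat.shape)  # (3,)
```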

Contributor

Separate PRs for separate problems would be appreciated!

         return other

     def push(self, x):
@@ -267,8 +274,10 @@ def sync(self, other: "MeanStdFilter") -> None:
         self.demean = other.demean
         self.destd = other.destd
         self.clip = other.clip
+        # TODO: Remove these safe-guards if not needed anymore.
         self.running_stats = tree.map_structure(
-            lambda rs: rs.copy(), other.running_stats
+            lambda rs: rs.copy(),
+            other.running_stats if hasattr(other, "running_stats") else other.rs,
         )
         self.buffer = tree.map_structure(lambda b: b.copy(), other.buffer)
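The `sync` change applies the same fallback one level up, on the filter's container attribute. A minimal sketch of that lookup follows, with assumptions labeled: `CurrentFilter`, `LegacyFilter`, `resolve_running_stats`, and `copy_stats` are hypothetical names, and a flat dict comprehension stands in for `tree.map_structure`. It also shows one plausible reason to prefer the conditional expression over `getattr` with a default: the default argument is evaluated eagerly.

```python
class CurrentFilter:
    """Hypothetical stand-in for a filter created after the rename."""

    def __init__(self, stats):
        self.running_stats = stats


class LegacyFilter:
    """Hypothetical stand-in for a filter restored from an older checkpoint."""

    def __init__(self, stats):
        self.rs = stats  # pre-rename attribute name


def resolve_running_stats(other):
    # Only touch `other.rs` when the new attribute is really missing.
    return other.running_stats if hasattr(other, "running_stats") else other.rs


def copy_stats(stats):
    # Flat-dict stand-in for tree.map_structure(lambda rs: rs.copy(), stats).
    return {key: list(value) for key, value in stats.items()}


current = CurrentFilter({"obs": [0.5, 1.0]})
legacy = LegacyFilter({"obs": [2.0, 3.0]})

synced_from_current = copy_stats(resolve_running_stats(current))
synced_from_legacy = copy_stats(resolve_running_stats(legacy))
print(synced_from_current)  # {'obs': [0.5, 1.0]}
print(synced_from_legacy)   # {'obs': [2.0, 3.0]}

# The eager getattr form fails on current objects: the default argument
# `current.rs` is evaluated before getattr is even called.
eager_failed = False
try:
    getattr(current, "running_stats", current.rs)
except AttributeError:
    eager_failed = True
print(eager_failed)  # True
```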
