-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] Add filters to connector pipeline #27864
Changes from all commits
5680d70
4123ac5
c62faf2
12d849f
f1d4eec
1505c28
af3f6de
4ddcdeb
6c32319
26dbe68
481fcdf
6119961
ab0bba9
0c3460d
ec262a2
ad0cc88
9da99cb
0adc85a
71f4fe2
6050961
6eb34d6
4ae9758
cc1c24e
3915fd0
df36fa3
54c6736
1098b11
9ff974e
c09a0e3
41b9bff
cee9ad5
3b76170
56fafa5
4d2cec4
e637f4f
86f1754
c89790c
06beebc
1b7f0bb
f2833d3
6abb336
9c294c1
0fd3e09
0f0461f
e6164a2
45b576e
a5eb07d
26dbf8b
de68a44
9cb9ebd
555b48d
cf00366
25db0b1
3665588
8d5fc3f
070cd91
a9ecae8
9200021
aa9c59a
a644bb4
c06a63a
40c2264
60ca25d
d166ac4
ead68d3
6492024
71038d3
8e4b8dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
from typing import Any, List | ||
from gym.spaces import Discrete, MultiDiscrete | ||
|
||
import numpy as np | ||
import tree | ||
|
||
from ray.rllib.connectors.agent.synced_filter import SyncedFilterAgentConnector | ||
from ray.rllib.connectors.connector import AgentConnector | ||
from ray.rllib.connectors.connector import ( | ||
ConnectorContext, | ||
register_connector, | ||
) | ||
from ray.rllib.policy.sample_batch import SampleBatch | ||
from ray.rllib.utils.filter import Filter | ||
from ray.rllib.utils.filter import MeanStdFilter, ConcurrentMeanStdFilter | ||
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space | ||
from ray.rllib.utils.typing import AgentConnectorDataType | ||
from ray.util.annotations import PublicAPI | ||
from ray.rllib.utils.filter import RunningStat | ||
|
||
|
||
@PublicAPI(stability="alpha")
class MeanStdObservationFilterAgentConnector(SyncedFilterAgentConnector):
    """A connector used to mean-std-filter observations.

    Incoming observations are filtered such that the output of this filter is on
    average zero and has a standard deviation of 1. This filtering is applied
    separately per element of the observation space.

    The actual filtering is delegated to a wrapped ``MeanStdFilter`` stored in
    ``self.filter``.
    """

    def __init__(
        self,
        ctx: ConnectorContext,
        demean: bool = True,
        destd: bool = True,
        clip: float = 10.0,
    ):
        """Builds the wrapped filter from the context's observation space.

        Args:
            ctx: Connector context; ``ctx.observation_space`` determines the
                (possibly nested) shape structure handed to the filter.
            demean: Forwarded unchanged to ``MeanStdFilter``.
            destd: Forwarded unchanged to ``MeanStdFilter``.
            clip: Forwarded unchanged to ``MeanStdFilter``.
        """
        SyncedFilterAgentConnector.__init__(self, ctx)
        # We simply use the old MeanStdFilter until non-connector env_runner is fully
        # deprecated to avoid duplicate code

        # Discrete / MultiDiscrete sub-spaces are mapped to None; every other
        # sub-space contributes its shape as an array.
        filter_shape = tree.map_structure(
            lambda s: (
                None
                if isinstance(s, (Discrete, MultiDiscrete))  # noqa
                else np.array(s.shape)
            ),
            get_base_struct_from_space(ctx.observation_space),
        )
        self.filter = MeanStdFilter(filter_shape, demean=demean, destd=destd, clip=clip)

    def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
        """Filters the OBS and NEXT_OBS entries of ``ac_data.data`` in place.

        The filter is called with ``update=self._is_training``.
        NOTE(review): ``_is_training`` is not set in this class; presumably it is
        maintained by the connector base class -- confirm.

        Args:
            ac_data: Agent connector data whose ``data`` attribute must be a plain
                ``dict``.

        Returns:
            The same ``ac_data`` object, with the filtered entries written back.
        """
        d = ac_data.data
        # Exact type check is kept on purpose (dict subclasses are rejected,
        # as in the original `type(d) == dict` comparison).
        assert (
            type(d) is dict
        ), "Single agent data must be of type Dict[str, TensorStructType]"
        if SampleBatch.OBS in d:
            d[SampleBatch.OBS] = self.filter(
                d[SampleBatch.OBS], update=self._is_training
            )
        if SampleBatch.NEXT_OBS in d:
            d[SampleBatch.NEXT_OBS] = self.filter(
                d[SampleBatch.NEXT_OBS], update=self._is_training
            )

        return ac_data

    def to_state(self):
        """Serializes the connector as a ``(name, params)`` pair.

        Returns:
            A tuple of this class' name and a dict holding the wrapped filter's
            ``shape``, ``no_preprocessor``, ``demean``, ``destd`` and ``clip``
            attributes, plus the flattened, serialized running stats and buffer.
        """
        # Flattening is deterministic
        flattened_rs = tree.flatten(self.filter.running_stats)
        flattened_buffer = tree.flatten(self.filter.buffer)
        # NOTE(review): the name is hard-coded to this class, so subclasses
        # that inherit to_state() serialize under the parent's name -- confirm
        # this is intended.
        return MeanStdObservationFilterAgentConnector.__name__, {
            "shape": self.filter.shape,
            "no_preprocessor": self.filter.no_preprocessor,
            "demean": self.filter.demean,
            "destd": self.filter.destd,
            "clip": self.filter.clip,
            "running_stats": [s.to_state() for s in flattened_rs],
            "buffer": [s.to_state() for s in flattened_buffer],
        }

    # demean, destd, clip, and a state dict
    @staticmethod
    def from_state(
        ctx: ConnectorContext,
        params: List[Any] = None,
        demean: bool = True,
        destd: bool = True,
        clip: float = 10.0,
    ):
        """Rebuilds a connector, optionally restoring filter state.

        Args:
            ctx: Connector context used to construct the new connector.
            params: Optional state as produced by ``to_state()``. Despite the
                ``List[Any]`` annotation it is indexed by string keys
                (``"shape"``, ``"running_stats"``, ...), i.e. used as a mapping.
                When falsy, a freshly initialized connector is returned.
            demean: Passed to the constructor; overwritten by ``params`` if given.
            destd: Passed to the constructor; overwritten by ``params`` if given.
            clip: Passed to the constructor; overwritten by ``params`` if given.

        Returns:
            A ``MeanStdObservationFilterAgentConnector`` instance.
        """
        connector = MeanStdObservationFilterAgentConnector(ctx, demean, destd, clip)
        if params:
            connector.filter.shape = params["shape"]
            connector.filter.no_preprocessor = params["no_preprocessor"]
            connector.filter.demean = params["demean"]
            connector.filter.destd = params["destd"]
            connector.filter.clip = params["clip"]

            # Unflattening is deterministic
            running_stats = [RunningStat.from_state(s) for s in params["running_stats"]]
            connector.filter.running_stats = tree.unflatten_as(
                connector.filter.shape, running_stats
            )

            # Unflattening is deterministic
            buffer = [RunningStat.from_state(s) for s in params["buffer"]]
            connector.filter.buffer = tree.unflatten_as(connector.filter.shape, buffer)

        return connector

    def reset_state(self) -> None:
        """Resets the wrapped filter's buffer via ``reset_buffer()``.

        Raises:
            ValueError: If the connector is not in training mode.
        """
        if not self._is_training:
            # Instances have no `__name__`; look the name up on the class so
            # the intended ValueError (not an AttributeError) is raised.
            raise ValueError(
                "State of {} can only be changed when training.".format(
                    type(self).__name__
                )
            )
        self.filter.reset_buffer()

    def apply_changes(self, other: "Filter", *args, **kwargs) -> None:
        """Updates self with state from other filter.

        Raises:
            ValueError: If the connector is not in training mode.
        """
        # inline this as soon as we deprecate ordinary filter with non-connector
        # env_runner
        if not self._is_training:
            # Review discussion on this change: updates are deliberately
            # rejected outside of training, because it is considered best
            # practice to stop updating mean-std filters after training.
            raise ValueError(
                "Changes can only be applied to {} when training.".format(
                    type(self).__name__
                )
            )
        return self.filter.apply_changes(other, *args, **kwargs)

    def copy(self) -> "Filter":
        """Creates a new object with same state as self.

        This is a legacy Filter method that we need to keep around for now

        Returns:
            A copy of self.
        """
        # inline this as soon as we deprecate ordinary filter with non-connector
        # env_runner
        return self.filter.copy()

    def sync(self, other: "AgentConnector") -> None:
        """Copies all state from other filter to self.

        Raises:
            ValueError: If the connector is not in training mode.
        """
        # inline this as soon as we deprecate ordinary filter with non-connector
        # env_runner
        if not self._is_training:
            # Same rationale as in apply_changes(): filter state is only
            # meant to change while training.
            raise ValueError(
                "{} can only be synced when training.".format(type(self).__name__)
            )
        return self.filter.sync(other.filter)
|
||
|
||
@PublicAPI(stability="alpha")
class ConcurrentMeanStdObservationFilterAgentConnector(
    MeanStdObservationFilterAgentConnector
):
    """A concurrent version of the MeanStdObservationFilterAgentConnector.

    This version's filter has all operations wrapped by a threading.RLock.
    It can therefore be safely used by multiple threads.
    """

    def __init__(self, ctx: ConnectorContext, demean=True, destd=True, clip=10.0):
        """Builds the wrapped ConcurrentMeanStdFilter.

        Args:
            ctx: Connector context; ``ctx.observation_space`` determines the
                (possibly nested) shape structure handed to the filter.
            demean: Forwarded unchanged to ``ConcurrentMeanStdFilter``.
            destd: Forwarded unchanged to ``ConcurrentMeanStdFilter``.
            clip: Forwarded unchanged to ``ConcurrentMeanStdFilter``.
        """
        # The parent's __init__ is skipped on purpose: it would build a plain
        # MeanStdFilter that is replaced right away.
        SyncedFilterAgentConnector.__init__(self, ctx)
        # We simply use the old MeanStdFilter until non-connector env_runner is fully
        # deprecated to avoid duplicate code

        # Discrete / MultiDiscrete sub-spaces are mapped to None; every other
        # sub-space contributes its shape as an array.
        filter_shape = tree.map_structure(
            lambda s: (
                None
                if isinstance(s, (Discrete, MultiDiscrete))  # noqa
                else np.array(s.shape)
            ),
            get_base_struct_from_space(ctx.observation_space),
        )
        # Forward the caller's settings; previously the defaults were
        # hard-coded here and the constructor arguments were silently ignored.
        self.filter = ConcurrentMeanStdFilter(
            filter_shape, demean=demean, destd=destd, clip=clip
        )
|
||
|
||
# Make both filter connectors available under their class names.
# The plain variant wraps a MeanStdFilter.
register_connector(
    MeanStdObservationFilterAgentConnector.__name__,
    MeanStdObservationFilterAgentConnector,
)
# The concurrent variant wraps a ConcurrentMeanStdFilter instead; according to
# its class docstring, that filter's operations are wrapped by a
# threading.RLock so it can be used from multiple threads.
register_connector(
    ConcurrentMeanStdObservationFilterAgentConnector.__name__,
    ConcurrentMeanStdObservationFilterAgentConnector,
)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from ray.rllib.connectors.connector import ( | ||
AgentConnector, | ||
ConnectorContext, | ||
) | ||
from ray.util.annotations import PublicAPI | ||
from ray.rllib.utils.filter import Filter | ||
|
||
|
||
@PublicAPI(stability="alpha")
class SyncedFilterAgentConnector(AgentConnector):
    """An agent connector that filters with synchronized parameters.

    Most methods delegate to ``self.filter``. This class never assigns that
    attribute itself, so it has to be provided elsewhere (e.g. by a subclass).
    """

    def __init__(self, ctx: ConnectorContext, *args, **kwargs):
        """Initializes the base connector and rejects any extra arguments.

        Raises:
            ValueError: If any positional or keyword argument besides ``ctx``
                is passed.
        """
        super().__init__(ctx)
        got_extras = bool(args or kwargs)
        if not got_extras:
            return
        raise ValueError(
            "SyncedFilterAgentConnector does not take any additional arguments, "
            "but got args=`{}` and kwargs={}.".format(args, kwargs)
        )

    def apply_changes(self, other: "Filter", *args, **kwargs) -> None:
        """Applies the state of ``other`` to the wrapped filter."""
        # TODO: (artur) inline this as soon as we deprecate ordinary filter with
        #  non-connector env_runner
        wrapped = self.filter
        return wrapped.apply_changes(other, *args, **kwargs)

    def copy(self) -> "Filter":
        """Returns a copy of the wrapped filter.

        Legacy Filter method that has to stay around for now.

        Returns:
            Whatever ``self.filter.copy()`` returns.
        """
        # inline this as soon as we deprecate ordinary filter with non-connector
        # env_runner
        wrapped = self.filter
        return wrapped.copy()

    def sync(self, other: "AgentConnector") -> None:
        """Syncs the wrapped filter with the filter of ``other``."""
        # TODO: (artur) inline this as soon as we deprecate ordinary filter with
        #  non-connector env_runner
        own_filter = self.filter
        return own_filter.sync(other.filter)

    def reset_state(self) -> None:
        """Not implemented in this class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def as_serializable(self) -> "Filter":
        """Returns the wrapped filter's serializable form."""
        # TODO: (artur) inline this as soon as we deprecate ordinary filter with
        #  non-connector env_runner
        wrapped = self.filter
        return wrapped.as_serializable()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
when do we usually need this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right now the reset_state() method is part of the SyncedFilterAgentConnector interface.
Synchronization is done via the old filter synchronization mechanism, which I would like to leave in place until we switch connectors on by default. After that, I would like to simply inline all the filter code and call the connector's reset_state method directly.
Until we have gotten that far, I can delete this method if you like.