forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[RLlib] Add filters to connector pipeline (ray-project#27864)
* initial Signed-off-by: Artur Niederfahrenhorst <[email protected]> * initial Signed-off-by: Artur Niederfahrenhorst <[email protected]> * lint and comments Signed-off-by: Artur Niederfahrenhorst <[email protected]> * wip * format Signed-off-by: Artur Niederfahrenhorst <[email protected]> * implements and uses filters, working for ppo cartpole with meanstd Signed-off-by: Artur Niederfahrenhorst <[email protected]> * get rid of synced custom filter abstraction Signed-off-by: Artur Niederfahrenhorst <[email protected]> * add meanstd filter connetor test and minor fixes Signed-off-by: Artur Niederfahrenhorst <[email protected]> * jun's comment Signed-off-by: Artur Niederfahrenhorst <[email protected]> * move observation space struct logic because information is already in connector context Signed-off-by: Artur Niederfahrenhorst <[email protected]> * fix docstrings Signed-off-by: Artur Niederfahrenhorst <[email protected]> * minor fixes Signed-off-by: Artur Niederfahrenhorst <[email protected]> * initial Signed-off-by: Artur Niederfahrenhorst <[email protected]> * fix for config=None in add_policy and connector=None Signed-off-by: Artur Niederfahrenhorst <[email protected]> * fix config name Signed-off-by: Artur Niederfahrenhorst <[email protected]> * filter connector state is now json serializable Signed-off-by: Artur Niederfahrenhorst <[email protected]> * jun's comments Signed-off-by: Artur Niederfahrenhorst <[email protected]> * initial Signed-off-by: Artur Niederfahrenhorst <[email protected]> * create connectors only in create_connectors_for_policy Signed-off-by: Artur Niederfahrenhorst <[email protected]> * initial Signed-off-by: Artur Niederfahrenhorst <[email protected]> * get filter connector by __get__item Signed-off-by: Artur Niederfahrenhorst <[email protected]> * remove observation filters Signed-off-by: Artur Niederfahrenhorst <[email protected]> * minor fixes Signed-off-by: Artur Niederfahrenhorst <[email protected]> * initial Signed-off-by: Artur Niederfahrenhorst <[email protected]> * revert spelling error Signed-off-by: Artur Niederfahrenhorst <[email protected]> * initial Signed-off-by: Artur Niederfahrenhorst <[email protected]> * Revert "Merge branch 'make_add_policy_config_explicit' into filterstoconnectors" This reverts commit 06beebc, reversing changes made to e637f4f. * accomodate case in which config={} in add_policy Signed-off-by: Artur Niederfahrenhorst <[email protected]> * fix connectors enabled not no SyncedFilterAgentConnector case Signed-off-by: Artur Niederfahrenhorst <[email protected]> * initial Signed-off-by: Artur Niederfahrenhorst <[email protected]> * merge configs in add_policy Signed-off-by: Artur Niederfahrenhorst <[email protected]> * format Signed-off-by: Artur Niederfahrenhorst <[email protected]> * revert random cloudpickle linter error Signed-off-by: Artur Niederfahrenhorst <[email protected]> * small change to trigger CI Signed-off-by: Artur Niederfahrenhorst <[email protected]> * remove all random changes outside rllib that made it into this PR Signed-off-by: Artur Niederfahrenhorst <[email protected]> * remove random rst Signed-off-by: Artur Niederfahrenhorst <[email protected]> * fix deprecated is_training call Signed-off-by: Artur Niederfahrenhorst <[email protected]> * correct in_eval call Signed-off-by: Artur Niederfahrenhorst <[email protected]> * nit Signed-off-by: Artur Niederfahrenhorst <[email protected]> * jun's comment Signed-off-by: Artur Niederfahrenhorst <[email protected]> * use merged config to create connectors Signed-off-by: Artur Niederfahrenhorst <[email protected]> * Add meaningful assertion error and switch order of if/else block to intuitive order Signed-off-by: Artur Niederfahrenhorst <[email protected]> * better warning Signed-off-by: Artur Niederfahrenhorst <[email protected]> * jun's comments Signed-off-by: Artur Niederfahrenhorst <[email protected]> * shorter function signature for helper fn Signed-off-by: Artur Niederfahrenhorst <[email protected]> Signed-off-by: Artur Niederfahrenhorst <[email protected]> Signed-off-by: Weichen Xu <[email protected]>
- Loading branch information
1 parent
e85d7c4
commit e770151
Showing
10 changed files
with
515 additions
and
92 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
from typing import Any, List | ||
from gym.spaces import Discrete, MultiDiscrete | ||
|
||
import numpy as np | ||
import tree | ||
|
||
from ray.rllib.connectors.agent.synced_filter import SyncedFilterAgentConnector | ||
from ray.rllib.connectors.connector import AgentConnector | ||
from ray.rllib.connectors.connector import ( | ||
ConnectorContext, | ||
register_connector, | ||
) | ||
from ray.rllib.policy.sample_batch import SampleBatch | ||
from ray.rllib.utils.filter import Filter | ||
from ray.rllib.utils.filter import MeanStdFilter, ConcurrentMeanStdFilter | ||
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space | ||
from ray.rllib.utils.typing import AgentConnectorDataType | ||
from ray.util.annotations import PublicAPI | ||
from ray.rllib.utils.filter import RunningStat | ||
|
||
|
||
@PublicAPI(stability="alpha") | ||
class MeanStdObservationFilterAgentConnector(SyncedFilterAgentConnector): | ||
"""A connector used to mean-std-filter observations. | ||
Incoming observations are filtered such that the output of this filter is on | ||
average zero and has a standard deviation of 1. This filtering is applied | ||
separately per element of the observation space. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
ctx: ConnectorContext, | ||
demean: bool = True, | ||
destd: bool = True, | ||
clip: float = 10.0, | ||
): | ||
SyncedFilterAgentConnector.__init__(self, ctx) | ||
# We simply use the old MeanStdFilter until non-connector env_runner is fully | ||
# deprecated to avoid duplicate code | ||
|
||
filter_shape = tree.map_structure( | ||
lambda s: ( | ||
None | ||
if isinstance(s, (Discrete, MultiDiscrete)) # noqa | ||
else np.array(s.shape) | ||
), | ||
get_base_struct_from_space(ctx.observation_space), | ||
) | ||
self.filter = MeanStdFilter(filter_shape, demean=demean, destd=destd, clip=clip) | ||
|
||
def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType: | ||
d = ac_data.data | ||
assert ( | ||
type(d) == dict | ||
), "Single agent data must be of type Dict[str, TensorStructType]" | ||
if SampleBatch.OBS in d: | ||
d[SampleBatch.OBS] = self.filter( | ||
d[SampleBatch.OBS], update=self._is_training | ||
) | ||
if SampleBatch.NEXT_OBS in d: | ||
d[SampleBatch.NEXT_OBS] = self.filter( | ||
d[SampleBatch.NEXT_OBS], update=self._is_training | ||
) | ||
|
||
return ac_data | ||
|
||
def to_state(self): | ||
# Flattening is deterministic | ||
flattened_rs = tree.flatten(self.filter.running_stats) | ||
flattened_buffer = tree.flatten(self.filter.buffer) | ||
return MeanStdObservationFilterAgentConnector.__name__, { | ||
"shape": self.filter.shape, | ||
"no_preprocessor": self.filter.no_preprocessor, | ||
"demean": self.filter.demean, | ||
"destd": self.filter.destd, | ||
"clip": self.filter.clip, | ||
"running_stats": [s.to_state() for s in flattened_rs], | ||
"buffer": [s.to_state() for s in flattened_buffer], | ||
} | ||
|
||
# demean, destd, clip, and a state dict | ||
@staticmethod | ||
def from_state( | ||
ctx: ConnectorContext, | ||
params: List[Any] = None, | ||
demean: bool = True, | ||
destd: bool = True, | ||
clip: float = 10.0, | ||
): | ||
connector = MeanStdObservationFilterAgentConnector(ctx, demean, destd, clip) | ||
if params: | ||
connector.filter.shape = params["shape"] | ||
connector.filter.no_preprocessor = params["no_preprocessor"] | ||
connector.filter.demean = params["demean"] | ||
connector.filter.destd = params["destd"] | ||
connector.filter.clip = params["clip"] | ||
|
||
# Unflattening is deterministic | ||
running_stats = [RunningStat.from_state(s) for s in params["running_stats"]] | ||
connector.filter.running_stats = tree.unflatten_as( | ||
connector.filter.shape, running_stats | ||
) | ||
|
||
# Unflattening is deterministic | ||
buffer = [RunningStat.from_state(s) for s in params["buffer"]] | ||
connector.filter.buffer = tree.unflatten_as(connector.filter.shape, buffer) | ||
|
||
return connector | ||
|
||
def reset_state(self) -> None: | ||
"""Creates copy of current state and resets accumulated state""" | ||
if not self._is_training: | ||
raise ValueError( | ||
"State of {} can only be changed when trainin.".format(self.__name__) | ||
) | ||
self.filter.reset_buffer() | ||
|
||
def apply_changes(self, other: "Filter", *args, **kwargs) -> None: | ||
"""Updates self with state from other filter.""" | ||
# inline this as soon as we deprecate ordinary filter with non-connector | ||
# env_runner | ||
if not self._is_training: | ||
raise ValueError( | ||
"Changes can only be applied to {} when trainin.".format(self.__name__) | ||
) | ||
return self.filter.apply_changes(other, *args, **kwargs) | ||
|
||
def copy(self) -> "Filter": | ||
"""Creates a new object with same state as self. | ||
This is a legacy Filter method that we need to keep around for now | ||
Returns: | ||
A copy of self. | ||
""" | ||
# inline this as soon as we deprecate ordinary filter with non-connector | ||
# env_runner | ||
return self.filter.copy() | ||
|
||
def sync(self, other: "AgentConnector") -> None: | ||
"""Copies all state from other filter to self.""" | ||
# inline this as soon as we deprecate ordinary filter with non-connector | ||
# env_runner | ||
if not self._is_training: | ||
raise ValueError( | ||
"{} can only be synced when trainin.".format(self.__name__) | ||
) | ||
return self.filter.sync(other.filter) | ||
|
||
|
||
@PublicAPI(stability="alpha") | ||
class ConcurrentMeanStdObservationFilterAgentConnector( | ||
MeanStdObservationFilterAgentConnector | ||
): | ||
"""A concurrent version of the MeanStdObservationFilterAgentConnector. | ||
This version's filter has all operations wrapped by a threading.RLock. | ||
It can therefore be safely used by multiple threads. | ||
""" | ||
|
||
def __init__(self, ctx: ConnectorContext, demean=True, destd=True, clip=10.0): | ||
SyncedFilterAgentConnector.__init__(self, ctx) | ||
# We simply use the old MeanStdFilter until non-connector env_runner is fully | ||
# deprecated to avoid duplicate code | ||
|
||
filter_shape = tree.map_structure( | ||
lambda s: ( | ||
None | ||
if isinstance(s, (Discrete, MultiDiscrete)) # noqa | ||
else np.array(s.shape) | ||
), | ||
get_base_struct_from_space(ctx.observation_space), | ||
) | ||
self.filter = ConcurrentMeanStdFilter( | ||
filter_shape, demean=True, destd=True, clip=10.0 | ||
) | ||
|
||
|
||
register_connector( | ||
MeanStdObservationFilterAgentConnector.__name__, | ||
MeanStdObservationFilterAgentConnector, | ||
) | ||
register_connector( | ||
ConcurrentMeanStdObservationFilterAgentConnector.__name__, | ||
ConcurrentMeanStdObservationFilterAgentConnector, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from ray.rllib.connectors.connector import ( | ||
AgentConnector, | ||
ConnectorContext, | ||
) | ||
from ray.util.annotations import PublicAPI | ||
from ray.rllib.utils.filter import Filter | ||
|
||
|
||
@PublicAPI(stability="alpha") | ||
class SyncedFilterAgentConnector(AgentConnector): | ||
"""An agent connector that filters with synchronized parameters.""" | ||
|
||
def __init__(self, ctx: ConnectorContext, *args, **kwargs): | ||
super().__init__(ctx) | ||
if args or kwargs: | ||
raise ValueError( | ||
"SyncedFilterAgentConnector does not take any additional arguments, " | ||
"but got args=`{}` and kwargs={}.".format(args, kwargs) | ||
) | ||
|
||
def apply_changes(self, other: "Filter", *args, **kwargs) -> None: | ||
"""Updates self with state from other filter.""" | ||
# TODO: (artur) inline this as soon as we deprecate ordinary filter with | ||
# non-connecto env_runner | ||
return self.filter.apply_changes(other, *args, **kwargs) | ||
|
||
def copy(self) -> "Filter": | ||
"""Creates a new object with same state as self. | ||
This is a legacy Filter method that we need to keep around for now | ||
Returns: | ||
A copy of self. | ||
""" | ||
# inline this as soon as we deprecate ordinary filter with non-connector | ||
# env_runner | ||
return self.filter.copy() | ||
|
||
def sync(self, other: "AgentConnector") -> None: | ||
"""Copies all state from other filter to self.""" | ||
# TODO: (artur) inline this as soon as we deprecate ordinary filter with | ||
# non-connector env_runner | ||
return self.filter.sync(other.filter) | ||
|
||
def reset_state(self) -> None: | ||
"""Creates copy of current state and resets accumulated state""" | ||
raise NotImplementedError | ||
|
||
def as_serializable(self) -> "Filter": | ||
# TODO: (artur) inline this as soon as we deprecate ordinary filter with | ||
# non-connector env_runner | ||
return self.filter.as_serializable() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.