diff --git a/rllib/connectors/agent/mean_std_filter.py b/rllib/connectors/agent/mean_std_filter.py new file mode 100644 index 000000000000..cbad1f451183 --- /dev/null +++ b/rllib/connectors/agent/mean_std_filter.py @@ -0,0 +1,187 @@ +from typing import Any, List +from gym.spaces import Discrete, MultiDiscrete + +import numpy as np +import tree + +from ray.rllib.connectors.agent.synced_filter import SyncedFilterAgentConnector +from ray.rllib.connectors.connector import AgentConnector +from ray.rllib.connectors.connector import ( + ConnectorContext, + register_connector, +) +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.filter import Filter +from ray.rllib.utils.filter import MeanStdFilter, ConcurrentMeanStdFilter +from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space +from ray.rllib.utils.typing import AgentConnectorDataType +from ray.util.annotations import PublicAPI +from ray.rllib.utils.filter import RunningStat + + +@PublicAPI(stability="alpha") +class MeanStdObservationFilterAgentConnector(SyncedFilterAgentConnector): + """A connector used to mean-std-filter observations. + + Incoming observations are filtered such that the output of this filter is on + average zero and has a standard deviation of 1. This filtering is applied + separately per element of the observation space. + """ + + def __init__( + self, + ctx: ConnectorContext, + demean: bool = True, + destd: bool = True, + clip: float = 10.0, + ): + SyncedFilterAgentConnector.__init__(self, ctx) + # We simply use the old MeanStdFilter until non-connector env_runner is fully + # deprecated to avoid duplicate code + + filter_shape = tree.map_structure( + lambda s: ( + None + if isinstance(s, (Discrete, MultiDiscrete)) # noqa + else np.array(s.shape) + ), + get_base_struct_from_space(ctx.observation_space), + ) + self.filter = MeanStdFilter(filter_shape, demean=demean, destd=destd, clip=clip) + + def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType: + d = ac_data.data + assert ( + type(d) == dict + ), "Single agent data must be of type Dict[str, TensorStructType]" + if SampleBatch.OBS in d: + d[SampleBatch.OBS] = self.filter( + d[SampleBatch.OBS], update=self._is_training + ) + if SampleBatch.NEXT_OBS in d: + d[SampleBatch.NEXT_OBS] = self.filter( + d[SampleBatch.NEXT_OBS], update=self._is_training + ) + + return ac_data + + def to_state(self): + # Flattening is deterministic + flattened_rs = tree.flatten(self.filter.running_stats) + flattened_buffer = tree.flatten(self.filter.buffer) + return MeanStdObservationFilterAgentConnector.__name__, { + "shape": self.filter.shape, + "no_preprocessor": self.filter.no_preprocessor, + "demean": self.filter.demean, + "destd": self.filter.destd, + "clip": self.filter.clip, + "running_stats": [s.to_state() for s in flattened_rs], + "buffer": [s.to_state() for s in flattened_buffer], + } + + # demean, destd, clip, and a state dict + @staticmethod + def from_state( + ctx: ConnectorContext, + params: List[Any] = None, + demean: bool = True, + destd: bool = True, + clip: float = 10.0, + ): + connector = MeanStdObservationFilterAgentConnector(ctx, demean, destd, clip) + if params: + connector.filter.shape = params["shape"] + connector.filter.no_preprocessor = params["no_preprocessor"] + connector.filter.demean = params["demean"] + connector.filter.destd = params["destd"] + connector.filter.clip = params["clip"] + + # Unflattening is deterministic + running_stats = [RunningStat.from_state(s) for s in params["running_stats"]] + 
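+            # `running_stats` and `buffer` were flattened in `to_state()`; rebuild
+            # the original (possibly nested) structure, using the filter's shape
+            # struct as the reference structure for unflattening.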
            connector.filter.running_stats = tree.unflatten_as(
+                connector.filter.shape, running_stats
+            )
+
+            # Unflattening is deterministic
+            buffer = [RunningStat.from_state(s) for s in params["buffer"]]
+            connector.filter.buffer = tree.unflatten_as(connector.filter.shape, buffer)
+
+        return connector
+
+    def reset_state(self) -> None:
+        """Resets the accumulated (buffer) state of this connector's filter."""
+        if not self._is_training:
+            raise ValueError(
+                "State of {} can only be changed when training.".format(
+                    type(self).__name__
+                )
+            )
+        self.filter.reset_buffer()
+
+    def apply_changes(self, other: "Filter", *args, **kwargs) -> None:
+        """Updates self with state from other filter."""
+        # inline this as soon as we deprecate ordinary filter with non-connector
+        # env_runner
+        if not self._is_training:
+            raise ValueError(
+                "Changes can only be applied to {} when training.".format(
+                    type(self).__name__
+                )
+            )
+        return self.filter.apply_changes(other, *args, **kwargs)
+
+    def copy(self) -> "Filter":
+        """Creates a new object with the same state as self.
+
+        This is a legacy Filter method that we need to keep around for now.
+
+        Returns:
+            A copy of self.
+        """
+        # inline this as soon as we deprecate ordinary filter with non-connector
+        # env_runner
+        return self.filter.copy()
+
+    def sync(self, other: "AgentConnector") -> None:
+        """Copies all state from the other connector's filter to self."""
+        # inline this as soon as we deprecate ordinary filter with non-connector
+        # env_runner
+        if not self._is_training:
+            raise ValueError(
+                "{} can only be synced when training.".format(type(self).__name__)
+            )
+        return self.filter.sync(other.filter)
+
+
+@PublicAPI(stability="alpha")
+class ConcurrentMeanStdObservationFilterAgentConnector(
+    MeanStdObservationFilterAgentConnector
+):
+    """A concurrent version of the MeanStdObservationFilterAgentConnector.
+
+    This version's filter has all operations wrapped by a threading.RLock.
+    It can therefore be safely used by multiple threads.
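+
+    A minimal usage sketch (illustrative only; assumes a simple Box observation
+    space and manual invocation outside of a connector pipeline):
+
+    .. code-block:: python
+
+        from gym.spaces import Box
+        from ray.rllib.connectors.connector import ConnectorContext
+        from ray.rllib.policy.sample_batch import SampleBatch
+        from ray.rllib.utils.typing import AgentConnectorDataType
+
+        ctx = ConnectorContext(observation_space=Box(-1.0, 1.0, (4,)))
+        connector = ConcurrentMeanStdObservationFilterAgentConnector(ctx)
+        data = AgentConnectorDataType(
+            0, 0, {SampleBatch.OBS: ctx.observation_space.sample()}
+        )
+        # Returns the same AgentConnectorDataType with the observation
+        # mean-std filtered (running stats are only updated while training).
+        filtered = connector.transform(data)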
+ """ + + def __init__(self, ctx: ConnectorContext, demean=True, destd=True, clip=10.0): + SyncedFilterAgentConnector.__init__(self, ctx) + # We simply use the old MeanStdFilter until non-connector env_runner is fully + # deprecated to avoid duplicate code + + filter_shape = tree.map_structure( + lambda s: ( + None + if isinstance(s, (Discrete, MultiDiscrete)) # noqa + else np.array(s.shape) + ), + get_base_struct_from_space(ctx.observation_space), + ) + self.filter = ConcurrentMeanStdFilter( + filter_shape, demean=True, destd=True, clip=10.0 + ) + + +register_connector( + MeanStdObservationFilterAgentConnector.__name__, + MeanStdObservationFilterAgentConnector, +) +register_connector( + ConcurrentMeanStdObservationFilterAgentConnector.__name__, + ConcurrentMeanStdObservationFilterAgentConnector, +) diff --git a/rllib/connectors/agent/synced_filter.py b/rllib/connectors/agent/synced_filter.py new file mode 100644 index 000000000000..08d147fbcb81 --- /dev/null +++ b/rllib/connectors/agent/synced_filter.py @@ -0,0 +1,52 @@ +from ray.rllib.connectors.connector import ( + AgentConnector, + ConnectorContext, +) +from ray.util.annotations import PublicAPI +from ray.rllib.utils.filter import Filter + + +@PublicAPI(stability="alpha") +class SyncedFilterAgentConnector(AgentConnector): + """An agent connector that filters with synchronized parameters.""" + + def __init__(self, ctx: ConnectorContext, *args, **kwargs): + super().__init__(ctx) + if args or kwargs: + raise ValueError( + "SyncedFilterAgentConnector does not take any additional arguments, " + "but got args=`{}` and kwargs={}.".format(args, kwargs) + ) + + def apply_changes(self, other: "Filter", *args, **kwargs) -> None: + """Updates self with state from other filter.""" + # TODO: (artur) inline this as soon as we deprecate ordinary filter with + # non-connecto env_runner + return self.filter.apply_changes(other, *args, **kwargs) + + def copy(self) -> "Filter": + """Creates a new object with same state as self. + + This is a legacy Filter method that we need to keep around for now + + Returns: + A copy of self. + """ + # inline this as soon as we deprecate ordinary filter with non-connector + # env_runner + return self.filter.copy() + + def sync(self, other: "AgentConnector") -> None: + """Copies all state from other filter to self.""" + # TODO: (artur) inline this as soon as we deprecate ordinary filter with + # non-connector env_runner + return self.filter.sync(other.filter) + + def reset_state(self) -> None: + """Creates copy of current state and resets accumulated state""" + raise NotImplementedError + + def as_serializable(self) -> "Filter": + # TODO: (artur) inline this as soon as we deprecate ordinary filter with + # non-connector env_runner + return self.filter.as_serializable() diff --git a/rllib/connectors/tests/test_agent.py b/rllib/connectors/tests/test_agent.py index d233d5c6cb71..f720d497fa13 100644 --- a/rllib/connectors/tests/test_agent.py +++ b/rllib/connectors/tests/test_agent.py @@ -1,6 +1,7 @@ import gym import numpy as np import unittest +from gym.spaces import Box from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector @@ -18,6 +19,9 @@ AgentConnectorDataType, AgentConnectorsOutput, ) +from ray.rllib.connectors.agent.mean_std_filter import ( + MeanStdObservationFilterAgentConnector, +) class TestAgentConnector(unittest.TestCase): @@ -473,6 +477,89 @@ def test_vr_connector_only_keeps_useful_timesteps(self): # Data matches the latest timestep. 
        self.assertTrue(np.array_equal(obs_data[0], np.array([4, 5, 6, 7])))

+    def test_mean_std_observation_filter_connector(self):
+        for bounds in [
+            (-1, 1),  # normalized
+            (-2, 2),  # scaled
+            (0, 2),  # shifted
+            (0, 4),  # scaled and shifted
+        ]:
+            print("Testing uniform sampling with bounds: {}".format(bounds))
+
+            observation_space = Box(bounds[0], bounds[1], (3, 64, 64))
+            ctx = ConnectorContext(observation_space=observation_space)
+            filter_connector = MeanStdObservationFilterAgentConnector(ctx)
+
+            # Warm up Mean-Std filter
+            for i in range(1000):
+                obs = observation_space.sample()
+                sample_batch = {
+                    SampleBatch.NEXT_OBS: obs,
+                }
+                ac = AgentConnectorDataType(0, 0, sample_batch)
+                filter_connector.transform(ac)
+
+            # Create another connector and restore the first one's state into it
+            _, state = filter_connector.to_state()
+            another_filter_connector = (
+                MeanStdObservationFilterAgentConnector.from_state(ctx, state)
+            )
+
+            another_filter_connector.in_eval()
+
+            # Collect transformed observations
+            transformed_observations = []
+            for i in range(1000):
+                obs = observation_space.sample()
+                sample_batch = {
+                    SampleBatch.NEXT_OBS: obs,
+                }
+                ac = AgentConnectorDataType(0, 0, sample_batch)
+                connector_output = another_filter_connector.transform(ac)
+                transformed_observations.append(
+                    connector_output.data[SampleBatch.NEXT_OBS]
+                )
+
+            # Check if transformed observations are actually mean-std filtered
+            self.assertTrue(
+                np.isclose(np.mean(transformed_observations), 0, atol=0.001)
+            )
+            self.assertTrue(np.isclose(np.var(transformed_observations), 1, atol=0.01))
+
+            # Check if filter parameters were frozen because we are not training
+            self.assertTrue(
+                filter_connector.filter.running_stats.num_pushes
+                == another_filter_connector.filter.running_stats.num_pushes,
+            )
+            self.assertTrue(
+                np.all(
+                    filter_connector.filter.running_stats.mean_array
+                    == another_filter_connector.filter.running_stats.mean_array,
+                )
+            )
+            self.assertTrue(
+                np.all(
+                    filter_connector.filter.running_stats.std_array
+                    == another_filter_connector.filter.running_stats.std_array,
+                )
+            )
+            self.assertTrue(
+                filter_connector.filter.buffer.num_pushes
+                == another_filter_connector.filter.buffer.num_pushes,
+            )
+            self.assertTrue(
+                np.all(
+                    filter_connector.filter.buffer.mean_array
+                    == another_filter_connector.filter.buffer.mean_array,
+                )
+            )
+            self.assertTrue(
+                np.all(
+                    filter_connector.filter.buffer.std_array
+                    == another_filter_connector.filter.buffer.std_array,
+                )
+            )
+

 if __name__ == "__main__":
     import sys
diff --git a/rllib/connectors/util.py b/rllib/connectors/util.py
index 639ada8bda35..502f4c582d49
--- a/rllib/connectors/util.py
+++ b/rllib/connectors/util.py
@@ -12,8 +12,13 @@
 from ray.rllib.connectors.agent.state_buffer import StateBufferConnector
 from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector
 from ray.rllib.connectors.connector import Connector, ConnectorContext, get_connector
+from ray.rllib.connectors.agent.mean_std_filter import (
+    MeanStdObservationFilterAgentConnector,
+    ConcurrentMeanStdObservationFilterAgentConnector,
+)
 from ray.rllib.utils.typing import TrainerConfigDict
-from ray.util.annotations import PublicAPI
+from ray.util.annotations import PublicAPI, DeveloperAPI
+from ray.rllib.connectors.agent.synced_filter import SyncedFilterAgentConnector

 if TYPE_CHECKING:
     from ray.rllib.policy.policy import Policy
@@ -38,6 +43,14 @@ def get_agent_connectors_from_config(
     if not config["_disable_preprocessor_api"]:
         connectors.append(ObsPreprocessorConnector(ctx))

+    # Filters should be after observation preprocessing
+    filter_connector = get_synced_filter_connector(
+        ctx,
+    )
+    # Configuration option "NoFilter" results in `filter_connector==None`.
+    if filter_connector:
+        connectors.append(filter_connector)
+
     connectors.extend(
         [
             StateBufferConnector(ctx),
@@ -78,7 +91,11 @@ def create_connectors_for_policy(policy: "Policy", config: TrainerConfigDict):
     """
     ctx: ConnectorContext = ConnectorContext.from_policy(policy)
-    assert policy.agent_connectors is None and policy.agent_connectors is None
+    assert policy.agent_connectors is None and policy.action_connectors is None, (
+        "Cannot create connectors for a policy that already has connectors. This "
+        "can happen if you add a Policy that has connectors attached to a "
+        "RolloutWorker with add_policy()."
+    )
     policy.agent_connectors = get_agent_connectors_from_config(ctx, config)
     policy.action_connectors = get_action_connectors_from_config(ctx, config)
@@ -101,3 +118,34 @@ def restore_connectors_for_policy(
     ctx: ConnectorContext = ConnectorContext.from_policy(policy)
     name, params = connector_config
     return get_connector(ctx, name, params)
+
+
+# We need this filter selection mechanism temporarily to remain compatible with the
+# old API.
+@DeveloperAPI
+def get_synced_filter_connector(ctx: ConnectorContext):
+    filter_specifier = ctx.config.get("observation_filter")
+    if filter_specifier == "MeanStdFilter":
+        return MeanStdObservationFilterAgentConnector(ctx, clip=None)
+    elif filter_specifier == "ConcurrentMeanStdFilter":
+        return ConcurrentMeanStdObservationFilterAgentConnector(ctx, clip=None)
+    elif filter_specifier == "NoFilter":
+        return None
+    else:
+        raise Exception("Unknown observation_filter: " + str(filter_specifier))
+
+
+@DeveloperAPI
+def maybe_get_filters_for_syncing(rollout_worker, policy_id):
+    # As long as the historic filter synchronization mechanism is in
+    # place, we need to put filters into the RolloutWorker's `filters` dict
+    # so that they get synchronized.
+    filter_connectors = rollout_worker.policy_map[policy_id].agent_connectors[
+        SyncedFilterAgentConnector
+    ]
+    # There can only be one such filter at a time.
+    if filter_connectors:
+        assert len(filter_connectors) == 1, (
+            "ConnectorPipeline has more than one connector of type "
+            "SyncedFilterAgentConnector, but it can only have one."
+        )
+        rollout_worker.filters[policy_id] = filter_connectors[0].filter
diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py
index 7090029c4318..bafc787437c6
--- a/rllib/evaluation/rollout_worker.py
+++ b/rllib/evaluation/rollout_worker.py
@@ -24,7 +24,10 @@
 import ray
 from ray import ObjectRef
 from ray import cloudpickle as pickle
-from ray.rllib.connectors.util import create_connectors_for_policy
+from ray.rllib.connectors.util import (
+    create_connectors_for_policy,
+    maybe_get_filters_for_syncing,
+)
 from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
 from ray.rllib.env.env_context import EnvContext
 from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
@@ -224,7 +227,6 @@ def __init__(
         compress_observations: bool = False,
         num_envs: int = 1,
         observation_fn: Optional["ObservationFunction"] = None,
-        observation_filter: str = "NoFilter",
         clip_rewards: Optional[Union[bool, float]] = None,
         normalize_actions: bool = True,
         clip_actions: bool = False,
@@ -307,7 +309,6 @@
                 and vectorize the computation of actions. This has no effect if
                 if the env already implements VectorEnv.
            observation_fn: Optional multi-agent observation function.
- observation_filter: Name of observation filter to use. clip_rewards: True for clipping rewards to [-1.0, 1.0] prior to experience postprocessing. None: Clip for Atari only. float: Clip to [-clip_rewards; +clip_rewards]. @@ -457,7 +458,6 @@ def gen_rollouts(): self.preprocessing_enabled: bool = not policy_config.get( "_disable_preprocessor_api" ) - self.observation_filter = observation_filter self.last_batch: Optional[SampleBatchType] = None self.global_vars: Optional[dict] = None self.fake_sampler: bool = fake_sampler @@ -640,15 +640,21 @@ def wrap(env): # TODO(jungong) : clean up after non-connector env_runner is fully deprecated. self.filters: Dict[PolicyID, Filter] = {} for (policy_id, policy) in self.policy_map.items(): - filter_shape = tree.map_structure( - lambda s: ( - None - if isinstance(s, (Discrete, MultiDiscrete)) # noqa - else np.array(s.shape) - ), - policy.observation_space_struct, - ) - self.filters[policy_id] = get_filter(self.observation_filter, filter_shape) + if not policy_config.get("enable_connectors"): + filter_shape = tree.map_structure( + lambda s: ( + None + if isinstance(s, (Discrete, MultiDiscrete)) # noqa + else np.array(s.shape) + ), + policy.observation_space_struct, + ) + self.filters[policy_id] = get_filter( + self.policy_map[policy_id].config.get( + "observation_filter", "NoFilter" + ), + filter_shape, + ) if self.worker_index == 0: logger.info("Built filter map: {}".format(self.filters)) @@ -1013,7 +1019,7 @@ def compute_gradients( if log_once("compute_gradients"): logger.info("Compute gradients on:\n\n{}\n".format(summarize(samples))) - # Backward compatiblity for A2C: Single-agent only (ComputeGradients execution + # Backward compatibility for A2C: Single-agent only (ComputeGradients execution # op must not return multi-agent dict b/c of A2C's `.batch()` in the execution # plan; this would "batch" over the "default_policy" keys instead of the data). if single_agent is True: @@ -1272,22 +1278,25 @@ def add_policy( if policy_state: new_policy.set_state(policy_state) - filter_shape = tree.map_structure( - lambda s: ( - None - if isinstance(s, (Discrete, MultiDiscrete)) # noqa - else np.array(s.shape) - ), - new_policy.observation_space_struct, - ) + connectors_enabled = merged_config.get("enable_connectors", False) - self.filters[policy_id] = get_filter(self.observation_filter, filter_shape) + if connectors_enabled: + policy = self.policy_map[policy_id] + create_connectors_for_policy(policy, merged_config) + maybe_get_filters_for_syncing(self, policy_id) + else: + filter_shape = tree.map_structure( + lambda s: ( + None + if isinstance(s, (Discrete, MultiDiscrete)) # noqa + else np.array(s.shape) + ), + new_policy.observation_space_struct, + ) - # Create connectors for the new policy, if necessary. - # Only if connectors are enables and we created the new policy from scratch - # (it was not provided to us via the `policy` arg. 
- if policy is None and self.policy_config.get("enable_connectors"): - create_connectors_for_policy(new_policy, merged_config) + self.filters[policy_id] = get_filter( + (config or {}).get("observation_filter", "NoFilter"), filter_shape + ) self.set_policy_mapping_fn(policy_mapping_fn) if policies_to_train is not None: @@ -1860,6 +1869,7 @@ def _build_policy_map( if connectors_enabled and name in self.policy_map: create_connectors_for_policy(self.policy_map[name], policy_config) + maybe_get_filters_for_syncing(self, name) if name in self.policy_map: self.callbacks.on_create_policy( diff --git a/rllib/evaluation/tests/test_rollout_worker.py b/rllib/evaluation/tests/test_rollout_worker.py index 0d57cb67b854..04f834b68601 100644 --- a/rllib/evaluation/tests/test_rollout_worker.py +++ b/rllib/evaluation/tests/test_rollout_worker.py @@ -774,13 +774,13 @@ def test_filter_sync(self): env_creator=lambda _: gym.make("CartPole-v0"), policy_spec=MockPolicy, sample_async=True, - observation_filter="ConcurrentMeanStdFilter", + policy_config={"observation_filter": "ConcurrentMeanStdFilter"}, ) time.sleep(2) ev.sample() filters = ev.get_filters(flush_after=True) obs_f = filters[DEFAULT_POLICY_ID] - self.assertNotEqual(obs_f.rs.n, 0) + self.assertNotEqual(obs_f.running_stats.n, 0) self.assertNotEqual(obs_f.buffer.n, 0) ev.stop() @@ -789,7 +789,7 @@ def test_get_filters(self): env_creator=lambda _: gym.make("CartPole-v0"), policy_spec=MockPolicy, sample_async=True, - observation_filter="ConcurrentMeanStdFilter", + policy_config={"observation_filter": "ConcurrentMeanStdFilter"}, ) self.sample_and_flush(ev) filters = ev.get_filters(flush_after=False) @@ -797,7 +797,7 @@ def test_get_filters(self): filters2 = ev.get_filters(flush_after=False) obs_f = filters[DEFAULT_POLICY_ID] obs_f2 = filters2[DEFAULT_POLICY_ID] - self.assertGreaterEqual(obs_f2.rs.n, obs_f.rs.n) + self.assertGreaterEqual(obs_f2.running_stats.n, obs_f.running_stats.n) self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n) ev.stop() @@ -806,7 +806,7 @@ def test_sync_filter(self): env_creator=lambda _: gym.make("CartPole-v0"), policy_spec=MockPolicy, sample_async=True, - observation_filter="ConcurrentMeanStdFilter", + policy_config={"observation_filter": "ConcurrentMeanStdFilter"}, ) obs_f = self.sample_and_flush(ev) @@ -817,11 +817,11 @@ def test_sync_filter(self): self.assertLessEqual(obs_f.buffer.n, 20) new_obsf = obs_f.copy() - new_obsf.rs._n = 100 + new_obsf.running_stats.num_pushes = 100 ev.sync_filters({DEFAULT_POLICY_ID: new_obsf}) filters = ev.get_filters(flush_after=False) obs_f = filters[DEFAULT_POLICY_ID] - self.assertGreaterEqual(obs_f.rs.n, 100) + self.assertGreaterEqual(obs_f.running_stats.n, 100) self.assertLessEqual(obs_f.buffer.n, 20) ev.stop() @@ -957,7 +957,7 @@ def sample_and_flush(self, ev): ev.sample() filters = ev.get_filters(flush_after=True) obs_f = filters[DEFAULT_POLICY_ID] - self.assertNotEqual(obs_f.rs.n, 0) + self.assertNotEqual(obs_f.running_stats.n, 0) self.assertNotEqual(obs_f.buffer.n, 0) return obs_f diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 02b119ae02ec..118776b81aa6 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -902,7 +902,6 @@ def valid_module(class_path): compress_observations=config["compress_observations"], num_envs=config["num_envs_per_worker"], observation_fn=config["multiagent"]["observation_fn"], - observation_filter=config["observation_filter"], clip_rewards=config["clip_rewards"], 
normalize_actions=config["normalize_actions"], clip_actions=config["clip_actions"], diff --git a/rllib/examples/custom_observation_filters.py b/rllib/examples/custom_observation_filters.py index 4f6d8a40c3e8..0da66e7922c2 100644 --- a/rllib/examples/custom_observation_filters.py +++ b/rllib/examples/custom_observation_filters.py @@ -38,7 +38,7 @@ def push(self, x): def update(self, other): n1 = self._n - n2 = other._n + n2 = other.num_pushes n = n1 + n2 if n == 0: return diff --git a/rllib/tests/test_filters.py b/rllib/tests/test_filters.py index b641d01d4408..03c08c971392 100644 --- a/rllib/tests/test_filters.py +++ b/rllib/tests/test_filters.py @@ -47,12 +47,12 @@ def testBasic(self): filt = MeanStdFilter(shape) for i in range(5): filt(np.ones(shape)) - self.assertEqual(filt.rs.n, 5) + self.assertEqual(filt.running_stats.n, 5) self.assertEqual(filt.buffer.n, 5) filt2 = MeanStdFilter(shape) filt2.sync(filt) - self.assertEqual(filt2.rs.n, 5) + self.assertEqual(filt2.running_stats.n, 5) self.assertEqual(filt2.buffer.n, 5) filt.reset_buffer() @@ -61,11 +61,11 @@ def testBasic(self): filt.apply_changes(filt2, with_buffer=False) self.assertEqual(filt.buffer.n, 0) - self.assertEqual(filt.rs.n, 10) + self.assertEqual(filt.running_stats.n, 10) filt.apply_changes(filt2, with_buffer=True) self.assertEqual(filt.buffer.n, 5) - self.assertEqual(filt.rs.n, 15) + self.assertEqual(filt.running_stats.n, 15) class FilterManagerTest(unittest.TestCase): @@ -82,7 +82,7 @@ def test_synchronize(self): filt1 = MeanStdFilter(()) for i in range(10): filt1(i) - self.assertEqual(filt1.rs.n, 10) + self.assertEqual(filt1.running_stats.n, 10) filt1.reset_buffer() self.assertEqual(filt1.buffer.n, 0) @@ -96,9 +96,9 @@ def test_synchronize(self): filters = ray.get(remote_e.get_filters.remote()) obs_f = filters["obs_filter"] - self.assertEqual(filt1.rs.n, 20) + self.assertEqual(filt1.running_stats.n, 20) self.assertEqual(filt1.buffer.n, 0) - self.assertEqual(obs_f.rs.n, filt1.rs.n) + self.assertEqual(obs_f.running_stats.n, filt1.running_stats.n) self.assertEqual(obs_f.buffer.n, filt1.buffer.n) diff --git a/rllib/utils/filter.py b/rllib/utils/filter.py index 246db4ccb376..69d480861bf7 100644 --- a/rllib/utils/filter.py +++ b/rllib/utils/filter.py @@ -8,13 +8,12 @@ from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.numpy import SMALL_NUMBER from ray.rllib.utils.typing import TensorStructType +from ray.rllib.utils.serialization import _serialize_ndarray, _deserialize_ndarray +from ray.rllib.utils.deprecation import deprecation_warning logger = logging.getLogger(__name__) -# TODO(jungong) : Add Adapters to use these filters as agent connectors. - - @DeveloperAPI class Filter: """Processes input, possibly statefully.""" @@ -81,49 +80,50 @@ def as_serializable(self) -> "NoFilter": @DeveloperAPI class RunningStat: def __init__(self, shape=None): - self._n = 0 - self._M = np.zeros(shape) - self._S = np.zeros(shape) + self.num_pushes = 0 + self.mean_array = np.zeros(shape) + self.std_array = np.zeros(shape) def copy(self): other = RunningStat() - other._n = self._n - other._M = np.copy(self._M) - other._S = np.copy(self._S) + other.num_pushes = self.num_pushes + other.mean_array = np.copy(self.mean_array) + other.std_array = np.copy(self.std_array) return other def push(self, x): x = np.asarray(x) # Unvectorized update of the running statistics. 
- if x.shape != self._M.shape: + if x.shape != self.mean_array.shape: raise ValueError( "Unexpected input shape {}, expected {}, value = {}".format( - x.shape, self._M.shape, x + x.shape, self.mean_array.shape, x ) ) - n1 = self._n - self._n += 1 - if self._n == 1: - self._M[...] = x + self.num_pushes += 1 + if self.num_pushes == 1: + self.mean_array[...] = x else: - delta = x - self._M - self._M[...] += delta / self._n - self._S[...] += delta * delta * n1 / self._n + delta = x - self.mean_array + self.mean_array[...] += delta / self.num_pushes + self.std_array[...] += ( + delta * delta * (self.num_pushes - 1) / self.num_pushes + ) def update(self, other): - n1 = self._n - n2 = other._n + n1 = self.num_pushes + n2 = other.num_pushes n = n1 + n2 if n == 0: # Avoid divide by zero, which creates nans return - delta = self._M - other._M + delta = self.mean_array - other.mean_array delta2 = delta * delta - M = (n1 * self._M + n2 * other._M) / n - S = self._S + other._S + delta2 * n1 * n2 / n - self._n = n - self._M = M - self._S = S + m = (n1 * self.mean_array + n2 * other.mean_array) / n + s = self.std_array + other.std_array + delta2 * n1 * n2 / n + self.num_pushes = n + self.mean_array = m + self.std_array = s def __repr__(self): return "(n={}, mean_mean={}, mean_std={})".format( @@ -132,15 +132,19 @@ def __repr__(self): @property def n(self): - return self._n + return self.num_pushes @property def mean(self): - return self._M + return self.mean_array @property def var(self): - return self._S / (self._n - 1) if self._n > 1 else np.square(self._M) + return ( + self.std_array / (self.num_pushes - 1) + if self.num_pushes > 1 + else np.square(self.mean_array) + ) @property def std(self): @@ -148,7 +152,22 @@ def std(self): @property def shape(self): - return self._M.shape + return self.mean_array.shape + + def to_state(self): + return { + "num_pushes": self.num_pushes, + "mean_array": _serialize_ndarray(self.mean_array), + "std_array": _serialize_ndarray(self.std_array), + } + + @staticmethod + def from_state(state): + running_stats = RunningStat() + running_stats.num_pushes = state["num_pushes"] + running_stats.mean_array = _deserialize_ndarray(state["mean_array"]) + running_stats.std_array = _deserialize_ndarray(state["std_array"]) + return running_stats @DeveloperAPI @@ -168,8 +187,8 @@ def __init__(self, shape, demean=True, destd=True, clip=10.0): and len(flat_shape) > 0 and isinstance(flat_shape[0], np.ndarray) ) - # If preprocessing (flattning dicts/tuples), make sure shape - # is an np.ndarray so we don't confuse it with a complex Tuple + # If preprocessing (flattening dicts/tuples), make sure shape + # is an np.ndarray, so we don't confuse it with a complex Tuple # space's shape structure (which is a Tuple[np.ndarray]). if not self.no_preprocessor: self.shape = np.array(self.shape) @@ -177,7 +196,7 @@ def __init__(self, shape, demean=True, destd=True, clip=10.0): self.destd = destd self.clip = clip # Running stats. - self.rs = tree.map_structure(lambda s: RunningStat(s), self.shape) + self.running_stats = tree.map_structure(lambda s: RunningStat(s), self.shape) # In distributed rollouts, each worker sees different states. 
        # The buffer is used to keep track of deltas amongst all the
@@ -202,19 +221,19 @@ def apply_changes(
        >>> a = MeanStdFilter(())
        >>> a(1)
        >>> a(2)
-        >>> print([a.rs.n, a.rs.mean, a.buffer.n])
+        >>> print([a.running_stats.n, a.running_stats.mean, a.buffer.n])
        [2, 1.5, 2]
        >>> b = MeanStdFilter(())
        >>> b(10)
        >>> a.apply_changes(b, with_buffer=False)
-        >>> print([a.rs.n, a.rs.mean, a.buffer.n])
+        >>> print([a.running_stats.n, a.running_stats.mean, a.buffer.n])
        [3, 4.333333333333333, 2]
        >>> a.apply_changes(b, with_buffer=True)
-        >>> print([a.rs.n, a.rs.mean, a.buffer.n])
+        >>> print([a.running_stats.n, a.running_stats.mean, a.buffer.n])
        [4, 5.75, 1]
        """
        tree.map_structure(
-            lambda rs, other_rs: rs.update(other_rs), self.rs, other.buffer
+            lambda rs, other_rs: rs.update(other_rs), self.running_stats, other.buffer
        )
        if with_buffer:
            self.buffer = tree.map_structure(lambda b: b.copy(), other.buffer)
@@ -235,20 +254,22 @@ def sync(self, other: "MeanStdFilter") -> None:
        >>> a = MeanStdFilter(())
        >>> a(1)
        >>> a(2)
-        >>> print([a.rs.n, a.rs.mean, a.buffer.n])
+        >>> print([a.running_stats.n, a.running_stats.mean, a.buffer.n])
        [2, array(1.5), 2]
        >>> b = MeanStdFilter(())
        >>> b(10)
-        >>> print([b.rs.n, b.rs.mean, b.buffer.n])
+        >>> print([b.running_stats.n, b.running_stats.mean, b.buffer.n])
        [1, array(10.0), 1]
        >>> a.sync(b)
-        >>> print([a.rs.n, a.rs.mean, a.buffer.n])
+        >>> print([a.running_stats.n, a.running_stats.mean, a.buffer.n])
        [1, array(10.0), 1]
        """
        self.demean = other.demean
        self.destd = other.destd
        self.clip = other.clip
-        self.rs = tree.map_structure(lambda rs: rs.copy(), other.rs)
+        self.running_stats = tree.map_structure(
+            lambda rs: rs.copy(), other.running_stats
+        )
        self.buffer = tree.map_structure(lambda b: b.copy(), other.buffer)

    def __call__(self, x: TensorStructType, update: bool = True) -> TensorStructType:
@@ -285,14 +306,19 @@ def _helper(x, rs, buffer, shape):
        if self.no_preprocessor:
            return tree.map_structure_up_to(
-                x, _helper, x, self.rs, self.buffer, self.shape
+                x, _helper, x, self.running_stats, self.buffer, self.shape
            )
        else:
-            return _helper(x, self.rs, self.buffer, self.shape)
+            return _helper(x, self.running_stats, self.buffer, self.shape)

    def __repr__(self) -> str:
        return "MeanStdFilter({}, {}, {}, {}, {}, {})".format(
-            self.shape, self.demean, self.destd, self.clip, self.rs, self.buffer
+            self.shape,
+            self.demean,
+            self.destd,
+            self.clip,
+            self.running_stats,
+            self.buffer,
        )
@@ -302,6 +328,15 @@ class ConcurrentMeanStdFilter(MeanStdFilter):
    def __init__(self, *args, **kwargs):
        super(ConcurrentMeanStdFilter, self).__init__(*args, **kwargs)
+        deprecation_warning(
+            old="ConcurrentMeanStdFilter",
+            error=False,
+            help="ConcurrentMeanStd filters are only used for testing and will "
+            "therefore be deprecated in the course of moving to the "
+            "Connectors API, where testing of filters will be done by other "
+            "means.",
+        )
+
        self._lock = threading.RLock()

        def lock_wrap(func):
@@ -327,7 +362,12 @@ def copy(self) -> "ConcurrentMeanStdFilter":

    def __repr__(self) -> str:
        return "ConcurrentMeanStdFilter({}, {}, {}, {}, {}, {})".format(
-            self.shape, self.demean, self.destd, self.clip, self.rs, self.buffer
+            self.shape,
+            self.demean,
+            self.destd,
+            self.clip,
+            self.running_stats,
+            self.buffer,
        )