[RLlib] Add filters to connector pipeline (ray-project#27864)
* initial

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* initial

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* lint and comments

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* wip

* format

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* implements and uses filters, working for ppo cartpole with meanstd

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* get rid of synced custom filter abstraction

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* add meanstd filter connector test and minor fixes

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* jun's comment

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* move observation space struct logic because information is already in connector context

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* fix docstrings

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* minor fixes

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* initial

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* fix for config=None in add_policy and connector=None

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* fix config name

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* filter connector state is now json serializable

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* jun's comments

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* initial

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* create connectors only in create_connectors_for_policy

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* initial

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* get filter connector by __getitem__

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* remove observation filters

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* minor fixes

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* initial

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* revert spelling error

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* initial

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* Revert "Merge branch 'make_add_policy_config_explicit' into filterstoconnectors"

This reverts commit 06beebc, reversing
changes made to e637f4f.

* accommodate case in which config={} in add_policy

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* fix connectors enabled not no SyncedFilterAgentConnector case

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* initial

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* merge configs in add_policy

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* format

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* revert random cloudpickle linter error

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* small change to trigger CI

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* remove all random changes outside rllib that made it into this PR

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* remove random rst

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* fix deprecated is_training call

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* correct in_eval call

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* nit

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* jun's comment

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* use merged config to create connectors

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* Add meaningful assertion error and switch order of if/else block to intuitive order

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* better warning

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* jun's comments

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

* shorter function signature for helper fn

Signed-off-by: Artur Niederfahrenhorst <[email protected]>

Signed-off-by: Artur Niederfahrenhorst <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
ArturNiederfahrenhorst authored and WeichenXu123 committed Dec 19, 2022
1 parent e85d7c4 commit e770151
Showing 10 changed files with 515 additions and 92 deletions.
187 changes: 187 additions & 0 deletions rllib/connectors/agent/mean_std_filter.py
@@ -0,0 +1,187 @@
from typing import Any, List
from gym.spaces import Discrete, MultiDiscrete

import numpy as np
import tree

from ray.rllib.connectors.agent.synced_filter import SyncedFilterAgentConnector
from ray.rllib.connectors.connector import AgentConnector
from ray.rllib.connectors.connector import (
    ConnectorContext,
    register_connector,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.filter import MeanStdFilter, ConcurrentMeanStdFilter
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import AgentConnectorDataType
from ray.util.annotations import PublicAPI
from ray.rllib.utils.filter import RunningStat


@PublicAPI(stability="alpha")
class MeanStdObservationFilterAgentConnector(SyncedFilterAgentConnector):
    """A connector used to mean-std-filter observations.

    Incoming observations are filtered such that the output of this filter is on
    average zero and has a standard deviation of 1. This filtering is applied
    separately per element of the observation space.
    """

    def __init__(
        self,
        ctx: ConnectorContext,
        demean: bool = True,
        destd: bool = True,
        clip: float = 10.0,
    ):
        SyncedFilterAgentConnector.__init__(self, ctx)
        # We simply use the old MeanStdFilter until non-connector env_runner is fully
        # deprecated to avoid duplicate code

        filter_shape = tree.map_structure(
            lambda s: (
                None
                if isinstance(s, (Discrete, MultiDiscrete))  # noqa
                else np.array(s.shape)
            ),
            get_base_struct_from_space(ctx.observation_space),
        )
        self.filter = MeanStdFilter(filter_shape, demean=demean, destd=destd, clip=clip)

    def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
        d = ac_data.data
        assert (
            type(d) == dict
        ), "Single agent data must be of type Dict[str, TensorStructType]"
        # Only update the running statistics while training.
        if SampleBatch.OBS in d:
            d[SampleBatch.OBS] = self.filter(
                d[SampleBatch.OBS], update=self._is_training
            )
        if SampleBatch.NEXT_OBS in d:
            d[SampleBatch.NEXT_OBS] = self.filter(
                d[SampleBatch.NEXT_OBS], update=self._is_training
            )

        return ac_data

    def to_state(self):
        # Flattening is deterministic
        flattened_rs = tree.flatten(self.filter.running_stats)
        flattened_buffer = tree.flatten(self.filter.buffer)
        return MeanStdObservationFilterAgentConnector.__name__, {
            "shape": self.filter.shape,
            "no_preprocessor": self.filter.no_preprocessor,
            "demean": self.filter.demean,
            "destd": self.filter.destd,
            "clip": self.filter.clip,
            "running_stats": [s.to_state() for s in flattened_rs],
            "buffer": [s.to_state() for s in flattened_buffer],
        }

    # demean, destd, clip, and a state dict
    @staticmethod
    def from_state(
        ctx: ConnectorContext,
        params: List[Any] = None,
        demean: bool = True,
        destd: bool = True,
        clip: float = 10.0,
    ):
        connector = MeanStdObservationFilterAgentConnector(ctx, demean, destd, clip)
        if params:
            connector.filter.shape = params["shape"]
            connector.filter.no_preprocessor = params["no_preprocessor"]
            connector.filter.demean = params["demean"]
            connector.filter.destd = params["destd"]
            connector.filter.clip = params["clip"]

            # Unflattening is deterministic
            running_stats = [RunningStat.from_state(s) for s in params["running_stats"]]
            connector.filter.running_stats = tree.unflatten_as(
                connector.filter.shape, running_stats
            )

            # Unflattening is deterministic
            buffer = [RunningStat.from_state(s) for s in params["buffer"]]
            connector.filter.buffer = tree.unflatten_as(connector.filter.shape, buffer)

        return connector

    def reset_state(self) -> None:
        """Resets the accumulated state (the filter's buffer)."""
        if not self._is_training:
            # Instances have no `__name__`; use the class name instead.
            raise ValueError(
                "State of {} can only be changed when training.".format(
                    type(self).__name__
                )
            )
        self.filter.reset_buffer()

    def apply_changes(self, other: "Filter", *args, **kwargs) -> None:
        """Updates self with state from other filter."""
        # inline this as soon as we deprecate ordinary filter with non-connector
        # env_runner
        if not self._is_training:
            raise ValueError(
                "Changes can only be applied to {} when training.".format(
                    type(self).__name__
                )
            )
        return self.filter.apply_changes(other, *args, **kwargs)

    def copy(self) -> "Filter":
        """Creates a new object with same state as self.

        This is a legacy Filter method that we need to keep around for now.

        Returns:
            A copy of self.
        """
        # inline this as soon as we deprecate ordinary filter with non-connector
        # env_runner
        return self.filter.copy()

    def sync(self, other: "AgentConnector") -> None:
        """Copies all state from other filter to self."""
        # inline this as soon as we deprecate ordinary filter with non-connector
        # env_runner
        if not self._is_training:
            raise ValueError(
                "{} can only be synced when training.".format(type(self).__name__)
            )
        return self.filter.sync(other.filter)


@PublicAPI(stability="alpha")
class ConcurrentMeanStdObservationFilterAgentConnector(
    MeanStdObservationFilterAgentConnector
):
    """A concurrent version of the MeanStdObservationFilterAgentConnector.

    This version's filter has all operations wrapped by a threading.RLock.
    It can therefore be safely used by multiple threads.
    """

    def __init__(self, ctx: ConnectorContext, demean=True, destd=True, clip=10.0):
        SyncedFilterAgentConnector.__init__(self, ctx)
        # We simply use the old MeanStdFilter until non-connector env_runner is fully
        # deprecated to avoid duplicate code

        filter_shape = tree.map_structure(
            lambda s: (
                None
                if isinstance(s, (Discrete, MultiDiscrete))  # noqa
                else np.array(s.shape)
            ),
            get_base_struct_from_space(ctx.observation_space),
        )
        # Forward the constructor arguments instead of hard-coded defaults.
        self.filter = ConcurrentMeanStdFilter(
            filter_shape, demean=demean, destd=destd, clip=clip
        )


register_connector(
    MeanStdObservationFilterAgentConnector.__name__,
    MeanStdObservationFilterAgentConnector,
)
register_connector(
    ConcurrentMeanStdObservationFilterAgentConnector.__name__,
    ConcurrentMeanStdObservationFilterAgentConnector,
)
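
As an illustration of what "mean-std filtering" with `demean`, `destd`, and `clip` means, here is a minimal, self-contained sketch. It is not RLlib's `MeanStdFilter`; the names `RunningMeanStd` and `normalize` are hypothetical, and the use of Welford's online algorithm plus a small `eps` in the denominator are assumptions made for this example.

```python
import numpy as np


class RunningMeanStd:
    """Hypothetical helper: per-element running mean/std (Welford's algorithm)."""

    def __init__(self, shape):
        self.count = 0
        self.mean = np.zeros(shape)
        self.m2 = np.zeros(shape)  # running sum of squared deviations

    def push(self, x):
        x = np.asarray(x, dtype=np.float64)
        self.count += 1
        delta = x - self.mean
        self.mean = self.mean + delta / self.count
        self.m2 = self.m2 + delta * (x - self.mean)

    @property
    def std(self):
        # With fewer than two samples there is no variance estimate yet.
        if self.count < 2:
            return np.ones_like(self.mean)
        return np.sqrt(self.m2 / (self.count - 1))


def normalize(stats, x, update=True, demean=True, destd=True, clip=10.0, eps=1e-8):
    """Normalize x element-wise; only touch the statistics if update is True."""
    x = np.asarray(x, dtype=np.float64)
    if update:
        stats.push(x)
    if demean:
        x = x - stats.mean
    if destd:
        x = x / (stats.std + eps)
    if clip:
        x = np.clip(x, -clip, clip)
    return x
```

Calling `normalize(..., update=False)` mirrors the evaluation case in the connector above, where `update=self._is_training` keeps the statistics frozen.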
52 changes: 52 additions & 0 deletions rllib/connectors/agent/synced_filter.py
@@ -0,0 +1,52 @@
from ray.rllib.connectors.connector import (
    AgentConnector,
    ConnectorContext,
)
from ray.util.annotations import PublicAPI
from ray.rllib.utils.filter import Filter


@PublicAPI(stability="alpha")
class SyncedFilterAgentConnector(AgentConnector):
    """An agent connector that filters with synchronized parameters."""

    def __init__(self, ctx: ConnectorContext, *args, **kwargs):
        super().__init__(ctx)
        if args or kwargs:
            raise ValueError(
                "SyncedFilterAgentConnector does not take any additional arguments, "
                "but got args=`{}` and kwargs=`{}`.".format(args, kwargs)
            )

    def apply_changes(self, other: "Filter", *args, **kwargs) -> None:
        """Updates self with state from other filter."""
        # TODO: (artur) inline this as soon as we deprecate ordinary filter with
        #  non-connector env_runner
        return self.filter.apply_changes(other, *args, **kwargs)

    def copy(self) -> "Filter":
        """Creates a new object with same state as self.

        This is a legacy Filter method that we need to keep around for now.

        Returns:
            A copy of self.
        """
        # inline this as soon as we deprecate ordinary filter with non-connector
        # env_runner
        return self.filter.copy()

    def sync(self, other: "AgentConnector") -> None:
        """Copies all state from other filter to self."""
        # TODO: (artur) inline this as soon as we deprecate ordinary filter with
        #  non-connector env_runner
        return self.filter.sync(other.filter)

    def reset_state(self) -> None:
        """Creates copy of current state and resets accumulated state"""
        raise NotImplementedError

    def as_serializable(self) -> "Filter":
        # TODO: (artur) inline this as soon as we deprecate ordinary filter with
        #  non-connector env_runner
        return self.filter.as_serializable()
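
The `apply_changes` and `sync` methods above delegate to the wrapped filter, whose internals are not part of this diff. As a rough sketch of how statistics gathered by separate workers could be combined into one, the example below merges two sets of running statistics with the standard pairwise (parallel) variance update. The `Stat` class and its `merge` method are hypothetical and only illustrate the idea; they are not the `RunningStat` API.

```python
import math
from dataclasses import dataclass


@dataclass
class Stat:
    """Hypothetical scalar running statistic (count, mean, sum of squared deviations)."""

    count: int = 0
    mean: float = 0.0
    m2: float = 0.0

    def push(self, x: float) -> None:
        # Welford's online update for a single new sample.
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    def merge(self, other: "Stat") -> None:
        # Pairwise update: combine two independent summaries without raw samples.
        if other.count == 0:
            return
        total = self.count + other.count
        delta = other.mean - self.mean
        self.m2 += other.m2 + delta * delta * self.count * other.count / total
        self.mean += delta * other.count / total
        self.count = total

    @property
    def variance(self) -> float:
        return self.m2 / (self.count - 1) if self.count > 1 else 0.0

    @property
    def std(self) -> float:
        return math.sqrt(self.variance)
```

Merging summaries this way gives the same mean and variance as pushing every sample into a single `Stat`, up to floating-point rounding.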
87 changes: 87 additions & 0 deletions rllib/connectors/tests/test_agent.py
@@ -1,6 +1,7 @@
import gym
import numpy as np
import unittest
from gym.spaces import Box

from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig
from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
@@ -18,6 +19,9 @@
    AgentConnectorDataType,
    AgentConnectorsOutput,
)
from ray.rllib.connectors.agent.mean_std_filter import (
    MeanStdObservationFilterAgentConnector,
)


class TestAgentConnector(unittest.TestCase):
@@ -473,6 +477,89 @@ def test_vr_connector_only_keeps_useful_timesteps(self):
        # Data matches the latest timestep.
        self.assertTrue(np.array_equal(obs_data[0], np.array([4, 5, 6, 7])))

    def test_mean_std_observation_filter_connector(self):
        for bounds in [
            (-1, 1),  # normalized
            (-2, 2),  # scaled
            (0, 2),  # shifted
            (0, 4),  # scaled and shifted
        ]:
            print("Testing uniform sampling with bounds: {}".format(bounds))

            observation_space = Box(bounds[0], bounds[1], (3, 64, 64))
            ctx = ConnectorContext(observation_space=observation_space)
            filter_connector = MeanStdObservationFilterAgentConnector(ctx)

            # Warm up Mean-Std filter
            for i in range(1000):
                obs = observation_space.sample()
                sample_batch = {
                    SampleBatch.NEXT_OBS: obs,
                }
                ac = AgentConnectorDataType(0, 0, sample_batch)
                filter_connector.transform(ac)

            # Create another connector to set state to
            _, state = filter_connector.to_state()
            another_filter_connector = (
                MeanStdObservationFilterAgentConnector.from_state(ctx, state)
            )

            another_filter_connector.in_eval()

            # Collect transformed observations
            transformed_observations = []
            for i in range(1000):
                obs = observation_space.sample()
                sample_batch = {
                    SampleBatch.NEXT_OBS: obs,
                }
                ac = AgentConnectorDataType(0, 0, sample_batch)
                connector_output = another_filter_connector.transform(ac)
                transformed_observations.append(
                    connector_output.data[SampleBatch.NEXT_OBS]
                )

            # Check if transformed observations are actually mean-std filtered
            self.assertTrue(
                np.isclose(np.mean(transformed_observations), 0, atol=0.001)
            )
            self.assertTrue(np.isclose(np.var(transformed_observations), 1, atol=0.01))

            # Check if filter parameters were frozen because we are not training
            self.assertTrue(
                filter_connector.filter.running_stats.num_pushes
                == another_filter_connector.filter.running_stats.num_pushes,
            )
            self.assertTrue(
                np.all(
                    filter_connector.filter.running_stats.mean_array
                    == another_filter_connector.filter.running_stats.mean_array,
                )
            )
            self.assertTrue(
                np.all(
                    filter_connector.filter.running_stats.std_array
                    == another_filter_connector.filter.running_stats.std_array,
                )
            )
            self.assertTrue(
                filter_connector.filter.buffer.num_pushes
                == another_filter_connector.filter.buffer.num_pushes,
            )
            self.assertTrue(
                np.all(
                    filter_connector.filter.buffer.mean_array
                    == another_filter_connector.filter.buffer.mean_array,
                )
            )
            self.assertTrue(
                np.all(
                    filter_connector.filter.buffer.std_array
                    == another_filter_connector.filter.buffer.std_array,
                )
            )


if __name__ == "__main__":
    import sys
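
The test above relies on a `to_state` / `from_state` round trip, and the commit message notes that the filter connector state is JSON serializable. The following stand-alone sketch shows that pattern on a toy object: the nested statistics are flattened into a list in a deterministic key order, serialized with `json`, and rebuilt in the same order. `TinyStat` and `TinyFilterConnector` are hypothetical names invented for this example and do not exist in RLlib.

```python
import json


class TinyStat:
    """Hypothetical statistic holding only JSON-friendly values."""

    def __init__(self, count=0, mean=0.0, m2=0.0):
        self.count, self.mean, self.m2 = count, mean, m2

    def to_state(self):
        return {"count": self.count, "mean": self.mean, "m2": self.m2}

    @staticmethod
    def from_state(state):
        return TinyStat(**state)


class TinyFilterConnector:
    """Hypothetical connector whose state is one statistic per observation key."""

    def __init__(self, clip=10.0):
        self.clip = clip
        self.stats = {"position": TinyStat(), "velocity": TinyStat()}

    def to_state(self):
        # Sorting the keys makes the flattening order deterministic.
        keys = sorted(self.stats)
        return type(self).__name__, {
            "clip": self.clip,
            "keys": keys,
            "running_stats": [self.stats[k].to_state() for k in keys],
        }

    @staticmethod
    def from_state(params):
        connector = TinyFilterConnector(clip=params["clip"])
        # Rebuild the mapping in the same order it was flattened.
        connector.stats = {
            k: TinyStat.from_state(s)
            for k, s in zip(params["keys"], params["running_stats"])
        }
        return connector
```

Passing the state through `json.dumps` and `json.loads` before calling `from_state` is a quick way to check that nothing non-serializable, such as a NumPy array, has leaked into it.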