[RLlib] Add filters to connector pipeline #27864

Merged
68 commits
5680d70
initial
ArturNiederfahrenhorst Aug 11, 2022
4123ac5
initial
ArturNiederfahrenhorst Aug 11, 2022
c62faf2
lint and comments
ArturNiederfahrenhorst Aug 11, 2022
12d849f
Merge branch 'connectorsstates' into filterstoconnectors
ArturNiederfahrenhorst Aug 12, 2022
f1d4eec
wip
ArturNiederfahrenhorst Aug 12, 2022
1505c28
format
ArturNiederfahrenhorst Aug 13, 2022
af3f6de
implements and uses filters, working for ppo cartpole with meanstd
ArturNiederfahrenhorst Aug 15, 2022
4ddcdeb
merge master
ArturNiederfahrenhorst Aug 26, 2022
6c32319
get rid of synced custom filter abstraction
ArturNiederfahrenhorst Aug 26, 2022
26dbe68
add meanstd filter connetor test and minor fixes
ArturNiederfahrenhorst Aug 26, 2022
481fcdf
jun's comment
ArturNiederfahrenhorst Aug 26, 2022
6119961
move observation space struct logic because information is already in…
ArturNiederfahrenhorst Aug 26, 2022
ab0bba9
Merge branch 'connectorsstates' into filterstoconnectors
ArturNiederfahrenhorst Aug 26, 2022
0c3460d
fix docstrings
ArturNiederfahrenhorst Aug 26, 2022
ec262a2
Merge branch 'connectorsstates' into filterstoconnectors
ArturNiederfahrenhorst Aug 26, 2022
ad0cc88
minor fixes
ArturNiederfahrenhorst Aug 27, 2022
9da99cb
initial
ArturNiederfahrenhorst Aug 27, 2022
0adc85a
merge minor add_policy_fix
ArturNiederfahrenhorst Aug 27, 2022
71f4fe2
fix for config=None in add_policy and connector=None
ArturNiederfahrenhorst Aug 27, 2022
6050961
fix config name
ArturNiederfahrenhorst Aug 29, 2022
6eb34d6
merge master
ArturNiederfahrenhorst Aug 29, 2022
4ae9758
filter connector state is now json serializable
ArturNiederfahrenhorst Aug 30, 2022
cc1c24e
jun's comments
ArturNiederfahrenhorst Aug 30, 2022
3915fd0
initial
ArturNiederfahrenhorst Aug 31, 2022
df36fa3
Merge branch 'connectorsinaddpolicy' into filterstoconnectors
ArturNiederfahrenhorst Aug 31, 2022
54c6736
create connectors only in create_connectors_for_policy
ArturNiederfahrenhorst Aug 31, 2022
1098b11
initial
ArturNiederfahrenhorst Aug 31, 2022
9ff974e
Merge branch 'ConnectorPipelineget' into filterstoconnectors
ArturNiederfahrenhorst Aug 31, 2022
c09a0e3
get filter connector by __get__item
ArturNiederfahrenhorst Aug 31, 2022
41b9bff
Merge branch 'master' into filterstoconnectors
ArturNiederfahrenhorst Sep 1, 2022
cee9ad5
Merge branch 'master' into filterstoconnectors
ArturNiederfahrenhorst Sep 12, 2022
3b76170
merge master
ArturNiederfahrenhorst Sep 12, 2022
56fafa5
remove observation filters
ArturNiederfahrenhorst Sep 19, 2022
4d2cec4
Merge branch 'master' into filterstoconnectors
ArturNiederfahrenhorst Sep 21, 2022
e637f4f
minor fixes
ArturNiederfahrenhorst Sep 21, 2022
86f1754
initial
ArturNiederfahrenhorst Sep 22, 2022
c89790c
revert spelling error
ArturNiederfahrenhorst Sep 22, 2022
06beebc
Merge branch 'make_add_policy_config_explicit' into filterstoconnectors
ArturNiederfahrenhorst Sep 22, 2022
1b7f0bb
initial
ArturNiederfahrenhorst Sep 22, 2022
f2833d3
Revert "Merge branch 'make_add_policy_config_explicit' into filtersto…
ArturNiederfahrenhorst Sep 22, 2022
6abb336
merge master
ArturNiederfahrenhorst Sep 22, 2022
9c294c1
accomodate case in which config={} in add_policy
ArturNiederfahrenhorst Sep 22, 2022
0fd3e09
Merge branch 'fix_ale' into filterstoconnectors
ArturNiederfahrenhorst Sep 22, 2022
0f0461f
fix connectors enabled not no SyncedFilterAgentConnector case
ArturNiederfahrenhorst Sep 23, 2022
e6164a2
Merge branch 'master' into filterstoconnectors
ArturNiederfahrenhorst Sep 23, 2022
45b576e
initial
ArturNiederfahrenhorst Sep 23, 2022
a5eb07d
merge configs in add_policy
ArturNiederfahrenhorst Sep 26, 2022
26dbf8b
format
ArturNiederfahrenhorst Sep 26, 2022
de68a44
merge connectors from own config fix
ArturNiederfahrenhorst Sep 26, 2022
9cb9ebd
merge master
ArturNiederfahrenhorst Sep 28, 2022
555b48d
merge master
ArturNiederfahrenhorst Sep 28, 2022
cf00366
merge from 28739
ArturNiederfahrenhorst Sep 28, 2022
25db0b1
revert random cloudpickle linter error
ArturNiederfahrenhorst Sep 28, 2022
3665588
Merge branch 'create_connectors_from_own_config_in_add_policy' into f…
ArturNiederfahrenhorst Sep 28, 2022
8d5fc3f
small change to trigger CI
ArturNiederfahrenhorst Sep 28, 2022
070cd91
remove all random changes outside rllib that made it into this PR
ArturNiederfahrenhorst Sep 28, 2022
a9ecae8
remove random rst
ArturNiederfahrenhorst Sep 28, 2022
9200021
fix deprecated is_training call
ArturNiederfahrenhorst Sep 28, 2022
aa9c59a
correct in_eval call
ArturNiederfahrenhorst Sep 28, 2022
a644bb4
nit
ArturNiederfahrenhorst Sep 30, 2022
c06a63a
merge master
ArturNiederfahrenhorst Sep 30, 2022
40c2264
jun's comment
ArturNiederfahrenhorst Sep 30, 2022
60ca25d
use merged config to create connectors
ArturNiederfahrenhorst Sep 30, 2022
d166ac4
merge master
ArturNiederfahrenhorst Oct 3, 2022
ead68d3
Add meaningful assertion error and switch order of if/else block to i…
ArturNiederfahrenhorst Oct 3, 2022
6492024
better warning
ArturNiederfahrenhorst Oct 3, 2022
71038d3
jun's comments
ArturNiederfahrenhorst Oct 3, 2022
8e4b8dd
shorter function signature for helper fn
ArturNiederfahrenhorst Oct 3, 2022
187 changes: 187 additions & 0 deletions rllib/connectors/agent/mean_std_filter.py
@@ -0,0 +1,187 @@
from typing import Any, List
from gym.spaces import Discrete, MultiDiscrete

import numpy as np
import tree

from ray.rllib.connectors.agent.synced_filter import SyncedFilterAgentConnector
from ray.rllib.connectors.connector import AgentConnector
from ray.rllib.connectors.connector import (
    ConnectorContext,
    register_connector,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.filter import MeanStdFilter, ConcurrentMeanStdFilter
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import AgentConnectorDataType
from ray.util.annotations import PublicAPI
from ray.rllib.utils.filter import RunningStat


@PublicAPI(stability="alpha")
class MeanStdObservationFilterAgentConnector(SyncedFilterAgentConnector):
    """A connector used to mean-std-filter observations.

    Incoming observations are filtered such that the output of this filter is on
    average zero and has a standard deviation of 1. This filtering is applied
    separately per element of the observation space.
    """

    def __init__(
        self,
        ctx: ConnectorContext,
        demean: bool = True,
        destd: bool = True,
        clip: float = 10.0,
    ):
        SyncedFilterAgentConnector.__init__(self, ctx)
        # We simply use the old MeanStdFilter until non-connector env_runner is fully
        # deprecated to avoid duplicate code

        filter_shape = tree.map_structure(
            lambda s: (
                None
                if isinstance(s, (Discrete, MultiDiscrete))  # noqa
                else np.array(s.shape)
            ),
            get_base_struct_from_space(ctx.observation_space),
        )
        self.filter = MeanStdFilter(filter_shape, demean=demean, destd=destd, clip=clip)

    def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
        d = ac_data.data
        assert (
            type(d) == dict
        ), "Single agent data must be of type Dict[str, TensorStructType]"
        if SampleBatch.OBS in d:
            d[SampleBatch.OBS] = self.filter(
                d[SampleBatch.OBS], update=self._is_training
            )
        if SampleBatch.NEXT_OBS in d:
            d[SampleBatch.NEXT_OBS] = self.filter(
                d[SampleBatch.NEXT_OBS], update=self._is_training
            )

        return ac_data

    def to_state(self):
        # Flattening is deterministic
        flattened_rs = tree.flatten(self.filter.running_stats)
        flattened_buffer = tree.flatten(self.filter.buffer)
        return MeanStdObservationFilterAgentConnector.__name__, {
            "shape": self.filter.shape,
            "no_preprocessor": self.filter.no_preprocessor,
            "demean": self.filter.demean,
            "destd": self.filter.destd,
            "clip": self.filter.clip,
            "running_stats": [s.to_state() for s in flattened_rs],
            "buffer": [s.to_state() for s in flattened_buffer],
        }

    # demean, destd, clip, and a state dict
    @staticmethod
    def from_state(
        ctx: ConnectorContext,
        params: List[Any] = None,
        demean: bool = True,
        destd: bool = True,
        clip: float = 10.0,
    ):
        connector = MeanStdObservationFilterAgentConnector(ctx, demean, destd, clip)
        if params:
            connector.filter.shape = params["shape"]
            connector.filter.no_preprocessor = params["no_preprocessor"]
            connector.filter.demean = params["demean"]
            connector.filter.destd = params["destd"]
            connector.filter.clip = params["clip"]

            # Unflattening is deterministic
            running_stats = [RunningStat.from_state(s) for s in params["running_stats"]]
            connector.filter.running_stats = tree.unflatten_as(
                connector.filter.shape, running_stats
            )

            # Unflattening is deterministic
            buffer = [RunningStat.from_state(s) for s in params["buffer"]]
            connector.filter.buffer = tree.unflatten_as(connector.filter.shape, buffer)

        return connector

    def reset_state(self) -> None:
Member

when do we usually need this?

Contributor Author

Right now the reset_state() method is part of the SyncedFilterAgentConnector interface.
Synchronization is done via the old filter synchronization mechanism, which I would like to leave in place until we switch connectors on by default. After that, I would like to simply inline all the filter code and call the connector's reset_state method directly.

Until we have gotten that far, I can delete this method if you like.

        """Creates copy of current state and resets accumulated state"""
        if not self._is_training:
            raise ValueError(
                "State of {} can only be changed when training.".format(type(self).__name__)
            )
        self.filter.reset_buffer()

    def apply_changes(self, other: "Filter", *args, **kwargs) -> None:
        """Updates self with state from other filter."""
        # inline this as soon as we deprecate ordinary filter with non-connector
        # env_runner
        if not self._is_training:
            raise ValueError(
                "Changes can only be applied to {} when training.".format(type(self).__name__)
Member

wait, why can't we update during inference as well?

Contributor Author

I asked @kouroshHakha whether mean-std filters are updated during deployment, because I was wondering the same thing. He told me that it is best practice to stop updating them after training.

            )
        return self.filter.apply_changes(other, *args, **kwargs)

    def copy(self) -> "Filter":
        """Creates a new object with same state as self.

        This is a legacy Filter method that we need to keep around for now

        Returns:
            A copy of self.
        """
        # inline this as soon as we deprecate ordinary filter with non-connector
        # env_runner
        return self.filter.copy()

    def sync(self, other: "AgentConnector") -> None:
        """Copies all state from other filter to self."""
        # inline this as soon as we deprecate ordinary filter with non-connector
        # env_runner
        if not self._is_training:
Member

same ...

Contributor Author

Let's keep the discussion in your first comment!

            raise ValueError(
                "{} can only be synced when training.".format(type(self).__name__)
            )
        return self.filter.sync(other.filter)


@PublicAPI(stability="alpha")
class ConcurrentMeanStdObservationFilterAgentConnector(
    MeanStdObservationFilterAgentConnector
):
    """A concurrent version of the MeanStdObservationFilterAgentConnector.

    This version's filter has all operations wrapped by a threading.RLock.
    It can therefore be safely used by multiple threads.
    """

    def __init__(self, ctx: ConnectorContext, demean=True, destd=True, clip=10.0):
        SyncedFilterAgentConnector.__init__(self, ctx)
        # We simply use the old MeanStdFilter until non-connector env_runner is fully
        # deprecated to avoid duplicate code

        filter_shape = tree.map_structure(
            lambda s: (
                None
                if isinstance(s, (Discrete, MultiDiscrete))  # noqa
                else np.array(s.shape)
            ),
            get_base_struct_from_space(ctx.observation_space),
        )
        self.filter = ConcurrentMeanStdFilter(
            filter_shape, demean=demean, destd=destd, clip=clip
        )


register_connector(
    MeanStdObservationFilterAgentConnector.__name__,
    MeanStdObservationFilterAgentConnector,
)
register_connector(
Member

can you add some comments here describing the difference between these 2 filter connectors?
thanks.

Contributor Author

Done.

    ConcurrentMeanStdObservationFilterAgentConnector.__name__,
    ConcurrentMeanStdObservationFilterAgentConnector,
)
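
For readers skimming the diff, the following is a minimal sketch of the behaviour the connector's docstring and the review thread describe: per-element demeaning and scaling, clipping, and statistics that stop changing once updates are switched off. It is not RLlib code. `RunningMeanStd` and `normalize` are hypothetical names, and the use of Welford's update and the `eps` guard are assumptions made for illustration.

```python
import numpy as np


class RunningMeanStd:
    """Toy per-element running mean/std tracker (Welford's algorithm)."""

    def __init__(self, shape):
        self.n = 0
        self.mean = np.zeros(shape)
        self.m2 = np.zeros(shape)  # running sum of squared deviations

    def push(self, x):
        x = np.asarray(x, dtype=np.float64)
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (x - self.mean)

    @property
    def std(self):
        # With fewer than two samples there is no variance estimate yet.
        if self.n < 2:
            return np.ones_like(self.mean)
        return np.sqrt(self.m2 / (self.n - 1))


def normalize(stat, x, update=True, clip=10.0, eps=1e-8):
    """Demean, destd, and clip x; statistics only change when update is True."""
    if update:
        stat.push(x)
    out = (np.asarray(x, dtype=np.float64) - stat.mean) / (stat.std + eps)
    return np.clip(out, -clip, clip)


rng = np.random.default_rng(0)
stat = RunningMeanStd((4,))

# "Training": every observation updates the statistics.
for _ in range(5000):
    normalize(stat, rng.uniform(0.0, 4.0, size=4), update=True)

# "Evaluation": observations are filtered, but the statistics stay frozen.
pushes_before = stat.n
eval_out = np.stack(
    [normalize(stat, rng.uniform(0.0, 4.0, size=4), update=False) for _ in range(5000)]
)
```

In this sketch the `update` flag plays the role that `self._is_training` plays in `transform()` above.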
52 changes: 52 additions & 0 deletions rllib/connectors/agent/synced_filter.py
@@ -0,0 +1,52 @@
from ray.rllib.connectors.connector import (
    AgentConnector,
    ConnectorContext,
)
from ray.util.annotations import PublicAPI
from ray.rllib.utils.filter import Filter


@PublicAPI(stability="alpha")
class SyncedFilterAgentConnector(AgentConnector):
    """An agent connector that filters with synchronized parameters."""

    def __init__(self, ctx: ConnectorContext, *args, **kwargs):
        super().__init__(ctx)
        if args or kwargs:
            raise ValueError(
                "SyncedFilterAgentConnector does not take any additional arguments, "
                "but got args=`{}` and kwargs=`{}`.".format(args, kwargs)
            )

    def apply_changes(self, other: "Filter", *args, **kwargs) -> None:
        """Updates self with state from other filter."""
        # TODO: (artur) inline this as soon as we deprecate ordinary filter with
        # non-connector env_runner
        return self.filter.apply_changes(other, *args, **kwargs)

    def copy(self) -> "Filter":
        """Creates a new object with same state as self.

        This is a legacy Filter method that we need to keep around for now

        Returns:
            A copy of self.
        """
        # inline this as soon as we deprecate ordinary filter with non-connector
        # env_runner
        return self.filter.copy()

    def sync(self, other: "AgentConnector") -> None:
        """Copies all state from other filter to self."""
        # TODO: (artur) inline this as soon as we deprecate ordinary filter with
        # non-connector env_runner
        return self.filter.sync(other.filter)

    def reset_state(self) -> None:
        """Creates copy of current state and resets accumulated state"""
        raise NotImplementedError

    def as_serializable(self) -> "Filter":
        # TODO: (artur) inline this as soon as we deprecate ordinary filter with
        # non-connector env_runner
        return self.filter.as_serializable()
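
The diff delegates synchronization to the legacy filter, so it does not show how per-worker statistics could be combined. The sketch below is one plausible reading of the `apply_changes()` / `reset_state()` pair: each worker keeps a buffer of what it has seen since the last sync, and a central filter folds those buffers into its own running statistics. `Stat` and `ToyFilter` are hypothetical classes, and the merge formula (the pairwise update by Chan et al.) is an assumption, not something this PR confirms about `MeanStdFilter`.

```python
import numpy as np


class Stat:
    """Toy running statistic: count, mean, and sum of squared deviations."""

    def __init__(self, shape):
        self.n = 0
        self.mean = np.zeros(shape)
        self.m2 = np.zeros(shape)

    def push(self, x):
        x = np.asarray(x, dtype=np.float64)
        self.n += 1
        delta = x - self.mean
        self.mean = self.mean + delta / self.n
        self.m2 = self.m2 + delta * (x - self.mean)

    def merge(self, other):
        """Fold another Stat into this one without revisiting the raw data."""
        if other.n == 0:
            return
        n = self.n + other.n
        delta = other.mean - self.mean
        self.m2 = self.m2 + other.m2 + delta**2 * self.n * other.n / n
        self.mean = self.mean + delta * other.n / n
        self.n = n


class ToyFilter:
    """Keeps long-running stats plus a buffer of changes since the last sync."""

    def __init__(self, shape):
        self.shape = shape
        self.running_stats = Stat(shape)
        self.buffer = Stat(shape)

    def observe(self, x):
        self.running_stats.push(x)
        self.buffer.push(x)

    def apply_changes(self, other):
        # Only the other filter's buffered changes are merged in.
        self.running_stats.merge(other.buffer)

    def reset_buffer(self):
        self.buffer = Stat(self.shape)


rng = np.random.default_rng(1)
data_a = rng.normal(3.0, 2.0, size=(200, 2))
data_b = rng.normal(-1.0, 0.5, size=(300, 2))

worker_a, worker_b, central = ToyFilter((2,)), ToyFilter((2,)), ToyFilter((2,))
for x in data_a:
    worker_a.observe(x)
for x in data_b:
    worker_b.observe(x)

# Central filter collects both buffers, then the workers start fresh buffers.
for worker in (worker_a, worker_b):
    central.apply_changes(worker)
    worker.reset_buffer()

all_data = np.concatenate([data_a, data_b])
central_var = central.running_stats.m2 / (central.running_stats.n - 1)
```

Merging buffers rather than full statistics avoids counting the same observations twice when a worker is synced repeatedly, which is why the buffer is reset after each merge in this sketch.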
87 changes: 87 additions & 0 deletions rllib/connectors/tests/test_agent.py
@@ -1,6 +1,7 @@
import gym
import numpy as np
import unittest
from gym.spaces import Box

from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig
from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
@@ -18,6 +19,9 @@
    AgentConnectorDataType,
    AgentConnectorsOutput,
)
from ray.rllib.connectors.agent.mean_std_filter import (
    MeanStdObservationFilterAgentConnector,
)


class TestAgentConnector(unittest.TestCase):
@@ -473,6 +477,89 @@ def test_vr_connector_only_keeps_useful_timesteps(self):
        # Data matches the latest timestep.
        self.assertTrue(np.array_equal(obs_data[0], np.array([4, 5, 6, 7])))

    def test_mean_std_observation_filter_connector(self):
        for bounds in [
            (-1, 1),  # normalized
            (-2, 2),  # scaled
            (0, 2),  # shifted
            (0, 4),  # scaled and shifted
        ]:
            print("Testing uniform sampling with bounds: {}".format(bounds))

            observation_space = Box(bounds[0], bounds[1], (3, 64, 64))
            ctx = ConnectorContext(observation_space=observation_space)
            filter_connector = MeanStdObservationFilterAgentConnector(ctx)

            # Warm up Mean-Std filter
            for i in range(1000):
                obs = observation_space.sample()
                sample_batch = {
                    SampleBatch.NEXT_OBS: obs,
                }
                ac = AgentConnectorDataType(0, 0, sample_batch)
                filter_connector.transform(ac)

            # Create another connector to set state to
            _, state = filter_connector.to_state()
            another_filter_connector = (
                MeanStdObservationFilterAgentConnector.from_state(ctx, state)
            )

            another_filter_connector.in_eval()

            # Collect transformed observations
            transformed_observations = []
            for i in range(1000):
                obs = observation_space.sample()
                sample_batch = {
                    SampleBatch.NEXT_OBS: obs,
                }
                ac = AgentConnectorDataType(0, 0, sample_batch)
                connector_output = another_filter_connector.transform(ac)
                transformed_observations.append(
                    connector_output.data[SampleBatch.NEXT_OBS]
                )

            # Check if transformed observations are actually mean-std filtered
            self.assertTrue(
                np.isclose(np.mean(transformed_observations), 0, atol=0.001)
            )
            self.assertTrue(np.isclose(np.var(transformed_observations), 1, atol=0.01))

            # Check if filter parameters were frozen because we are not training
            self.assertTrue(
                filter_connector.filter.running_stats.num_pushes
                == another_filter_connector.filter.running_stats.num_pushes,
            )
            self.assertTrue(
                np.all(
                    filter_connector.filter.running_stats.mean_array
                    == another_filter_connector.filter.running_stats.mean_array,
                )
            )
            self.assertTrue(
                np.all(
                    filter_connector.filter.running_stats.std_array
                    == another_filter_connector.filter.running_stats.std_array,
                )
            )
            self.assertTrue(
                filter_connector.filter.buffer.num_pushes
                == another_filter_connector.filter.buffer.num_pushes,
            )
            self.assertTrue(
                np.all(
                    filter_connector.filter.buffer.mean_array
                    == another_filter_connector.filter.buffer.mean_array,
                )
            )
            self.assertTrue(
                np.all(
                    filter_connector.filter.buffer.std_array
                    == another_filter_connector.filter.buffer.std_array,
                )
            )


if __name__ == "__main__":
    import sys
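
The test above restores a second connector from the first connector's `to_state()` output, and one commit in this PR makes that state JSON serializable. As a rough illustration of the idea, the sketch below flattens a nested structure of statistics in a fixed order, serializes the leaves to JSON, and rebuilds the same nesting on the other side. All helper names here (`stat_to_state`, `stat_from_state`, `flatten`, `unflatten_as`) are hypothetical stand-ins written in plain Python; they are not the `tree` or `RunningStat` functions used in the diff, and sorted-key traversal is an assumption chosen to make the ordering deterministic.

```python
import json

import numpy as np


def stat_to_state(n, mean, m2):
    """Convert count/mean/M2 into plain types that json.dumps accepts."""
    return {
        "n": int(n),
        "mean": np.asarray(mean).tolist(),
        "m2": np.asarray(m2).tolist(),
    }


def stat_from_state(state):
    """Inverse of stat_to_state: lists become NumPy arrays again."""
    return state["n"], np.array(state["mean"]), np.array(state["m2"])


def flatten(struct):
    """Flatten a nested dict into a list of leaves, visiting keys in sorted order."""
    if isinstance(struct, dict):
        return [leaf for key in sorted(struct) for leaf in flatten(struct[key])]
    return [struct]


def unflatten_as(template, leaves):
    """Rebuild the template's nesting from a flat list of leaves."""
    leaf_iter = iter(leaves)

    def build(node):
        if isinstance(node, dict):
            return {key: build(node[key]) for key in sorted(node)}
        return next(leaf_iter)

    return build(template)


# One (count, mean, M2) tuple per leaf of a nested observation structure.
stats = {
    "position": (3, np.array([0.5, -0.5]), np.array([1.0, 2.0])),
    "sensors": {"lidar": (7, np.array([1.0, 2.0, 3.0]), np.array([0.1, 0.2, 0.3]))},
}

payload = json.dumps([stat_to_state(*leaf) for leaf in flatten(stats)])
restored = unflatten_as(stats, [stat_from_state(s) for s in json.loads(payload)])
```

Because both sides walk the structure in the same order, the flat list alone is enough to restore every leaf to its original position, which is the property the "Flattening is deterministic" comments in `to_state()` rely on.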