Skip to content

Commit

Permalink
[RLlib] Quick state buffer connector fix (ray-project#26836)
Browse files Browse the repository at this point in the history
Signed-off-by: Stefan van der Kleij <[email protected]>
  • Loading branch information
Jun Gong authored and Stefan van der Kleij committed Aug 18, 2022
1 parent 2fa7615 commit 74d9dd4
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 3 deletions.
2 changes: 1 addition & 1 deletion rllib/connectors/agent/state_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, ctx: ConnectorContext):

def reset(self, env_id: str):
    """Forget the state buffered for ``env_id`` at an episode boundary.

    Nothing is removed when ``self._soft_horizon`` is set, because states
    are then meant to carry over between episodes. Resetting an env that
    has no buffered entry is a no-op rather than an error.

    Args:
        env_id: Key of the environment whose buffered entry is dropped.
    """
    # If soft horizon, states should be carried over between episodes.
    # The membership test keeps a reset that happens before anything was
    # buffered for this env from failing on a missing key.
    if not self._soft_horizon and env_id in self._states:
        del self._states[env_id]

def on_policy_output(self, ac_data: ActionConnectorDataType):
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
data format. E.g., python dict instead of DictSpace, python tuple
instead of TupleSpace.
"""
self.config = config
self.config = config or {}
self.initial_states = model_initial_states or []
self.observation_space = observation_space
self.action_space = action_space
Expand Down
37 changes: 36 additions & 1 deletion rllib/connectors/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@
from ray.rllib.connectors.agent.lambdas import FlattenDataAgentConnector
from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
from ray.rllib.connectors.agent.state_buffer import StateBufferConnector
from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector
from ray.rllib.connectors.connector import ConnectorContext, get_connector
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentConnectorDataType, AgentConnectorsOutput
from ray.rllib.utils.typing import (
ActionConnectorDataType,
AgentConnectorDataType,
AgentConnectorsOutput,
)


class TestAgentConnector(unittest.TestCase):
Expand Down Expand Up @@ -120,6 +125,36 @@ def test_flatten_data_connector(self):
self.assertEqual(len(batch[SampleBatch.ACTIONS]), 2)
self.assertEqual(batch[SampleBatch.INFOS]["random"], "info")

def test_state_buffer_connector(self):
    """Exercise StateBufferConnector before and after a policy output.

    Verifies that resetting with nothing buffered is harmless, that the
    first call yields an all-zero action, and that the action passed to
    ``on_policy_output`` is returned on the following call.
    """
    connector = StateBufferConnector(
        ConnectorContext(
            action_space=gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)),
        )
    )

    # Reset without any buffered data should do nothing.
    connector.reset(env_id=0)

    next_obs = {
        "sensor1": [[1, 1], [2, 2]],
        "sensor2": 8.8,
    }
    agent_data = AgentConnectorDataType(0, 1, {SampleBatch.NEXT_OBS: next_obs})

    # No policy output has been recorded yet: expect a zero action.
    first_pass = connector([agent_data])
    self.assertEqual(len(first_pass), 1)
    initial_action = first_pass[0].data[SampleBatch.ACTIONS]
    self.assertTrue((initial_action == [0, 0, 0]).all())

    # Record a policy output using the same leading (0, 1) identifiers ...
    policy_output = ActionConnectorDataType(0, 1, ([1, 2, 3], [], {}))
    connector.on_policy_output(policy_output)

    # ... and expect its action to be returned on the next call.
    second_pass = connector([agent_data])
    self.assertEqual(len(second_pass), 1)
    self.assertEqual(second_pass[0].data[SampleBatch.ACTIONS], [1, 2, 3])

def test_view_requirement_connector(self):
view_requirements = {
"obs": ViewRequirement(
Expand Down

0 comments on commit 74d9dd4

Please sign in to comment.