From 74d9dd4355438573edecc7b7776ac9351f28bab6 Mon Sep 17 00:00:00 2001
From: Jun Gong
Date: Thu, 21 Jul 2022 20:43:59 -0700
Subject: [PATCH] [RLlib] Quick state buffer connector fix (#26836)

Signed-off-by: Stefan van der Kleij
---
 rllib/connectors/agent/state_buffer.py |  2 +-
 rllib/connectors/connector.py          |  2 +-
 rllib/connectors/tests/test_agent.py   | 37 +++++++++++++++++++++++++-
 3 files changed, 38 insertions(+), 3 deletions(-)

diff --git a/rllib/connectors/agent/state_buffer.py b/rllib/connectors/agent/state_buffer.py
index 6d7e0f92d3f0..f387bbca46b1 100644
--- a/rllib/connectors/agent/state_buffer.py
+++ b/rllib/connectors/agent/state_buffer.py
@@ -27,7 +27,7 @@ def __init__(self, ctx: ConnectorContext):
 
     def reset(self, env_id: str):
         # If soft horizon, states should be carried over between episodes.
-        if not self._soft_horizon:
+        if not self._soft_horizon and env_id in self._states:
             del self._states[env_id]
 
     def on_policy_output(self, ac_data: ActionConnectorDataType):
diff --git a/rllib/connectors/connector.py b/rllib/connectors/connector.py
index cd3178e3fae1..378361d5c332 100644
--- a/rllib/connectors/connector.py
+++ b/rllib/connectors/connector.py
@@ -52,7 +52,7 @@ def __init__(
                 data format. E.g., python dict instead of DictSpace, python tuple
                 instead of TupleSpace.
         """
-        self.config = config
+        self.config = config or {}
         self.initial_states = model_initial_states or []
         self.observation_space = observation_space
         self.action_space = action_space
diff --git a/rllib/connectors/tests/test_agent.py b/rllib/connectors/tests/test_agent.py
index 61ab638fdc60..9de09ce04eec 100644
--- a/rllib/connectors/tests/test_agent.py
+++ b/rllib/connectors/tests/test_agent.py
@@ -6,11 +6,16 @@
 from ray.rllib.connectors.agent.lambdas import FlattenDataAgentConnector
 from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
 from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
+from ray.rllib.connectors.agent.state_buffer import StateBufferConnector
 from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector
 from ray.rllib.connectors.connector import ConnectorContext, get_connector
 from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.utils.typing import AgentConnectorDataType, AgentConnectorsOutput
+from ray.rllib.utils.typing import (
+    ActionConnectorDataType,
+    AgentConnectorDataType,
+    AgentConnectorsOutput,
+)
 
 
 class TestAgentConnector(unittest.TestCase):
@@ -120,6 +125,36 @@ def test_flatten_data_connector(self):
         self.assertEqual(len(batch[SampleBatch.ACTIONS]), 2)
         self.assertEqual(batch[SampleBatch.INFOS]["random"], "info")
 
+    def test_state_buffer_connector(self):
+        ctx = ConnectorContext(
+            action_space=gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)),
+        )
+        c = StateBufferConnector(ctx)
+
+        # Reset without any buffered data should do nothing.
+        c.reset(env_id=0)
+
+        d = AgentConnectorDataType(
+            0,
+            1,
+            {
+                SampleBatch.NEXT_OBS: {
+                    "sensor1": [[1, 1], [2, 2]],
+                    "sensor2": 8.8,
+                },
+            },
+        )
+
+        with_buffered = c([d])
+        self.assertEqual(len(with_buffered), 1)
+        self.assertTrue((with_buffered[0].data[SampleBatch.ACTIONS] == [0, 0, 0]).all())
+
+        c.on_policy_output(ActionConnectorDataType(0, 1, ([1, 2, 3], [], {})))
+
+        with_buffered = c([d])
+        self.assertEqual(len(with_buffered), 1)
+        self.assertEqual(with_buffered[0].data[SampleBatch.ACTIONS], [1, 2, 3])
+
     def test_view_requirement_connector(self):
         view_requirements = {
             "obs": ViewRequirement(
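
Notes (not part of the patch): a minimal standalone sketch of the failure mode this
change guards against. The class below is hypothetical and only mirrors the relevant
shape of StateBufferConnector; before the fix, reset() deleted the buffered state
unconditionally, so resetting an env_id that had never buffered any state raised a
KeyError.

    class StateBuffer:
        """Simplified, hypothetical stand-in for StateBufferConnector."""

        def __init__(self, soft_horizon: bool = False):
            self._soft_horizon = soft_horizon
            self._states = {}  # env_id -> buffered policy output

        def reset(self, env_id):
            # Fixed behavior: only delete a state that was actually buffered.
            # (Before the fix, the `env_id in self._states` guard was missing.)
            if not self._soft_horizon and env_id in self._states:
                del self._states[env_id]

    buf = StateBuffer()
    buf.reset(env_id=0)  # Previously a KeyError; now a safe no-op.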