From 970275d969d3c7b7146983873ae1495ef3c54646 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Tue, 9 Apr 2024 20:58:51 +0200 Subject: [PATCH 1/9] wip Signed-off-by: sven1977 --- rllib/env/multi_agent_episode.py | 72 ++++++++++++++++---------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/rllib/env/multi_agent_episode.py b/rllib/env/multi_agent_episode.py index 07f97aa79927..3ada5116562e 100644 --- a/rllib/env/multi_agent_episode.py +++ b/rllib/env/multi_agent_episode.py @@ -226,9 +226,9 @@ def __init__( # a next observation, yet. In this case we buffer the action, add the rewards, # and record `is_terminated/is_truncated` until the next observation is # received. - self._agent_buffered_actions = {} - self._agent_buffered_extra_model_outputs = defaultdict(dict) - self._agent_buffered_rewards = {} + self._hanging_actions = {} + self._hanging_extra_model_outputs = defaultdict(dict) + self._hanging_rewards = {} # If this is an ongoing episode than the last `__all__` should be `False` self.is_terminated: bool = ( @@ -470,7 +470,7 @@ def add_env_step( # collected buffered rewards. # b) The observation is the first observation for this agent ID. elif _observation is not None and _action is None: - _action = self._agent_buffered_actions.pop(agent_id, None) + _action = self._hanging_actions.pop(agent_id, None) # We have a buffered action (the agent had acted after the previous # observation, but the env had not responded - until now - with another @@ -478,10 +478,10 @@ def add_env_step( # ...[buffered action] ... ... -> next obs + (reward)? ... if _action is not None: # Get the extra model output if available. - _extra_model_outputs = self._agent_buffered_extra_model_outputs.pop( + _extra_model_outputs = self._hanging_extra_model_outputs.pop( agent_id, None ) - _reward = self._agent_buffered_rewards.pop(agent_id, 0.0) + _reward + _reward = self._hanging_rewards.pop(agent_id, 0.0) + _reward # _agent_step = len(sa_episode) # First observation for this agent, we have no buffered action. # ... [done]? ... -> [1st obs for agent ID] @@ -528,10 +528,10 @@ def add_env_step( # [previous obs] [action] (to be buffered) ... else: # Buffer action, reward, and extra_model_outputs. - assert agent_id not in self._agent_buffered_actions - self._agent_buffered_actions[agent_id] = _action - self._agent_buffered_rewards[agent_id] = _reward - self._agent_buffered_extra_model_outputs[ + assert agent_id not in self._hanging_actions + self._hanging_actions[agent_id] = _action + self._hanging_rewards[agent_id] = _reward + self._hanging_extra_model_outputs[ agent_id ] = _extra_model_outputs @@ -540,7 +540,7 @@ def add_env_step( # -------------------------------------------------------------------------- # Record reward and terminated/truncated flags. else: - _action = self._agent_buffered_actions.get(agent_id) + _action = self._hanging_actions.get(agent_id) # Agent is done. if _terminated or _truncated: @@ -571,14 +571,14 @@ def add_env_step( # `_action` is already `get` above. We don't need to pop out from # the buffer as it gets wiped out anyway below b/c the agent is # done. - _extra_model_outputs = self._agent_buffered_extra_model_outputs.pop( + _extra_model_outputs = self._hanging_extra_model_outputs.pop( agent_id, None ) - _reward = self._agent_buffered_rewards.pop(agent_id, 0.0) + _reward + _reward = self._hanging_rewards.pop(agent_id, 0.0) + _reward # The agent is still alive, just add current reward to buffer. 
else: - self._agent_buffered_rewards[agent_id] = ( - self._agent_buffered_rewards.get(agent_id, 0.0) + _reward + self._hanging_rewards[agent_id] = ( + self._hanging_rewards.get(agent_id, 0.0) + _reward ) # If agent is stepping, add timestep to `SingleAgentEpisode`. @@ -799,7 +799,7 @@ def cut(self, len_lookback_buffer: int = 0) -> "MultiAgentEpisode": # If there is data (e.g. actions) in the agents' buffers, we might have to # re-adjust the lookback len further into the past to make sure that these # agents have at least one observation to look back to. - for agent_id, agent_actions in self._agent_buffered_actions.items(): + for agent_id, agent_actions in self._hanging_actions.items(): assert self.env_t_to_agent_t[agent_id].get(-1) == self.SKIP_ENV_TS_TAG for i in range(1, self.env_t_to_agent_t[agent_id].lookback + 1): if ( @@ -852,10 +852,10 @@ def cut(self, len_lookback_buffer: int = 0) -> "MultiAgentEpisode": ) # Copy over the current buffer values. - successor._agent_buffered_actions = copy.deepcopy(self._agent_buffered_actions) - successor._agent_buffered_rewards = self._agent_buffered_rewards.copy() - successor._agent_buffered_extra_model_outputs = copy.deepcopy( - self._agent_buffered_extra_model_outputs + successor._hanging_actions = copy.deepcopy(self._hanging_actions) + successor._hanging_rewards = self._hanging_rewards.copy() + successor._hanging_extra_model_outputs = copy.deepcopy( + self._hanging_extra_model_outputs ) return successor @@ -1347,8 +1347,8 @@ def slice(self, slice_: slice) -> "MultiAgentEpisode": check(a0.observations, [0]) check(a0.actions, []) check(a0.rewards, []) - check(slice._agent_buffered_actions["a0"], 0) - check(slice._agent_buffered_rewards["a0"], 0.1) + check(slice._hanging_actions["a0"], 0) + check(slice._hanging_rewards["a0"], 0.1) Args: slice_: The slice object to use for slicing. This should exclude the @@ -1629,7 +1629,7 @@ def get_return( agent_eps.get_return() for agent_eps in self.agent_episodes.values() ) if consider_buffer: - for buffered_r in self._agent_buffered_rewards.values(): + for buffered_r in self._hanging_rewards.values(): env_return += buffered_r return env_return @@ -1767,13 +1767,13 @@ def _init_single_agent_episodes( # complete step for agent. if len(observations_per_agent[agent_id]) > 1: actions_per_agent[agent_id].append( - self._agent_buffered_actions.pop(agent_id) + self._hanging_actions.pop(agent_id) ) extra_model_outputs_per_agent[agent_id].append( - self._agent_buffered_extra_model_outputs.pop(agent_id) + self._hanging_extra_model_outputs.pop(agent_id) ) rewards_per_agent[agent_id].append( - self._agent_buffered_rewards.pop(agent_id) + self._hanging_rewards.pop(agent_id) ) # First obs for this agent. Make sure the agent's mapping is # appropriately prepended with self.SKIP_ENV_TS_TAG tags. @@ -1787,13 +1787,13 @@ def _init_single_agent_episodes( if agent_id in act: # Always push actions/extra outputs into buffer, then remove them # from there, once the next observation comes in. Same for rewards. - self._agent_buffered_actions[agent_id] = act[agent_id] - self._agent_buffered_extra_model_outputs[agent_id] = extra_outs.get( + self._hanging_actions[agent_id] = act[agent_id] + self._hanging_extra_model_outputs[agent_id] = extra_outs.get( agent_id, {} ) - self._agent_buffered_rewards[ + self._hanging_rewards[ agent_id - ] = self._agent_buffered_rewards.get(agent_id, 0.0) + rew.get( + ] = self._hanging_rewards.get(agent_id, 0.0) + rew.get( agent_id, 0.0 ) # Agent is done (has no action for the next step). 
@@ -2397,17 +2397,17 @@ def _get_single_agent_data_by_env_step_indices( def _get_buffer_value(self, what: str, agent_id: AgentID) -> Any: """Returns the buffered action/reward/extra_model_outputs for given agent.""" if what == "actions": - return self._agent_buffered_actions.get(agent_id) + return self._hanging_actions.get(agent_id) elif what == "extra_model_outputs": - return self._agent_buffered_extra_model_outputs.get(agent_id) + return self._hanging_extra_model_outputs.get(agent_id) elif what == "rewards": - return self._agent_buffered_rewards.get(agent_id) + return self._hanging_rewards.get(agent_id) def _del_buffers(self, agent_id: AgentID) -> None: """Deletes all action, reward, extra_model_outputs buffers for given agent.""" - self._agent_buffered_actions.pop(agent_id, None) - self._agent_buffered_extra_model_outputs.pop(agent_id, None) - self._agent_buffered_rewards.pop(agent_id, None) + self._hanging_actions.pop(agent_id, None) + self._hanging_extra_model_outputs.pop(agent_id, None) + self._hanging_rewards.pop(agent_id, None) def _del_agent(self, agent_id: AgentID) -> None: """Deletes all data of given agent from this episode.""" From 6606c3b12c08f537adce1871d9d8f70ceb6804f1 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Tue, 9 Apr 2024 21:16:21 +0200 Subject: [PATCH 2/9] wip Signed-off-by: sven1977 --- rllib/env/multi_agent_episode.py | 155 ++++++++++---------- rllib/env/tests/test_multi_agent_episode.py | 39 ++--- 2 files changed, 96 insertions(+), 98 deletions(-) diff --git a/rllib/env/multi_agent_episode.py b/rllib/env/multi_agent_episode.py index 3ada5116562e..4b5be9f42cf6 100644 --- a/rllib/env/multi_agent_episode.py +++ b/rllib/env/multi_agent_episode.py @@ -221,9 +221,9 @@ def __init__( AgentID, InfiniteLookbackBuffer ] = defaultdict(InfiniteLookbackBuffer) - # In the `MultiAgentEpisode` we need these buffers to keep track of actions, + # In the `MultiAgentEpisode` we need these caches to keep track of actions, # that happen when an agent got observations and acted, but did not receive - # a next observation, yet. In this case we buffer the action, add the rewards, + # a next observation, yet. In this case we store the action, add the rewards, # and record `is_terminated/is_truncated` until the next observation is # received. self._hanging_actions = {} @@ -463,19 +463,19 @@ def add_env_step( f"receive any reward from the env!" ) - # CASE 2: Step gets completed with a buffered action OR first observation. + # CASE 2: Step gets completed with a hanging action OR first observation. # ------------------------------------------------------------------------ # We have an observation, but no action -> - # a) Action (and extra model outputs) must be buffered already. Also use - # collected buffered rewards. + # a) Action (and extra model outputs) must be hanging already. Also use + # collected hanging rewards. # b) The observation is the first observation for this agent ID. elif _observation is not None and _action is None: _action = self._hanging_actions.pop(agent_id, None) - # We have a buffered action (the agent had acted after the previous + # We have a hanging action (the agent had acted after the previous # observation, but the env had not responded - until now - with another # observation). - # ...[buffered action] ... ... -> next obs + (reward)? ... + # ...[hanging action] ... ... -> next obs + (reward)? ... if _action is not None: # Get the extra model output if available. 
_extra_model_outputs = self._hanging_extra_model_outputs.pop( @@ -483,13 +483,13 @@ def add_env_step( ) _reward = self._hanging_rewards.pop(agent_id, 0.0) + _reward # _agent_step = len(sa_episode) - # First observation for this agent, we have no buffered action. + # First observation for this agent, we have no hanging action. # ... [done]? ... -> [1st obs for agent ID] else: # The agent is already done -> The agent thus has never stepped once # and we do not have to create a SingleAgentEpisode for it. if _terminated or _truncated: - self._del_buffers(agent_id) + self._del_hanging(agent_id) continue # This must be the agent's initial observation. else: @@ -503,12 +503,12 @@ def add_env_step( # CASE 3: Step is started (by an action), but not completed (no next obs). # ------------------------------------------------------------------------ - # We have no observation, but we have an action to be buffered (and used - # when we do receive the next obs for this agent in the future). + # We have no observation, but we have a hanging action (used when we receive + # the next obs for this agent in the future). elif agent_id not in observations and agent_id in actions: # Agent got truncated -> Error b/c we would need a last (truncation) # observation for this (otherwise, e.g. bootstrapping would not work). - # [previous obs] [action] (to be buffered) ... ... [truncated] + # [previous obs] [action] (hanging) ... ... [truncated] if _truncated: raise MultiAgentEnvError( f"Agent {agent_id} acted and then got truncated, but did NOT " @@ -516,7 +516,7 @@ def add_env_step( "value function bootstrapping!" ) # Agent got terminated. - # [previous obs] [action] (to be buffered) ... ... [terminated] + # [previous obs] [action] (hanging) ... ... [terminated] elif _terminated: # If the agent was terminated and no observation is provided, # duplicate the previous one (this is a technical "fix" to properly @@ -525,15 +525,13 @@ def add_env_step( _observation = sa_episode.get_observations(-1) _infos = sa_episode.get_infos(-1) # Agent is still alive. - # [previous obs] [action] (to be buffered) ... + # [previous obs] [action] (hanging) ... else: - # Buffer action, reward, and extra_model_outputs. + # Hanging action, reward, and extra_model_outputs. assert agent_id not in self._hanging_actions self._hanging_actions[agent_id] = _action self._hanging_rewards[agent_id] = _reward - self._hanging_extra_model_outputs[ - agent_id - ] = _extra_model_outputs + self._hanging_extra_model_outputs[agent_id] = _extra_model_outputs # CASE 4: Step has started in the past and is still ongoing (no observation, # no action). @@ -548,7 +546,7 @@ def add_env_step( # part of this episode. # ... ... [other agents doing stuff] ... ... [agent done] if _action is None: - self._del_buffers(agent_id) + self._del_hanging(agent_id) continue # Agent got truncated -> Error b/c we would need a last (truncation) @@ -561,7 +559,7 @@ def add_env_step( "for e.g. value function bootstrapping!" ) - # [obs] ... ... [buffered action] ... ... [done] + # [obs] ... ... [hanging action] ... ... [done] # If the agent was terminated and no observation is provided, # duplicate the previous one (this is a technical "fix" to properly # complete the single agent episode; this last observation is never @@ -569,16 +567,16 @@ def add_env_step( _observation = sa_episode.get_observations(-1) _infos = sa_episode.get_infos(-1) # `_action` is already `get` above. 
We don't need to pop out from - # the buffer as it gets wiped out anyway below b/c the agent is + # the cache as it gets wiped out anyway below b/c the agent is # done. _extra_model_outputs = self._hanging_extra_model_outputs.pop( agent_id, None ) _reward = self._hanging_rewards.pop(agent_id, 0.0) + _reward - # The agent is still alive, just add current reward to buffer. + # The agent is still alive, just add current reward to cache. else: self._hanging_rewards[agent_id] = ( - self._hanging_rewards.get(agent_id, 0.0) + _reward + self._hanging_rewards.get(agent_id, 0.0) + _reward ) # If agent is stepping, add timestep to `SingleAgentEpisode`. @@ -597,10 +595,10 @@ def add_env_step( len(sa_episode) + sa_episode.observations.lookback ) - # Agent is also done. -> Erase all buffered values for this agent + # Agent is also done. -> Erase all hanging values for this agent # (they should be empty at this point anyways). if _terminated or _truncated: - self._del_buffers(agent_id) + self._del_hanging(agent_id) def validate(self) -> None: """Validates the episode's data. @@ -613,7 +611,7 @@ def validate(self) -> None: eps.validate() # TODO (sven): Validate MultiAgentEpisode specifics, like the timestep mappings, - # action/reward buffers, etc.. + # action/reward caches, etc.. @property def is_finalized(self) -> bool: @@ -796,8 +794,8 @@ def cut(self, len_lookback_buffer: int = 0) -> "MultiAgentEpisode": "Can't call `MultiAgentEpisode.cut()` when the episode is already done!" ) - # If there is data (e.g. actions) in the agents' buffers, we might have to - # re-adjust the lookback len further into the past to make sure that these + # If there is hanging data (e.g. actions) in the agents' caches, we might have + # to re-adjust the lookback len further into the past to make sure that these # agents have at least one observation to look back to. for agent_id, agent_actions in self._hanging_actions.items(): assert self.env_t_to_agent_t[agent_id].get(-1) == self.SKIP_ENV_TS_TAG @@ -851,7 +849,7 @@ def cut(self, len_lookback_buffer: int = 0) -> "MultiAgentEpisode": len_lookback_buffer="auto", ) - # Copy over the current buffer values. + # Copy over the current hanging values. successor._hanging_actions = copy.deepcopy(self._hanging_actions) successor._hanging_rewards = self._hanging_rewards.copy() successor._hanging_extra_model_outputs = copy.deepcopy( @@ -1301,7 +1299,8 @@ def slice(self, slice_: slice) -> "MultiAgentEpisode": - In case `slice_` ends - for a certain agent - in an env step, where that particular agent does not have an observation, the previous observation will be included, but the next action and sum of rewards until this point will - be stored in the agent's buffer for the returned MultiAgentEpisode slice. + be stored in the agent's hanging values caches for the returned + MultiAgentEpisode slice. .. testcode:: @@ -1341,7 +1340,7 @@ def slice(self, slice_: slice) -> "MultiAgentEpisode": check((a0.is_done, a1.is_done), (False, False)) # If a slice ends in a "gap" for an agent, expect actions and rewards to be - # cached in the agent's buffer. + # cached for this agent. slice = episode[:2] a0 = slice.agent_episodes["a0"] check(a0.observations, [0]) @@ -1536,7 +1535,7 @@ def get_state(self) -> Dict[str, Any]: Returns: A dicitonary containing pickable data fro a `MultiAgentEpisode`. """ - # TODO (simon): Add the buffers. + # TODO (simon): Add the agent caches. 
return list( { "id_": self.id_, @@ -1568,7 +1567,7 @@ def from_state(state) -> None: `MultiAgentEpisode` from a state, this state has to be complete, i.e. all data must have been stored in the state. """ - # TODO (simon): Add the buffers. + # TODO (simon): Add the agent caches. episode = MultiAgentEpisode(id=state[0][1]) episode._agent_ids = state[1][1] episode.env_t_to_agent_t = state[2][1] @@ -1610,27 +1609,27 @@ def get_sample_batch(self) -> MultiAgentBatch: def get_return( self, - consider_buffer: bool = False, + consider_hanging_rewards: bool = False, ) -> float: """Returns all-agent return. Args: - consider_buffer: Whether we should also consider - buffered rewards wehn calculating the overall return. Agents might + consider_hanging_rewards: Whether we should also consider + hanging rewards wehn calculating the overall return. Agents might have received partial rewards, i.e. rewards without an - observation. These are stored to the buffer for each agent and added up + observation. These are stored to the cache for each agent and added up until the next observation is received by that agent. Returns: - The sum of all single-agents' returns (maybe including the buffered + The sum of all single-agents' returns (maybe including the hanging rewards per agent). """ env_return = sum( agent_eps.get_return() for agent_eps in self.agent_episodes.values() ) - if consider_buffer: - for buffered_r in self._hanging_rewards.values(): - env_return += buffered_r + if consider_hanging_rewards: + for hanging_r in self._hanging_rewards.values(): + env_return += hanging_r return env_return @@ -1747,7 +1746,7 @@ def _init_single_agent_episodes( for data_idx, (obs, inf) in enumerate(zip(observations, infos)): # If we do have actions/extra outs/rewards for this timestep, use the data. # It may be that these lists have the same length as the observations list, - # in which case the data will be buffered (agent did step/send an action, + # in which case the data will be cached (agent did step/send an action, # but the step has not been concluded yet by the env). act = actions[data_idx] if len(actions) > data_idx else {} extra_outs = ( @@ -1763,7 +1762,7 @@ def _init_single_agent_episodes( observations_per_agent[agent_id].append(agent_obs) infos_per_agent[agent_id].append(inf.get(agent_id, {})) - # Pull out buffered action (if not first obs for this agent) and + # Pull out hanging action (if not first obs for this agent) and # complete step for agent. if len(observations_per_agent[agent_id]) > 1: actions_per_agent[agent_id].append( @@ -1785,17 +1784,15 @@ def _init_single_agent_episodes( # Agent is still continuing (has an action for the next step). if agent_id in act: - # Always push actions/extra outputs into buffer, then remove them + # Always push actions/extra outputs into cache, then remove them # from there, once the next observation comes in. Same for rewards. self._hanging_actions[agent_id] = act[agent_id] self._hanging_extra_model_outputs[agent_id] = extra_outs.get( agent_id, {} ) - self._hanging_rewards[ - agent_id - ] = self._hanging_rewards.get(agent_id, 0.0) + rew.get( + self._hanging_rewards[agent_id] = self._hanging_rewards.get( agent_id, 0.0 - ) + ) + rew.get(agent_id, 0.0) # Agent is done (has no action for the next step). 
elif terminateds.get(agent_id) or truncateds.get(agent_id): done_per_agent[agent_id] = True @@ -1955,15 +1952,15 @@ def _get_data_by_agent_steps( if agent_id not in agent_ids: continue inf_lookback_buffer = getattr(sa_episode, what) - buffer_val = self._get_buffer_value(what, agent_id) + hanging_val = self._get_hanging_value(what, agent_id) if extra_model_outputs_key is not None: inf_lookback_buffer = inf_lookback_buffer[extra_model_outputs_key] - buffer_val = buffer_val[extra_model_outputs_key] + hanging_val = hanging_val[extra_model_outputs_key] agent_value = inf_lookback_buffer.get( indices=indices, neg_indices_left_of_zero=neg_indices_left_of_zero, fill=fill, - _add_last_ts_value=buffer_val, + _add_last_ts_value=hanging_val, **one_hot_discrete, ) if agent_value is None or agent_value == []: @@ -2051,7 +2048,7 @@ def _get_data_by_env_steps_as_list( for i in range(len(next(iter(agent_indices.values())))): ret2 = {} for agent_id, idxes in agent_indices.items(): - buffer_val = self._get_buffer_value(what, agent_id) + hanging_val = self._get_hanging_value(what, agent_id) ( inf_lookback_buffer, indices_to_use, @@ -2059,7 +2056,7 @@ def _get_data_by_env_steps_as_list( agent_id, what, extra_model_outputs_key, - buffer_val, + hanging_val, filter_for_skip_indices=idxes[i], ) agent_value = self._get_single_agent_data_by_index( @@ -2069,7 +2066,7 @@ def _get_data_by_env_steps_as_list( index_incl_lookback=indices_to_use, fill=fill, one_hot_discrete=one_hot_discrete, - buffer_val=buffer_val, + hanging_val=hanging_val, extra_model_outputs_key=extra_model_outputs_key, ) if agent_value is not None: @@ -2142,7 +2139,7 @@ def _get_data_by_env_steps( for agent_id, sa_episode in self.agent_episodes.items(): if agent_id not in agent_ids: continue - buffer_val = self._get_buffer_value(what, agent_id) + hanging_val = self._get_hanging_value(what, agent_id) agent_indices = self.env_t_to_agent_t[agent_id].get( indices, neg_indices_left_of_zero=neg_indices_left_of_zero, @@ -2156,7 +2153,7 @@ def _get_data_by_env_steps( agent_id, what, extra_model_outputs_key, - buffer_val, + hanging_val, filter_for_skip_indices=agent_indices, ) if isinstance(agent_indices, list): @@ -2166,7 +2163,7 @@ def _get_data_by_env_steps( indices_incl_lookback=agent_indices, fill=fill, one_hot_discrete=one_hot_discrete, - buffer_val=buffer_val, + hanging_val=hanging_val, extra_model_outputs_key=extra_model_outputs_key, ) if len(agent_values) > 0: @@ -2179,7 +2176,7 @@ def _get_data_by_env_steps( index_incl_lookback=agent_indices, fill=fill, one_hot_discrete=one_hot_discrete, - buffer_val=buffer_val, + hanging_val=hanging_val, extra_model_outputs_key=extra_model_outputs_key, ) if agent_values is not None: @@ -2196,7 +2193,7 @@ def _get_single_agent_data_by_index( fill: Optional[Any] = None, one_hot_discrete: bool = False, extra_model_outputs_key: Optional[str] = None, - buffer_val: Optional[Any] = None, + hanging_val: Optional[Any] = None, ) -> Any: """Returns single data item from the episode based on given (env step) index. @@ -2232,12 +2229,12 @@ def _get_single_agent_data_by_index( extra_model_outputs_key: Only if what is "extra_model_outputs", this specifies the sub-key (str) inside the extra_model_outputs dict, e.g. STATE_OUT or ACTION_DIST_INPUTS. - buffer_val: In case we are pulling actions, rewards, or extra_model_outputs - data, there might be information "in-flight" (buffered). For example, + hanging_val: In case we are pulling actions, rewards, or extra_model_outputs + data, there might be information "hanging" (cached). 
For example, if an agent receives an observation o0 and then immediately sends an action a0 back, but then does NOT immediately reveive a next - observation, a0 is now buffered (not fully logged yet with this - episode). The currently buffered value must be provided here to be able + observation, a0 is now cached (not fully logged yet with this + episode). The currently cached value must be provided here to be able to return it in case the index is -1 (most recent timestep). Returns: @@ -2274,7 +2271,7 @@ def _get_single_agent_data_by_index( indices=index_incl_lookback - sub_buffer.lookback, neg_indices_left_of_zero=True, fill=fill, - _add_last_ts_value=buffer_val, + _add_last_ts_value=hanging_val, **one_hot_discrete, ) for key, sub_buffer in inf_lookback_buffer.items() @@ -2285,7 +2282,7 @@ def _get_single_agent_data_by_index( indices=index_incl_lookback - inf_lookback_buffer.lookback, neg_indices_left_of_zero=True, fill=fill, - _add_last_ts_value=buffer_val, + _add_last_ts_value=hanging_val, **one_hot_discrete, ) @@ -2298,7 +2295,7 @@ def _get_single_agent_data_by_env_step_indices( fill: Optional[Any] = None, one_hot_discrete: bool = False, extra_model_outputs_key: Optional[str] = None, - buffer_val: Optional[Any] = None, + hanging_val: Optional[Any] = None, ) -> Any: """Returns single data item from the episode based on given (env step) indices. @@ -2335,12 +2332,12 @@ def _get_single_agent_data_by_env_step_indices( extra_model_outputs_key: Only if what is "extra_model_outputs", this specifies the sub-key (str) inside the extra_model_outputs dict, e.g. STATE_OUT or ACTION_DIST_INPUTS. - buffer_val: In case we are pulling actions, rewards, or extra_model_outputs - data, there might be information "in-flight" (buffered). For example, + hanging_val: In case we are pulling actions, rewards, or extra_model_outputs + data, there might be information "hanging" (cached). For example, if an agent receives an observation o0 and then immediately sends an action a0 back, but then does NOT immediately reveive a next - observation, a0 is now buffered (not fully logged yet with this - episode). The currently buffered value must be provided here to be able + observation, a0 is now cached (not fully logged yet with this + episode). The currently cached value must be provided here to be able to return it in case the index is -1 (most recent timestep). 
Returns: @@ -2372,7 +2369,7 @@ def _get_single_agent_data_by_env_step_indices( indices=i - getattr(sa_episode, what).lookback, neg_indices_left_of_zero=True, fill=fill, - _add_last_ts_value=buffer_val, + _add_last_ts_value=hanging_val, **one_hot_discrete, ) ) @@ -2389,13 +2386,13 @@ def _get_single_agent_data_by_env_step_indices( indices=indices, neg_indices_left_of_zero=True, fill=fill, - _add_last_ts_value=buffer_val, + _add_last_ts_value=hanging_val, **one_hot_discrete, ) return ret - def _get_buffer_value(self, what: str, agent_id: AgentID) -> Any: - """Returns the buffered action/reward/extra_model_outputs for given agent.""" + def _get_hanging_value(self, what: str, agent_id: AgentID) -> Any: + """Returns the hanging action/reward/extra_model_outputs for given agent.""" if what == "actions": return self._hanging_actions.get(agent_id) elif what == "extra_model_outputs": @@ -2403,15 +2400,15 @@ def _get_buffer_value(self, what: str, agent_id: AgentID) -> Any: elif what == "rewards": return self._hanging_rewards.get(agent_id) - def _del_buffers(self, agent_id: AgentID) -> None: - """Deletes all action, reward, extra_model_outputs buffers for given agent.""" + def _del_hanging(self, agent_id: AgentID) -> None: + """Deletes all hanging action, reward, extra_model_outputs of given agent.""" self._hanging_actions.pop(agent_id, None) self._hanging_extra_model_outputs.pop(agent_id, None) self._hanging_rewards.pop(agent_id, None) def _del_agent(self, agent_id: AgentID) -> None: """Deletes all data of given agent from this episode.""" - self._del_buffers(agent_id) + self._del_hanging(agent_id) self.agent_episodes.pop(agent_id, None) self.agent_ids.discard(agent_id) self.env_t_to_agent_t.pop(agent_id, None) @@ -2423,7 +2420,7 @@ def _get_inf_lookback_buffer_or_dict( agent_id: AgentID, what: str, extra_model_outputs_key: Optional[str] = None, - buffer_val: Optional[Any] = None, + hanging_val: Optional[Any] = None, filter_for_skip_indices=None, ): """Returns a single InfiniteLookbackBuffer or a dict of such. 
@@ -2451,7 +2448,7 @@ def _get_inf_lookback_buffer_or_dict( inf_lookback_buffer_len = ( len(inf_lookback_buffer) + inf_lookback_buffer.lookback - + (buffer_val is not None) + + (hanging_val is not None) ) ignore_last_ts = what not in ["observations", "infos"] if isinstance(filter_for_skip_indices, list): diff --git a/rllib/env/tests/test_multi_agent_episode.py b/rllib/env/tests/test_multi_agent_episode.py index e46caba48f34..ca259c0a2e8a 100644 --- a/rllib/env/tests/test_multi_agent_episode.py +++ b/rllib/env/tests/test_multi_agent_episode.py @@ -165,9 +165,9 @@ def test_init(self): check(episode.agent_episodes["a1"].actions.data, [0, 1, 2]) check(episode.agent_episodes["a0"].rewards.data, []) check(episode.agent_episodes["a1"].rewards.data, [0.1, 0.2, 0.3]) - check(episode._agent_buffered_actions, {"a0": 0}) - check(episode._agent_buffered_rewards, {"a0": 0.1}) - check(episode._agent_buffered_extra_model_outputs, {"a0": {}}) + check(episode._hanging_actions, {"a0": 0}) + check(episode._hanging_rewards, {"a0": 0.1}) + check(episode._hanging_extra_model_outputs, {"a0": {}}) check(episode.env_t_to_agent_t["a0"].data, [0, "S", "S", "S"]) check(episode.env_t_to_agent_t["a1"].data, [0, 1, 2, 3]) check(episode.env_t_to_agent_t["a0"].lookback, 3) @@ -186,9 +186,9 @@ def test_init(self): check(episode.agent_episodes["a1"].actions.data, [1, 2, 3]) check(episode.agent_episodes["a0"].rewards.data, [0.1]) check(episode.agent_episodes["a1"].rewards.data, [0.2, 0.3, 0.4]) - check(episode._agent_buffered_actions, {"a0": 2}) - check(episode._agent_buffered_rewards, {"a0": 0.3}) - check(episode._agent_buffered_extra_model_outputs, {"a0": {}}) + check(episode._hanging_actions, {"a0": 2}) + check(episode._hanging_rewards, {"a0": 0.3}) + check(episode._hanging_extra_model_outputs, {"a0": {}}) check(episode.env_t_to_agent_t["a0"].data, [0, "S", 1, "S", "S"]) check(episode.env_t_to_agent_t["a1"].data, ["S", 0, 1, 2, 3]) check(episode.env_t_to_agent_t["a0"].lookback, 4) @@ -433,9 +433,9 @@ def test_add_env_step(self): self.assertTrue(episode.agent_episodes["agent_5"].is_done) # Also ensure that their buffers are all empty: for agent_id in ["agent_1", "agent_5"]: - self.assertTrue(agent_id not in episode._agent_buffered_actions) - self.assertTrue(agent_id not in episode._agent_buffered_rewards) - self.assertTrue(agent_id not in episode._agent_buffered_extra_model_outputs) + self.assertTrue(agent_id not in episode._hanging_actions) + self.assertTrue(agent_id not in episode._hanging_rewards) + self.assertTrue(agent_id not in episode._hanging_extra_model_outputs) # Check validity of agent_0's env_t_to_agent_t mapping. check(episode.env_t_to_agent_t["agent_0"].data, agent_0_steps) @@ -511,10 +511,10 @@ def test_add_env_step(self): # Assert that the action buffer for agent 4 is full. # Note, agent 4 acts, but receives no observation. # Note also, all other buffers are always full, due to their defaults. - self.assertTrue(episode._agent_buffered_actions["agent_4"] is not None) + self.assertTrue(episode._hanging_actions["agent_4"] is not None) # Assert that the reward buffers of agents 3 and 5 are at 1.0. - check(episode._agent_buffered_rewards["agent_3"], 2.2) - check(episode._agent_buffered_rewards["agent_5"], 1.0) + check(episode._hanging_rewards["agent_3"], 2.2) + check(episode._hanging_rewards["agent_5"], 1.0) def test_get_observations(self): # Generate simple records for a multi agent environment. 
@@ -2277,7 +2277,7 @@ def test_cut(self): check(episode_2.agent_episodes["a0"].observations.lookback, 0) # Action was "logged" -> Buffer should now be completely empty. check(episode_2.agent_episodes["a0"].actions.data, [0]) - check(episode_2._agent_buffered_actions, {}) + check(episode_2._hanging_actions, {}) check(episode_2.agent_episodes["a0"].actions.lookback, 0) check(episode_2.get_observations(-1), {"a0": 1, "a1": 4}) check(episode_2.get_observations(-1, env_steps=False), {"a0": 1, "a1": 4}) @@ -2640,8 +2640,8 @@ def test_slice(self): check((a0.actions, a1.actions), ([], [0])) check((a0.rewards, a1.rewards), ([], [0.1])) check((a0.is_done, a1.is_done), (False, False)) - check(slice_._agent_buffered_actions["a0"], 0) - check(slice_._agent_buffered_rewards["a0"], 0.1) + check(slice_._hanging_actions["a0"], 0) + check(slice_._hanging_rewards["a0"], 0.1) # To pos stop. slice_ = episode[:3] check(len(slice_), 3) @@ -2669,7 +2669,7 @@ def test_slice(self): ) check((a0.is_done, a1.is_done), (False, False)) # Expect the hanging action to be found in the buffer. - check(slice_._agent_buffered_actions["a0"], 6) + check(slice_._hanging_actions["a0"], 6) slice_ = episode[:-4] check(len(slice_), 5) @@ -3004,9 +3004,10 @@ def test_get_return(self): # Assert that adding the buffered rewards to the agent returns # gives the expected result when considering the buffer in # `get_return()`. - buffered_rewards = sum(episode._agent_buffered_rewards.values()) + buffered_rewards = sum(episode._hanging_rewards.values()) self.assertTrue( - episode.get_return(consider_buffer=True), agent_returns + buffered_rewards + episode.get_return(consider_hanging_rewards=True), + agent_returns + buffered_rewards, ) def test_len(self): @@ -3165,7 +3166,7 @@ def _mock_multi_agent_records_from_env( obs = { agent_id: agent_obs for agent_id, agent_obs in episode.get_observations().items() - if episode._agent_buffered_actions[agent_id] + if episode._hanging_actions[agent_id] } # Sample `size` many records. From ba98ade8765f8490599676d5f0ae8377f8316a5e Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 11 Apr 2024 12:22:05 +0200 Subject: [PATCH 3/9] wip Signed-off-by: sven1977 --- rllib/env/multi_agent_episode.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/rllib/env/multi_agent_episode.py b/rllib/env/multi_agent_episode.py index 0a59a7a4c58e..6d3912592902 100644 --- a/rllib/env/multi_agent_episode.py +++ b/rllib/env/multi_agent_episode.py @@ -221,14 +221,21 @@ def __init__( AgentID, InfiniteLookbackBuffer ] = defaultdict(InfiniteLookbackBuffer) - # In the `MultiAgentEpisode` we need these caches to keep track of actions, - # that happen when an agent got observations and acted, but did not receive - # a next observation, yet. In this case we store the action, add the rewards, - # and record `is_terminated/is_truncated` until the next observation is - # received. - self._hanging_actions = {} - self._hanging_extra_model_outputs = defaultdict(dict) - self._hanging_rewards = {} + # Create caches for hanging actions/rewards/extra_model_outputs. + # When an agent gets an observation (and then sends an action), but does not + # receive immediately a next observation, we store the "hanging" action (and + # related rewards and extra model outputs) in the caches postfixed w/ `_end` + # until the next observation is received. 
+ self._hanging_actions_end = {} + self._hanging_extra_model_outputs_end = defaultdict(dict) + self._hanging_rewards_end = defaultdict(float) + + # In case of a `cut()` or `slice()`, we also need to store the hanging actions, + # rewards, and extra model outputs that were already "hanging" in preceeding + # episode slice. + self._hanging_actions_begin = {} + self._hanging_extra_model_outputs_begin = defaultdict(dict) + self._hanging_rewards_begin = defaultdict(float) # If this is an ongoing episode than the last `__all__` should be `False` self.is_terminated: bool = ( From d7910772697a0063597756c7ad16cd747df8ec88 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 11 Apr 2024 12:29:28 +0200 Subject: [PATCH 4/9] wip Signed-off-by: sven1977 --- rllib/env/multi_agent_episode.py | 66 ++++++++++----------- rllib/env/tests/test_multi_agent_episode.py | 36 +++++------ 2 files changed, 50 insertions(+), 52 deletions(-) diff --git a/rllib/env/multi_agent_episode.py b/rllib/env/multi_agent_episode.py index 6d3912592902..90b32d61b017 100644 --- a/rllib/env/multi_agent_episode.py +++ b/rllib/env/multi_agent_episode.py @@ -477,7 +477,7 @@ def add_env_step( # collected hanging rewards. # b) The observation is the first observation for this agent ID. elif _observation is not None and _action is None: - _action = self._hanging_actions.pop(agent_id, None) + _action = self._hanging_actions_end.pop(agent_id, None) # We have a hanging action (the agent had acted after the previous # observation, but the env had not responded - until now - with another @@ -485,10 +485,10 @@ def add_env_step( # ...[hanging action] ... ... -> next obs + (reward)? ... if _action is not None: # Get the extra model output if available. - _extra_model_outputs = self._hanging_extra_model_outputs.pop( + _extra_model_outputs = self._hanging_extra_model_outputs_end.pop( agent_id, None ) - _reward = self._hanging_rewards.pop(agent_id, 0.0) + _reward + _reward = self._hanging_rewards_end.pop(agent_id, 0.0) + _reward # _agent_step = len(sa_episode) # First observation for this agent, we have no hanging action. # ... [done]? ... -> [1st obs for agent ID] @@ -535,17 +535,17 @@ def add_env_step( # [previous obs] [action] (hanging) ... else: # Hanging action, reward, and extra_model_outputs. - assert agent_id not in self._hanging_actions - self._hanging_actions[agent_id] = _action - self._hanging_rewards[agent_id] = _reward - self._hanging_extra_model_outputs[agent_id] = _extra_model_outputs + assert agent_id not in self._hanging_actions_end + self._hanging_actions_end[agent_id] = _action + self._hanging_rewards_end[agent_id] = _reward + self._hanging_extra_model_outputs_end[agent_id] = _extra_model_outputs # CASE 4: Step has started in the past and is still ongoing (no observation, # no action). # -------------------------------------------------------------------------- # Record reward and terminated/truncated flags. else: - _action = self._hanging_actions.get(agent_id) + _action = self._hanging_actions_end.get(agent_id) # Agent is done. if _terminated or _truncated: @@ -576,14 +576,14 @@ def add_env_step( # `_action` is already `get` above. We don't need to pop out from # the cache as it gets wiped out anyway below b/c the agent is # done. 
- _extra_model_outputs = self._hanging_extra_model_outputs.pop( + _extra_model_outputs = self._hanging_extra_model_outputs_end.pop( agent_id, None ) - _reward = self._hanging_rewards.pop(agent_id, 0.0) + _reward + _reward = self._hanging_rewards_end.pop(agent_id, 0.0) + _reward # The agent is still alive, just add current reward to cache. else: - self._hanging_rewards[agent_id] = ( - self._hanging_rewards.get(agent_id, 0.0) + _reward + self._hanging_rewards_end[agent_id] = ( + self._hanging_rewards_end.get(agent_id, 0.0) + _reward ) # If agent is stepping, add timestep to `SingleAgentEpisode`. @@ -804,7 +804,7 @@ def cut(self, len_lookback_buffer: int = 0) -> "MultiAgentEpisode": # If there is hanging data (e.g. actions) in the agents' caches, we might have # to re-adjust the lookback len further into the past to make sure that these # agents have at least one observation to look back to. - for agent_id, agent_actions in self._hanging_actions.items(): + for agent_id, agent_actions in self._hanging_actions_end.items(): assert self.env_t_to_agent_t[agent_id].get(-1) == self.SKIP_ENV_TS_TAG for i in range(1, self.env_t_to_agent_t[agent_id].lookback + 1): if ( @@ -857,10 +857,10 @@ def cut(self, len_lookback_buffer: int = 0) -> "MultiAgentEpisode": ) # Copy over the current hanging values. - successor._hanging_actions = copy.deepcopy(self._hanging_actions) - successor._hanging_rewards = self._hanging_rewards.copy() - successor._hanging_extra_model_outputs = copy.deepcopy( - self._hanging_extra_model_outputs + successor._hanging_actions_end = copy.deepcopy(self._hanging_actions_end) + successor._hanging_rewards_end = self._hanging_rewards_end.copy() + successor._hanging_extra_model_outputs_end = copy.deepcopy( + self._hanging_extra_model_outputs_end ) return successor @@ -1353,8 +1353,8 @@ def slice(self, slice_: slice) -> "MultiAgentEpisode": check(a0.observations, [0]) check(a0.actions, []) check(a0.rewards, []) - check(slice._hanging_actions["a0"], 0) - check(slice._hanging_rewards["a0"], 0.1) + check(slice._hanging_actions_end["a0"], 0) + check(slice._hanging_rewards_end["a0"], 0.1) Args: slice_: The slice object to use for slicing. This should exclude the @@ -1646,7 +1646,7 @@ def get_return( agent_eps.get_return() for agent_eps in self.agent_episodes.values() ) if consider_hanging_rewards: - for hanging_r in self._hanging_rewards.values(): + for hanging_r in self._hanging_rewards_end.values(): env_return += hanging_r return env_return @@ -1784,13 +1784,13 @@ def _init_single_agent_episodes( # complete step for agent. if len(observations_per_agent[agent_id]) > 1: actions_per_agent[agent_id].append( - self._hanging_actions.pop(agent_id) + self._hanging_actions_end.pop(agent_id) ) extra_model_outputs_per_agent[agent_id].append( - self._hanging_extra_model_outputs.pop(agent_id) + self._hanging_extra_model_outputs_end.pop(agent_id) ) rewards_per_agent[agent_id].append( - self._hanging_rewards.pop(agent_id) + self._hanging_rewards_end.pop(agent_id) ) # First obs for this agent. Make sure the agent's mapping is # appropriately prepended with self.SKIP_ENV_TS_TAG tags. @@ -1804,13 +1804,11 @@ def _init_single_agent_episodes( if agent_id in act: # Always push actions/extra outputs into cache, then remove them # from there, once the next observation comes in. Same for rewards. 
- self._hanging_actions[agent_id] = act[agent_id] - self._hanging_extra_model_outputs[agent_id] = extra_outs.get( + self._hanging_actions_end[agent_id] = act[agent_id] + self._hanging_extra_model_outputs_end[agent_id] = extra_outs.get( agent_id, {} ) - self._hanging_rewards[agent_id] = self._hanging_rewards.get( - agent_id, 0.0 - ) + rew.get(agent_id, 0.0) + self._hanging_rewards_end[agent_id] += rew.get(agent_id, 0.0) # Agent is done (has no action for the next step). elif terminateds.get(agent_id) or truncateds.get(agent_id): done_per_agent[agent_id] = True @@ -2414,17 +2412,17 @@ def _get_single_agent_data_by_env_step_indices( def _get_hanging_value(self, what: str, agent_id: AgentID) -> Any: """Returns the hanging action/reward/extra_model_outputs for given agent.""" if what == "actions": - return self._hanging_actions.get(agent_id) + return self._hanging_actions_end.get(agent_id) elif what == "extra_model_outputs": - return self._hanging_extra_model_outputs.get(agent_id) + return self._hanging_extra_model_outputs_end.get(agent_id) elif what == "rewards": - return self._hanging_rewards.get(agent_id) + return self._hanging_rewards_end.get(agent_id) def _del_hanging(self, agent_id: AgentID) -> None: """Deletes all hanging action, reward, extra_model_outputs of given agent.""" - self._hanging_actions.pop(agent_id, None) - self._hanging_extra_model_outputs.pop(agent_id, None) - self._hanging_rewards.pop(agent_id, None) + self._hanging_actions_end.pop(agent_id, None) + self._hanging_extra_model_outputs_end.pop(agent_id, None) + self._hanging_rewards_end.pop(agent_id, None) def _del_agent(self, agent_id: AgentID) -> None: """Deletes all data of given agent from this episode.""" diff --git a/rllib/env/tests/test_multi_agent_episode.py b/rllib/env/tests/test_multi_agent_episode.py index 9e53378fcb88..c1022a3320da 100644 --- a/rllib/env/tests/test_multi_agent_episode.py +++ b/rllib/env/tests/test_multi_agent_episode.py @@ -165,9 +165,9 @@ def test_init(self): check(episode.agent_episodes["a1"].actions.data, [0, 1, 2]) check(episode.agent_episodes["a0"].rewards.data, []) check(episode.agent_episodes["a1"].rewards.data, [0.1, 0.2, 0.3]) - check(episode._hanging_actions, {"a0": 0}) - check(episode._hanging_rewards, {"a0": 0.1}) - check(episode._hanging_extra_model_outputs, {"a0": {}}) + check(episode._hanging_actions_end, {"a0": 0}) + check(episode._hanging_rewards_end, {"a0": 0.1}) + check(episode._hanging_extra_model_outputs_end, {"a0": {}}) check(episode.env_t_to_agent_t["a0"].data, [0, "S", "S", "S"]) check(episode.env_t_to_agent_t["a1"].data, [0, 1, 2, 3]) check(episode.env_t_to_agent_t["a0"].lookback, 3) @@ -186,9 +186,9 @@ def test_init(self): check(episode.agent_episodes["a1"].actions.data, [1, 2, 3]) check(episode.agent_episodes["a0"].rewards.data, [0.1]) check(episode.agent_episodes["a1"].rewards.data, [0.2, 0.3, 0.4]) - check(episode._hanging_actions, {"a0": 2}) - check(episode._hanging_rewards, {"a0": 0.3}) - check(episode._hanging_extra_model_outputs, {"a0": {}}) + check(episode._hanging_actions_end, {"a0": 2}) + check(episode._hanging_rewards_end, {"a0": 0.3}) + check(episode._hanging_extra_model_outputs_end, {"a0": {}}) check(episode.env_t_to_agent_t["a0"].data, [0, "S", 1, "S", "S"]) check(episode.env_t_to_agent_t["a1"].data, ["S", 0, 1, 2, 3]) check(episode.env_t_to_agent_t["a0"].lookback, 4) @@ -433,9 +433,9 @@ def test_add_env_step(self): self.assertTrue(episode.agent_episodes["agent_5"].is_done) # Also ensure that their buffers are all empty: for agent_id in ["agent_1", 
"agent_5"]: - self.assertTrue(agent_id not in episode._hanging_actions) - self.assertTrue(agent_id not in episode._hanging_rewards) - self.assertTrue(agent_id not in episode._hanging_extra_model_outputs) + self.assertTrue(agent_id not in episode._hanging_actions_end) + self.assertTrue(agent_id not in episode._hanging_rewards_end) + self.assertTrue(agent_id not in episode._hanging_extra_model_outputs_end) # Check validity of agent_0's env_t_to_agent_t mapping. check(episode.env_t_to_agent_t["agent_0"].data, agent_0_steps) @@ -511,10 +511,10 @@ def test_add_env_step(self): # Assert that the action buffer for agent 4 is full. # Note, agent 4 acts, but receives no observation. # Note also, all other buffers are always full, due to their defaults. - self.assertTrue(episode._hanging_actions["agent_4"] is not None) + self.assertTrue(episode._hanging_actions_end["agent_4"] is not None) # Assert that the reward buffers of agents 3 and 5 are at 1.0. - check(episode._hanging_rewards["agent_3"], 2.2) - check(episode._hanging_rewards["agent_5"], 1.0) + check(episode._hanging_rewards_end["agent_3"], 2.2) + check(episode._hanging_rewards_end["agent_5"], 1.0) def test_get_observations(self): # Generate simple records for a multi agent environment. @@ -2277,7 +2277,7 @@ def test_cut(self): check(episode_2.agent_episodes["a0"].observations.lookback, 0) # Action was "logged" -> Buffer should now be completely empty. check(episode_2.agent_episodes["a0"].actions.data, [0]) - check(episode_2._hanging_actions, {}) + check(episode_2._hanging_actions_end, {}) check(episode_2.agent_episodes["a0"].actions.lookback, 0) check(episode_2.get_observations(-1), {"a0": 1, "a1": 4}) check(episode_2.get_observations(-1, env_steps=False), {"a0": 1, "a1": 4}) @@ -2648,8 +2648,8 @@ def test_slice(self): check((a0.actions, a1.actions), ([], [0])) check((a0.rewards, a1.rewards), ([], [0.0])) check((a0.is_done, a1.is_done), (False, False)) - check(slice_._hanging_actions["a0"], 0) - check(slice_._hanging_rewards["a0"], 0.0) + check(slice_._hanging_actions_end["a0"], 0) + check(slice_._hanging_rewards_end["a0"], 0.0) # To pos stop. slice_ = episode[:3] check(len(slice_), 3) @@ -2683,7 +2683,7 @@ def test_slice(self): ) check((a0.is_done, a1.is_done), (False, False)) # Expect the hanging action to be found in the buffer. - check(slice_._hanging_actions["a0"], 6) + check(slice_._hanging_actions_end["a0"], 6) slice_ = episode[:-4] check(len(slice_), 5) @@ -3124,7 +3124,7 @@ def test_get_return(self): # Assert that adding the buffered rewards to the agent returns # gives the expected result when considering the buffer in # `get_return()`. - buffered_rewards = sum(episode._hanging_rewards.values()) + buffered_rewards = sum(episode._hanging_rewards_end.values()) self.assertTrue( episode.get_return(consider_hanging_rewards=True), agent_returns + buffered_rewards, @@ -3294,7 +3294,7 @@ def _mock_multi_agent_records_from_env( obs = { agent_id: agent_obs for agent_id, agent_obs in episode.get_observations().items() - if episode._hanging_actions[agent_id] + if episode._hanging_actions_end[agent_id] } # Sample `size` many records. 
From 7f18cb184827bc14da9360676dd0aa32666f6864 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 11 Apr 2024 12:29:44 +0200 Subject: [PATCH 5/9] LINT Signed-off-by: sven1977 --- rllib/env/multi_agent_episode.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rllib/env/multi_agent_episode.py b/rllib/env/multi_agent_episode.py index 90b32d61b017..7dc4b9712397 100644 --- a/rllib/env/multi_agent_episode.py +++ b/rllib/env/multi_agent_episode.py @@ -538,7 +538,9 @@ def add_env_step( assert agent_id not in self._hanging_actions_end self._hanging_actions_end[agent_id] = _action self._hanging_rewards_end[agent_id] = _reward - self._hanging_extra_model_outputs_end[agent_id] = _extra_model_outputs + self._hanging_extra_model_outputs_end[ + agent_id + ] = _extra_model_outputs # CASE 4: Step has started in the past and is still ongoing (no observation, # no action). From 64e53d76f8d34edd155fb01bc8a680a2e980559d Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 11 Apr 2024 13:50:39 +0200 Subject: [PATCH 6/9] wip Signed-off-by: sven1977 --- rllib/env/multi_agent_episode.py | 30 +++--- rllib/env/tests/test_multi_agent_episode.py | 107 ++++++++++++++++++-- 2 files changed, 118 insertions(+), 19 deletions(-) diff --git a/rllib/env/multi_agent_episode.py b/rllib/env/multi_agent_episode.py index 7dc4b9712397..e4cadd5bf555 100644 --- a/rllib/env/multi_agent_episode.py +++ b/rllib/env/multi_agent_episode.py @@ -829,13 +829,6 @@ def cut(self, len_lookback_buffer: int = 0) -> "MultiAgentEpisode": successor = MultiAgentEpisode( # Same ID. id_=self.id_, - # Same agent IDs. - # Same single agents' episode IDs. - agent_episode_ids=self.agent_episode_ids, - agent_module_ids={ - aid: self.agent_episodes[aid].module_id for aid in self.agent_ids - }, - agent_to_module_mapping_fn=self.agent_to_module_mapping_fn, observations=self.get_observations( indices=indices_obs_and_infos, return_list=True ), @@ -853,15 +846,28 @@ def cut(self, len_lookback_buffer: int = 0) -> "MultiAgentEpisode": ), terminateds=self.get_terminateds(), truncateds=self.get_truncateds(), - # Continue with `self`'s current timestep. + # Continue with `self`'s current timesteps. env_t_started=self.env_t, + agent_t_started={ + aid: self.agent_episodes[aid].t + for aid in self.agent_ids if not self.agent_episodes[aid].is_done + }, + # Same AgentIDs and SingleAgentEpisode IDs. + agent_episode_ids=self.agent_episode_ids, + agent_module_ids={ + aid: self.agent_episodes[aid].module_id for aid in self.agent_ids + }, + agent_to_module_mapping_fn=self.agent_to_module_mapping_fn, + + # All data we provided to the c'tor goes into the lookback buffer. len_lookback_buffer="auto", ) - # Copy over the current hanging values. - successor._hanging_actions_end = copy.deepcopy(self._hanging_actions_end) - successor._hanging_rewards_end = self._hanging_rewards_end.copy() - successor._hanging_extra_model_outputs_end = copy.deepcopy( + # Copy over the hanging (end) values into the hanging (begin) chaches of the + # successor. 
+ successor._hanging_actions_begin = copy.deepcopy(self._hanging_actions_end) + successor._hanging_rewards_begin = self._hanging_rewards_end.copy() + successor._hanging_extra_model_outputs_begin = copy.deepcopy( self._hanging_extra_model_outputs_end ) diff --git a/rllib/env/tests/test_multi_agent_episode.py b/rllib/env/tests/test_multi_agent_episode.py index c1022a3320da..6bec9d9bf25b 100644 --- a/rllib/env/tests/test_multi_agent_episode.py +++ b/rllib/env/tests/test_multi_agent_episode.py @@ -2210,15 +2210,108 @@ def test_other_getters(self): # --- is_terminated, is_truncated --- def test_cut(self): + # Simple multi-agent episode, in which all agents always step. + episode = self._create_simple_episode([ + {"a0": 0, "a1": 0}, + {"a0": 1, "a1": 1}, + {"a0": 2, "a1": 2}, + ]) + successor = episode.cut() + check(len(successor), 0) + check(successor.env_t_started, 2) + check(successor.env_t, 2) + check(successor.env_t_to_agent_t, {"a0": [0], "a1": [0]}) + a0 = successor.agent_episodes["a0"] + a1 = successor.agent_episodes["a1"] + check((len(a0), len(a1)), (0, 0)) + check((a0.t_started, a1.t_started), (2, 2)) + check((a0.t, a1.t), (2, 2)) + check((a0.observations, a1.observations), ([2], [2])) + check((a0.actions, a1.actions), ([], [])) + check((a0.rewards, a1.rewards), ([], [])) + check(successor._hanging_actions_end, {}) + check(successor._hanging_rewards_end, {}) + check(successor._hanging_extra_model_outputs_end, {}) + + # Multi-agent episode with lookback buffer, in which all agents always step. + episode = self._create_simple_episode([ + {"a0": 0, "a1": 0}, + {"a0": 1, "a1": 1}, + {"a0": 2, "a1": 2}, + {"a0": 3, "a1": 3}, + ], len_lookback_buffer=2) + # Cut with lookback=0 argument (default). + successor = episode.cut() + check(len(successor), 0) + check(successor.env_t_started, 1) + check(successor.env_t, 1) + check(successor.env_t_to_agent_t, {"a0": [0], "a1": [0]}) + a0 = successor.agent_episodes["a0"] + a1 = successor.agent_episodes["a1"] + check((len(a0), len(a1)), (0, 0)) + check((a0.t_started, a1.t_started), (1, 1)) + check((a0.t, a1.t), (1, 1)) + check((a0.observations, a1.observations), ([3], [3])) + check((a0.actions, a1.actions), ([], [])) + check((a0.rewards, a1.rewards), ([], [])) + check(successor._hanging_actions_end, {}) + check(successor._hanging_rewards_end, {}) + check(successor._hanging_extra_model_outputs_end, {}) + # Cut with lookback=2 argument. + successor = episode.cut(len_lookback_buffer=2) + check(len(successor), 0) + check(successor.env_t_started, 1) + check(successor.env_t, 1) + check(successor.env_t_to_agent_t["a0"].data, [0, 1, 2]) + check(successor.env_t_to_agent_t["a1"].data, [0, 1, 2]) + check(successor.env_t_to_agent_t["a0"].lookback, 2) + check(successor.env_t_to_agent_t["a1"].lookback, 2) + a0 = successor.agent_episodes["a0"] + a1 = successor.agent_episodes["a1"] + check((len(a0), len(a1)), (0, 0)) + check((a0.t_started, a1.t_started), (1, 1)) + check((a0.t, a1.t), (1, 1)) + check((a0.observations, a1.observations), ([3], [3])) + check((a0.actions, a1.actions), ([], [])) + check((a0.rewards, a1.rewards), ([], [])) + check(successor._hanging_actions_end, {}) + check(successor._hanging_rewards_end, {}) + check(successor._hanging_extra_model_outputs_end, {}) + + # Multi-agent episode, in which one agent has a long sequence of not acting. 
+ episode = self._create_simple_episode([ + {"a0": 0, "a1": 0}, # 0 + {"a0": 1}, # 1 + {"a0": 2}, # 2 + {"a0": 3}, # 3 + ]) + successor = episode.cut() + check(len(successor), 0) + check(successor.env_t_started, 3) + check(successor.env_t, 3) + a0 = successor.agent_episodes["a0"] + self.assertTrue("a1" not in successor.agent_episodes) + check(len(a0), 0) + check(a0.t_started, 3) + check(a0.t, 3) + check(a0.observations, [3]) + check(a0.actions, []) + check(a0.rewards, []) + check(successor._hanging_actions_begin, {"a1": 0}) + check(successor._hanging_rewards_begin, {"a1": 0.0}) + check(successor._hanging_extra_model_outputs_begin, {"a1": {}}) + check(successor._hanging_actions_end, {}) + check(successor._hanging_rewards_end, {}) + check(successor._hanging_extra_model_outputs_end, {}) + # Generate a simple multi-agent episode and check all internals after # construction. - observations = [{"a0": 0, "a1": 0}, {"a1": 1}, {"a1": 2}, {"a1": 3}] - actions = [{"a0": 0, "a1": 0}, {"a1": 1}, {"a1": 2}] - rewards = [{"a0": 0.1, "a1": 0.1}, {"a1": 0.2}, {"a1": 0.3}] - episode_1 = MultiAgentEpisode( - observations=observations, actions=actions, rewards=rewards - ) - + episode_1 = self._create_simple_episode([ + {"a0": 0, "a1": 0}, + {"a1": 1}, + {"a1": 2}, + {"a1": 3}, + ], len_lookback_buffer="auto") episode_2 = episode_1.cut() check(episode_1.id_, episode_2.id_) check(len(episode_1), 0) From e1ef1114dac7d359eac3789208952cab35f976fb Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 11 Apr 2024 16:22:43 +0200 Subject: [PATCH 7/9] merge Signed-off-by: sven1977 --- rllib/env/multi_agent_episode.py | 18 ++-- rllib/env/tests/test_multi_agent_episode.py | 92 +++++++++++++++------ 2 files changed, 79 insertions(+), 31 deletions(-) diff --git a/rllib/env/multi_agent_episode.py b/rllib/env/multi_agent_episode.py index 1959620c5cc7..8b50ce290475 100644 --- a/rllib/env/multi_agent_episode.py +++ b/rllib/env/multi_agent_episode.py @@ -584,9 +584,7 @@ def add_env_step( _reward = self._hanging_rewards_end.pop(agent_id, 0.0) + _reward # The agent is still alive, just add current reward to cache. else: - self._hanging_rewards_end[agent_id] = ( - self._hanging_rewards_end.get(agent_id, 0.0) + _reward - ) + self._hanging_rewards_end[agent_id] += _reward # If agent is stepping, add timestep to `SingleAgentEpisode`. if _observation is not None: @@ -850,7 +848,8 @@ def cut(self, len_lookback_buffer: int = 0) -> "MultiAgentEpisode": env_t_started=self.env_t, agent_t_started={ aid: self.agent_episodes[aid].t - for aid in self.agent_ids if not self.agent_episodes[aid].is_done + for aid in self.agent_ids + if not self.agent_episodes[aid].is_done }, # Same AgentIDs and SingleAgentEpisode IDs. agent_episode_ids=self.agent_episode_ids, @@ -858,7 +857,6 @@ def cut(self, len_lookback_buffer: int = 0) -> "MultiAgentEpisode": aid: self.agent_episodes[aid].module_id for aid in self.agent_ids }, agent_to_module_mapping_fn=self.agent_to_module_mapping_fn, - # All data we provided to the c'tor goes into the lookback buffer. len_lookback_buffer="auto", ) @@ -1833,11 +1831,13 @@ def _init_single_agent_episodes( len(observations_per_agent[agent_id]) - 1 ) - # Those agents that did NOT step get self.SKIP_ENV_TS_TAG added to their - # mapping. + # Those agents that did NOT step: + # - Get self.SKIP_ENV_TS_TAG added to their env_t_to_agent_t mapping. + # - Get their reward (if any) added up. 
for agent_id in all_agent_ids: if agent_id not in obs and agent_id not in done_per_agent: self.env_t_to_agent_t[agent_id].append(self.SKIP_ENV_TS_TAG) + self._hanging_rewards_end[agent_id] += rew.get(agent_id, 0.0) # Update per-agent lookback buffer sizes to be used when creating the # indiviual `SingleAgentEpisode` objects below. @@ -2431,6 +2431,10 @@ def _get_hanging_value(self, what: str, agent_id: AgentID) -> Any: def _del_hanging(self, agent_id: AgentID) -> None: """Deletes all hanging action, reward, extra_model_outputs of given agent.""" + self._hanging_actions_begin.pop(agent_id, None) + self._hanging_extra_model_outputs_begin.pop(agent_id, None) + self._hanging_rewards_begin.pop(agent_id, None) + self._hanging_actions_end.pop(agent_id, None) self._hanging_extra_model_outputs_end.pop(agent_id, None) self._hanging_rewards_end.pop(agent_id, None) diff --git a/rllib/env/tests/test_multi_agent_episode.py b/rllib/env/tests/test_multi_agent_episode.py index 6bec9d9bf25b..b86afd6fdaf1 100644 --- a/rllib/env/tests/test_multi_agent_episode.py +++ b/rllib/env/tests/test_multi_agent_episode.py @@ -2211,11 +2211,13 @@ def test_other_getters(self): def test_cut(self): # Simple multi-agent episode, in which all agents always step. - episode = self._create_simple_episode([ - {"a0": 0, "a1": 0}, - {"a0": 1, "a1": 1}, - {"a0": 2, "a1": 2}, - ]) + episode = self._create_simple_episode( + [ + {"a0": 0, "a1": 0}, + {"a0": 1, "a1": 1}, + {"a0": 2, "a1": 2}, + ] + ) successor = episode.cut() check(len(successor), 0) check(successor.env_t_started, 2) @@ -2234,12 +2236,15 @@ def test_cut(self): check(successor._hanging_extra_model_outputs_end, {}) # Multi-agent episode with lookback buffer, in which all agents always step. - episode = self._create_simple_episode([ - {"a0": 0, "a1": 0}, - {"a0": 1, "a1": 1}, - {"a0": 2, "a1": 2}, - {"a0": 3, "a1": 3}, - ], len_lookback_buffer=2) + episode = self._create_simple_episode( + [ + {"a0": 0, "a1": 0}, + {"a0": 1, "a1": 1}, + {"a0": 2, "a1": 2}, + {"a0": 3, "a1": 3}, + ], + len_lookback_buffer=2, + ) # Cut with lookback=0 argument (default). successor = episode.cut() check(len(successor), 0) @@ -2278,19 +2283,30 @@ def test_cut(self): check(successor._hanging_rewards_end, {}) check(successor._hanging_extra_model_outputs_end, {}) - # Multi-agent episode, in which one agent has a long sequence of not acting. - episode = self._create_simple_episode([ + # Multi-agent episode, in which one agent has a long sequence of not acting, + # but does receive (intermittend/hanging) rewards during this time. 
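# --- Editor's note (not from the patch): expected cache contents for the case
# constructed below. "a1" gets only its very first observation (env step 0)
# and acts once; rewards 0.0, 0.1, and 0.2 then arrive without it ever
# stepping again, so they accumulate in the running episode's *end* reward
# cache. cut() moves that total (0.3) into the successor's *begin* cache,
# which is what the checks further down verify.
from collections import defaultdict

hanging_rewards_end = defaultdict(float)
for step_reward in (0.0, 0.1, 0.2):          # a1's rewards while not stepping
    hanging_rewards_end["a1"] += step_reward
hanging_rewards_begin_after_cut = hanging_rewards_end.copy()
assert abs(hanging_rewards_begin_after_cut["a1"] - 0.3) < 1e-9
# ------------------------------------------------------------------------------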
+ observations = [ {"a0": 0, "a1": 0}, # 0 {"a0": 1}, # 1 {"a0": 2}, # 2 {"a0": 3}, # 3 - ]) + ] + episode = MultiAgentEpisode( + observations=observations, + actions=observations[:-1], + rewards=[ + {"a0": 0.0, "a1": 0.0}, # 0 + {"a0": 0.1, "a1": 0.1}, # 1 + {"a0": 0.2, "a1": 0.2}, # 2 + ], + len_lookback_buffer=0, + ) successor = episode.cut() check(len(successor), 0) check(successor.env_t_started, 3) check(successor.env_t, 3) a0 = successor.agent_episodes["a0"] - self.assertTrue("a1" not in successor.agent_episodes) + self.assertTrue("a1" not in successor.agent_episodes) check(len(a0), 0) check(a0.t_started, 3) check(a0.t, 3) @@ -2298,20 +2314,48 @@ def test_cut(self): check(a0.actions, []) check(a0.rewards, []) check(successor._hanging_actions_begin, {"a1": 0}) - check(successor._hanging_rewards_begin, {"a1": 0.0}) + check(successor._hanging_rewards_begin, {"a1": 0.3}) check(successor._hanging_extra_model_outputs_begin, {"a1": {}}) check(successor._hanging_actions_end, {}) - check(successor._hanging_rewards_end, {}) + check(successor._hanging_rewards_end, {"a1": 0.0}) + check(successor._hanging_extra_model_outputs_end, {}) + # Add a few timesteps to successor and test the resulting episode. + successor.add_env_step( + observations={"a0": 4}, + actions={"a0": 3}, + rewards={"a0": 0.3, "a1": 0.3}, + ) + check(len(successor), 1) + check(successor.env_t_started, 3) + check(successor.env_t, 4) + # Just b/c we added an intermittend reward for a1 does not mean it should + # already have a SAEps in `successor`. It still hasn't received its first obs + # yet after the cut. + self.assertTrue("a1" not in successor.agent_episodes) + check(len(a0), 1) + check(a0.t_started, 3) + check(a0.t, 4) + check(a0.observations, [3, 4]) + check(a0.actions, [3]) + check(a0.rewards, [0.3]) + check(successor._hanging_actions_begin, {"a1": 0}) + check(successor._hanging_rewards_begin, {"a1": 0.3}) + check(successor._hanging_extra_model_outputs_begin, {"a1": {}}) + check(successor._hanging_actions_end, {}) + check(successor._hanging_rewards_end, {"a1": 0.3}) check(successor._hanging_extra_model_outputs_end, {}) # Generate a simple multi-agent episode and check all internals after # construction. - episode_1 = self._create_simple_episode([ - {"a0": 0, "a1": 0}, - {"a1": 1}, - {"a1": 2}, - {"a1": 3}, - ], len_lookback_buffer="auto") + episode_1 = self._create_simple_episode( + [ + {"a0": 0, "a1": 0}, + {"a1": 1}, + {"a1": 2}, + {"a1": 3}, + ], + len_lookback_buffer="auto", + ) episode_2 = episode_1.cut() check(episode_1.id_, episode_2.id_) check(len(episode_1), 0) @@ -3219,7 +3263,7 @@ def test_get_return(self): # `get_return()`. 
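# --- Editor's sketch (not from the patch): what the renamed flag computes.
# With include_hanging_rewards=True, the episode return is the sum of the
# per-agent returns plus whatever is still sitting in the hanging *end*
# reward cache (rewards received after an agent's last completed step).
# Hypothetical numbers, only to show the arithmetic:
per_agent_returns = [2.0, 1.0]          # returns of the SingleAgentEpisodes
hanging_rewards_end = {"a1": 0.5}       # reward not yet tied to a finished step
env_return = sum(per_agent_returns) + sum(hanging_rewards_end.values())
assert env_return == 3.5
# ------------------------------------------------------------------------------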
buffered_rewards = sum(episode._hanging_rewards_end.values()) self.assertTrue( - episode.get_return(consider_hanging_rewards=True), + episode.get_return(include_hanging_rewards=True), agent_returns + buffered_rewards, ) From 02148b9dbc3c4c731cf7e52cee2c6ff6296ae3b6 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 11 Apr 2024 17:36:26 +0200 Subject: [PATCH 8/9] wip Signed-off-by: sven1977 --- rllib/env/multi_agent_episode.py | 17 ++++++++++--- rllib/env/tests/test_multi_agent_episode.py | 28 +++++++++++++++++++-- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/rllib/env/multi_agent_episode.py b/rllib/env/multi_agent_episode.py index 8b50ce290475..16242ce1e346 100644 --- a/rllib/env/multi_agent_episode.py +++ b/rllib/env/multi_agent_episode.py @@ -422,14 +422,15 @@ def add_env_step( ) - {"__all__"} for agent_id in agent_ids_with_data: if agent_id not in self.agent_episodes: - self.agent_episodes[agent_id] = SingleAgentEpisode( + sa_episode = SingleAgentEpisode( agent_id=agent_id, module_id=self.module_for(agent_id), multi_agent_episode_id=self.id_, observation_space=self.observation_space.get(agent_id), action_space=self.action_space.get(agent_id), ) - sa_episode: SingleAgentEpisode = self.agent_episodes[agent_id] + else: + sa_episode = self.agent_episodes.get(agent_id) # Collect value to be passed (at end of for-loop) into `add_env_step()` # call. @@ -489,7 +490,6 @@ def add_env_step( agent_id, None ) _reward = self._hanging_rewards_end.pop(agent_id, 0.0) + _reward - # _agent_step = len(sa_episode) # First observation for this agent, we have no hanging action. # ... [done]? ... -> [1st obs for agent ID] else: @@ -506,6 +506,10 @@ def add_env_step( ) # Make `add_env_reset` call and continue with next agent. sa_episode.add_env_reset(observation=_observation, infos=_infos) + # Add possible reward to begin cache. + self._hanging_rewards_begin[agent_id] += _reward + # Now that the SAEps is valid, add it to our dict. + self.agent_episodes[agent_id] = sa_episode continue # CASE 3: Step is started (by an action), but not completed (no next obs). @@ -584,7 +588,12 @@ def add_env_step( _reward = self._hanging_rewards_end.pop(agent_id, 0.0) + _reward # The agent is still alive, just add current reward to cache. else: - self._hanging_rewards_end[agent_id] += _reward + # But has never stepped in this episode -> add to begin cache. + if agent_id not in self.agent_episodes: + self._hanging_rewards_begin[agent_id] += _reward + # Otherwise, add to end cache. + else: + self._hanging_rewards_end[agent_id] += _reward # If agent is stepping, add timestep to `SingleAgentEpisode`. if _observation is not None: diff --git a/rllib/env/tests/test_multi_agent_episode.py b/rllib/env/tests/test_multi_agent_episode.py index b86afd6fdaf1..acecf59324a0 100644 --- a/rllib/env/tests/test_multi_agent_episode.py +++ b/rllib/env/tests/test_multi_agent_episode.py @@ -2339,10 +2339,34 @@ def test_cut(self): check(a0.actions, [3]) check(a0.rewards, [0.3]) check(successor._hanging_actions_begin, {"a1": 0}) - check(successor._hanging_rewards_begin, {"a1": 0.3}) + check(successor._hanging_rewards_begin, {"a1": 0.6}) + check(successor._hanging_extra_model_outputs_begin, {"a1": {}}) + check(successor._hanging_actions_end, {}) + check(successor._hanging_rewards_end, {"a1": 0.0}) + check(successor._hanging_extra_model_outputs_end, {}) + # Now a1 actually does receive its next obs. 
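# --- Editor's sketch (not from the patch): the reward routing around this step.
# As long as "a1" has no SingleAgentEpisode in this chunk (no observation since
# the cut), every reward addressed to it, including the 0.4 that arrives
# together with its first observation below, is credited to the *begin* reward
# cache (0.3 + 0.3 + 0.4 -> 1.0). Only rewards that arrive while an agent is
# waiting for its *next* observation go to the *end* cache; the 0.5 step at the
# end of this sketch is hypothetical and only illustrates that path.
from collections import defaultdict

agent_episodes = {}                          # a1 has not stepped in this chunk
begin, end = defaultdict(float), defaultdict(float)

def credit(agent_id, reward):
    # Simplified routing: no SingleAgentEpisode for the agent yet -> begin
    # cache; otherwise -> end cache.
    if agent_id not in agent_episodes:
        begin[agent_id] += reward
    else:
        end[agent_id] += reward

begin["a1"] = 0.3                            # handed over by cut()
credit("a1", 0.3)                            # reward-only step after the cut
credit("a1", 0.4)                            # reward arriving with its 1st obs
agent_episodes["a1"] = object()              # SAEps is created after that obs
credit("a1", 0.5)                            # hypothetical: waiting for next obs
assert abs(begin["a1"] - 1.0) < 1e-9
assert abs(end["a1"] - 0.5) < 1e-9
# ------------------------------------------------------------------------------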
+ successor.add_env_step( + observations={"a0": 5, "a1": 5}, # <- this is a1's 1st obs in this chunk + actions={"a0": 4}, + rewards={"a0": 0.4, "a1": 0.4}, + ) + check(len(successor), 2) + check(successor.env_t_started, 3) + check(successor.env_t, 5) + a1 = successor.agent_episodes["a1"] + check((len(a0), len(a1)), (2, 0)) + check((a0.t_started, a1.t_started), (3, 0)) + check((a0.t, a1.t), (5, 0)) + check((a0.observations, a1.observations), ([3, 4, 5], [5])) + check((a0.actions, a1.actions), ([3, 4], [])) + check((a0.rewards, a1.rewards), ([0.3, 0.4], [])) + # Begin caches keep accumulating a1's rewards. + check(successor._hanging_actions_begin, {"a1": 0}) + check(successor._hanging_rewards_begin, {"a1": 1.0}) check(successor._hanging_extra_model_outputs_begin, {"a1": {}}) + # But end caches are now empty (due to a1's observation/finished step). check(successor._hanging_actions_end, {}) - check(successor._hanging_rewards_end, {"a1": 0.3}) + check(successor._hanging_rewards_end, {"a1": 0.0}) check(successor._hanging_extra_model_outputs_end, {}) # Generate a simple multi-agent episode and check all internals after From f02fe5647e73185269611b033ef7d282cf9016b3 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 12 Apr 2024 07:14:02 +0200 Subject: [PATCH 9/9] wip Signed-off-by: sven1977 --- rllib/env/tests/test_multi_agent_episode.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/rllib/env/tests/test_multi_agent_episode.py b/rllib/env/tests/test_multi_agent_episode.py index acecf59324a0..ec7d5c406d2d 100644 --- a/rllib/env/tests/test_multi_agent_episode.py +++ b/rllib/env/tests/test_multi_agent_episode.py @@ -508,13 +508,14 @@ def test_add_env_step(self): terminateds=terminated, truncateds=truncated, ) - # Assert that the action buffer for agent 4 is full. + # Assert that the action cache for agent 4 is used. # Note, agent 4 acts, but receives no observation. - # Note also, all other buffers are always full, due to their defaults. + # Note also, all other caches are always used, due to their defaults. self.assertTrue(episode._hanging_actions_end["agent_4"] is not None) - # Assert that the reward buffers of agents 3 and 5 are at 1.0. + # Assert that the reward caches of agents 3 and 5 are there. + # For agent_5 (b/c it has never done anything), we add to the begin cache. check(episode._hanging_rewards_end["agent_3"], 2.2) - check(episode._hanging_rewards_end["agent_5"], 1.0) + check(episode._hanging_rewards_begin["agent_5"], 1.0) def test_get_observations(self): # Generate simple records for a multi agent environment.
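# --- Editor's note (not from the patch): the bare "+=" updates and the direct
# indexing of the hanging reward caches above (e.g.
# episode._hanging_rewards_begin["agent_5"]) suggest these caches default to
# 0.0, presumably collections.defaultdict(float), so a first reward for an
# agent can be accumulated without initializing the key:
from collections import defaultdict

hanging_rewards_begin = defaultdict(float)
hanging_rewards_begin["agent_5"] += 1.0      # no prior entry needed
assert hanging_rewards_begin["agent_5"] == 1.0
# ------------------------------------------------------------------------------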