[RLlib] MultiAgentEpisode: Fix various bugs in slice(). (#44594)
sven1977 authored Apr 10, 2024
1 parent 8888f9b commit 3fea138
Showing 4 changed files with 265 additions and 108 deletions.
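
The headline fix concerns `MultiAgentEpisode.slice()`, which previously could return chunks with wrong per-agent start timesteps and a wrong global start timestep. A minimal sketch of the call under repair (the class and its constructor are real new-API-stack RLlib code, but all data values below are invented for illustration; see the `MultiAgentEpisode` docstring for the exact data contract):

    from ray.rllib.env.multi_agent_episode import MultiAgentEpisode

    # Four env steps; agent "a1" steps every time, agent "a0" skips env step 1.
    episode = MultiAgentEpisode(
        observations=[
            {"a0": 0, "a1": 0},
            {"a1": 1},
            {"a0": 2, "a1": 2},
            {"a0": 3, "a1": 3},
        ],
        actions=[{"a0": 0, "a1": 0}, {"a1": 1}, {"a0": 2, "a1": 2}],
        rewards=[{"a0": 0.0, "a1": 0.0}, {"a1": 0.1}, {"a0": 0.2, "a1": 0.2}],
        len_lookback_buffer=0,
    )

    # The method under repair: cut out env steps 1..2. Before this commit, the
    # returned chunk could carry wrong `agent_t_started` and `env_t_started`
    # values (see the hunks below).
    sliced = episode.slice(slice(1, 3))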
39 changes: 26 additions & 13 deletions rllib/env/multi_agent_episode.py
@@ -1407,24 +1407,33 @@ def slice(self, slice_: slice) -> "MultiAgentEpisode":
                 "episode have the exact same size!"
             )

-        # Determine terminateds/truncateds.
+        # Determine terminateds/truncateds and when (in agent timesteps) the
+        # single-agent episode slices start.
         terminateds = {}
         truncateds = {}
+        agent_t_started = {}
         for aid, sa_episode in self.agent_episodes.items():
+            mapping = self.env_t_to_agent_t[aid]
             # If the (agent) timestep directly at the slice stop boundary is equal to
             # the length of the single-agent episode of this agent -> Use the
             # single-agent episode's terminated/truncated flags.
-            # If `stop` is already beyond this agents single-agent episode, then we
+            # If `stop` is already beyond this agent's single-agent episode, then we
             # don't have to keep track of this: The MultiAgentEpisode initializer will
-            # automatically determine that this agent must be done (b/c has no action
+            # automatically determine that this agent must be done (b/c it has no action
             # following its final observation).
             if (
-                stop < len(self.env_t_to_agent_t[aid])
-                and self.env_t_to_agent_t[aid][stop] != self.SKIP_ENV_TS_TAG
-                and len(sa_episode) == self.env_t_to_agent_t[aid][stop]
+                stop < len(mapping)
+                and mapping[stop] != self.SKIP_ENV_TS_TAG
+                and len(sa_episode) == mapping[stop]
             ):
                 terminateds[aid] = sa_episode.is_terminated
                 truncateds[aid] = sa_episode.is_truncated
+            # Determine this agent's t_started.
+            if start < len(mapping):
+                for i in range(start, len(mapping)):
+                    if mapping[i] != self.SKIP_ENV_TS_TAG:
+                        agent_t_started[aid] = sa_episode.t_started + mapping[i]
+                        break
         terminateds["__all__"] = all(
             terminateds.get(aid) for aid in self.agent_episodes
         )
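
The newly added `agent_t_started` logic, restated as a self-contained snippet (`SKIP`, `env_t_to_agent_t`, and `sa_t_started` are stand-ins for the class attributes the hunk uses; values are invented):

    # Per agent: global env step -> agent's local timestep ("S" = agent did
    # not act at that env step, i.e. SKIP_ENV_TS_TAG).
    SKIP = "S"
    env_t_to_agent_t = {"a0": [0, SKIP, 1, 2], "a1": [0, 1, 2, 3]}
    sa_t_started = {"a0": 0, "a1": 0}  # each SingleAgentEpisode's t_started

    start = 1  # slice start (global env step)
    agent_t_started = {}
    for aid, mapping in env_t_to_agent_t.items():
        # The first non-skipped agent timestep at or after `start` becomes the
        # agent's local start timestep of the sliced episode.
        for i in range(start, len(mapping)):
            if mapping[i] != SKIP:
                agent_t_started[aid] = sa_t_started[aid] + mapping[i]
                break

    assert agent_t_started == {"a0": 1, "a1": 1}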
@@ -1465,14 +1474,16 @@ def slice(self, slice_: slice) -> "MultiAgentEpisode":
             terminateds=terminateds,
             truncateds=truncateds,
             len_lookback_buffer=ref_lookback,
+            env_t_started=self.env_t_started + start,
             agent_episode_ids={
                 aid: eid.id_ for aid, eid in self.agent_episodes.items()
             },
+            agent_t_started=agent_t_started,
             agent_module_ids=self._agent_to_module_mapping,
             agent_to_module_mapping_fn=self.agent_to_module_mapping_fn,
         )

-        # Finalize slice if `self` is finalized.
+        # Finalize slice if `self` is also finalized.
         if self.is_finalized:
             ma_episode.finalize()

@@ -1525,7 +1536,7 @@ def print(self) -> None:
             rows.append(row.rstrip())

         # Join all components into a final string
-        return header + "\n".join(rows)
+        print(header + "\n".join(rows))

     def get_state(self) -> Dict[str, Any]:
         """Returns the state of a multi-agent episode.
Expand Down Expand Up @@ -1722,7 +1733,7 @@ def _init_single_agent_episodes(
rewards = []
extra_model_outputs = []

# Infos and extra_model_outputs are allowed to be None -> Fill them with
# Infos and `extra_model_outputs` are allowed to be None -> Fill them with
# proper dummy values, if so.
if infos is None:
infos = [{} for _ in range(len(observations))]
@@ -1743,7 +1754,7 @@ def _init_single_agent_episodes(
         agent_module_ids = agent_module_ids or {}

         # Step through all observations and interpret these as the (global) env steps.
-        env_t = self.env_t - len_lookback_buffer
+        env_t = self.env_t_started - len_lookback_buffer
         for data_idx, (obs, inf) in enumerate(zip(observations, infos)):
             # If we do have actions/extra outs/rewards for this timestep, use the data.
             # It may be that these lists have the same length as the observations list,
Expand Down Expand Up @@ -1800,7 +1811,7 @@ def _init_single_agent_episodes(
elif terminateds.get(agent_id) or truncateds.get(agent_id):
done_per_agent[agent_id] = True
# There is more (global) action/reward data. This agent must therefore
# be done. Auto-add it to terminateds.
# be done. Automatically add it to `done_per_agent` and `terminateds`.
elif data_idx < len(observations) - 1:
done_per_agent[agent_id] = terminateds[agent_id] = True

@@ … @@ def _init_single_agent_episodes(
                     len(observations_per_agent[agent_id]) - 1
                 )

-            # Those agents that did NOT step get None added to their mapping.
+            # Those agents that did NOT step get self.SKIP_ENV_TS_TAG added to their
+            # mapping.
             for agent_id in all_agent_ids:
                 if agent_id not in obs and agent_id not in done_per_agent:
                     self.env_t_to_agent_t[agent_id].append(self.SKIP_ENV_TS_TAG)

-            # Update per-agent lookback buffer and t_started counters.
+            # Update per-agent lookback buffer sizes to be used when creating the
+            # individual `SingleAgentEpisode` objects below.
             for agent_id in all_agent_ids:
                 if env_t < self.env_t_started:
                     if agent_id not in done_per_agent:
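
The one-line change in the `@@ -1743,7 +1754,7 @@` hunk above anchors the constructor data at the correct global timestep. A toy illustration of the corrected arithmetic (values invented):

    env_t_started = 10       # this episode chunk begins at global env step 10
    len_lookback_buffer = 3  # the data also covers 3 lookback steps before that

    # Fixed: the first observation in the data belongs to global env step 7.
    # The replaced code derived this from `self.env_t` (the running step
    # counter), which need not equal the chunk's start step.
    env_t = env_t_started - len_lookback_buffer
    assert env_t == 7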
106 changes: 53 additions & 53 deletions rllib/env/single_agent_episode.py
@@ -349,59 +349,6 @@ def __init__(
         # Validate the episode data thus far.
         self.validate()

-    def concat_episode(self, episode_chunk: "SingleAgentEpisode") -> None:
-        """Adds the given `episode_chunk` to the right side of self.
-
-        In order for this to work, both chunks (`self` and `episode_chunk`) must fit
-        together. This is checked by the IDs (must be identical), the time step counters
-        (`self.t` must be the same as `episode_chunk.t_started`), as well as the
-        observations/infos at the concatenation boundaries (`self.observations[-1]`
-        must match `episode_chunk.observations[0]`). Also, `self.is_done` must not be
-        True, meaning `self.is_terminated` and `self.is_truncated` are both False.
-
-        Args:
-            episode_chunk: Another `SingleAgentEpisode` to be concatenated.
-
-        Returns: A `SingleAegntEpisode` instance containing the concatenated
-            from both episodes.
-        """
-        assert episode_chunk.id_ == self.id_
-        # NOTE (sven): This is what we agreed on. As the replay buffers must be
-        # able to concatenate.
-        assert not self.is_done
-        # Make sure the timesteps match.
-        assert self.t == episode_chunk.t_started
-
-        episode_chunk.validate()
-
-        # Make sure, end matches other episode chunk's beginning.
-        assert np.all(episode_chunk.observations[0] == self.observations[-1])
-        # Pop out our last observations and infos (as these are identical
-        # to the first obs and infos in the next episode).
-        self.observations.pop()
-        self.infos.pop()
-
-        # Extend ourselves. In case, episode_chunk is already terminated (and numpyfied)
-        # we need to convert to lists (as we are ourselves still filling up lists).
-        self.observations.extend(episode_chunk.get_observations())
-        self.actions.extend(episode_chunk.get_actions())
-        self.rewards.extend(episode_chunk.get_rewards())
-        self.infos.extend(episode_chunk.get_infos())
-        self.t = episode_chunk.t
-
-        if episode_chunk.is_terminated:
-            self.is_terminated = True
-        elif episode_chunk.is_truncated:
-            self.is_truncated = True
-
-        for model_out_key in episode_chunk.extra_model_outputs.keys():
-            self.extra_model_outputs[model_out_key].extend(
-                episode_chunk.get_extra_model_outputs(model_out_key)
-            )
-
-        # Validate.
-        self.validate()
-
     def add_env_reset(
         self,
         observation: ObsType,
@@ -637,6 +584,59 @@ def finalize(self) -> "SingleAgentEpisode":

         return self

+    def concat_episode(self, other: "SingleAgentEpisode") -> None:
+        """Adds the given `other` SingleAgentEpisode to the right side of self.
+
+        In order for this to work, both chunks (`self` and `other`) must fit
+        together. This is checked by the IDs (must be identical), the time step counters
+        (`self.t` must be the same as `other.t_started`), as well as the
+        observations/infos at the concatenation boundaries. Also, `self.is_done` must
+        not be True, meaning `self.is_terminated` and `self.is_truncated` are both
+        False.
+
+        Args:
+            other: The other `SingleAgentEpisode` to be concatenated to this one.
+
+        Returns: A `SingleAgentEpisode` instance containing the concatenated data
+            from both episodes (`self` and `other`).
+        """
+        assert other.id_ == self.id_
+        # NOTE (sven): This is what we agreed on. As the replay buffers must be
+        # able to concatenate.
+        assert not self.is_done
+        # Make sure the timesteps match.
+        assert self.t == other.t_started
+        # Validate `other`.
+        other.validate()
+
+        # Make sure our end matches the other episode chunk's beginning.
+        assert np.all(other.observations[0] == self.observations[-1])
+        # Pop out our last observations and infos (as these are identical
+        # to the first obs and infos in the next episode).
+        self.observations.pop()
+        self.infos.pop()
+
+        # Extend ourselves. In case `other` is already terminated (and finalized),
+        # we need to convert to lists (as we are ourselves still filling up lists).
+        self.observations.extend(other.get_observations())
+        self.actions.extend(other.get_actions())
+        self.rewards.extend(other.get_rewards())
+        self.infos.extend(other.get_infos())
+        self.t = other.t
+
+        if other.is_terminated:
+            self.is_terminated = True
+        elif other.is_truncated:
+            self.is_truncated = True
+
+        for model_out_key in other.extra_model_outputs.keys():
+            self.extra_model_outputs[model_out_key].extend(
+                other.get_extra_model_outputs(model_out_key)
+            )
+
+        # Validate.
+        self.validate()
+
     def cut(self, len_lookback_buffer: int = 0) -> "SingleAgentEpisode":
         """Returns a successor episode chunk (of len=0) continuing from this Episode.
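
For reference, a hedged usage sketch of the relocated `concat_episode` (the constructor arguments and the `cut`/`add_env_step` calls follow the public new-API-stack interface; all values are invented):

    from ray.rllib.env.single_agent_episode import SingleAgentEpisode

    # A 2-step episode: observations are one longer than actions/rewards.
    eps = SingleAgentEpisode(
        observations=[0, 1, 2],
        actions=[0, 1],
        rewards=[0.1, 0.2],
        len_lookback_buffer=0,
    )
    assert eps.t == 2 and not eps.is_done

    # Cut a successor chunk (len 0, same id_, starting at eps' last
    # observation), step it once, then re-attach it. The asserts in
    # `concat_episode` above (matching id_, t == t_started, and matching
    # boundary observations) all hold here.
    chunk = eps.cut()
    chunk.add_env_step(observation=3, action=2, reward=0.3)
    eps.concat_episode(chunk)
    assert eps.t == 3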

