Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Fix MultiAgentEpisode getter bugs. #44898

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions rllib/env/multi_agent_episode.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


# TODO (simon): Include cases in which the number of agents in an
# episode are shrinking or growing during the episode itself.
# episode are shrinking or growing during the episode itself.
@PublicAPI(stability="alpha")
class MultiAgentEpisode:
"""Stores multi-agent episode data.
Expand Down Expand Up @@ -1633,21 +1633,30 @@ def __repr__(self):
)

def print(self) -> None:
"""Prints this MultiAgentEpisode as a table of observations for the agents."""
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made this a little nicer. :)


# Find the maximum timestep across all agents to determine the grid width.
max_ts = max(len(ts) for ts in self.env_t_to_agent_t.values())
max_ts = max(ts.len_incl_lookback() for ts in self.env_t_to_agent_t.values())
lookback = next(iter(self.env_t_to_agent_t.values())).lookback
longest_agent = max(len(aid) for aid in self.agent_ids)
# Construct the header.
header = "ts " + " ".join(str(i) for i in range(max_ts)) + "\n"
header = (
"ts"
+ (" " * longest_agent)
+ " ".join(str(i) for i in range(-lookback, max_ts - lookback))
+ "\n"
)
# Construct each agent's row.
rows = []
for agent, timesteps in self.env_t_to_agent_t.items():
row = f"{agent} "
for t in timesteps:
for agent, inf_buffer in self.env_t_to_agent_t.items():
row = f"{agent} " + (" " * (longest_agent - len(agent)))
for t in inf_buffer.data:
# Two spaces for alignment.
if t == "S":
row += " "
row += " "
# Mark the step with an x.
else:
row += "x "
row += " x "
# Remove trailing space for alignment.
rows.append(row.rstrip())

Expand Down
85 changes: 68 additions & 17 deletions rllib/env/tests/test_multi_agent_episode.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,7 +975,7 @@ def test_get_actions(self):
check(act, actions[i])
# Access >=0 integer indices (expect index error as everything is in
# lookback buffer).
for i in range(1, 5):
for i in range(0, 5):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idx=0 was NOT working properly before this fix.

with self.assertRaises(IndexError):
episode.get_actions(i)
# Access <= -5 integer indices (expect index error as this goes beyond length of
Expand Down Expand Up @@ -1023,6 +1023,50 @@ def test_get_actions(self):
act = episode.get_actions(-4, env_steps=False, fill="skip")
check(act, {"a0": "skip", "a1": 0})

episode.add_env_step(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this to the tests, mostly to figure out whether a hanging action at the edge of the episode or further back would make a difference in `get_actions(-1)`.

observations={"a0": 5, "a1": 5}, actions={"a1": 4}, rewards={"a1": 4}
)
check(episode.get_actions(0), {"a1": 4})
check(episode.get_actions(-1), {"a1": 4})
check(episode.get_actions(-2), {"a1": 3})
episode.add_env_step(
observations={"a1": 6},
actions={"a0": 5, "a1": 5},
rewards={"a0": 5, "a1": 5},
)
check(episode.get_actions(0), {"a1": 4})
check(episode.get_actions(1), {"a0": 5, "a1": 5})
check(episode.get_actions(-1), {"a0": 5, "a1": 5})

# Generate a simple multi-agent episode, where a hanging action is at the end.
observations = [
{"a0": 0, "a1": 0},
{"a0": 0, "a1": 1},
{"a0": 2},
]
actions = [{"a0": 0, "a1": 0}, {"a0": 1, "a1": 1}]
rewards = [{"a0": 0.0, "a1": 0.0}, {"a0": 0.1, "a1": 0.1}]
episode = MultiAgentEpisode(
observations=observations,
actions=actions,
rewards=rewards,
len_lookback_buffer=0,
)
# Test whether the hanging action of a1 at the end gets returned properly
# for idx=-1.
act = episode.get_actions(-1)
check(act, {"a0": 1, "a1": 1})
act = episode.get_actions(-2)
check(act, {"a0": 0, "a1": 0})
act = episode.get_actions(0)
check(act, {"a0": 0, "a1": 0})
act = episode.get_actions(1)
check(act, {"a0": 1, "a1": 1})
with self.assertRaises(IndexError):
episode.get_actions(2)
with self.assertRaises(IndexError):
episode.get_actions(-3)

# Generate a simple multi-agent episode, where one agent is done.
# observations = [
# {"a0": 0, "a1": 0},
Expand Down Expand Up @@ -1132,15 +1176,23 @@ def test_get_actions(self):
check(
act,
{
"agent_1": [-10, 0],
"agent_2": [-10, 0],
"agent_3": [-10, 0],
"agent_4": [-10, -10],
"agent_1": [0, 1],
"agent_2": [0, -10],
"agent_3": [0, 1],
"agent_4": [-10, 1],
},
)
# Same, but w/o fill.
act = episode.get_actions(indices=[-2, -1], neg_indices_left_of_zero=True)
check(
act,
{
"agent_1": [0, 1],
"agent_2": [0],
"agent_3": [0, 1],
"agent_4": [1],
},
)
# Same, but w/o fill (should produce error as the lookback is only 1 long).
with self.assertRaises(IndexError):
episode.get_actions(indices=[-2, -1], neg_indices_left_of_zero=True)

# Get last actions for each individual agent.
act = episode.get_actions(indices=-1, env_steps=False)
Expand All @@ -1158,7 +1210,7 @@ def test_get_actions(self):
act = episode.get_actions(-1, env_steps=False, agent_ids=["agent_1", "agent_2"])
check(act, {"agent_1": 1, "agent_2": 0})
act = episode.get_actions(-2, env_steps=True, agent_ids={"agent_4"})
check(act, {"agent_4": 1})
check(act, {})
act = episode.get_actions([-1, -2], env_steps=True, agent_ids={"agent_4"})
check(act, {"agent_4": [1]})
# Agent 4 has only acted 2x, so there is no (local) ts=-2 for it.
Expand All @@ -1173,7 +1225,7 @@ def test_get_actions(self):
# actions are in these buffers (and won't get returned here).
act = episode.get_actions(return_list=True)
self.assertTrue(act == [])
# Expect error when calling with env_steps=False.
# Expect error when calling with env_steps=False AND return_list=True.
with self.assertRaises(ValueError):
episode.get_actions(env_steps=False, return_list=True)
# List of indices.
Expand Down Expand Up @@ -1364,15 +1416,14 @@ def test_get_rewards(self):
check(
rew,
{
"agent_1": [-10, 0.5],
"agent_2": [-10, 0.6],
"agent_3": [-10, 0.7],
"agent_4": [-10, -10],
"agent_1": [0.5, 1.1],
"agent_2": [0.6, -10],
"agent_3": [0.7, 1.2],
"agent_4": [-10, 1.3],
},
)
# Same, but w/o fill (should produce error as the lookback is only 1 long).
with self.assertRaises(IndexError):
episode.get_rewards(indices=[-2, -1], neg_indices_left_of_zero=True)
# Same, but w/o fill.
episode.get_rewards(indices=[-2, -1], neg_indices_left_of_zero=True)

# Get last rewards for each individual agent.
rew = episode.get_rewards(indices=-1, env_steps=False)
Expand Down
2 changes: 1 addition & 1 deletion rllib/env/utils/infinite_lookback_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def _get_int_index(
# If index >= 0 -> Ignore lookback buffer.
# Otherwise, include lookback buffer.
if idx >= 0 or neg_indices_left_of_zero:
idx = self.lookback + idx - (_ignore_last_ts is True)
idx = self.lookback + idx
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the actual bug fix.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing how such a small modification changes the landscape completely.

# Negative indices mean: Go to left into lookback buffer starting from idx=0.
# But if we pass the lookback buffer, the index should be invalid and we will
# have to fill, if required. Invalidate the index by setting it to one larger
Expand Down
Loading