-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] Fix MultiAgentEpisode getter bugs. #44898
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -975,7 +975,7 @@ def test_get_actions(self): | |
check(act, actions[i]) | ||
# Access >=0 integer indices (expect index error as everything is in | ||
# lookback buffer). | ||
for i in range(1, 5): | ||
for i in range(0, 5): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. idx=0 was NOT working properly before this fix. |
||
with self.assertRaises(IndexError): | ||
episode.get_actions(i) | ||
# Access <= -5 integer indices (expect index error as this goes beyond length of | ||
|
@@ -1023,6 +1023,50 @@ def test_get_actions(self): | |
act = episode.get_actions(-4, env_steps=False, fill="skip") | ||
check(act, {"a0": "skip", "a1": 0}) | ||
|
||
episode.add_env_step( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added this to the tests. Mostly to figure out, whether a hanging action at the edge of the episode or further back would make a difference in |
||
observations={"a0": 5, "a1": 5}, actions={"a1": 4}, rewards={"a1": 4} | ||
) | ||
check(episode.get_actions(0), {"a1": 4}) | ||
check(episode.get_actions(-1), {"a1": 4}) | ||
check(episode.get_actions(-2), {"a1": 3}) | ||
episode.add_env_step( | ||
observations={"a1": 6}, | ||
actions={"a0": 5, "a1": 5}, | ||
rewards={"a0": 5, "a1": 5}, | ||
) | ||
check(episode.get_actions(0), {"a1": 4}) | ||
check(episode.get_actions(1), {"a0": 5, "a1": 5}) | ||
check(episode.get_actions(-1), {"a0": 5, "a1": 5}) | ||
|
||
# Generate a simple multi-agent episode, where a hanging action is at the end. | ||
observations = [ | ||
{"a0": 0, "a1": 0}, | ||
{"a0": 0, "a1": 1}, | ||
{"a0": 2}, | ||
] | ||
actions = [{"a0": 0, "a1": 0}, {"a0": 1, "a1": 1}] | ||
rewards = [{"a0": 0.0, "a1": 0.0}, {"a0": 0.1, "a1": 0.1}] | ||
episode = MultiAgentEpisode( | ||
observations=observations, | ||
actions=actions, | ||
rewards=rewards, | ||
len_lookback_buffer=0, | ||
) | ||
# Test, whether the hanging action of a1 at the end gets returned properly | ||
# for idx=-1. | ||
act = episode.get_actions(-1) | ||
check(act, {"a0": 1, "a1": 1}) | ||
act = episode.get_actions(-2) | ||
check(act, {"a0": 0, "a1": 0}) | ||
act = episode.get_actions(0) | ||
check(act, {"a0": 0, "a1": 0}) | ||
act = episode.get_actions(1) | ||
check(act, {"a0": 1, "a1": 1}) | ||
with self.assertRaises(IndexError): | ||
episode.get_actions(2) | ||
with self.assertRaises(IndexError): | ||
episode.get_actions(-3) | ||
|
||
# Generate a simple multi-agent episode, where one agent is done. | ||
# observations = [ | ||
# {"a0": 0, "a1": 0}, | ||
|
@@ -1132,15 +1176,23 @@ def test_get_actions(self): | |
check( | ||
act, | ||
{ | ||
"agent_1": [-10, 0], | ||
"agent_2": [-10, 0], | ||
"agent_3": [-10, 0], | ||
"agent_4": [-10, -10], | ||
"agent_1": [0, 1], | ||
"agent_2": [0, -10], | ||
"agent_3": [0, 1], | ||
"agent_4": [-10, 1], | ||
}, | ||
) | ||
# Same, but w/o fill. | ||
act = episode.get_actions(indices=[-2, -1], neg_indices_left_of_zero=True) | ||
check( | ||
act, | ||
{ | ||
"agent_1": [0, 1], | ||
"agent_2": [0], | ||
"agent_3": [0, 1], | ||
"agent_4": [1], | ||
}, | ||
) | ||
# Same, but w/o fill (should produce error as the lookback is only 1 long). | ||
with self.assertRaises(IndexError): | ||
episode.get_actions(indices=[-2, -1], neg_indices_left_of_zero=True) | ||
|
||
# Get last actions for each individual agent. | ||
act = episode.get_actions(indices=-1, env_steps=False) | ||
|
@@ -1158,7 +1210,7 @@ def test_get_actions(self): | |
act = episode.get_actions(-1, env_steps=False, agent_ids=["agent_1", "agent_2"]) | ||
check(act, {"agent_1": 1, "agent_2": 0}) | ||
act = episode.get_actions(-2, env_steps=True, agent_ids={"agent_4"}) | ||
check(act, {"agent_4": 1}) | ||
check(act, {}) | ||
act = episode.get_actions([-1, -2], env_steps=True, agent_ids={"agent_4"}) | ||
check(act, {"agent_4": [1]}) | ||
# Agent 4 has only acted 2x, so there is no (local) ts=-2 for it. | ||
|
@@ -1173,7 +1225,7 @@ def test_get_actions(self): | |
# actions are in these buffers (and won't get returned here). | ||
act = episode.get_actions(return_list=True) | ||
self.assertTrue(act == []) | ||
# Expect error when calling with env_steps=False. | ||
# Expect error when calling with env_steps=False AND return_list=True. | ||
with self.assertRaises(ValueError): | ||
episode.get_actions(env_steps=False, return_list=True) | ||
# List of indices. | ||
|
@@ -1364,15 +1416,14 @@ def test_get_rewards(self): | |
check( | ||
rew, | ||
{ | ||
"agent_1": [-10, 0.5], | ||
"agent_2": [-10, 0.6], | ||
"agent_3": [-10, 0.7], | ||
"agent_4": [-10, -10], | ||
"agent_1": [0.5, 1.1], | ||
"agent_2": [0.6, -10], | ||
"agent_3": [0.7, 1.2], | ||
"agent_4": [-10, 1.3], | ||
}, | ||
) | ||
# Same, but w/o fill (should produce error as the lookback is only 1 long). | ||
with self.assertRaises(IndexError): | ||
episode.get_rewards(indices=[-2, -1], neg_indices_left_of_zero=True) | ||
# Same, but w/o fill. | ||
episode.get_rewards(indices=[-2, -1], neg_indices_left_of_zero=True) | ||
|
||
# Get last rewards for each individual agent. | ||
rew = episode.get_rewards(indices=-1, env_steps=False) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -464,7 +464,7 @@ def _get_int_index( | |
# If index >= 0 -> Ignore lookback buffer. | ||
# Otherwise, include lookback buffer. | ||
if idx >= 0 or neg_indices_left_of_zero: | ||
idx = self.lookback + idx - (_ignore_last_ts is True) | ||
idx = self.lookback + idx | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the actual bug fix. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Amazing how such a small modification changes the landscape completely. |
||
# Negative indices mean: Go to left into lookback buffer starting from idx=0. | ||
# But if we pass the lookback buffer, the index should be invalid and we will | ||
# have to fill, if required. Invalidate the index by setting it to one larger | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Made this a little nicer. :)