Skip to content

Commit

Permalink
[RLlib] Issue 22625: MultiAgentBatch.timeslices() does not behave a…
Browse files Browse the repository at this point in the history
…s expected. (#22657)
  • Loading branch information
ArturNiederfahrenhorst authored Mar 8, 2022
1 parent 4576f53 commit c0ade5f
Show file tree
Hide file tree
Showing 5 changed files with 335 additions and 7 deletions.
7 changes: 7 additions & 0 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,13 @@ py_test(
srcs = ["policy/tests/test_compute_log_likelihoods.py"]
)

py_test(
name = "policy/tests/test_multi_agent_batch",
tags = ["team:ml", "policy"],
size = "small",
srcs = ["policy/tests/test_multi_agent_batch.py"]
)

py_test(
name = "policy/tests/test_policy",
tags = ["team:ml", "policy"],
Expand Down
12 changes: 6 additions & 6 deletions rllib/policy/sample_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def copy(self, shallow: bool = False) -> "SampleBatch":
def rows(self) -> Iterator[Dict[str, TensorType]]:
"""Returns an iterator over data rows, i.e. dicts with column values.
Note that if `seq_lens` is set in self, we set it to [1] in the rows.
Note that if `seq_lens` is set in self, we set it to 1 in the rows.
Yields:
The column values of the row in this iteration.
Expand All @@ -325,13 +325,12 @@ def rows(self) -> Iterator[Dict[str, TensorType]]:
... })
>>> for row in batch.rows():
print(row)
{"a": 1, "b": 4, "seq_lens": [1]}
{"a": 2, "b": 5, "seq_lens": [1]}
{"a": 3, "b": 6, "seq_lens": [1]}
{"a": 1, "b": 4, "seq_lens": 1}
{"a": 2, "b": 5, "seq_lens": 1}
{"a": 3, "b": 6, "seq_lens": 1}
"""

# Do we add seq_lens=[1] to each row?
seq_lens = None if self.get(SampleBatch.SEQ_LENS) is None else np.array([1])
seq_lens = None if self.get(SampleBatch.SEQ_LENS, 1) is None else 1

self_as_dict = {k: v for k, v in self.items()}

Expand Down Expand Up @@ -1182,6 +1181,7 @@ def finish_slice():
{k: v.build_and_reset() for k, v in cur_slice.items()}, cur_slice_size
)
cur_slice_size = 0
cur_slice.clear()
finished_slices.append(batch)

# For each unique env timestep.
Expand Down
241 changes: 241 additions & 0 deletions rllib/policy/tests/test_multi_agent_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
import unittest

from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.test_utils import check_same_batch


class TestMultiAgentBatch(unittest.TestCase):
def test_timeslices_non_overlapping_experiences(self):
"""Tests if timeslices works as expected on a MultiAgentBatch
consisting of two non-overlapping SampleBatches.
"""

def _generate_data(agent_idx):
batch = SampleBatch(
{
SampleBatch.T: [0, 1],
SampleBatch.EPS_ID: 2 * [agent_idx],
SampleBatch.AGENT_INDEX: 2 * [agent_idx],
SampleBatch.SEQ_LENS: [2],
}
)
return batch

policy_batches = {str(idx): _generate_data(idx) for idx in (range(2))}
ma_batch = MultiAgentBatch(policy_batches, 4)
sliced_ma_batches = ma_batch.timeslices(1)

[
check_same_batch(i, j)
for i, j in zip(
sliced_ma_batches,
[
MultiAgentBatch(
{
"0": SampleBatch(
{
SampleBatch.T: [0],
SampleBatch.EPS_ID: [0],
SampleBatch.AGENT_INDEX: [0],
SampleBatch.SEQ_LENS: [1],
}
)
},
1,
),
MultiAgentBatch(
{
"0": SampleBatch(
{
SampleBatch.T: [1],
SampleBatch.EPS_ID: [0],
SampleBatch.AGENT_INDEX: [0],
SampleBatch.SEQ_LENS: [1],
}
)
},
1,
),
MultiAgentBatch(
{
"1": SampleBatch(
{
SampleBatch.T: [0],
SampleBatch.EPS_ID: [1],
SampleBatch.AGENT_INDEX: [1],
SampleBatch.SEQ_LENS: [1],
}
)
},
1,
),
MultiAgentBatch(
{
"1": SampleBatch(
{
SampleBatch.T: [1],
SampleBatch.EPS_ID: [1],
SampleBatch.AGENT_INDEX: [1],
SampleBatch.SEQ_LENS: [1],
}
)
},
1,
),
],
)
]

def test_timeslices_partially_overlapping_experiences(self):
"""Tests if timeslices works as expected on a MultiAgentBatch
consisting of two partially overlapping SampleBatches.
"""

def _generate_data(agent_idx, t_start):
batch = SampleBatch(
{
SampleBatch.T: [t_start, t_start + 1],
SampleBatch.EPS_ID: [0, 0],
SampleBatch.AGENT_INDEX: 2 * [agent_idx],
SampleBatch.SEQ_LENS: [2],
}
)
return batch

policy_batches = {str(idx): _generate_data(idx, idx) for idx in (range(2))}
ma_batch = MultiAgentBatch(policy_batches, 4)
sliced_ma_batches = ma_batch.timeslices(1)

[
check_same_batch(i, j)
for i, j in zip(
sliced_ma_batches,
[
MultiAgentBatch(
{
"0": SampleBatch(
{
SampleBatch.T: [0],
SampleBatch.EPS_ID: [0],
SampleBatch.AGENT_INDEX: [0],
SampleBatch.SEQ_LENS: [1],
}
)
},
1,
),
MultiAgentBatch(
{
"0": SampleBatch(
{
SampleBatch.T: [1],
SampleBatch.EPS_ID: [0],
SampleBatch.AGENT_INDEX: [0],
SampleBatch.SEQ_LENS: [1],
}
),
"1": SampleBatch(
{
SampleBatch.T: [1],
SampleBatch.EPS_ID: [0],
SampleBatch.AGENT_INDEX: [1],
SampleBatch.SEQ_LENS: [1],
}
),
},
1,
),
MultiAgentBatch(
{
"1": SampleBatch(
{
SampleBatch.T: [2],
SampleBatch.EPS_ID: [0],
SampleBatch.AGENT_INDEX: [1],
SampleBatch.SEQ_LENS: [1],
}
)
},
1,
),
],
)
]

def test_timeslices_fully_overlapping_experiences(self):
"""Tests if timeslices works as expected on a MultiAgentBatch
consisting of two fully overlapping SampleBatches.
"""

def _generate_data(agent_idx):
batch = SampleBatch(
{
SampleBatch.T: [0, 1],
SampleBatch.EPS_ID: [0, 0],
SampleBatch.AGENT_INDEX: 2 * [agent_idx],
SampleBatch.SEQ_LENS: [2],
}
)
return batch

policy_batches = {str(idx): _generate_data(idx) for idx in (range(2))}
ma_batch = MultiAgentBatch(policy_batches, 4)
sliced_ma_batches = ma_batch.timeslices(1)

[
check_same_batch(i, j)
for i, j in zip(
sliced_ma_batches,
[
MultiAgentBatch(
{
"0": SampleBatch(
{
SampleBatch.T: [0],
SampleBatch.EPS_ID: [0],
SampleBatch.AGENT_INDEX: [0],
SampleBatch.SEQ_LENS: [1],
}
),
"1": SampleBatch(
{
SampleBatch.T: [0],
SampleBatch.EPS_ID: [0],
SampleBatch.AGENT_INDEX: [1],
SampleBatch.SEQ_LENS: [1],
}
),
},
1,
),
MultiAgentBatch(
{
"0": SampleBatch(
{
SampleBatch.T: [1],
SampleBatch.EPS_ID: [0],
SampleBatch.AGENT_INDEX: [0],
SampleBatch.SEQ_LENS: [1],
}
),
"1": SampleBatch(
{
SampleBatch.T: [1],
SampleBatch.EPS_ID: [0],
SampleBatch.AGENT_INDEX: [1],
SampleBatch.SEQ_LENS: [1],
}
),
},
1,
),
],
)
]


if __name__ == "__main__":
import pytest
import sys

sys.exit(pytest.main(["-v", __file__]))
2 changes: 1 addition & 1 deletion rllib/policy/tests/test_sample_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_rows(self):
)
check(
next(s1.rows()),
{"a": [1, 1], "b": {"c": [4, 4]}, SampleBatch.SEQ_LENS: [1]},
{"a": [1, 1], "b": {"c": [4, 4]}, SampleBatch.SEQ_LENS: 1},
)

def test_compression(self):
Expand Down
80 changes: 80 additions & 0 deletions rllib/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,3 +820,83 @@ def should_check_eval(experiment):
}

return result


def check_same_batch(batch1, batch2) -> None:
"""Check if both batches are (almost) identical.
For MultiAgentBatches, the step count and individual policy's
SampleBatches are checked for identity. For SampleBatches, identity is
checked as the almost numerical key-value-pair identity between batches
with ray.rllib.utils.test_utils.check(). unroll_id is compared only if
both batches have an unroll_id.
Args:
batch1: Batch to compare against batch2
batch2: Batch to compare against batch1
"""
# Avoids circular import
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch

assert type(batch1) == type(
batch2
), "Input batches are of different " "types {} and {}".format(
str(type(batch1)), str(type(batch2))
)

def check_sample_batches(_batch1, _batch2, _policy_id=None):
unroll_id_1 = _batch1.get("unroll_id", None)
unroll_id_2 = _batch2.get("unroll_id", None)
# unroll IDs only have to fit if both batches have them
if unroll_id_1 is not None and unroll_id_2 is not None:
assert unroll_id_1 == unroll_id_2

batch1_keys = set()
for k, v in _batch1.items():
# unroll_id is compared above already
if k == "unroll_id":
continue
check(v, _batch2[k])
batch1_keys.add(k)

batch2_keys = set(_batch2.keys())
# unroll_id is compared above already
batch2_keys.discard("unroll_id")
_difference = batch1_keys.symmetric_difference(batch2_keys)

# Cases where one batch has info and the other has not
if _policy_id:
assert not _difference, (
"SampleBatches for policy with ID {} "
"don't share information on the "
"following information: \n{}"
"".format(_policy_id, _difference)
)
else:
assert not _difference, (
"SampleBatches don't share information "
"on the following information: \n{}"
"".format(_difference)
)

if type(batch1) == SampleBatch:
check_sample_batches(batch1, batch2)
elif type(batch1) == MultiAgentBatch:
assert batch1.count == batch2.count
batch1_ids = set()
for policy_id, policy_batch in batch1.policy_batches.items():
check_sample_batches(
policy_batch, batch2.policy_batches[policy_id], policy_id
)
batch1_ids.add(policy_id)

# Case where one ma batch has info on a policy the other has not
batch2_ids = set(batch2.policy_batches.keys())
difference = batch1_ids.symmetric_difference(batch2_ids)
assert (
not difference
), "MultiAgentBatches don't share the following" "information: \n{}.".format(
difference
)
else:
raise ValueError("Unsupported batch type " + str(type(batch1)))

0 comments on commit c0ade5f

Please sign in to comment.