From c0ade5f0b7cfc9aeba46cde7af3b36068a6420df Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Tue, 8 Mar 2022 18:55:48 +0530 Subject: [PATCH] [RLlib] Issue 22625: `MultiAgentBatch.timeslices()` does not behave as expected. (#22657) --- rllib/BUILD | 7 + rllib/policy/sample_batch.py | 12 +- rllib/policy/tests/test_multi_agent_batch.py | 241 +++++++++++++++++++ rllib/policy/tests/test_sample_batch.py | 2 +- rllib/utils/test_utils.py | 80 ++++++ 5 files changed, 335 insertions(+), 7 deletions(-) create mode 100644 rllib/policy/tests/test_multi_agent_batch.py diff --git a/rllib/BUILD b/rllib/BUILD index af550cbf7e45..f89d87af844e 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1344,6 +1344,13 @@ py_test( srcs = ["policy/tests/test_compute_log_likelihoods.py"] ) +py_test( + name = "policy/tests/test_multi_agent_batch", + tags = ["team:ml", "policy"], + size = "small", + srcs = ["policy/tests/test_multi_agent_batch.py"] +) + py_test( name = "policy/tests/test_policy", tags = ["team:ml", "policy"], diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 022451d94b64..adebca69f66b 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -312,7 +312,7 @@ def copy(self, shallow: bool = False) -> "SampleBatch": def rows(self) -> Iterator[Dict[str, TensorType]]: """Returns an iterator over data rows, i.e. dicts with column values. - Note that if `seq_lens` is set in self, we set it to [1] in the rows. + Note that if `seq_lens` is set in self, we set it to 1 in the rows. Yields: The column values of the row in this iteration. @@ -325,13 +325,12 @@ def rows(self) -> Iterator[Dict[str, TensorType]]: ... }) >>> for row in batch.rows(): print(row) - {"a": 1, "b": 4, "seq_lens": [1]} - {"a": 2, "b": 5, "seq_lens": [1]} - {"a": 3, "b": 6, "seq_lens": [1]} + {"a": 1, "b": 4, "seq_lens": 1} + {"a": 2, "b": 5, "seq_lens": 1} + {"a": 3, "b": 6, "seq_lens": 1} """ - # Do we add seq_lens=[1] to each row? - seq_lens = None if self.get(SampleBatch.SEQ_LENS) is None else np.array([1]) + seq_lens = None if self.get(SampleBatch.SEQ_LENS, 1) is None else 1 self_as_dict = {k: v for k, v in self.items()} @@ -1182,6 +1181,7 @@ def finish_slice(): {k: v.build_and_reset() for k, v in cur_slice.items()}, cur_slice_size ) cur_slice_size = 0 + cur_slice.clear() finished_slices.append(batch) # For each unique env timestep. diff --git a/rllib/policy/tests/test_multi_agent_batch.py b/rllib/policy/tests/test_multi_agent_batch.py new file mode 100644 index 000000000000..b1047e6114a5 --- /dev/null +++ b/rllib/policy/tests/test_multi_agent_batch.py @@ -0,0 +1,241 @@ +import unittest + +from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch +from ray.rllib.utils.test_utils import check_same_batch + + +class TestMultiAgentBatch(unittest.TestCase): + def test_timeslices_non_overlapping_experiences(self): + """Tests if timeslices works as expected on a MultiAgentBatch + consisting of two non-overlapping SampleBatches. + """ + + def _generate_data(agent_idx): + batch = SampleBatch( + { + SampleBatch.T: [0, 1], + SampleBatch.EPS_ID: 2 * [agent_idx], + SampleBatch.AGENT_INDEX: 2 * [agent_idx], + SampleBatch.SEQ_LENS: [2], + } + ) + return batch + + policy_batches = {str(idx): _generate_data(idx) for idx in (range(2))} + ma_batch = MultiAgentBatch(policy_batches, 4) + sliced_ma_batches = ma_batch.timeslices(1) + + [ + check_same_batch(i, j) + for i, j in zip( + sliced_ma_batches, + [ + MultiAgentBatch( + { + "0": SampleBatch( + { + SampleBatch.T: [0], + SampleBatch.EPS_ID: [0], + SampleBatch.AGENT_INDEX: [0], + SampleBatch.SEQ_LENS: [1], + } + ) + }, + 1, + ), + MultiAgentBatch( + { + "0": SampleBatch( + { + SampleBatch.T: [1], + SampleBatch.EPS_ID: [0], + SampleBatch.AGENT_INDEX: [0], + SampleBatch.SEQ_LENS: [1], + } + ) + }, + 1, + ), + MultiAgentBatch( + { + "1": SampleBatch( + { + SampleBatch.T: [0], + SampleBatch.EPS_ID: [1], + SampleBatch.AGENT_INDEX: [1], + SampleBatch.SEQ_LENS: [1], + } + ) + }, + 1, + ), + MultiAgentBatch( + { + "1": SampleBatch( + { + SampleBatch.T: [1], + SampleBatch.EPS_ID: [1], + SampleBatch.AGENT_INDEX: [1], + SampleBatch.SEQ_LENS: [1], + } + ) + }, + 1, + ), + ], + ) + ] + + def test_timeslices_partially_overlapping_experiences(self): + """Tests if timeslices works as expected on a MultiAgentBatch + consisting of two partially overlapping SampleBatches. + """ + + def _generate_data(agent_idx, t_start): + batch = SampleBatch( + { + SampleBatch.T: [t_start, t_start + 1], + SampleBatch.EPS_ID: [0, 0], + SampleBatch.AGENT_INDEX: 2 * [agent_idx], + SampleBatch.SEQ_LENS: [2], + } + ) + return batch + + policy_batches = {str(idx): _generate_data(idx, idx) for idx in (range(2))} + ma_batch = MultiAgentBatch(policy_batches, 4) + sliced_ma_batches = ma_batch.timeslices(1) + + [ + check_same_batch(i, j) + for i, j in zip( + sliced_ma_batches, + [ + MultiAgentBatch( + { + "0": SampleBatch( + { + SampleBatch.T: [0], + SampleBatch.EPS_ID: [0], + SampleBatch.AGENT_INDEX: [0], + SampleBatch.SEQ_LENS: [1], + } + ) + }, + 1, + ), + MultiAgentBatch( + { + "0": SampleBatch( + { + SampleBatch.T: [1], + SampleBatch.EPS_ID: [0], + SampleBatch.AGENT_INDEX: [0], + SampleBatch.SEQ_LENS: [1], + } + ), + "1": SampleBatch( + { + SampleBatch.T: [1], + SampleBatch.EPS_ID: [0], + SampleBatch.AGENT_INDEX: [1], + SampleBatch.SEQ_LENS: [1], + } + ), + }, + 1, + ), + MultiAgentBatch( + { + "1": SampleBatch( + { + SampleBatch.T: [2], + SampleBatch.EPS_ID: [0], + SampleBatch.AGENT_INDEX: [1], + SampleBatch.SEQ_LENS: [1], + } + ) + }, + 1, + ), + ], + ) + ] + + def test_timeslices_fully_overlapping_experiences(self): + """Tests if timeslices works as expected on a MultiAgentBatch + consisting of two fully overlapping SampleBatches. + """ + + def _generate_data(agent_idx): + batch = SampleBatch( + { + SampleBatch.T: [0, 1], + SampleBatch.EPS_ID: [0, 0], + SampleBatch.AGENT_INDEX: 2 * [agent_idx], + SampleBatch.SEQ_LENS: [2], + } + ) + return batch + + policy_batches = {str(idx): _generate_data(idx) for idx in (range(2))} + ma_batch = MultiAgentBatch(policy_batches, 4) + sliced_ma_batches = ma_batch.timeslices(1) + + [ + check_same_batch(i, j) + for i, j in zip( + sliced_ma_batches, + [ + MultiAgentBatch( + { + "0": SampleBatch( + { + SampleBatch.T: [0], + SampleBatch.EPS_ID: [0], + SampleBatch.AGENT_INDEX: [0], + SampleBatch.SEQ_LENS: [1], + } + ), + "1": SampleBatch( + { + SampleBatch.T: [0], + SampleBatch.EPS_ID: [0], + SampleBatch.AGENT_INDEX: [1], + SampleBatch.SEQ_LENS: [1], + } + ), + }, + 1, + ), + MultiAgentBatch( + { + "0": SampleBatch( + { + SampleBatch.T: [1], + SampleBatch.EPS_ID: [0], + SampleBatch.AGENT_INDEX: [0], + SampleBatch.SEQ_LENS: [1], + } + ), + "1": SampleBatch( + { + SampleBatch.T: [1], + SampleBatch.EPS_ID: [0], + SampleBatch.AGENT_INDEX: [1], + SampleBatch.SEQ_LENS: [1], + } + ), + }, + 1, + ), + ], + ) + ] + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/policy/tests/test_sample_batch.py b/rllib/policy/tests/test_sample_batch.py index 5a536091b145..dbf86aeb3ecc 100644 --- a/rllib/policy/tests/test_sample_batch.py +++ b/rllib/policy/tests/test_sample_batch.py @@ -136,7 +136,7 @@ def test_rows(self): ) check( next(s1.rows()), - {"a": [1, 1], "b": {"c": [4, 4]}, SampleBatch.SEQ_LENS: [1]}, + {"a": [1, 1], "b": {"c": [4, 4]}, SampleBatch.SEQ_LENS: 1}, ) def test_compression(self): diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index 3fe2064991aa..f983460e0ad7 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -820,3 +820,83 @@ def should_check_eval(experiment): } return result + + +def check_same_batch(batch1, batch2) -> None: + """Check if both batches are (almost) identical. + + For MultiAgentBatches, the step count and individual policy's + SampleBatches are checked for identity. For SampleBatches, identity is + checked as the almost numerical key-value-pair identity between batches + with ray.rllib.utils.test_utils.check(). unroll_id is compared only if + both batches have an unroll_id. + + Args: + batch1: Batch to compare against batch2 + batch2: Batch to compare against batch1 + """ + # Avoids circular import + from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch + + assert type(batch1) == type( + batch2 + ), "Input batches are of different " "types {} and {}".format( + str(type(batch1)), str(type(batch2)) + ) + + def check_sample_batches(_batch1, _batch2, _policy_id=None): + unroll_id_1 = _batch1.get("unroll_id", None) + unroll_id_2 = _batch2.get("unroll_id", None) + # unroll IDs only have to fit if both batches have them + if unroll_id_1 is not None and unroll_id_2 is not None: + assert unroll_id_1 == unroll_id_2 + + batch1_keys = set() + for k, v in _batch1.items(): + # unroll_id is compared above already + if k == "unroll_id": + continue + check(v, _batch2[k]) + batch1_keys.add(k) + + batch2_keys = set(_batch2.keys()) + # unroll_id is compared above already + batch2_keys.discard("unroll_id") + _difference = batch1_keys.symmetric_difference(batch2_keys) + + # Cases where one batch has info and the other has not + if _policy_id: + assert not _difference, ( + "SampleBatches for policy with ID {} " + "don't share information on the " + "following information: \n{}" + "".format(_policy_id, _difference) + ) + else: + assert not _difference, ( + "SampleBatches don't share information " + "on the following information: \n{}" + "".format(_difference) + ) + + if type(batch1) == SampleBatch: + check_sample_batches(batch1, batch2) + elif type(batch1) == MultiAgentBatch: + assert batch1.count == batch2.count + batch1_ids = set() + for policy_id, policy_batch in batch1.policy_batches.items(): + check_sample_batches( + policy_batch, batch2.policy_batches[policy_id], policy_id + ) + batch1_ids.add(policy_id) + + # Case where one ma batch has info on a policy the other has not + batch2_ids = set(batch2.policy_batches.keys()) + difference = batch1_ids.symmetric_difference(batch2_ids) + assert ( + not difference + ), "MultiAgentBatches don't share the following" "information: \n{}.".format( + difference + ) + else: + raise ValueError("Unsupported batch type " + str(type(batch1)))