diff --git a/rllib/env/tests/test_multi_agent_env.py b/rllib/env/tests/test_multi_agent_env.py index 8a526957e21a..fe9194468751 100644 --- a/rllib/env/tests/test_multi_agent_env.py +++ b/rllib/env/tests/test_multi_agent_env.py @@ -359,6 +359,8 @@ def is_recurrent(self): self.assertEqual(batch["state_out_0"][i], h) def test_returning_model_based_rollouts_data(self): + # TODO(avnishn): This test only works with the old api + class ModelBasedPolicy(DQNTFPolicy): def compute_actions_from_input_dict( self, input_dict, explore=None, timestep=None, episodes=None, **kwargs @@ -416,6 +418,7 @@ def compute_actions_from_input_dict( .rollouts( rollout_fragment_length=5, num_rollout_workers=0, + enable_connectors=False, # only works with old episode API ) .multi_agent( policies={"p0", "p1"},