Skip to content

Commit

Permalink
[RLlib] [Connectors] Fix test returning model based rollouts data con…
Browse files Browse the repository at this point in the history
…nectors (ray-project#30308)

test_returning_model_based_rollouts_data is only compatible with the old episode api

Signed-off-by: Avnish <[email protected]>
Co-authored-by: Artur Niederfahrenhorst <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
2 people authored and WeichenXu123 committed Dec 19, 2022
1 parent 47e68ae commit 40ddb00
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions rllib/env/tests/test_multi_agent_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,8 @@ def is_recurrent(self):
self.assertEqual(batch["state_out_0"][i], h)

def test_returning_model_based_rollouts_data(self):
# TODO(avnishn): This test only works with the old api

class ModelBasedPolicy(DQNTFPolicy):
def compute_actions_from_input_dict(
self, input_dict, explore=None, timestep=None, episodes=None, **kwargs
Expand Down Expand Up @@ -416,6 +418,7 @@ def compute_actions_from_input_dict(
.rollouts(
rollout_fragment_length=5,
num_rollout_workers=0,
enable_connectors=False, # only works with old episode API
)
.multi_agent(
policies={"p0", "p1"},
Expand Down

0 comments on commit 40ddb00

Please sign in to comment.