Skip to content

Commit

Permalink
[RLlib] Fix test case multi-agent Pendulum PPO (--num-agents=2 arg missing in BUILD). (#45820)
Browse files Browse the repository at this point in the history
  • Loading branch information
sven1977 authored Jun 10, 2024
1 parent 13940c9 commit 829ac71
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 17 deletions.
2 changes: 1 addition & 1 deletion rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ py_test(
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "torch_only"],
size = "large", # bazel may complain about it being too long sometimes - large is on purpose as some frameworks take longer
srcs = ["tuned_examples/ppo/multi_agent_pendulum_ppo.py"],
args = ["--as-test", "--enable-new-api-stack"]
args = ["--enable-new-api-stack", "--num-agents=2", "--as-test"]
)

#@OldAPIStack
Expand Down
23 changes: 7 additions & 16 deletions rllib/algorithms/bc/tests/test_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_bc_compilation_and_learning_from_offline_file(self):
.offline_data(input_=[data_file])
)
num_iterations = 350
min_reward = 75.0
min_return_to_reach = 75.0

# Test for RLModule API and ModelV2.
for rl_modules in [True, False]:
Expand Down Expand Up @@ -79,29 +79,20 @@ def test_bc_compilation_and_learning_from_offline_file(self):

eval_results = results.get("evaluation")
if eval_results:
print(
"iter={} R={}".format(
i,
eval_results[
f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}"
],
)
)
mean_return = eval_results[ENV_RUNNER_RESULTS][
EPISODE_RETURN_MEAN
]
print("iter={} R={}".format(i, mean_return))
# Learn until good reward is reached in the actual env.
if (
eval_results[
f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}"
]
> min_reward
):
if mean_return > min_return_to_reach:
print("learnt!")
learnt = True
break

if not learnt:
raise ValueError(
"`BC` did not reach {} reward from expert offline "
"data!".format(min_reward)
"data!".format(min_return_to_reach)
)

check_compute_single_action(algo, include_prev_action_reward=True)
Expand Down

0 comments on commit 829ac71

Please sign in to comment.