[RLlib] Fix 2 broken CI test cases: test_learner_group and cartpole_dqn_envrunner. #45110

Merged (3 commits) on May 2, 2024
67 changes: 28 additions & 39 deletions rllib/core/learner/tests/test_learner_group.py
@@ -247,7 +247,7 @@ def test_add_remove_module(self):
),
)

self._check_multi_worker_weights(learner_group, results)
_check_multi_worker_weights(learner_group, results)

# check that module ids are updated to include the new module
module_ids_after_add = {DEFAULT_MODULE_ID, new_module_id}
@@ -260,7 +260,7 @@ def test_add_remove_module(self):
# run training without the test_module
results = learner_group.update_from_batch(batch.as_multi_agent())

self._check_multi_worker_weights(learner_group, results)
_check_multi_worker_weights(learner_group, results)

# check that module ids are updated after remove operation to not
# include the new module
@@ -272,20 +272,6 @@ def test_add_remove_module(self):
learner_group.shutdown()
del learner_group

def _check_multi_worker_weights(self, learner_group, results):
# Check that module weights are updated across workers and synchronized.
# for i in range(1, len(results)):
for module_id, mod_results in results.items():
if module_id == ALL_MODULES:
continue
# Compare the reported mean weights (merged across all Learner workers,
# which all should have the same weights after updating) with the actual
# current mean weights.
reported_mean_weights = mod_results["mean_weight"]
parameters = learner_group.get_weights(module_ids=[module_id])[module_id]
actual_mean_weights = np.mean([w.mean() for w in parameters.values()])
check(reported_mean_weights, actual_mean_weights, rtol=0.02)


class TestLearnerGroupCheckpointRestore(unittest.TestCase):
@classmethod
@@ -525,7 +511,6 @@ def test_async_update(self):
config = BaseTestingAlgorithmConfig().update_from_dict(config_overrides)
learner_group = config.build_learner_group(env=env)
reader = get_cartpole_dataset_reader(batch_size=512)
min_loss = float("inf")
batch = reader.next()
timer_sync = _Timer()
timer_async = _Timer()
@@ -541,8 +526,7 @@ def test_async_update(self):
# way to check that is if the time for an async update call is faster
# than the time for a sync update call.
self.assertLess(timer_async.mean, timer_sync.mean)
self.assertIsInstance(result_async, list)
self.assertEqual(len(result_async), 0)
self.assertIsInstance(result_async, dict)
iter_i = 0
while True:
batch = reader.next()
@@ -551,31 +535,36 @@ def test_async_update(self):
)
if not async_results:
continue
losses = [
np.mean(
[res[ALL_MODULES][Learner.TOTAL_LOSS_KEY] for res in results]
)
for results in async_results
]
min_loss_this_iter = min(losses)
min_loss = min(min_loss_this_iter, min_loss)
print(
f"[iter = {iter_i}] Loss: {min_loss_this_iter:.3f}, Min Loss: "
f"{min_loss:.3f}"
)
loss = async_results[ALL_MODULES][Learner.TOTAL_LOSS_KEY]
# The loss is initially around 0.69 (ln2). When it gets to around
# 0.57 the return of the policy gets to around 100.
if min_loss < 0.57:
if loss < 0.57:
break
for results in async_results:
for res1, res2 in zip(results, results[1:]):
self.assertEqual(
res1[DEFAULT_MODULE_ID]["mean_weight"],
res2[DEFAULT_MODULE_ID]["mean_weight"],
)
# Compare reported "mean_weight" with actual ones.
# TODO (sven): Right now, we don't have any way to know, whether
# an async update result came from the most recent call to
# `learner_group.update_from_batch(async_update=True)` or an earlier
# one. Once APPO/IMPALA are properly implemented on the new API stack,
# this problem should be resolved and we can uncomment the below line.
# _check_multi_worker_weights(learner_group, async_results)
iter_i += 1
learner_group.shutdown()
self.assertLess(min_loss, 0.57)
self.assertLess(loss, 0.57)


def _check_multi_worker_weights(learner_group, results):
# Check that module weights are updated across workers and synchronized.
# for i in range(1, len(results)):
for module_id, mod_results in results.items():
if module_id == ALL_MODULES:
continue
# Compare the reported mean weights (merged across all Learner workers,
# which all should have the same weights after updating) with the actual
# current mean weights.
reported_mean_weights = mod_results["mean_weight"]
parameters = learner_group.get_weights(module_ids=[module_id])[module_id]
actual_mean_weights = np.mean([w.mean() for w in parameters.values()])
check(reported_mean_weights, actual_mean_weights, rtol=0.02)


if __name__ == "__main__":
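
For context on the assertion changes in test_async_update above: the fixed test treats the return value of an async update as a single (possibly empty) results dict keyed by module ID, rather than a list of per-worker result lists. Below is a minimal sketch of that pattern, assuming a learner_group and reader set up as in the test, the ALL_MODULES / Learner.TOTAL_LOSS_KEY constants from RLlib, and that update_from_batch() accepts the async_update=True flag referenced in the test's TODO comment. This illustrates the expected result shape; it is not the test code itself.

    while True:
        batch = reader.next()
        # Kick off / poll an async update; an empty dict means no Learner worker
        # has finished a previously submitted request yet.
        async_results = learner_group.update_from_batch(
            batch.as_multi_agent(), async_update=True
        )
        if not async_results:
            continue
        # Results arrive as one merged dict, so the total loss can be read directly.
        loss = async_results[ALL_MODULES][Learner.TOTAL_LOSS_KEY]
        # Starts near 0.69 (ln 2); around 0.57 the CartPole return reaches ~100
        # (see the comment in the test above).
        if loss < 0.57:
            break
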
4 changes: 2 additions & 2 deletions rllib/env/single_agent_env_runner.py
@@ -274,7 +274,7 @@ def _sample_timesteps(
if explore:
env_steps_lifetime = self.metrics.peek(
NUM_ENV_STEPS_SAMPLED_LIFETIME
Collaborator: NUM_ENV_STEPS_SAMPLED_LIFETIME is then updated when syncing the metrics in between sampling rounds?

Contributor (author): Correct, it's always updated anyway (this was a bug) after each single env step, inside self._increase_sampled_metrics().

Contributor (author): And it's also hard-overwritten by the Algorithm itself after each iteration, in order to sum up the steps from all the other EnvRunners (making sure each EnvRunner always has the total steps sampled, not just its own lifetime count).

) + self.metrics.peek(NUM_ENV_STEPS_SAMPLED, default=0)
)
to_env = self.module.forward_exploration(
to_module, t=env_steps_lifetime
)
@@ -465,7 +465,7 @@ def _sample_episodes(
if explore:
env_steps_lifetime = self.metrics.peek(
NUM_ENV_STEPS_SAMPLED_LIFETIME
) + self.metrics.peek(NUM_ENV_STEPS_SAMPLED, default=0)
)
to_env = self.module.forward_exploration(
to_module, t=env_steps_lifetime
)
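
The review thread above explains the env-runner change: NUM_ENV_STEPS_SAMPLED_LIFETIME is already incremented after every single env step (inside self._increase_sampled_metrics()) and hard-overwritten by the Algorithm with the global total each iteration, so adding the per-iteration NUM_ENV_STEPS_SAMPLED counter on top double-counts the current iteration's steps. The following self-contained toy (a hypothetical ToyMetrics class, not RLlib's MetricsLogger) makes the double-counting concrete:

    class ToyMetrics:
        """Tiny stand-in for a metrics logger; all names here are illustrative."""

        def __init__(self):
            self.counters = {"env_steps_sampled": 0, "env_steps_sampled_lifetime": 0}

        def log_step(self):
            # Mirrors the fixed behavior described above: the lifetime counter is
            # bumped after every single env step, not only at metrics-sync time.
            self.counters["env_steps_sampled"] += 1
            self.counters["env_steps_sampled_lifetime"] += 1

        def peek(self, key, default=0):
            return self.counters.get(key, default)


    metrics = ToyMetrics()
    for _ in range(100):
        metrics.log_step()

    # Old expression: lifetime (which already includes this iteration) + this iteration.
    old_t = metrics.peek("env_steps_sampled_lifetime") + metrics.peek("env_steps_sampled")
    # New expression: the lifetime counter alone is the correct exploration timestep.
    new_t = metrics.peek("env_steps_sampled_lifetime")
    assert (old_t, new_t) == (200, 100)  # the old form over-reports by one iteration
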
13 changes: 10 additions & 3 deletions rllib/tuned_examples/dqn/cartpole_dqn_envrunner.py
@@ -13,7 +13,7 @@
model_config_dict={
"fcnet_hiddens": [256],
"fcnet_activation": "relu",
"epsilon": [(0, 1.0), (50000, 0.05)],
"epsilon": [(0, 1.0), (10000, 0.02)],
"fcnet_bias_initializer": "zeros_",
"post_fcnet_bias_initializer": "zeros_",
"post_fcnet_hiddens": [256],
@@ -23,7 +23,7 @@
# Settings identical to old stack.
replay_buffer_config={
"type": "PrioritizedEpisodeReplayBuffer",
"capacity": 100000,
"capacity": 50000,
"alpha": 0.6,
"beta": 0.4,
},
@@ -37,7 +37,14 @@
evaluation_parallel_to_training=True,
evaluation_num_env_runners=1,
evaluation_duration="auto",
evaluation_config={"explore": False},
evaluation_config={
"explore": False,
# TODO (sven): Add support for window=float(inf) and reduce=mean for
# evaluation episode_return_mean reductions (identical to old stack
# behavior, which does NOT use a window (100 by default) to reduce
# eval episode returns).
"metrics_num_episodes_for_smoothing": 4,
},
)
)

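
The tuned-example changes above tighten exploration (epsilon annealed from 1.0 down to 0.02 over the first 10,000 env steps instead of down to 0.05 over 50,000), halve the replay-buffer capacity to 50,000, and smooth evaluation returns over only 4 episodes. Assuming the point-list epsilon schedule is interpreted as linear interpolation between the given (timestep, value) pairs and held constant afterwards (the usual convention for such schedules; the helper below is illustrative, not RLlib's implementation), the new schedule behaves as follows:

    def epsilon_at(t, schedule=((0, 1.0), (10000, 0.02))):
        """Linearly interpolate between two (timestep, value) points, then hold flat."""
        (t0, v0), (t1, v1) = schedule
        if t <= t0:
            return v0
        if t >= t1:
            return v1
        return v0 + (t - t0) / (t1 - t0) * (v1 - v0)


    assert epsilon_at(0) == 1.0
    assert abs(epsilon_at(5_000) - 0.51) < 1e-9   # halfway through the anneal
    assert epsilon_at(20_000) == 0.02             # fully annealed after 10k steps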