Skip to content

Commit

Permalink
[RLlib] Fix DreamerV3 bug for num_env_runners > 0 (users should use…
Browse files Browse the repository at this point in the history
… `num_envs_per_env_runner > 1` instead, though). (ray-project#45819)

Signed-off-by: Richard Liu <[email protected]>
  • Loading branch information
sven1977 authored and richardsliu committed Jun 12, 2024
1 parent 3f094d4 commit 0b846e3
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 11 deletions.
6 changes: 6 additions & 0 deletions rllib/algorithms/dreamerv3/dreamerv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ def __init__(self, algo_class=None):
# Dreamer only runs on the new API stack.
self.enable_rl_module_and_learner = True
self.enable_env_runner_and_connector_v2 = True
# TODO (sven): DreamerV3 still uses its own EnvRunner class. This env-runner
# does not use connectors. We therefore should not attempt to merge/broadcast
# the connector states between EnvRunners (if >0). Note that this is only
# relevant if num_env_runners > 0, which is normally not the case when using
# this algo.
self.use_worker_filter_stats = False
# __sphinx_doc_end__
# fmt: on

Expand Down
5 changes: 3 additions & 2 deletions rllib/algorithms/dreamerv3/tests/test_dreamerv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_dreamerv3_compilation(self):
config = (
dreamerv3.DreamerV3Config()
.framework(eager_tracing=False)
.env_runners(num_env_runners=2)
.training(
# Keep things simple. Especially the long dream rollouts seem
# to take an enormous amount of time (initially).
Expand All @@ -52,13 +53,13 @@ def test_dreamerv3_compilation(self):
use_float16=False,
)
.learners(
num_learners=0, # TODO 2 # Try with 2 Learners.
num_learners=2, # Try with 2 Learners.
num_cpus_per_learner=1,
num_gpus_per_learner=0,
)
)

num_iterations = 2
num_iterations = 3

for env in [
"FrozenLake-v1",
Expand Down
12 changes: 10 additions & 2 deletions rllib/algorithms/dreamerv3/utils/env_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,16 @@ def get_metrics(self) -> ResultDict:
# Return reduced metrics.
return self.metrics.reduce()

# TODO (sven): Remove the requirement for EnvRunners/RolloutWorkers to have this
# API. Replace by proper state overriding via `EnvRunner.set_state()`
def get_weights(self, policies, inference_only):
"""Returns the weights of our (single-agent) RLModule."""
if self.module is None:
assert self.config.share_module_between_env_runner_and_learner
return {}
else:
return {
DEFAULT_MODULE_ID: self.module.get_state(inference_only=inference_only),
}

def set_weights(self, weights, global_vars=None):
"""Writes the weights of our (single-agent) RLModule."""
if self.module is None:
Expand Down
15 changes: 8 additions & 7 deletions rllib/env/env_runner_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def sync_env_runner_states(
}
# Ignore states from remote EnvRunners (use the current `from_worker` states
# only).
else:
elif hasattr(from_worker, "_env_to_module"):
env_runner_states["connector_states"] = {
"env_to_module_states": from_worker._env_to_module.get_state(),
"module_to_env_states": from_worker._module_to_env.get_state(),
Expand All @@ -464,12 +464,13 @@ def sync_env_runner_states(

def _update(_env_runner: EnvRunner) -> Any:
env_runner_states = ray.get(ref_env_runner_states)
_env_runner._env_to_module.set_state(
env_runner_states["connector_states"]["env_to_module_states"]
)
_env_runner._module_to_env.set_state(
env_runner_states["connector_states"]["module_to_env_states"]
)
if hasattr(_env_runner, "_env_to_module"):
_env_runner._env_to_module.set_state(
env_runner_states["connector_states"]["env_to_module_states"]
)
_env_runner._module_to_env.set_state(
env_runner_states["connector_states"]["module_to_env_states"]
)
# Update the global number of environment steps for each worker.
if "env_steps_sampled" in env_runner_states:
# _env_runner.global_num_env_steps_sampled =
Expand Down

0 comments on commit 0b846e3

Please sign in to comment.