diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py
index 6fcc10a1f614..c33d8b43d2f4 100644
--- a/rllib/algorithms/algorithm.py
+++ b/rllib/algorithms/algorithm.py
@@ -685,7 +685,6 @@ def setup(self, config: AlgorithmConfig) -> None:
                     )
                     and self.config.input_ != "sampler"
                     and self.config.enable_rl_module_and_learner
-                    and self.config.enable_env_runner_and_connector_v2
                 )
                 else self.config.num_env_runners
             ),
@@ -707,7 +706,6 @@ def setup(self, config: AlgorithmConfig) -> None:
             )
             and self.config.input_ != "sampler"
             and self.config.enable_rl_module_and_learner
-            and self.config.enable_env_runner_and_connector_v2
         ):
             from ray.rllib.offline.offline_data import OfflineData
 
@@ -797,20 +795,10 @@ def setup(self, config: AlgorithmConfig) -> None:
                 method_config["type"] = method_type
 
         if self.config.enable_rl_module_and_learner:
-            if self.config.enable_env_runner_and_connector_v2:
-                module_spec: MultiRLModuleSpec = self.config.get_multi_rl_module_spec(
-                    spaces=self.env_runner_group.get_spaces(),
-                    inference_only=False,
-                )
-            # TODO (Sven): Deprecate this path: Old stack API RolloutWorkers and
-            # DreamerV3's EnvRunners have a `multi_rl_module_spec` property.
-            elif hasattr(self.env_runner, "multi_rl_module_spec"):
-                module_spec: MultiRLModuleSpec = self.env_runner.multi_rl_module_spec
-            else:
-                raise AttributeError(
-                    "Your local EnvRunner/RolloutWorker does NOT have any property "
-                    "referring to its RLModule!"
-                )
+            module_spec: MultiRLModuleSpec = self.config.get_multi_rl_module_spec(
+                spaces=self.env_runner_group.get_spaces(),
+                inference_only=False,
+            )
             self.learner_group = self.config.build_learner_group(
                 rl_module_spec=module_spec
             )
@@ -829,36 +817,19 @@ def setup(self, config: AlgorithmConfig) -> None:
                     rl_module_ckpt_dirs=rl_module_ckpt_dirs,
                 )
 
-            # Only when using RolloutWorkers: Update also the worker set's
-            # `is_policy_to_train`.
-            # Note that with the new EnvRunner API in combination with the new stack,
-            # this information only needs to be kept in the Learner and not on the
-            # EnvRunners anymore.
-            if not self.config.enable_env_runner_and_connector_v2:
-                policies_to_train = self.config.policies_to_train or set(
-                    self.config.policies
-                )
-                self.env_runner_group.foreach_worker(
-                    lambda w: w.set_is_policy_to_train(policies_to_train),
-                )
-                # Sync the weights from the learner group to the rollout workers.
-                self.env_runner.set_weights(self.learner_group.get_weights())
-                self.env_runner_group.sync_weights(inference_only=True)
-            # New stack/EnvRunner APIs: Use get/set_state.
-            else:
-                # Sync the weights from the learner group to the EnvRunners.
-                rl_module_state = self.learner_group.get_state(
-                    components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE,
-                    inference_only=True,
-                )[COMPONENT_LEARNER][COMPONENT_RL_MODULE]
-                self.env_runner.set_state({COMPONENT_RL_MODULE: rl_module_state})
-                self.env_runner_group.sync_env_runner_states(
-                    config=self.config,
-                    env_steps_sampled=self.metrics.peek(
-                        NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0
-                    ),
-                    rl_module_state=rl_module_state,
-                )
+            # Sync the weights from the learner group to the EnvRunners.
+            rl_module_state = self.learner_group.get_state(
+                components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE,
+                inference_only=True,
+            )[COMPONENT_LEARNER][COMPONENT_RL_MODULE]
+            self.env_runner.set_state({COMPONENT_RL_MODULE: rl_module_state})
+            self.env_runner_group.sync_env_runner_states(
+                config=self.config,
+                env_steps_sampled=self.metrics.peek(
+                    NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0
+                ),
+                rl_module_state=rl_module_state,
+            )
 
         if self.offline_data:
             # If the learners are remote we need to provide specific
@@ -1716,53 +1687,45 @@ def training_step(self) -> ResultDict:
                 "code and delete this error message)."
             )
 
-        # Collect SampleBatches from sample workers until we have a full batch.
+        # Collect a list of Episodes from EnvRunners until we reach the train batch
+        # size.
         with self.metrics.log_time((TIMERS, ENV_RUNNER_SAMPLING_TIMER)):
             if self.config.count_steps_by == "agent_steps":
-                train_batch, env_runner_results = synchronous_parallel_sample(
+                episodes, env_runner_results = synchronous_parallel_sample(
                     worker_set=self.env_runner_group,
                     max_agent_steps=self.config.total_train_batch_size,
                     sample_timeout_s=self.config.sample_timeout_s,
-                    _uses_new_env_runners=(
-                        self.config.enable_env_runner_and_connector_v2
-                    ),
+                    _uses_new_env_runners=True,
                     _return_metrics=True,
                 )
             else:
-                train_batch, env_runner_results = synchronous_parallel_sample(
+                episodes, env_runner_results = synchronous_parallel_sample(
                     worker_set=self.env_runner_group,
                     max_env_steps=self.config.total_train_batch_size,
                     sample_timeout_s=self.config.sample_timeout_s,
-                    _uses_new_env_runners=(
-                        self.config.enable_env_runner_and_connector_v2
-                    ),
+                    _uses_new_env_runners=True,
                     _return_metrics=True,
                 )
-        train_batch = train_batch.as_multi_agent()
-
         # Reduce EnvRunner metrics over the n EnvRunners.
         self.metrics.merge_and_log_n_dicts(env_runner_results, key=ENV_RUNNER_RESULTS)
 
-        # Only train if train_batch is not empty.
-        # In an extreme situation, all rollout workers die during the
-        # synchronous_parallel_sample() call above.
-        # In which case, we should skip training, wait a little bit, then probe again.
         with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
-            if train_batch.agent_steps() > 0:
-                learner_results = self.learner_group.update_from_batch(
-                    batch=train_batch
-                )
-                self.metrics.log_dict(learner_results, key=LEARNER_RESULTS)
-            else:
-                # Wait 1 sec before probing again via weight syncing.
-                time.sleep(1.0)
+            learner_results = self.learner_group.update_from_episodes(
+                episodes=episodes,
+                timesteps={
+                    NUM_ENV_STEPS_SAMPLED_LIFETIME: (
+                        self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME)
+                    ),
+                },
+            )
+            self.metrics.log_dict(learner_results, key=LEARNER_RESULTS)
 
         # Update weights - after learning on the local worker - on all
         # remote workers (only those RLModules that were actually trained).
         with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
             self.env_runner_group.sync_weights(
                 from_worker_or_learner_group=self.learner_group,
-                policies=set(learner_results.keys()) - {ALL_MODULES},
+                policies=list(set(learner_results.keys()) - {ALL_MODULES}),
                 inference_only=True,
             )
 
@@ -2492,29 +2455,6 @@ def add_policy(
                 module_spec=module_spec,
             )
 
-        # If Learner API is enabled, we need to also add the underlying module
-        # to the learner group.
-        if add_to_learners and self.config.enable_rl_module_and_learner:
-            policy = self.get_policy(policy_id)
-            module = policy.model
-            self.learner_group.add_module(
-                module_id=policy_id,
-                module_spec=RLModuleSpec.from_module(module),
-            )
-
-            # Update each Learner's `policies_to_train` information, but only
-            # if the arg is explicitly provided here.
-            if policies_to_train is not None:
-                self.learner_group.foreach_learner(
-                    func=lambda learner: learner.config.multi_agent(
-                        policies_to_train=policies_to_train
-                    ),
-                    timeout_seconds=0.0,  # fire-and-forget
-                )
-
-            weights = policy.get_weights()
-            self.learner_group.set_weights({policy_id: weights})
-
         # Add to evaluation workers, if necessary.
         if add_to_eval_env_runners is True and self.eval_env_runner_group is not None:
             self.eval_env_runner_group.add_policy(
@@ -2598,20 +2538,6 @@ def fn(worker):
         if remove_from_env_runners:
             self.env_runner_group.foreach_worker(fn, local_env_runner=True)
 
-        # Update each Learner's `policies_to_train` information, but only
-        # if the arg is explicitly provided here.
-        if (
-            remove_from_learners
-            and self.config.enable_rl_module_and_learner
-            and policies_to_train is not None
-        ):
-            self.learner_group.foreach_learner(
-                func=lambda learner: learner.config.multi_agent(
-                    policies_to_train=policies_to_train
-                ),
-                timeout_seconds=0.0,  # fire-and-forget
-            )
-
         # Update the evaluation worker set's workers, if required.
         if remove_from_eval_env_runners and self.eval_env_runner_group is not None:
             self.eval_env_runner_group.foreach_worker(fn, local_env_runner=True)
@@ -2705,10 +2631,7 @@ def save_checkpoint(self, checkpoint_dir: str) -> None:
         """
         # New API stack: Delegate to the `Checkpointable` implementation of
        # `save_to_path()`.
-        if (
-            self.config.enable_rl_module_and_learner
-            and self.config.enable_env_runner_and_connector_v2
-        ):
+        if self.config.enable_rl_module_and_learner:
             return self.save_to_path(checkpoint_dir)
 
         checkpoint_dir = pathlib.Path(checkpoint_dir)
@@ -2770,10 +2693,7 @@ def save_checkpoint(self, checkpoint_dir: str) -> None:
     def load_checkpoint(self, checkpoint_dir: str) -> None:
         # New API stack: Delegate to the `Checkpointable` implementation of
         # `restore_from_path()`.
-        if (
-            self.config.enable_rl_module_and_learner
-            and self.config.enable_env_runner_and_connector_v2
-        ):
+        if self.config.enable_rl_module_and_learner:
             self.restore_from_path(checkpoint_dir)
 
             # Call the `on_checkpoint_loaded` callback.
diff --git a/rllib/connectors/connector_v2.py b/rllib/connectors/connector_v2.py
index 842cd6e01b86..83eada4ba87f 100644
--- a/rllib/connectors/connector_v2.py
+++ b/rllib/connectors/connector_v2.py
@@ -320,7 +320,7 @@ def single_agent_episode_iterator(
         list_indices = defaultdict(int)
 
         # Single-agent case.
-        if isinstance(episodes[0], SingleAgentEpisode):
+        if episodes and isinstance(episodes[0], SingleAgentEpisode):
             if zip_with_batch_column is not None:
                 if len(zip_with_batch_column) != len(episodes):
                     raise ValueError(
diff --git a/rllib/core/rl_module/marl_module.py b/rllib/core/rl_module/marl_module.py
index 4f35e4b7b076..3071f2d6ba11 100644
--- a/rllib/core/rl_module/marl_module.py
+++ b/rllib/core/rl_module/marl_module.py
@@ -1,15 +1,5 @@
 from ray.rllib.utils.deprecation import deprecation_warning
-from ray.rllib.core.rl_module.multi_rl_module import (
-    MultiRLModule,
-    MultiRLModuleSpec,
-    MultiRLModuleConfig,
-)
-
-
-MultiAgentRLModule = MultiRLModule
-MultiAgentRLModuleConfig = MultiRLModuleConfig
-MultiAgentRLModuleSpec = MultiRLModuleSpec
 
 
 deprecation_warning(
     old="ray.rllib.core.rl_module.marl_module",
diff --git a/rllib/tuned_examples/bc/cartpole_bc.py b/rllib/tuned_examples/bc/cartpole_bc.py
index c032c69682ca..dc3ff9f23ce6 100644
--- a/rllib/tuned_examples/bc/cartpole_bc.py
+++ b/rllib/tuned_examples/bc/cartpole_bc.py
@@ -77,7 +77,7 @@
 )
 
 stop = {
-    f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 120.0,
+    f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 350.0,
     TRAINING_ITERATION: 350,
 }
 
diff --git a/rllib/utils/metrics/__init__.py b/rllib/utils/metrics/__init__.py
index 59b828321992..e33d769e26e9 100644
--- a/rllib/utils/metrics/__init__.py
+++ b/rllib/utils/metrics/__init__.py
@@ -6,9 +6,8 @@
 LEARNER_RESULTS = "learners"
 FAULT_TOLERANCE_STATS = "fault_tolerance"
 TIMERS = "timers"
-# ALGORITHM_RESULTS = "algorithm"
 
-# RLModule metrics
+# RLModule metrics.
 NUM_TRAINABLE_PARAMETERS = "num_trainable_parameters"
 NUM_NON_TRAINABLE_PARAMETERS = "num_non_trainable_parameters"
 
diff --git a/rllib/utils/serialization.py b/rllib/utils/serialization.py
index a3b6975c00b2..63aa7fdaa165 100644
--- a/rllib/utils/serialization.py
+++ b/rllib/utils/serialization.py
@@ -87,6 +87,8 @@ def gym_space_to_dict(space: gym.spaces.Space) -> Dict:
     Returns:
         Serialized JSON string.
     """
+    if space is None:
+        return None
 
     def _box(sp: gym.spaces.Box) -> Dict:
         return {
diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py
index 9b43041c2a86..7ffe87469557 100644
--- a/rllib/utils/test_utils.py
+++ b/rllib/utils/test_utils.py
@@ -1,19 +1,13 @@
 import argparse
 from collections import Counter
 import copy
-import gymnasium as gym
-from gymnasium.spaces import Box, Discrete, MultiDiscrete, MultiBinary
-from gymnasium.spaces import Dict as GymDict
-from gymnasium.spaces import Tuple as GymTuple
 import json
 import logging
-import numpy as np
 import os
 import pprint
 import random
 import re
 import time
-import tree  # pip install dm_tree
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -26,11 +20,19 @@
 )
 import yaml
 
+import gymnasium as gym
+from gymnasium.spaces import Box, Discrete, MultiDiscrete, MultiBinary
+from gymnasium.spaces import Dict as GymDict
+from gymnasium.spaces import Tuple as GymTuple
+import numpy as np
+import tree  # pip install dm_tree
+
 import ray
 from ray import air, tune
 from ray.air.constants import TRAINING_ITERATION
 from ray.air.integrations.wandb import WandbLoggerCallback, WANDB_ENV_VAR
 from ray.rllib.common import SupportedFileType
+from ray.rllib.core import DEFAULT_MODULE_ID
 from ray.rllib.env.wrappers.atari_wrappers import is_atari, wrap_deepmind
 from ray.rllib.train import load_experiments_from_file
 from ray.rllib.utils.annotations import OldAPIStack
@@ -789,7 +791,6 @@ def check_train_results_new_api_stack(train_results: ResultDict) -> None:
         data in it.
     """
     # Import these here to avoid circular dependencies.
-    from ray.rllib.core import DEFAULT_MODULE_ID
     from ray.rllib.utils.metrics import (
         ENV_RUNNER_RESULTS,
         FAULT_TOLERANCE_STATS,