From a4ceeb46d4b6cfd5a3b68f6b88a638c51ca9ff9d Mon Sep 17 00:00:00 2001 From: sven1977 Date: Sun, 22 May 2022 17:12:02 +0200 Subject: [PATCH 1/2] wip --- rllib/agents/dqn/__init__.py | 15 +- rllib/agents/dqn/apex.py | 214 ++++++++++++++---------- rllib/agents/dqn/r2d2.py | 203 ++++++++++++++++------ rllib/agents/dqn/tests/test_apex_dqn.py | 128 ++++++++------ rllib/agents/dqn/tests/test_r2d2.py | 34 ++-- rllib/agents/trainer_config.py | 5 +- rllib/algorithms/dqn/__init__.py | 3 +- rllib/algorithms/dqn/dqn.py | 38 +++-- 8 files changed, 405 insertions(+), 235 deletions(-) diff --git a/rllib/agents/dqn/__init__.py b/rllib/agents/dqn/__init__.py index 1326e0d39a5b..9b5097bec9b9 100644 --- a/rllib/agents/dqn/__init__.py +++ b/rllib/agents/dqn/__init__.py @@ -1,5 +1,5 @@ -from ray.rllib.agents.dqn.apex import ApexTrainer, APEX_DEFAULT_CONFIG -from ray.rllib.algorithms.dqn.dqn import DQNTrainer, DEFAULT_CONFIG +from ray.rllib.agents.dqn.apex import ApexConfig, ApexTrainer, APEX_DEFAULT_CONFIG +from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer, DEFAULT_CONFIG from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy from ray.rllib.agents.dqn.r2d2 import R2D2Trainer, R2D2_DEFAULT_CONFIG @@ -13,20 +13,23 @@ from ray.rllib.algorithms.dqn.simple_q_torch_policy import SimpleQTorchPolicy __all__ = [ + "ApexConfig", "ApexTrainer", - "APEX_DEFAULT_CONFIG", + "DQNConfig", "DQNTFPolicy", "DQNTorchPolicy", "DQNTrainer", - "DEFAULT_CONFIG", "R2D2TorchPolicy", "R2D2Trainer", - "R2D2_DEFAULT_CONFIG", - "SIMPLE_Q_DEFAULT_CONFIG", "SimpleQConfig", "SimpleQTFPolicy", "SimpleQTorchPolicy", "SimpleQTrainer", + # Deprecated. + "APEX_DEFAULT_CONFIG", + "DEFAULT_CONFIG", + "R2D2_DEFAULT_CONFIG", + "SIMPLE_Q_DEFAULT_CONFIG", ] from ray.rllib.utils.deprecation import deprecation_warning diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py index 36c6acd91913..10e414a967a1 100644 --- a/rllib/agents/dqn/apex.py +++ b/rllib/agents/dqn/apex.py @@ -16,31 +16,25 @@ import copy import platform import random -from typing import Tuple, Dict, List, DefaultDict, Set +from typing import Dict, List, DefaultDict, Set import ray from ray.actor import ActorHandle -from ray.rllib import RolloutWorker from ray.rllib.agents import Trainer -from ray.rllib.algorithms.dqn.dqn import ( - DEFAULT_CONFIG as DQN_DEFAULT_CONFIG, - DQNTrainer, -) +from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer from ray.rllib.algorithms.dqn.learner_thread import LearnerThread -from ray.rllib.evaluation.worker_set import WorkerSet +from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.execution.common import ( STEPS_TRAINED_COUNTER, STEPS_TRAINED_THIS_ITER_COUNTER, - _get_global_vars, - _get_shared_metrics, ) from ray.rllib.execution.parallel_requests import ( asynchronous_parallel_requests, wait_asynchronous_requests, ) -from ray.rllib.utils import merge_dicts from ray.rllib.utils.actors import create_colocated_actors from ray.rllib.utils.annotations import override +from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE from ray.rllib.utils.metrics import ( LAST_TARGET_UPDATE_TS, NUM_AGENT_STEPS_SAMPLED, @@ -53,7 +47,6 @@ TARGET_NET_UPDATE_TIMER, ) from ray.rllib.utils.typing import ( - SampleBatchType, TrainerConfigDict, ResultDict, PartialTrainerConfigDict, @@ -61,31 +54,96 @@ ) from ray.tune.trainable import Trainable from ray.tune.utils.placement_groups import PlacementGroupFactory -from 
ray.rllib.utils.deprecation import DEPRECATED_VALUE
-
-# fmt: off
-# __sphinx_doc_begin__
-APEX_DEFAULT_CONFIG = merge_dicts(
-    # See also the options in dqn.py, which are also supported.
-    DQN_DEFAULT_CONFIG,
-    {
-        "optimizer": merge_dicts(
-            DQN_DEFAULT_CONFIG["optimizer"], {
+from ray.util.ml_utils.dict import merge_dicts
+
+
+class ApexConfig(DQNConfig):
+    """Defines a configuration class from which an ApexTrainer can be built.
+
+    Example:
+        >>> from ray.rllib.agents.dqn.apex import ApexConfig
+        >>> config = ApexConfig()
+        >>> print(config.replay_buffer_config)
+        >>> replay_config = config.replay_buffer_config.update(
+        >>>     {
+        >>>         "capacity": 100000,
+        >>>         "prioritized_replay_alpha": 0.45,
+        >>>         "prioritized_replay_beta": 0.55,
+        >>>         "prioritized_replay_eps": 3e-6,
+        >>>     }
+        >>> )
+        >>> config.training(replay_buffer_config=replay_config)\
+        >>>       .resources(num_gpus=1)\
+        >>>       .rollouts(num_rollout_workers=30)\
+        >>>       .environment("CartPole-v1")
+        >>> trainer = ApexTrainer(config=config)
+        >>> while True:
+        >>>     trainer.train()
+
+    Example:
+        >>> from ray.rllib.agents.dqn.apex import ApexConfig
+        >>> from ray import tune
+        >>> config = ApexConfig()
+        >>> config.training(num_atoms=tune.grid_search(list(range(1, 11))))
+        >>> config.environment(env="CartPole-v1")
+        >>> tune.run(
+        >>>     "APEX",
+        >>>     stop={"episode_reward_mean": 200},
+        >>>     config=config.to_dict()
+        >>> )
+
+    Example:
+        >>> from ray.rllib.agents.dqn.apex import ApexConfig
+        >>> config = ApexConfig()
+        >>> print(config.exploration_config)
+        >>> explore_config = config.exploration_config.update(
+        >>>     {
+        >>>         "type": "EpsilonGreedy",
+        >>>         "initial_epsilon": 0.96,
+        >>>         "final_epsilon": 0.01,
+        >>>         "epsilon_timesteps": 5000,
+        >>>     }
+        >>> )
+        >>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
+        >>>       .exploration(exploration_config=explore_config)
+
+    Example:
+        >>> from ray.rllib.agents.dqn.apex import ApexConfig
+        >>> config = ApexConfig()
+        >>> print(config.exploration_config)
+        >>> explore_config = config.exploration_config.update(
+        >>>     {
+        >>>         "type": "SoftQ",
+        >>>         "temperature": [1.0],
+        >>>     }
+        >>> )
+        >>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
+        >>>       .exploration(exploration_config=explore_config)
+    """
+
+    def __init__(self, trainer_class=None):
+        """Initializes an ApexConfig instance."""
+        super().__init__(trainer_class=trainer_class or ApexTrainer)
+
+        # fmt: off
+        # __sphinx_doc_begin__
+        # APEX-DQN settings overriding DQN ones:
+        # .training()
+        self.optimizer = merge_dicts(
+            DQNConfig().optimizer, {
                 "max_weight_sync_delay": 400,
                 "num_replay_buffer_shards": 4,
                 "debug": False
-            }),
-        "n_step": 3,
-        "num_gpus": 1,
-        "num_workers": 32,
-
-        # TODO(jungong) : add proper replay_buffer_config after
-        # DistributedReplayBuffer type is supported.
-        "replay_buffer_config": {
+            })
+        self.n_step = 3
+        self.train_batch_size = 512
+        self.target_network_update_freq = 500000
+        self.training_intensity = 1
+        # APEX-DQN is using a distributed (non local) replay buffer.
+        self.replay_buffer_config = {
             "no_local_replay_buffer": True,
             # Specify prioritized replay by supplying a buffer type that supports
             # prioritization
-            "prioritized_replay": DEPRECATED_VALUE,
             "type": "MultiAgentPrioritizedReplayBuffer",
             "capacity": 2000000,
             "replay_batch_size": 32,
@@ -106,63 +164,26 @@
             # on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True, "worker_side_prioritization": True, - }, - - "train_batch_size": 512, - "rollout_fragment_length": 50, - # Update the target network every `target_network_update_freq` sample timesteps. - "target_network_update_freq": 500000, - # Minimum env sampling timesteps to accumulate within a single `train()` call. - # This value does not affect learning, only the number of times - # `Trainer.step_attempt()` is called by `Trainer.train()`. If - after one - # `step_attempt()`, the env sampling timestep count has not been reached, will - # perform n more `step_attempt()` calls until the minimum timesteps have been - # executed. Set to 0 for no minimum timesteps. - "min_sample_timesteps_per_reporting": 25000, - "exploration_config": {"type": "PerWorkerEpsilonGreedy"}, - "min_time_s_per_reporting": 30, - # This will set the ratio of replayed from a buffer and learned - # on timesteps to sampled from an environment and stored in the replay - # buffer timesteps. Must be greater than 0. - # TODO: Find a way to support None again as a means to replay - # proceeding as fast as possible. - "training_intensity": 1, - }, -) -# __sphinx_doc_end__ -# fmt: on - - -# Update worker weights as they finish generating experiences. -class UpdateWorkerWeights: - def __init__( - self, - learner_thread: LearnerThread, - workers: WorkerSet, - max_weight_sync_delay: int, - ): - self.learner_thread = learner_thread - self.workers = workers - self.steps_since_update = defaultdict(int) - self.max_weight_sync_delay = max_weight_sync_delay - self.weights = None - - def __call__(self, item: Tuple[ActorHandle, SampleBatchType]): - actor, batch = item - self.steps_since_update[actor] += batch.count - if self.steps_since_update[actor] >= self.max_weight_sync_delay: - # Note that it's important to pull new weights once - # updated to avoid excessive correlation between actors. - if self.weights is None or self.learner_thread.weights_updated: - self.learner_thread.weights_updated = False - self.weights = ray.put(self.workers.local_worker().get_weights()) - actor.set_weights.remote(self.weights, _get_global_vars()) - # Also update global vars of the local worker. - self.workers.local_worker().set_global_vars(_get_global_vars()) - self.steps_since_update[actor] = 0 - # Update metrics. - metrics = _get_shared_metrics() - metrics.counters["num_weight_syncs"] += 1 + # Deprecated key. + "prioritized_replay": DEPRECATED_VALUE, + } + + # .rollouts() + self.num_workers = 32 + self.rollout_fragment_length = 50 + self.exploration_config = { + "type": "PerWorkerEpsilonGreedy", + } + + # .resources() + self.num_gpus = 1 + + # .reporting() + self.min_time_s_per_reporting = 30 + self.min_sample_timesteps_per_reporting = 25000 + + # fmt: on + # __sphinx_doc_end__ class ApexTrainer(DQNTrainer): @@ -232,7 +253,7 @@ def setup(self, config: PartialTrainerConfigDict): @classmethod @override(DQNTrainer) def get_default_config(cls) -> TrainerConfigDict: - return APEX_DEFAULT_CONFIG + return ApexConfig().to_dict() @override(DQNTrainer) def validate_config(self, config): @@ -548,3 +569,20 @@ def default_resource_request(cls, config): ), strategy=config.get("placement_strategy", "PACK"), ) + + +# Deprecated: Use ray.rllib.algorithms.dqn.DQNConfig instead! 
+class _deprecated_default_config(dict): + def __init__(self): + super().__init__(ApexConfig().to_dict()) + + @Deprecated( + old="ray.rllib.agents.dqn.apex.APEX_DEFAULT_CONFIG", + new="ray.rllib.agents.dqn.apex.ApexConfig(...)", + error=False, + ) + def __getitem__(self, item): + return super().__getitem__(item) + + +APEX_DEFAULT_CONFIG = _deprecated_default_config() diff --git a/rllib/agents/dqn/r2d2.py b/rllib/agents/dqn/r2d2.py index b006c961b685..9f7d99734ede 100644 --- a/rllib/agents/dqn/r2d2.py +++ b/rllib/agents/dqn/r2d2.py @@ -1,37 +1,99 @@ import logging -from typing import Type +from typing import Optional, Type -from ray.rllib.algorithms.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_DEFAULT_CONFIG +from ray.rllib.algorithms.dqn import DQNConfig, DQNTrainer from ray.rllib.agents.dqn.r2d2_tf_policy import R2D2TFPolicy from ray.rllib.agents.dqn.r2d2_torch_policy import R2D2TorchPolicy -from ray.rllib.agents.trainer import Trainer from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import override +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.typing import TrainerConfigDict from ray.rllib.utils.deprecation import DEPRECATED_VALUE logger = logging.getLogger(__name__) -# fmt: off -# __sphinx_doc_begin__ -R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs( - DQN_DEFAULT_CONFIG, # See keys in dqn.py, which are also supported. - { - # Learning rate for adam optimizer. - "lr": 1e-4, - # Discount factor. - "gamma": 0.997, - # Train batch size (in number of single timesteps). - "train_batch_size": 64, - # Adam epsilon hyper parameter - "adam_epsilon": 1e-3, - # Run in parallel by default. - "num_workers": 2, - # Batch mode must be complete_episodes. - "batch_mode": "complete_episodes", - - # === Replay buffer === - "replay_buffer_config": { + +class R2D2Config(DQNConfig): + """Defines a configuration class from which a R2D2Trainer can be built. 
+
+    Example:
+        >>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
+        >>> config = R2D2Config()
+        >>> print(config.h_function_epsilon)
+        >>> replay_config = config.replay_buffer_config.update(
+        >>>     {
+        >>>         "capacity": 1000000,
+        >>>         "replay_burn_in": 20,
+        >>>     }
+        >>> )
+        >>> config.training(replay_buffer_config=replay_config)\
+        >>>       .resources(num_gpus=1)\
+        >>>       .rollouts(num_rollout_workers=30)\
+        >>>       .environment("CartPole-v1")
+        >>> trainer = R2D2Trainer(config=config)
+        >>> while True:
+        >>>     trainer.train()
+
+    Example:
+        >>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
+        >>> from ray import tune
+        >>> config = R2D2Config()
+        >>> config.training(train_batch_size=tune.grid_search([256, 64]))
+        >>> config.environment(env="CartPole-v1")
+        >>> tune.run(
+        >>>     "R2D2",
+        >>>     stop={"episode_reward_mean": 200},
+        >>>     config=config.to_dict()
+        >>> )
+
+    Example:
+        >>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
+        >>> config = R2D2Config()
+        >>> print(config.exploration_config)
+        >>> explore_config = config.exploration_config.update(
+        >>>     {
+        >>>         "initial_epsilon": 1.0,
+        >>>         "final_epsilon": 0.1,
+        >>>         "epsilon_timesteps": 200000,
+        >>>     }
+        >>> )
+        >>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
+        >>>       .exploration(exploration_config=explore_config)
+
+    Example:
+        >>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
+        >>> config = R2D2Config()
+        >>> print(config.exploration_config)
+        >>> explore_config = config.exploration_config.update(
+        >>>     {
+        >>>         "type": "SoftQ",
+        >>>         "temperature": [1.0],
+        >>>     }
+        >>> )
+        >>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
+        >>>       .exploration(exploration_config=explore_config)
+    """
+
+    def __init__(self, trainer_class=None):
+        """Initializes an R2D2Config instance."""
+        super().__init__(trainer_class=trainer_class or R2D2Trainer)
+
+        # fmt: off
+        # __sphinx_doc_begin__
+        # R2D2-specific settings:
+        self.zero_init_states = True
+        self.use_h_function = True
+        self.h_function_epsilon = 1e-3
+
+        # R2D2 settings overriding DQN ones:
+        # .training()
+        self.adam_epsilon = 1e-3
+        self.lr = 1e-4
+        self.gamma = 0.997
+        self.train_batch_size = 64
+        self.target_network_update_freq = 2500
+        # R2D2 is using a buffer that stores sequences.
+        self.replay_buffer_config = {
             "type": "MultiAgentReplayBuffer",
             # Specify prioritized replay by supplying a buffer type that supports
             # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
@@ -53,34 +115,54 @@
             # used for loss calculation is `n - replay_burn_in` time steps
             # (n=LSTM’s/attention net’s max_seq_len).
             "replay_burn_in": 0,
-        },
-        # If True, assume a zero-initialized state input (no matter where in
-        # the episode the sequence is located).
-        # If False, store the initial states along with each SampleBatch, use
-        # it (as initial state when running through the network for training),
-        # and update that initial state during training (from the internal
-        # state outputs of the immediately preceding sequence).
-        "zero_init_states": True,
-
-        # Whether to use the h-function from the paper [1] to scale target
-        # values in the R2D2-loss function:
-        # h(x) = sign(x)(√(|x| + 1) − 1) + εx
-        "use_h_function": True,
-        # The epsilon parameter from the R2D2 loss function (only used
-        # if `use_h_function`=True).
-        "h_function_epsilon": 1e-3,
-
-        # Update the target network every `target_network_update_freq` sample steps.
-        "target_network_update_freq": 2500,
-
-        # Deprecated keys:
-        # Use config["replay_buffer_config"]["replay_burn_in"] instead
-        "burn_in": DEPRECATED_VALUE
-    },
-    _allow_unknown_configs=True,
-)
-# __sphinx_doc_end__
-# fmt: on
+        }
+
+        # .rollouts()
+        self.num_workers = 2
+        self.batch_mode = "complete_episodes"
+
+        # fmt: on
+        # __sphinx_doc_end__
+
+        self.burn_in = DEPRECATED_VALUE
+
+    def training(
+        self,
+        *,
+        zero_init_states: Optional[bool] = None,
+        use_h_function: Optional[bool] = None,
+        h_function_epsilon: Optional[float] = None,
+        **kwargs,
+    ) -> "R2D2Config":
+        """Sets the training related configuration.
+
+        Args:
+            zero_init_states: If True, assume a zero-initialized state input (no
+                matter where in the episode the sequence is located).
+                If False, store the initial states along with each SampleBatch, use
+                it (as initial state when running through the network for training),
+                and update that initial state during training (from the internal
+                state outputs of the immediately preceding sequence).
+            use_h_function: Whether to use the h-function from the paper [1] to scale
+                target values in the R2D2-loss function:
+                h(x) = sign(x)(√(|x| + 1) − 1) + εx
+            h_function_epsilon: The epsilon parameter from the R2D2 loss function
+                (only used if `use_h_function`=True).
+
+        Returns:
+            This updated TrainerConfig object.
+        """
+        # Pass kwargs onto super's `training()` method.
+        super().training(**kwargs)
+
+        if zero_init_states is not None:
+            self.zero_init_states = zero_init_states
+        if use_h_function is not None:
+            self.use_h_function = use_h_function
+        if h_function_epsilon is not None:
+            self.h_function_epsilon = h_function_epsilon
+
+        return self
 
 
 # Build an R2D2 trainer, which uses the framework specific Policy
@@ -103,7 +185,7 @@ class R2D2Trainer(DQNTrainer):
     @classmethod
     @override(DQNTrainer)
     def get_default_config(cls) -> TrainerConfigDict:
-        return R2D2_DEFAULT_CONFIG
+        return R2D2Config().to_dict()
 
     @override(DQNTrainer)
     def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@@ -136,3 +218,20 @@ def validate_config(self, config: TrainerConfigDict) -> None:
 
         if config.get("batch_mode") != "complete_episodes":
             raise ValueError("`batch_mode` must be 'complete_episodes'!")
+
+
+# Deprecated: Use ray.rllib.agents.dqn.r2d2.R2D2Config instead!
+class _deprecated_default_config(dict): + def __init__(self): + super().__init__(R2D2Config().to_dict()) + + @Deprecated( + old="ray.rllib.agents.dqn.r2d2.R2D2_DEFAULT_CONFIG", + new="ray.rllib.agents.dqn.r2d2.R2D2Config(...)", + error=False, + ) + def __getitem__(self, item): + return super().__getitem__(item) + + +R2D2_DEFAULT_CONFIG = _deprecated_default_config() diff --git a/rllib/agents/dqn/tests/test_apex_dqn.py b/rllib/agents/dqn/tests/test_apex_dqn.py index c5dfcd4d5d1f..5a428fde224f 100644 --- a/rllib/agents/dqn/tests/test_apex_dqn.py +++ b/rllib/agents/dqn/tests/test_apex_dqn.py @@ -21,17 +21,26 @@ def tearDown(self): ray.shutdown() def test_apex_zero_workers(self): - config = apex.APEX_DEFAULT_CONFIG.copy() - config["num_workers"] = 0 - config["num_gpus"] = 0 - config["replay_buffer_config"] = { - "learning_starts": 1000, - } - config["min_sample_timesteps_per_reporting"] = 100 - config["min_time_s_per_reporting"] = 1 - config["optimizer"]["num_replay_buffer_shards"] = 1 + config = ( + apex.ApexConfig() + .rollouts(num_rollout_workers=0) + .resources(num_gpus=0) + .training( + replay_buffer_config={ + "learning_starts": 1000, + }, + optimizer={ + "num_replay_buffer_shards": 1, + }, + ) + .reporting( + min_sample_timesteps_per_reporting=100, + min_time_s_per_reporting=1, + ) + ) + for _ in framework_iterator(config): - trainer = apex.ApexTrainer(config=config, env="CartPole-v0") + trainer = config.build(env="CartPole-v0") results = trainer.train() check_train_results(results) print(results) @@ -39,19 +48,26 @@ def test_apex_zero_workers(self): def test_apex_dqn_compilation_and_per_worker_epsilon_values(self): """Test whether an APEX-DQNTrainer can be built on all frameworks.""" - config = apex.APEX_DEFAULT_CONFIG.copy() - config["num_workers"] = 3 - config["num_gpus"] = 0 - config["replay_buffer_config"] = { - "learning_starts": 1000, - } - config["min_sample_timesteps_per_reporting"] = 100 - config["min_time_s_per_reporting"] = 1 - config["optimizer"]["num_replay_buffer_shards"] = 1 + config = ( + apex.ApexConfig() + .rollouts(num_rollout_workers=3) + .resources(num_gpus=0) + .training( + replay_buffer_config={ + "learning_starts": 1000, + }, + optimizer={ + "num_replay_buffer_shards": 1, + }, + ) + .reporting( + min_sample_timesteps_per_reporting=100, + min_time_s_per_reporting=1, + ) + ) for _ in framework_iterator(config, with_eager_tracing=True): - plain_config = config.copy() - trainer = apex.ApexTrainer(config=plain_config, env="CartPole-v0") + trainer = config.build(env="CartPole-v0") # Test per-worker epsilon distribution. infos = trainer.workers.foreach_policy( @@ -77,37 +93,43 @@ def test_apex_dqn_compilation_and_per_worker_epsilon_values(self): trainer.stop() def test_apex_lr_schedule(self): - config = apex.APEX_DEFAULT_CONFIG.copy() - config["num_workers"] = 1 - config["num_gpus"] = 0 - config["train_batch_size"] = 10 - config["rollout_fragment_length"] = 5 - config["replay_buffer_config"] = { - "no_local_replay_buffer": True, - "type": "MultiAgentPrioritizedReplayBuffer", - "learning_starts": 10, - "capacity": 100, - "replay_batch_size": 10, - "prioritized_replay_alpha": 0.6, - # Beta parameter for sampling from prioritized replay buffer. - "prioritized_replay_beta": 0.4, - # Epsilon to add to the TD errors when updating priorities. - "prioritized_replay_eps": 1e-6, - } - config["min_sample_timesteps_per_reporting"] = 10 - # 0 metrics reporting delay, this makes sure timestep, - # which lr depends on, is updated after each worker rollout. 
- config["min_time_s_per_reporting"] = 0 - config["optimizer"]["num_replay_buffer_shards"] = 1 - # This makes sure learning schedule is checked every 10 timesteps. - config["optimizer"]["max_weight_sync_delay"] = 10 - # Initial lr, doesn't really matter because of the schedule below. - config["lr"] = 0.2 - lr_schedule = [ - [0, 0.2], - [100, 0.001], - ] - config["lr_schedule"] = lr_schedule + config = ( + apex.ApexConfig() + .rollouts( + num_rollout_workers=1, + rollout_fragment_length=5, + ) + .resources(num_gpus=0) + .training( + train_batch_size=10, + optimizer={ + "num_replay_buffer_shards": 1, + # This makes sure learning schedule is checked every 10 timesteps. + "max_weight_sync_delay": 10, + }, + replay_buffer_config={ + "no_local_replay_buffer": True, + "type": "MultiAgentPrioritizedReplayBuffer", + "learning_starts": 10, + "capacity": 100, + "replay_batch_size": 10, + "prioritized_replay_alpha": 0.6, + # Beta parameter for sampling from prioritized replay buffer. + "prioritized_replay_beta": 0.4, + # Epsilon to add to the TD errors when updating priorities. + "prioritized_replay_eps": 1e-6, + }, + # Initial lr, doesn't really matter because of the schedule below. + lr=0.2, + lr_schedule=[[0, 0.2], [100, 0.001]], + ) + .reporting( + min_sample_timesteps_per_reporting=10, + # 0 metrics reporting delay, this makes sure timestep, + # which lr depends on, is updated after each worker rollout. + min_time_s_per_reporting=0, + ) + ) def _step_n_times(trainer, n: int): """Step trainer n times. @@ -122,7 +144,7 @@ def _step_n_times(trainer, n: int): ] for _ in framework_iterator(config): - trainer = apex.ApexTrainer(config=config, env="CartPole-v0") + trainer = config.build(env="CartPole-v0") lr = _step_n_times(trainer, 5) # 50 timesteps # Close to 0.2 diff --git a/rllib/agents/dqn/tests/test_r2d2.py b/rllib/agents/dqn/tests/test_r2d2.py index 762bb207566d..5d7834474711 100644 --- a/rllib/agents/dqn/tests/test_r2d2.py +++ b/rllib/agents/dqn/tests/test_r2d2.py @@ -47,26 +47,30 @@ def tearDownClass(cls) -> None: def test_r2d2_compilation(self): """Test whether a R2D2Trainer can be built on all frameworks.""" - config = dqn.R2D2_DEFAULT_CONFIG.copy() - config["num_workers"] = 0 # Run locally. - # Wrap with an LSTM and use a very simple base-model. - config["model"]["use_lstm"] = True - config["model"]["max_seq_len"] = 20 - config["model"]["fcnet_hiddens"] = [32] - config["model"]["lstm_cell_size"] = 64 - - config["replay_buffer_config"]["replay_burn_in"] = 20 - config["zero_init_states"] = True - - config["dueling"] = False - config["lr"] = 5e-4 - config["exploration_config"]["epsilon_timesteps"] = 100000 + config = ( + dqn.r2d2.R2D2Config() + .rollouts(num_rollout_workers=0) + .training( + model={ + # Wrap with an LSTM and use a very simple base-model. + "use_lstm": True, + "max_seq_len": 20, + "fcnet_hiddens": [32], + "lstm_cell_size": 64, + }, + dueling=False, + lr=5e-4, + zero_init_states=True, + replay_buffer_config={"replay_burn_in": 20}, + ) + .exploration(exploration_config={"epsilon_timesteps": 100000}) + ) num_iterations = 1 # Test building an R2D2 agent in all frameworks. 
for _ in framework_iterator(config, with_eager_tracing=True): - trainer = dqn.R2D2Trainer(config=config, env="CartPole-v0") + trainer = config.build(env="CartPole-v0") for i in range(num_iterations): results = trainer.train() check_train_results(results) diff --git a/rllib/agents/trainer_config.py b/rllib/agents/trainer_config.py index cb3c0f95095b..8d3aedb34f71 100644 --- a/rllib/agents/trainer_config.py +++ b/rllib/agents/trainer_config.py @@ -19,6 +19,7 @@ from ray.rllib.offline.estimators.weighted_importance_sampling import ( WeightedImportanceSampling, ) +from ray.rllib.utils import merge_dicts from ray.rllib.utils.deprecation import DEPRECATED_VALUE from ray.rllib.utils.typing import ( EnvConfigDict, @@ -723,7 +724,7 @@ def training( if model is not None: self.model = model if optimizer is not None: - self.optimizer = optimizer + self.optimizer = merge_dicts(self.optimizer, optimizer) return self @@ -1060,7 +1061,7 @@ def reporting( timestep count has not been reached, will perform n more `step_attempt()` calls until the minimum timesteps have been executed. Set to 0 for no minimum timesteps. - min_sample_timesteps_per_reporting: Minimum env samplingtimesteps to + min_sample_timesteps_per_reporting: Minimum env sampling timesteps to accumulate within a single `train()` call. This value does not affect learning, only the number of times `Trainer.step_attempt()` is called by `Trauber.train()`. If - after one `step_attempt()`, the env sampling diff --git a/rllib/algorithms/dqn/__init__.py b/rllib/algorithms/dqn/__init__.py index d9497bf437dd..3ab21bbd365f 100644 --- a/rllib/algorithms/dqn/__init__.py +++ b/rllib/algorithms/dqn/__init__.py @@ -1,5 +1,5 @@ # from ray.rllib.agents.dqn.apex import ApexTrainer, APEX_DEFAULT_CONFIG -from ray.rllib.algorithms.dqn.dqn import DQNTrainer, DEFAULT_CONFIG +from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer, DEFAULT_CONFIG from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy @@ -16,6 +16,7 @@ __all__ = [ "ApexTrainer", "APEX_DEFAULT_CONFIG", + "DQNConfig", "DQNTFPolicy", "DQNTorchPolicy", "DQNTrainer", diff --git a/rllib/algorithms/dqn/dqn.py b/rllib/algorithms/dqn/dqn.py index 8091ce267bec..bf2fa0f6bf2d 100644 --- a/rllib/algorithms/dqn/dqn.py +++ b/rllib/algorithms/dqn/dqn.py @@ -114,15 +114,13 @@ class DQNConfig(SimpleQConfig): >>> .exploration(exploration_config=explore_config) """ - def __init__(self): + def __init__(self, trainer_class=None): """Initializes a DQNConfig instance.""" - super().__init__() + super().__init__(trainer_class=trainer_class or DQNTrainer) - # DQN specific + # DQN specific config settings. # fmt: off # __sphinx_doc_begin__ - # - self.trainer_class = DQNTrainer self.num_atoms = 1 self.v_min = -10.0 self.v_max = 10.0 @@ -213,6 +211,8 @@ def training( -> will make sure that replay+train op will be executed 4x asoften as rollout+insert op (4 * 250 = 1000). See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details. + TODO: Find a way to support None again as a means to replay + proceeding as fast as possible. worker_side_prioritization: Whether to compute priorities on workers. replay_buffer_config: Replay buffer config. Examples: @@ -283,19 +283,7 @@ def training( if replay_buffer_config is not None: self.replay_buffer_config = replay_buffer_config - -# Deprecated: Use ray.rllib.algorithms.dqn.DQNConfig instead! 
-class _deprecated_default_config(dict): - def __init__(self): - super().__init__(DQNConfig().to_dict()) - - @Deprecated( - old="ray.rllib.algorithms.dqn.dqn.DEFAULT_CONFIG", - new="ray.rllib.algorithms.dqn.dqn.DQNConfig(...)", - error=False, - ) - def __getitem__(self, item): - return super().__getitem__(item) + return self def calculate_rr_weights(config: TrainerConfigDict) -> List[float]: @@ -437,6 +425,20 @@ def training_iteration(self) -> ResultDict: return train_results +# Deprecated: Use ray.rllib.algorithms.dqn.DQNConfig instead! +class _deprecated_default_config(dict): + def __init__(self): + super().__init__(DQNConfig().to_dict()) + + @Deprecated( + old="ray.rllib.algorithms.dqn.dqn.DEFAULT_CONFIG", + new="ray.rllib.algorithms.dqn.dqn.DQNConfig(...)", + error=False, + ) + def __getitem__(self, item): + return super().__getitem__(item) + + DEFAULT_CONFIG = _deprecated_default_config() From 55dae3581af9e2029b61241c02aa2ce0ce02b939 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 23 May 2022 10:19:57 +0200 Subject: [PATCH 2/2] wip. --- rllib/algorithms/ddpg/ddpg.py | 4 ++-- rllib/algorithms/pg/pg.py | 4 ++-- rllib/utils/metrics/__init__.py | 4 ++++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/rllib/algorithms/ddpg/ddpg.py b/rllib/algorithms/ddpg/ddpg.py index 3456848bd81c..088d109f435a 100644 --- a/rllib/algorithms/ddpg/ddpg.py +++ b/rllib/algorithms/ddpg/ddpg.py @@ -29,8 +29,8 @@ class DDPGConfig(SimpleQConfig): >>> from ray import tune >>> config = DDPGConfig() >>> # Print out some default values. - >>> print(config.lr) - 0.0004 + >>> print(config.lr) # doctest: +SKIP + 0.0004 >>> # Update the config object. >>> config.training(lr=tune.grid_search([0.001, 0.0001])) >>> # Set the config object's env. diff --git a/rllib/algorithms/pg/pg.py b/rllib/algorithms/pg/pg.py index af86bb514c75..74eee106964f 100644 --- a/rllib/algorithms/pg/pg.py +++ b/rllib/algorithms/pg/pg.py @@ -26,8 +26,8 @@ class PGConfig(TrainerConfig): >>> from ray import tune >>> config = PGConfig() >>> # Print out some default values. - >>> print(config.lr) - ... 0.0004 + >>> print(config.lr) # doctest: +SKIP + 0.0004 >>> # Update the config object. >>> config.training(lr=tune.grid_search([0.001, 0.0001])) >>> # Set the config object's env. diff --git a/rllib/utils/metrics/__init__.py b/rllib/utils/metrics/__init__.py index 7cd3411edffa..080bc6b14802 100644 --- a/rllib/utils/metrics/__init__.py +++ b/rllib/utils/metrics/__init__.py @@ -1,8 +1,12 @@ # Counters for sampling and training steps (env- and agent steps). NUM_ENV_STEPS_SAMPLED = "num_env_steps_sampled" NUM_AGENT_STEPS_SAMPLED = "num_agent_steps_sampled" +NUM_ENV_STEPS_SAMPLED_THIS_ITER = "num_env_steps_sampled_this_iter" +NUM_AGENT_STEPS_SAMPLED_THIS_ITER = "num_agent_steps_sampled_this_iter" NUM_ENV_STEPS_TRAINED = "num_env_steps_trained" NUM_AGENT_STEPS_TRAINED = "num_agent_steps_trained" +NUM_ENV_STEPS_TRAINED_THIS_ITER = "num_env_steps_trained_this_iter" +NUM_AGENT_STEPS_TRAINED_THIS_ITER = "num_agent_steps_trained_this_iter" # Counters to track target network updates. LAST_TARGET_UPDATE_TS = "last_target_update_ts"
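For reviewers trying this patch locally, here is a minimal usage sketch of the new builder-style config API it introduces. It mirrors the patched tests and docstrings (ApexConfig, the .rollouts()/.resources()/.training()/.reporting() setters, and config.build()); module paths and defaults assume the rllib/agents layout at this commit and may differ in other Ray releases.

# Minimal sketch, assuming the rllib/agents layout at this commit.
import ray
from ray.rllib.agents.dqn.apex import ApexConfig

ray.init()

config = (
    ApexConfig()
    # Keep the example small: no GPUs, a few rollout workers, one replay shard.
    .rollouts(num_rollout_workers=3)
    .resources(num_gpus=0)
    .training(
        replay_buffer_config={"learning_starts": 1000},
        optimizer={"num_replay_buffer_shards": 1},
    )
    .reporting(
        min_sample_timesteps_per_reporting=100,
        min_time_s_per_reporting=1,
    )
)

# `build()` creates the ApexTrainer; the old dict-based APEX_DEFAULT_CONFIG
# path still works but now goes through the deprecation wrapper added above.
trainer = config.build(env="CartPole-v0")
results = trainer.train()
print(results["episode_reward_mean"])
trainer.stop()
ray.shutdown()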