[RLlib] APEX-DQN and R2D2 config objects. #25067

Merged (3 commits) on May 23, 2022
15 changes: 9 additions & 6 deletions rllib/agents/dqn/__init__.py
@@ -1,5 +1,5 @@
from ray.rllib.agents.dqn.apex import ApexTrainer, APEX_DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn import DQNTrainer, DEFAULT_CONFIG
from ray.rllib.agents.dqn.apex import ApexConfig, ApexTrainer, APEX_DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.agents.dqn.r2d2 import R2D2Trainer, R2D2_DEFAULT_CONFIG
@@ -13,20 +13,23 @@
from ray.rllib.algorithms.dqn.simple_q_torch_policy import SimpleQTorchPolicy

__all__ = [
"ApexConfig",
"ApexTrainer",
"APEX_DEFAULT_CONFIG",
"DQNConfig",
"DQNTFPolicy",
"DQNTorchPolicy",
"DQNTrainer",
"DEFAULT_CONFIG",
"R2D2TorchPolicy",
"R2D2Trainer",
"R2D2_DEFAULT_CONFIG",
"SIMPLE_Q_DEFAULT_CONFIG",
"SimpleQConfig",
"SimpleQTFPolicy",
"SimpleQTorchPolicy",
"SimpleQTrainer",
# Deprecated.
"APEX_DEFAULT_CONFIG",
"DEFAULT_CONFIG",
"R2D2_DEFAULT_CONFIG",
"SIMPLE_Q_DEFAULT_CONFIG",
]

from ray.rllib.utils.deprecation import deprecation_warning
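The net effect of this __init__.py change: the new config classes (e.g. ApexConfig) are exported alongside the old default-config dicts, which are regrouped under a "# Deprecated." marker. Below is a minimal sketch (not part of the diff) of how the two styles compare after this PR; the environment name and worker/GPU counts are illustrative only:

from ray.rllib.agents.dqn import APEX_DEFAULT_CONFIG, ApexConfig

# New style: build the config fluently, then materialize a dict where one is still needed.
config = (
    ApexConfig()
    .environment(env="CartPole-v1")     # illustrative env
    .rollouts(num_rollout_workers=4)    # illustrative worker count
    .resources(num_gpus=0)
)
config_dict = config.to_dict()

# Old style: the dict is still importable, but key access now goes through the
# deprecation wrapper added in apex.py below and logs a deprecation warning.
assert APEX_DEFAULT_CONFIG["n_step"] == config_dict["n_step"]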
214 changes: 126 additions & 88 deletions rllib/agents/dqn/apex.py
@@ -16,31 +16,25 @@
import copy
import platform
import random
from typing import Tuple, Dict, List, DefaultDict, Set
from typing import Dict, List, DefaultDict, Set

import ray
from ray.actor import ActorHandle
from ray.rllib import RolloutWorker
from ray.rllib.agents import Trainer
from ray.rllib.algorithms.dqn.dqn import (
DEFAULT_CONFIG as DQN_DEFAULT_CONFIG,
DQNTrainer,
)
from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer
from ray.rllib.algorithms.dqn.learner_thread import LearnerThread
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.common import (
STEPS_TRAINED_COUNTER,
STEPS_TRAINED_THIS_ITER_COUNTER,
_get_global_vars,
_get_shared_metrics,
)
from ray.rllib.execution.parallel_requests import (
asynchronous_parallel_requests,
wait_asynchronous_requests,
)
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_AGENT_STEPS_SAMPLED,
@@ -53,39 +47,103 @@
TARGET_NET_UPDATE_TIMER,
)
from ray.rllib.utils.typing import (
SampleBatchType,
TrainerConfigDict,
ResultDict,
PartialTrainerConfigDict,
T,
)
from ray.tune.trainable import Trainable
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.rllib.utils.deprecation import DEPRECATED_VALUE

# fmt: off
# __sphinx_doc_begin__
APEX_DEFAULT_CONFIG = merge_dicts(
# See also the options in dqn.py, which are also supported.
DQN_DEFAULT_CONFIG,
{
"optimizer": merge_dicts(
DQN_DEFAULT_CONFIG["optimizer"], {
from ray.util.ml_utils.dict import merge_dicts


class ApexConfig(DQNConfig):
"""Defines a configuration class from which an ApexTrainer can be built.
Example:
>>> from ray.rllib.agents.dqn.apex import ApexConfig
>>> config = ApexConfig()
>>> print(config.replay_buffer_config)
>>> replay_config = config.replay_buffer_config.update(
>>> {
>>> "capacity": 100000,
>>> "prioritized_replay_alpha": 0.45,
>>> "prioritized_replay_beta": 0.55,
>>> "prioritized_replay_eps": 3e-6,
>>> }
>>> )
>>> config.training(replay_buffer_config=replay_config)\
>>> .resources(num_gpus=1)\
>>> .rollouts(num_rollout_workers=30)\
>>> .environment("CartPole-v1")
>>> trainer = ApexTrainer(config=config)
>>> while True:
>>> trainer.train()

Example:
>>> from ray.rllib.agents.dqn.apex import ApexConfig
>>> from ray import tune
>>> config = ApexConfig()
>>> config.training(num_atoms=tune.grid_search(list(range(1, 11))))
>>> config.environment(env="CartPole-v1")
>>> tune.run(
>>> "APEX",
>>> stop={"episode_reward_mean":200},
>>> config=config.to_dict()
>>> )

Example:
>>> from ray.rllib.agents.dqn.apex import ApexConfig
>>> config = ApexConfig()
>>> print(config.exploration_config)
>>> explore_config = config.exploration_config.update(
>>> {
>>> "type": "EpsilonGreedy",
>>> "initial_epsilon": 0.96,
>>> "final_epsilon": 0.01,
>>> "epsilone_timesteps": 5000,
>>> }
>>> )
>>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
>>> .exploration(exploration_config=explore_config)

Example:
>>> from ray.rllib.agents.dqn.apex import ApexConfig
>>> config = ApexConfig()
>>> print(config.exploration_config)
>>> explore_config = config.exploration_config.update(
>>> {
>>> "type": "SoftQ",
>>> "temperature": [1.0],
>>> }
>>> )
>>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
>>> .exploration(exploration_config=explore_config)
"""

def __init__(self, trainer_class=None):
"""Initializes a ApexConfig instance."""
super().__init__(trainer_class=trainer_class or ApexTrainer)

# fmt: off
# __sphinx_doc_begin__
# APEX-DQN settings overriding DQN ones:
# .training()
self.optimizer = merge_dicts(
DQNConfig().optimizer, {
"max_weight_sync_delay": 400,
"num_replay_buffer_shards": 4,
"debug": False
}),
"n_step": 3,
"num_gpus": 1,
"num_workers": 32,

# TODO(jungong) : add proper replay_buffer_config after
# DistributedReplayBuffer type is supported.
"replay_buffer_config": {
})
self.n_step = 3
self.train_batch_size = 512
self.target_network_update_freq = 500000
self.training_intensity = 1
# APEX-DQN is using a distributed (non local) replay buffer.
self.replay_buffer_config = {
"no_local_replay_buffer": True,
# Specify prioritized replay by supplying a buffer type that supports
# prioritization
"prioritized_replay": DEPRECATED_VALUE,
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": 2000000,
"replay_batch_size": 32,
@@ -106,63 +164,26 @@
# on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True,
"worker_side_prioritization": True,
},

"train_batch_size": 512,
"rollout_fragment_length": 50,
# Update the target network every `target_network_update_freq` sample timesteps.
"target_network_update_freq": 500000,
# Minimum env sampling timesteps to accumulate within a single `train()` call.
# This value does not affect learning, only the number of times
# `Trainer.step_attempt()` is called by `Trainer.train()`. If - after one
# `step_attempt()`, the env sampling timestep count has not been reached, will
# perform n more `step_attempt()` calls until the minimum timesteps have been
# executed. Set to 0 for no minimum timesteps.
"min_sample_timesteps_per_reporting": 25000,
"exploration_config": {"type": "PerWorkerEpsilonGreedy"},
"min_time_s_per_reporting": 30,
# This will set the ratio of replayed from a buffer and learned
# on timesteps to sampled from an environment and stored in the replay
# buffer timesteps. Must be greater than 0.
# TODO: Find a way to support None again as a means to replay
# proceeding as fast as possible.
"training_intensity": 1,
},
)
# __sphinx_doc_end__
# fmt: on


# Update worker weights as they finish generating experiences.
class UpdateWorkerWeights:
def __init__(
self,
learner_thread: LearnerThread,
workers: WorkerSet,
max_weight_sync_delay: int,
):
self.learner_thread = learner_thread
self.workers = workers
self.steps_since_update = defaultdict(int)
self.max_weight_sync_delay = max_weight_sync_delay
self.weights = None

def __call__(self, item: Tuple[ActorHandle, SampleBatchType]):
actor, batch = item
self.steps_since_update[actor] += batch.count
if self.steps_since_update[actor] >= self.max_weight_sync_delay:
# Note that it's important to pull new weights once
# updated to avoid excessive correlation between actors.
if self.weights is None or self.learner_thread.weights_updated:
self.learner_thread.weights_updated = False
self.weights = ray.put(self.workers.local_worker().get_weights())
actor.set_weights.remote(self.weights, _get_global_vars())
# Also update global vars of the local worker.
self.workers.local_worker().set_global_vars(_get_global_vars())
self.steps_since_update[actor] = 0
# Update metrics.
metrics = _get_shared_metrics()
metrics.counters["num_weight_syncs"] += 1
# Deprecated key.
"prioritized_replay": DEPRECATED_VALUE,
}

# .rollouts()
self.num_workers = 32
self.rollout_fragment_length = 50
self.exploration_config = {
"type": "PerWorkerEpsilonGreedy",
}

# .resources()
self.num_gpus = 1

# .reporting()
self.min_time_s_per_reporting = 30
self.min_sample_timesteps_per_reporting = 25000

# fmt: on
# __sphinx_doc_end__


class ApexTrainer(DQNTrainer):
@@ -232,7 +253,7 @@ def setup(self, config: PartialTrainerConfigDict):
@classmethod
@override(DQNTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return APEX_DEFAULT_CONFIG
return ApexConfig().to_dict()

@override(DQNTrainer)
def validate_config(self, config):
@@ -548,3 +569,20 @@ def default_resource_request(cls, config):
),
strategy=config.get("placement_strategy", "PACK"),
)


# Deprecated: Use ray.rllib.agents.dqn.apex.ApexConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(ApexConfig().to_dict())

@Deprecated(
old="ray.rllib.agents.dqn.apex.APEX_DEFAULT_CONFIG",
new="ray.rllib.agents.dqn.apex.ApexConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)


APEX_DEFAULT_CONFIG = _deprecated_default_config()
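A quick sanity check of what the apex.py changes amount to. This is a sketch, not part of the PR; the asserted values are taken from the defaults set in ApexConfig above:

from ray.rllib.agents.dqn.apex import APEX_DEFAULT_CONFIG, ApexConfig, ApexTrainer

# get_default_config() now resolves to ApexConfig().to_dict() instead of the old dict.
defaults = ApexTrainer.get_default_config()
assert defaults["n_step"] == 3
assert defaults["num_workers"] == 32
assert defaults["target_network_update_freq"] == 500000
assert defaults["replay_buffer_config"]["type"] == "MultiAgentPrioritizedReplayBuffer"

# The legacy dict still answers lookups; @Deprecated(error=False) only logs a warning.
assert APEX_DEFAULT_CONFIG["n_step"] == ApexConfig().n_step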