Skip to content

Commit

Permalink
[RLlib] SimpleQ (minor cleanups) and DQN TrainerConfig objects. (#24584)
Browse files Browse the repository at this point in the history
  • Loading branch information
smorad authored May 15, 2022
1 parent de69b0d commit 5c96e72
Show file tree
Hide file tree
Showing 5 changed files with 321 additions and 138 deletions.
315 changes: 235 additions & 80 deletions rllib/agents/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@
""" # noqa: E501

import logging
from typing import List, Optional, Type
from typing import List, Optional, Type, Callable

from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.agents.dqn.simple_q import (
SimpleQConfig,
SimpleQTrainer,
)
from ray.rllib.agents.trainer import Trainer
from ray.rllib.execution.rollout_ops import (
synchronous_parallel_sample,
)
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.execution.train_ops import (
train_one_step,
multi_gpu_train_one_step,
Expand All @@ -39,7 +39,6 @@
)
from ray.rllib.utils.deprecation import (
Deprecated,
DEPRECATED_VALUE,
)
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.metrics import SYNCH_WORKER_WEIGHTS_TIMER
Expand All @@ -50,90 +49,243 @@

logger = logging.getLogger(__name__)

# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = Trainer.merge_trainer_configs(
SimpleQConfig().to_dict(),
{
# === Model ===
# Number of atoms for representing the distribution of return. When
# this is greater than 1, distributional Q-learning is used.
# the discrete supports are bounded by v_min and v_max
"num_atoms": 1,
"v_min": -10.0,
"v_max": 10.0,
# Whether to use noisy network
"noisy": False,
# control the initial value of noisy nets
"sigma0": 0.5,
# Whether to use dueling dqn
"dueling": True,
# Dense-layer setup for each the advantage branch and the value branch
# in a dueling architecture.
"hiddens": [256],
# Whether to use double dqn
"double_q": True,
# N-step Q learning
"n_step": 1,

# === Replay buffer ===
# Deprecated, use capacity in replay_buffer_config instead.
"buffer_size": DEPRECATED_VALUE,
"replay_buffer_config": {
# Enable the new ReplayBuffer API.

class DQNConfig(SimpleQConfig):
"""Defines a DQNTrainer configuration class from which a DQNTrainer can be built.
Example:
>>> from ray.rllib.agents.dqn.dqn import DQNConfig
>>> config = DQNConfig()
>>> print(config.replay_buffer_config)
>>> replay_config = config.replay_buffer_config.update(
>>> {
>>> "capacity": 60000,
>>> "prioritized_replay_alpha": 0.5,
>>> "prioritized_replay_beta": 0.5,
>>> "prioritized_replay_eps": 3e-6,
>>> }
>>> )
>>> config.training(replay_buffer_config=replay_config)\
>>> .resources(num_gpus=1)\
>>> .rollouts(num_rollout_workers=3)\
>>> .environment("CartPole-v1")
>>> trainer = DQNTrainer(config=config)
>>> while True:
>>> trainer.train()
Example:
>>> from ray.rllib.agents.dqn.dqn import DQNConfig
>>> from ray import tune
>>> config = DQNConfig()
>>> config.training(num_atoms=tune.grid_search(list(range(1,11)))
>>> config.environment(env="CartPole-v1")
>>> tune.run(
>>> "DQN",
>>> stop={"episode_reward_mean":200},
>>> config=config.to_dict()
>>> )
Example:
>>> from ray.rllib.agents.dqn.dqn import DQNConfig
>>> config = DQNConfig()
>>> print(config.exploration_config)
>>> explore_config = config.exploration_config.update(
>>> {
>>> "initial_epsilon": 1.5,
>>> "final_epsilon": 0.01,
>>> "epsilone_timesteps": 5000,
>>> }
>>> )
>>> config.training(lr_schedule=[[1, 1e-3, [500, 5e-3]])\
>>> .exploration(exploration_config=explore_config)
Example:
>>> from ray.rllib.agents.dqn.dqn import DQNConfig
>>> config = DQNConfig()
>>> print(config.exploration_config)
>>> explore_config = config.exploration_config.update(
>>> {
>>> "type": "softq",
>>> "temperature": [1.0],
>>> }
>>> )
>>> config.training(lr_schedule=[[1, 1e-3, [500, 5e-3]])\
>>> .exploration(exploration_config=explore_config)
"""

def __init__(self):
"""Initializes a DQNConfig instance."""
super().__init__()

# DQN specific
# fmt: off
# __sphinx_doc_begin__
#
self.trainer_class = DQNTrainer
self.num_atoms = 1
self.v_min = -10.0
self.v_max = 10.0
self.noisy = False
self.sigma0 = 0.5
self.dueling = True
self.hiddens = [256]
self.double_q = True
self.n_step = 1
self.before_learn_on_batch = None
self.training_intensity = None
self.worker_side_prioritization = False

# Changes to SimpleQConfig default
self.replay_buffer_config = {
"_enable_replay_buffer_api": True,
"type": "MultiAgentPrioritizedReplayBuffer",
# Size of the replay buffer. Note that if async_updates is set,
# then each worker will have a replay buffer of this size.
"capacity": 50000,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
# The number of continuous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
},
# Set this to True, if you want the contents of your buffer(s) to be
# stored in any saved checkpoints as well.
# Warnings will be created if:
# - This is True AND restoring from a checkpoint that contains no buffer
# data.
# - This is False AND restoring from a checkpoint that does contain
# buffer data.
"store_buffer_in_checkpoints": False,


# Callback to run before learning on a multi-agent batch of
# experiences.
"before_learn_on_batch": None,

# The intensity with which to update the model (vs collecting samples
# from the env). If None, uses the "natural" value of:
# `train_batch_size` / (`rollout_fragment_length` x `num_workers` x
# `num_envs_per_worker`).
# If provided, will make sure that the ratio between ts inserted into
# and sampled from the buffer matches the given value.
# Example:
# training_intensity=1000.0
# train_batch_size=250 rollout_fragment_length=1
# num_workers=1 (or 0) num_envs_per_worker=1
# -> natural value = 250 / 1 = 250.0
# -> will make sure that replay+train op will be executed 4x as
# often as rollout+insert op (4 * 250 = 1000).
# See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further
# details.
"training_intensity": None,

# === Parallelism ===
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
},
_allow_unknown_configs=True,
)
# __sphinx_doc_end__
# fmt: on
}
# fmt: on
# __sphinx_doc_end__

@override(SimpleQConfig)
def training(
self,
*,
num_atoms: Optional[int] = None,
v_min: Optional[float] = None,
v_max: Optional[float] = None,
noisy: Optional[bool] = None,
sigma0: Optional[float] = None,
dueling: Optional[bool] = None,
hiddens: Optional[int] = None,
double_q: Optional[bool] = None,
n_step: Optional[int] = None,
before_learn_on_batch: Callable[
[Type[MultiAgentBatch], List[Type[Policy]], Type[int]],
Type[MultiAgentBatch],
] = None,
training_intensity: Optional[float] = None,
worker_side_prioritization: Optional[bool] = None,
replay_buffer_config: Optional[dict] = None,
**kwargs,
) -> "DQNConfig":
"""Sets the training related configuration.
Args:
num_atoms: Number of atoms for representing the distribution of return.
When this is greater than 1, distributional Q-learning is used.
v_min: Minimum value estimation
v_max: Maximum value estimation
noisy: Whether to use noisy network to aid exploration. This adds parametric
noise to the model weights.
sigma0: Control the initial parameter noise for noisy nets.
dueling: Whether to use dueling DQN.
hiddens: Dense-layer setup for each the advantage branch and the value
branch
double_q: Whether to use double DQN.
n_step: N-step for Q-learning.
before_learn_on_batch: Callback to run before learning on a multi-agent
batch of experiences.
training_intensity: The intensity with which to update the model (vs
collecting samples from the env).
If None, uses "natural" values of:
`train_batch_size` / (`rollout_fragment_length` x `num_workers` x
`num_envs_per_worker`).
If not None, will make sure that the ratio between timesteps inserted
into and sampled from th buffer matches the given values.
Example:
training_intensity=1000.0
train_batch_size=250
rollout_fragment_length=1
num_workers=1 (or 0)
num_envs_per_worker=1
-> natural value = 250 / 1 = 250.0
-> will make sure that replay+train op will be executed 4x asoften as
rollout+insert op (4 * 250 = 1000).
See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further details.
worker_side_prioritization: Whether to compute priorities on workers.
replay_buffer_config: Replay buffer config.
Examples:
{
"_enable_replay_buffer_api": True,
"type": "MultiAgentReplayBuffer",
"learning_starts": 1000,
"capacity": 50000,
"replay_batch_size": 32,
"replay_sequence_length": 1,
}
- OR -
{
"_enable_replay_buffer_api": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": 50000,
"prioritized_replay_alpha": 0.6,
"prioritized_replay_beta": 0.4,
"prioritized_replay_eps": 1e-6,
"replay_sequence_length": 1,
}
- Where -
prioritized_replay_alpha: Alpha parameter controls the degree of
prioritization in the buffer. In other words, when a buffer sample has
a higher temporal-difference error, with how much more probability
should it drawn to use to update the parametrized Q-network. 0.0
corresponds to uniform probability. Setting much above 1.0 may quickly
result as the sampling distribution could become heavily “pointy” with
low entropy.
prioritized_replay_beta: Beta parameter controls the degree of
importance sampling which suppresses the influence of gradient updates
from samples that have higher probability of being sampled via alpha
parameter and the temporal-difference error.
prioritized_replay_eps: Epsilon parameter sets the baseline probability
for sampling so that when the temporal-difference error of a sample is
zero, there is still a chance of drawing the sample.
Returns:
This updated TrainerConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)

if num_atoms is not None:
self.num_atoms = num_atoms
if v_min is not None:
self.v_min = v_min
if v_max is not None:
self.v_max = v_max
if noisy is not None:
self.noisy = noisy
if sigma0 is not None:
self.sigma0 = sigma0
if dueling is not None:
self.dueling = dueling
if hiddens is not None:
self.hiddens = hiddens
if double_q is not None:
self.double_q = double_q
if n_step is not None:
self.n_step = n_step
if before_learn_on_batch is not None:
self.before_learn_on_batch = before_learn_on_batch
if training_intensity is not None:
self.training_intensity = training_intensity
if worker_side_prioritization is not None:
self.worker_side_priorizatiion = worker_side_prioritization
if replay_buffer_config is not None:
self.replay_buffer_config = replay_buffer_config


# Deprecated: Use ray.rllib.agents.dqn.DQNConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(DQNConfig().to_dict())

@Deprecated(
old="ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG",
new="ray.rllib.agents.dqn.dqn.DQNConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)


def calculate_rr_weights(config: TrainerConfigDict) -> List[float]:
Expand Down Expand Up @@ -271,6 +423,9 @@ def training_iteration(self) -> ResultDict:
return train_results


DEFAULT_CONFIG = _deprecated_default_config()


@Deprecated(
new="Sub-class directly from `DQNTrainer` and override its methods", error=False
)
Expand Down
Loading

0 comments on commit 5c96e72

Please sign in to comment.