[RLlib] Add separate learning rates for policy and alpha to SAC. (#…
simonsays1980 authored Aug 21, 2024
1 parent 2063395 commit c50e3b6
Showing 11 changed files with 135 additions and 32 deletions.
4 changes: 4 additions & 0 deletions doc/source/rllib/doc_code/new_api_stack.py
@@ -128,6 +128,10 @@
.training(
model={"uses_new_env_runners": True},
replay_buffer_config={"type": "EpisodeReplayBuffer"},
# Note: the new API stack's SAC uses its own learning rates for the actor,
# critic, and alpha; the base `lr` therefore needs to be set to `None`. See
# `actor_lr`, `critic_lr`, and `alpha_lr` for the respective learning rates.
lr=None,
)
)
# __enabling-new-api-stack-sa-sac-end__
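
For context, here is a self-contained editorial sketch (not part of the diff) of configuring SAC on the new API stack with the separate learning rates this commit introduces; the `api_stack()` call and the Pendulum-v1 environment are assumptions for illustration only:

# Editorial sketch only -- not part of this commit.
from ray.rllib.algorithms.sac import SACConfig

config = (
    SACConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("Pendulum-v1")
    .training(
        replay_buffer_config={"type": "EpisodeReplayBuffer"},
        lr=None,         # the base learning rate must be disabled for new-stack SAC
        actor_lr=3e-5,   # policy (two-timescale: smaller than the critic's LR)
        critic_lr=3e-4,  # Q-function(s)
        alpha_lr=3e-4,   # entropy temperature
    )
)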
40 changes: 20 additions & 20 deletions rllib/BUILD
@@ -336,7 +336,7 @@ py_test(
py_test(
name = "learning_tests_pendulum_cql_old_api_stack",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_continuous", "learning_tests_with_ray_data"],
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_continuous", "learning_tests_with_ray_data", "torch_only"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
# Include the zipped json data file as well.
@@ -2386,16 +2386,16 @@ py_test(
srcs = ["examples/_old_api_stack/connectors/adapt_connector_policy.py"],
)

py_test(
name = "examples/_old_api_stack/connectors/self_play_with_policy_checkpoint",
main = "examples/_old_api_stack/connectors/self_play_with_policy_checkpoint.py",
tags = ["team:rllib", "exclusive", "examples", "old_api_stack"],
size = "small",
srcs = ["examples/_old_api_stack/connectors/self_play_with_policy_checkpoint.py"],
args = [
"--train_iteration=1" # Smoke test.
]
)
# py_test(
# name = "examples/_old_api_stack/connectors/self_play_with_policy_checkpoint",
# main = "examples/_old_api_stack/connectors/self_play_with_policy_checkpoint.py",
# tags = ["team:rllib", "exclusive", "examples", "old_api_stack"],
# size = "small",
# srcs = ["examples/_old_api_stack/connectors/self_play_with_policy_checkpoint.py"],
# args = [
# "--train_iteration=1" # Smoke test.
# ]
# )

# ----------------------
# New API stack
@@ -3278,15 +3278,15 @@ py_test(
# ....................................

# @HybridAPIStack
py_test(
name = "examples/offline_rl/pretrain_bc_single_agent_evaluate_as_multi_agent",
main = "examples/offline_rl/pretrain_bc_single_agent_evaluate_as_multi_agent.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "large",
srcs = ["examples/offline_rl/pretrain_bc_single_agent_evaluate_as_multi_agent.py"],
data = ["tests/data/cartpole/large.json"],
args = ["--as-test"]
)
# py_test(
# name = "examples/offline_rl/pretrain_bc_single_agent_evaluate_as_multi_agent",
# main = "examples/offline_rl/pretrain_bc_single_agent_evaluate_as_multi_agent.py",
# tags = ["team:rllib", "exclusive", "examples"],
# size = "large",
# srcs = ["examples/offline_rl/pretrain_bc_single_agent_evaluate_as_multi_agent.py"],
# data = ["tests/data/cartpole/large.json"],
# args = ["--as-test"]
# )

#@OldAPIStack
# TODO (sven): Doesn't seem to learn at the moment. Uncomment once fixed.
1 change: 1 addition & 0 deletions rllib/algorithms/cql/cql.py
@@ -83,6 +83,7 @@ def __init__(self, algo_class=None):
self.lagrangian = False
self.lagrangian_thresh = 5.0
self.min_q_weight = 5.0
self.lr = 3e-4

# Changes to Algorithm's/SACConfig's default:

74 changes: 73 additions & 1 deletion rllib/algorithms/sac/sac.py
@@ -20,7 +20,7 @@
deprecation_warning,
)
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.typing import RLModuleSpecType, ResultDict
from ray.rllib.utils.typing import LearningRateOrSchedule, RLModuleSpecType, ResultDict

tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()
@@ -88,6 +88,11 @@ def __init__(self, algo_class=None):
"critic_learning_rate": 3e-4,
"entropy_learning_rate": 3e-4,
}
self.actor_lr = 3e-5
self.critic_lr = 3e-4
self.alpha_lr = 3e-4
# Set `lr` parameter to `None` and ensure it is not used.
self.lr = 3e-4
self.grad_clip = None
self.target_network_update_freq = 0

@@ -142,6 +147,9 @@ def training(
clip_actions: Optional[bool] = NotProvided,
grad_clip: Optional[float] = NotProvided,
optimization_config: Optional[Dict[str, Any]] = NotProvided,
actor_lr: Optional[LearningRateOrSchedule] = NotProvided,
critic_lr: Optional[LearningRateOrSchedule] = NotProvided,
alpha_lr: Optional[LearningRateOrSchedule] = NotProvided,
target_network_update_freq: Optional[int] = NotProvided,
_deterministic_loss: Optional[bool] = NotProvided,
_use_beta_distribution: Optional[bool] = NotProvided,
@@ -246,6 +254,56 @@ def training(
optimization_config: Config dict for optimization. Set the supported keys
`actor_learning_rate`, `critic_learning_rate`, and
`entropy_learning_rate` in here.
actor_lr: The learning rate (float) or learning rate schedule for the
policy in the format of
[[timestep, lr-value], [timestep, lr-value], ...] In case of a
schedule, intermediary timesteps will be assigned to linearly
interpolated learning rate values. A schedule config's first entry
must start with timestep 0, i.e.: [[0, initial_value], [...]].
Note: It is common practice (two-timescale approach) to use a smaller
learning rate for the policy than for the critic to ensure that the
critic gives adequate values for improving the policy.
Note: If you require a) more than one optimizer (per RLModule),
b) optimizer types that are not Adam, c) a learning rate schedule that
is not a linearly interpolated, piecewise schedule as described above,
or d) specifying c'tor arguments of the optimizer that are not the
learning rate (e.g. Adam's epsilon), then you must override your
Learner's `configure_optimizer_for_module()` method and handle
lr-scheduling yourself.
The default value is 3e-5, an order of magnitude smaller than the
critic's learning rate (see `critic_lr`).
critic_lr: The learning rate (float) or learning rate schedule for the
critic in the format of
[[timestep, lr-value], [timestep, lr-value], ...] In case of a
schedule, intermediary timesteps will be assigned to linearly
interpolated learning rate values. A schedule config's first entry
must start with timestep 0, i.e.: [[0, initial_value], [...]].
Note: It is common practice (two-timescale approach) to use a smaller
learning rate for the policy than for the critic to ensure that the
critic gives adequate values for improving the policy.
Note: If you require a) more than one optimizer (per RLModule),
b) optimizer types that are not Adam, c) a learning rate schedule that
is not a linearly interpolated, piecewise schedule as described above,
or d) specifying c'tor arguments of the optimizer that are not the
learning rate (e.g. Adam's epsilon), then you must override your
Learner's `configure_optimizer_for_module()` method and handle
lr-scheduling yourself.
The default value is 3e-4, an order of magnitude larger than the
actor's (policy) learning rate (see `actor_lr`).
alpha_lr: The learning rate (float) or learning rate schedule for the
hyperparameter alpha in the format of
[[timestep, lr-value], [timestep, lr-value], ...] In case of a
schedule, intermediary timesteps will be assigned to linearly
interpolated learning rate values. A schedule config's first entry
must start with timestep 0, i.e.: [[0, initial_value], [...]].
Note: If you require a) more than one optimizer (per RLModule),
b) optimizer types that are not Adam, c) a learning rate schedule that
is not a linearly interpolated, piecewise schedule as described above,
or d) specifying c'tor arguments of the optimizer that are not the
learning rate (e.g. Adam's epsilon), then you must override your
Learner's `configure_optimizer_for_module()` method and handle
lr-scheduling yourself.
The default value is 3e-4, identical to the critic's learning rate (see `critic_lr`).
target_network_update_freq: Update the target network every
`target_network_update_freq` steps.
_deterministic_loss: Whether the loss should be calculated deterministically
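
As a concrete illustration of the schedule format described in the docstrings above, an editorial sketch (values illustrative, not tuned) of a piecewise-linear decay of the policy learning rate:

# Editorial sketch of the piecewise-linear schedule format -- not part of this
# commit. The first entry must start at timestep 0; intermediate timesteps are
# linearly interpolated.
from ray.rllib.algorithms.sac import SACConfig

config = SACConfig().training(
    lr=None,
    # Decay the policy LR from 3e-4 to 3e-5 over the first 1M timesteps.
    actor_lr=[[0, 3e-4], [1_000_000, 3e-5]],
    critic_lr=3e-4,
    alpha_lr=3e-4,
)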
@@ -296,6 +354,12 @@ def training(
self.grad_clip = grad_clip
if optimization_config is not NotProvided:
self.optimization = optimization_config
if actor_lr is not NotProvided:
self.actor_lr = actor_lr
if critic_lr is not NotProvided:
self.critic_lr = critic_lr
if alpha_lr is not NotProvided:
self.alpha_lr = alpha_lr
if target_network_update_freq is not NotProvided:
self.target_network_update_freq = target_network_update_freq
if _deterministic_loss is not NotProvided:
@@ -385,6 +449,14 @@ def validate(self) -> None:
"`EpisodeReplayBuffer`."
)

if self.enable_rl_module_and_learner and self.lr is not None:
raise ValueError(
"Basic learning rate parameter `lr` is not `None`. For SAC "
"use the specific learning rate parameters `actor_lr`, `critic_lr` "
"and `alpha_lr`, for the actor, critic, and the hyperparameter "
"`alpha`, respectively."
)

@override(AlgorithmConfig)
def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
if self.rollout_fragment_length == "auto":
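
To make the intent of the new `validate()` check concrete, an editorial sketch of an invalid versus a valid configuration (assuming the new API stack is enabled elsewhere in the config):

# Editorial sketch only -- not part of this commit.
from ray.rllib.algorithms.sac import SACConfig

# Invalid on the new API stack: the base `lr` is still set, so `validate()`
# raises the ValueError added above.
bad = SACConfig().training(lr=3e-4, actor_lr=3e-5, critic_lr=3e-4, alpha_lr=3e-4)

# Valid: disable the base `lr` and rely on the per-component learning rates.
good = SACConfig().training(lr=None, actor_lr=3e-5, critic_lr=3e-4, alpha_lr=3e-4)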
8 changes: 4 additions & 4 deletions rllib/algorithms/sac/torch/sac_torch_learner.py
@@ -58,7 +58,7 @@ def configure_optimizers_for_module(
optimizer_name="qf",
optimizer=optim_critic,
params=params_critic,
lr_or_lr_schedule=config.lr,
lr_or_lr_schedule=config.critic_lr,
)
# If necessary register also an optimizer for a twin Q network.
if config.twin_q:
@@ -72,7 +72,7 @@
optimizer_name="qf_twin",
optimizer=optim_twin_critic,
params=params_twin_critic,
lr_or_lr_schedule=config.lr,
lr_or_lr_schedule=config.critic_lr,
)

# Define the optimizer for the actor.
@@ -86,7 +86,7 @@
optimizer_name="policy",
optimizer=optim_actor,
params=params_actor,
lr_or_lr_schedule=config.lr,
lr_or_lr_schedule=config.actor_lr,
)

# Define the optimizer for the temperature.
@@ -97,7 +97,7 @@
optimizer_name="alpha",
optimizer=optim_temperature,
params=[temperature],
lr_or_lr_schedule=config.lr,
lr_or_lr_schedule=config.alpha_lr,
)

@override(DQNRainbowTorchLearner)
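
Conceptually, the wiring above gives each SAC component its own Adam optimizer driven by its own learning rate; a simplified stand-alone illustration in plain PyTorch (editorial sketch with stand-in modules, not the RLlib Learner API):

# Editorial sketch only -- stand-in modules, not RLlib's actual RLModule parts.
import torch
import torch.nn as nn

actor = nn.Linear(4, 2)                    # stand-in policy network
critic = nn.Linear(6, 1)                   # stand-in Q-network
log_alpha = nn.Parameter(torch.zeros(()))  # stand-in entropy temperature

actor_optim = torch.optim.Adam(actor.parameters(), lr=3e-5)    # config.actor_lr
critic_optim = torch.optim.Adam(critic.parameters(), lr=3e-4)  # config.critic_lr
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)           # config.alpha_lr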
10 changes: 9 additions & 1 deletion rllib/algorithms/tests/test_worker_failures.py
@@ -434,7 +434,12 @@ def test_sync_replay(self):
)
.env_runners(env_runner_cls=ForwardHealthCheckToEnvWorker)
.reporting(min_sample_timesteps_per_iteration=1)
.training(replay_buffer_config={"type": "EpisodeReplayBuffer"})
.training(
replay_buffer_config={"type": "EpisodeReplayBuffer"},
# We need to set the base `lr` to `None` b/c SAC in the new stack
# has its own learning rates.
lr=None,
)
)

def test_multi_gpu(self):
@@ -787,6 +792,9 @@ def test_worker_failing_recover_with_hanging_workers(self):
)
.training(
replay_buffer_config={"type": "EpisodeReplayBuffer"},
# We need to set the base `lr` to `None` b/c new-stack SAC has its own
# specific learning rates for actor, critic, and alpha.
lr=None,
)
.env_runners(
env_runner_cls=ForwardHealthCheckToEnvWorker,
5 changes: 4 additions & 1 deletion rllib/tuned_examples/sac/benchmark_sac_mujoco.py
@@ -94,7 +94,10 @@ def stop_all(self):
# TODO (simon): Adjust to new model_config_dict.
.training(
initial_alpha=1.001,
lr=3e-4,
# Choose a smaller learning rate for the actor (policy).
actor_lr=3e-5,
critic_lr=3e-4,
alpha_lr=1e-4,
target_entropy="auto",
n_step=1,
tau=0.005,
8 changes: 6 additions & 2 deletions rllib/tuned_examples/sac/benchmark_sac_mujoco_pb2.py
@@ -44,7 +44,9 @@
# Copy bottom % with top % weights.
quantile_fraction=0.25,
hyperparam_bounds={
"lr": [1e-5, 1e-3],
"actor_lr": [1e-5, 1e-3],
"critic_lr": [1e-6, 1e-4],
"alpha_lr": [1e-6, 1e-3],
"gamma": [0.95, 0.99],
"n_step": [1, 3],
"initial_alpha": [1.0, 1.5],
@@ -80,7 +82,9 @@
# TODO (simon): Adjust to new model_config_dict.
.training(
initial_alpha=tune.choice([1.0, 1.5]),
lr=tune.uniform(1e-5, 1e-3),
actor_lr=tune.uniform(1e-5, 1e-3),
critic_lr=tune.uniform(1e-6, 1e-4),
alpha_lr=tune.uniform(1e-6, 1e-3),
target_entropy=tune.choice([-10, -5, -1, "auto"]),
n_step=tune.choice([1, 3, (1, 3)]),
tau=tune.uniform(0.001, 0.1),
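
Note the two different forms used in this file: PB2's `hyperparam_bounds` take plain `[min, max]` lists, whereas the initial Tune search space samples with `tune.uniform(lower, upper)` (two separate arguments, not a list). A small editorial reference snippet, not part of the diff:

# Editorial reference -- not part of the diff.
from ray import tune

# PB2 bounds: plain [min, max] lists per hyperparameter.
hyperparam_bounds = {"actor_lr": [1e-5, 1e-3], "critic_lr": [1e-6, 1e-4]}

# Initial search space: `tune.uniform()` takes separate lower/upper arguments.
initial_space = {
    "actor_lr": tune.uniform(1e-5, 1e-3),
    "critic_lr": tune.uniform(1e-6, 1e-4),
}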
5 changes: 4 additions & 1 deletion rllib/tuned_examples/sac/halfcheetah_sac.py
@@ -24,7 +24,10 @@
initial_alpha=1.001,
# lr=0.0006 is very high, w/ 4 GPUs -> 0.0012
# Might want to lower it for better stability, but it does learn well.
lr=0.0004 * (args.num_gpus or 1) ** 0.5,
actor_lr=2e-4 * (args.num_gpus or 1) ** 0.5,
critic_lr=8e-4 * (args.num_gpus or 1) ** 0.5,
alpha_lr=9e-4 * (args.num_gpus or 1) ** 0.5,
lr=None,
target_entropy="auto",
n_step=(1, 5), # 1?
tau=0.005,
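
For reference, the square-root GPU scaling used above works out as follows (an editorial illustration assuming 4 GPUs; not part of the diff):

# Editorial illustration of the sqrt(num_gpus) learning-rate scaling.
num_gpus = 4
scale = num_gpus ** 0.5   # = 2.0
actor_lr = 2e-4 * scale   # -> 4e-4
critic_lr = 8e-4 * scale  # -> 1.6e-3
alpha_lr = 9e-4 * scale   # -> 1.8e-3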
6 changes: 5 additions & 1 deletion rllib/tuned_examples/sac/multi_agent_pendulum_sac.py
@@ -36,7 +36,11 @@
.environment("multi_agent_pendulum")
.training(
initial_alpha=1.001,
lr=0.001 * ((args.num_gpus or 1) ** 0.5),
# Use a smaller learning rate for the policy.
actor_lr=2e-4 * (args.num_gpus or 1) ** 0.5,
critic_lr=8e-4 * (args.num_gpus or 1) ** 0.5,
alpha_lr=9e-4 * (args.num_gpus or 1) ** 0.5,
lr=None,
target_entropy="auto",
n_step=(2, 5),
tau=0.005,
6 changes: 5 additions & 1 deletion rllib/tuned_examples/sac/pendulum_sac.py
@@ -21,7 +21,11 @@
.environment("Pendulum-v1")
.training(
initial_alpha=1.001,
lr=0.001 * (args.num_gpus or 1) ** 0.5,
# Use a smaller learning rate for the policy.
actor_lr=2e-4 * (args.num_gpus or 1) ** 0.5,
critic_lr=8e-4 * (args.num_gpus or 1) ** 0.5,
alpha_lr=9e-4 * (args.num_gpus or 1) ** 0.5,
lr=None,
target_entropy="auto",
n_step=(2, 5),
tau=0.005,
