[RLlib; new API stack by default] Switch on new API stack by default for SAC and DQN. (ray-project#47217)

Signed-off-by: ujjawal-khare <[email protected]>
sven1977 authored and ujjawal-khare committed Oct 15, 2024
1 parent 33cfcca commit aa4bb87
Showing 11 changed files with 69 additions and 317 deletions.
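As the warning messages added in this commit explain, DQN and SAC now default to the new API stack. Below is a minimal sketch of opting back into the old stack, based on the `config.api_stack(...)` call quoted in those warnings; the `DQNConfig` import and the `CartPole-v1` environment are illustrative only.

    from ray.rllib.algorithms.dqn import DQNConfig

    config = (
        DQNConfig()
        .environment("CartPole-v1")
        # Both flags must be flipped together to return to the old API stack;
        # the hybrid combination is rejected by validate() (see dqn.py below).
        .api_stack(
            enable_rl_module_and_learner=False,
            enable_env_runner_and_connector_v2=False,
        )
    )
    algo = config.build()

The same call applies to `SACConfig`.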
2 changes: 0 additions & 2 deletions doc/source/rllib/doc_code/training.py
@@ -38,7 +38,6 @@
.api_stack(
enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
)
.framework("torch")
.environment("CartPole-v1")
.env_runners(num_env_runners=0)
.training(
@@ -113,7 +112,6 @@
.api_stack(
enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
)
.framework("torch")
.environment("CartPole-v1")
.training(
replay_buffer_config={
11 changes: 9 additions & 2 deletions rllib/BUILD
@@ -944,6 +944,13 @@ py_test(
srcs = ["algorithms/tests/test_callbacks_old_api_stack.py"]
)

py_test(
name = "test_node_failure",
tags = ["team:rllib", "tests_dir", "exclusive"],
size = "medium",
srcs = ["tests/test_node_failure.py"],
)

py_test(
name = "test_registry",
tags = ["team:rllib", "algorithms_dir", "algorithms_dir_generic"],
@@ -988,7 +995,7 @@ py_test(
name = "test_bc",
tags = ["team:rllib", "algorithms_dir"],
size = "medium",
# Include the offline data files.
# Include the parquet data files.
data = ["tests/data/cartpole/cartpole-v1_large"],
srcs = ["algorithms/bc/tests/test_bc.py"]
)
@@ -1053,7 +1060,7 @@ py_test(
name = "test_marwil",
tags = ["team:rllib", "algorithms_dir"],
size = "large",
# Include the offline data files.
# Include the parquet data folder.
data = [
"tests/data/cartpole/cartpole-v1_large",
"tests/data/pendulum/pendulum-v1_large",
21 changes: 14 additions & 7 deletions rllib/algorithms/cql/cql.py
@@ -89,17 +89,24 @@ def __init__(self, algo_class=None):
# Note, the new stack defines learning rates for each component.
# The base learning rate `lr` has to be set to `None`, if using
# the new stack.
self.actor_lr = 1e-4,
self.actor_lr = 1e-4
self.critic_lr = 1e-3
self.alpha_lr = 1e-3

# Changes to Algorithm's/SACConfig's default:
self.replay_buffer_config = {
"_enable_replay_buffer_api": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": int(1e6),
# If True prioritized replay buffer will be used.
"prioritized_replay": False,
"prioritized_replay_alpha": 0.6,
"prioritized_replay_beta": 0.4,
"prioritized_replay_eps": 1e-6,
# Whether to compute priorities already on the remote worker side.
"worker_side_prioritization": False,
}

# `.api_stack()`
self.api_stack(
enable_rl_module_and_learner=False,
enable_env_runner_and_connector_v2=False,
)
# Changes to Algorithm's/SACConfig's default:
# .reporting()
self.min_sample_timesteps_per_iteration = 0
self.min_train_timesteps_per_iteration = 100
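The note above says the new stack expects one learning rate per component, with the base `lr` left at `None`. A hedged sketch of a user-side CQL config that follows that convention and overrides the replay-buffer defaults shown above; the values are placeholders, not recommendations.

    from ray.rllib.algorithms.cql import CQLConfig

    config = CQLConfig().training(
        # Per-component learning rates on the new stack; keep the base `lr` at None.
        actor_lr=1e-4,
        critic_lr=1e-3,
        alpha_lr=1e-3,
        lr=None,
        # Override the buffer defaults, e.g. to switch prioritized replay on.
        replay_buffer_config={
            "type": "MultiAgentPrioritizedReplayBuffer",
            "capacity": int(1e6),
            "prioritized_replay": True,
            "prioritized_replay_alpha": 0.6,
            "prioritized_replay_beta": 0.4,
            "prioritized_replay_eps": 1e-6,
        },
    )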
24 changes: 16 additions & 8 deletions rllib/algorithms/dqn/dqn.py
@@ -424,15 +424,23 @@ def validate(self) -> None:
# Call super's validation method.
super().validate()

# Warn about new API stack on by default.
# Disallow hybrid API stack for DQN/SAC.
if self.enable_rl_module_and_learner:
logger.warning(
"You are running DQN on the new API stack! This is the new default "
"behavior for this algorithm. If you don't want to use the new API "
"stack, set `config.api_stack(enable_rl_module_and_learner=False, "
"enable_env_runner_and_connector_v2=False)`. For a detailed "
"migration guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html" # noqa
)
if not self.enable_env_runner_and_connector_v2:
raise ValueError(
"Hybrid API stack (`enable_rl_module_and_learner=True` and "
"`enable_env_runner_and_connector_v2=False`) no longer supported "
"for DQN! Set both to True (recommended new API stack) or both to "
"False (old API stack)."
)
else:
logger.warning(
"You are running DQN on the new API stack! This is the new default "
"behavior for this algorithm. If you don't want to use the new API "
"stack, set `config.api_stack(enable_rl_module_and_learner=False, "
"enable_env_runner_and_connector_v2=False)`. For a detailed "
"migration guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html" # noqa
)

if (
not self.enable_rl_module_and_learner
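Given the new branch above, only the two matched flag combinations pass `validate()` for DQN. A small sketch illustrating both, with the hybrid combination shown as the case that now raises; flag names are taken verbatim from the error message, everything else is assumed.

    from ray.rllib.algorithms.dqn import DQNConfig

    # Valid: full new API stack (the current default; only logs the info warning).
    DQNConfig().api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    ).environment("CartPole-v1").validate()

    # Valid: full old API stack.
    DQNConfig().api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    ).environment("CartPole-v1").validate()

    # No longer valid: the hybrid stack raises a ValueError in validate().
    # DQNConfig().api_stack(
    #     enable_rl_module_and_learner=True,
    #     enable_env_runner_and_connector_v2=False,
    # ).validate()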
1 change: 1 addition & 0 deletions rllib/algorithms/impala/tests/test_vtrace_v2.py
@@ -7,6 +7,7 @@
vtrace_torch,
make_time_major,
)
from ray.rllib.algorithms.impala.torch.vtrace_torch_v2 import vtrace_torch
from ray.rllib.algorithms.impala.tests.test_vtrace_old_api_stack import (
_ground_truth_vtrace_calculation,
)
139 changes: 9 additions & 130 deletions rllib/algorithms/ppo/tests/test_ppo.py
@@ -8,12 +8,12 @@
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.learner.learner import DEFAULT_OPTIMIZER, LR_KEY
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

from ray.rllib.utils.metrics import LEARNER_RESULTS
from ray.rllib.utils.test_utils import check, check_train_results_new_api_stack


def get_model_config(lstm=False):
def get_model_config(framework, lstm=False):
return (
dict(
use_lstm=True,
@@ -60,93 +60,16 @@ def setUpClass(cls):
def tearDownClass(cls):
ray.shutdown()

def test_ppo_compilation_w_connectors(self):
"""Test whether PPO can be built with all frameworks w/ connectors."""

# Build a PPOConfig object.
config = (
ppo.PPOConfig()
.training(
num_epochs=2,
# Setup lr schedule for testing.
lr_schedule=[[0, 5e-5], [128, 0.0]],
# Set entropy_coeff to a faulty value to proof that it'll get
# overridden by the schedule below (which is expected).
entropy_coeff=100.0,
entropy_coeff_schedule=[[0, 0.1], [256, 0.0]],
train_batch_size=128,
model=dict(
# Settings in case we use an LSTM.
lstm_cell_size=10,
max_seq_len=20,
),
)
.env_runners(
num_env_runners=1,
# Test with compression.
compress_observations=True,
enable_connectors=True,
)
.callbacks(MyCallbacks)
.evaluation(
evaluation_duration=2,
evaluation_duration_unit="episodes",
evaluation_num_env_runners=1,
)
) # For checking lr-schedule correctness.

num_iterations = 2

for env in ["FrozenLake-v1", "ALE/MsPacman-v5"]:
print("Env={}".format(env))
for lstm in [False, True]:
print("LSTM={}".format(lstm))
config.training(
model=dict(
use_lstm=lstm,
lstm_use_prev_action=lstm,
lstm_use_prev_reward=lstm,
)
)

algo = config.build(env=env)
policy = algo.get_policy()
entropy_coeff = algo.get_policy().entropy_coeff
lr = policy.cur_lr
check(entropy_coeff, 0.1)
check(lr, config.lr)

for i in range(num_iterations):
results = algo.train()
check_train_results(results)
print(results)

algo.evaluate()

check_inference_w_connectors(policy, env_name=env)
algo.stop()

def test_ppo_compilation_and_schedule_mixins(self):
"""Test whether PPO can be built with all frameworks."""

# Build a PPOConfig object with the `SingleAgentEnvRunner` class.
config = (
ppo.PPOConfig()
.training(
# Setup lr schedule for testing.
lr_schedule=[[0, 5e-5], [256, 0.0]],
# Set entropy_coeff to a faulty value to proof that it'll get
# overridden by the schedule below (which is expected).
entropy_coeff=100.0,
entropy_coeff_schedule=[[0, 0.1], [512, 0.0]],
train_batch_size=256,
minibatch_size=128,
num_epochs=2,
model=dict(
# Settings in case we use an LSTM.
lstm_cell_size=10,
max_seq_len=20,
),
# Enable new API stack and use EnvRunner.
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.env_runners(num_env_runners=0)
.training(
@@ -179,7 +102,9 @@ def test_ppo_compilation_and_schedule_mixins(self):
print("Env={}".format(env))
for lstm in [False]:
print("LSTM={}".format(lstm))
config.rl_module(model_config=get_model_config(lstm=lstm))
config.rl_module(
model_config_dict=get_model_config("torch", lstm=lstm)
).framework(eager_tracing=False)

algo = config.build(env=env)
# TODO: Maybe add an API to get the Learner(s) instances within
@@ -205,52 +130,6 @@ def test_ppo_compilation_and_schedule_mixins(self):
# algo.evaluate()
algo.stop()

def test_ppo_free_log_std(self):
"""Tests the free log std option works."""
config = (
ppo.PPOConfig()
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.environment("Pendulum-v1")
.env_runners(
num_env_runners=1,
)
.rl_module(
model_config=DefaultModelConfig(
fcnet_hiddens=[10],
fcnet_activation="linear",
free_log_std=True,
vf_share_layers=True,
),
)
.training(
gamma=0.99,
)
)

algo = config.build()
module = algo.get_module(DEFAULT_MODULE_ID)

# Check the free log std var is created.
matching = [v for (n, v) in module.named_parameters() if "log_std" in n]
assert len(matching) == 1, matching
log_std_var = matching[0]

def get_value(log_std_var=log_std_var):
return log_std_var.detach().cpu().numpy()[0]

# Check the variable is initially zero.
init_std = get_value()
assert init_std == 0.0, init_std
algo.train()

# Check the variable is updated.
post_std = get_value()
assert post_std != 0.0, post_std
algo.stop()


if __name__ == "__main__":
import pytest
1 change: 1 addition & 0 deletions rllib/algorithms/ppo/tests/test_ppo_old_api_stack.py
@@ -144,6 +144,7 @@ def test_ppo_compilation_w_connectors(self):
num_env_runners=1,
# Test with compression.
compress_observations=True,
enable_connectors=True,
)
.callbacks(MyCallbacks)
.evaluation(
23 changes: 15 additions & 8 deletions rllib/algorithms/sac/sac.py
@@ -485,14 +485,21 @@ def validate(self) -> None:
"and `alpha_lr`, for the actor, critic, and the hyperparameter "
"`alpha`, respectively and set `config.lr` to None."
)
# Warn about new API stack on by default.
logger.warning(
"You are running SAC on the new API stack! This is the new default "
"behavior for this algorithm. If you don't want to use the new API "
"stack, set `config.api_stack(enable_rl_module_and_learner=False, "
"enable_env_runner_and_connector_v2=False)`. For a detailed "
"migration guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html" # noqa
)
elif not self.enable_env_runner_and_connector_v2:
raise ValueError(
"Hybrid API stack (`enable_rl_module_and_learner=True` and "
"`enable_env_runner_and_connector_v2=False`) no longer supported "
"for SAC! Set both to True (recommended new API stack) or both to "
"False (old API stack)."
)
else:
logger.warning(
"You are running SAC on the new API stack! This is the new default "
"behavior for this algorithm. If you don't want to use the new API "
"stack, set `config.api_stack(enable_rl_module_and_learner=False, "
"enable_env_runner_and_connector_v2=False)`. For a detailed "
"migration guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html" # noqa
)

@override(AlgorithmConfig)
def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
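SAC's `validate()` enforces the same two rules quoted above: per-component learning rates with `config.lr` set to `None`, and no hybrid API stack. Below is a sketch of a new-stack SAC config that satisfies both; the environment and the specific learning-rate values are illustrative assumptions.

    from ray.rllib.algorithms.sac import SACConfig

    config = (
        SACConfig()
        .environment("Pendulum-v1")
        # The new API stack flags are already on by default after this commit.
        .training(
            actor_lr=3e-4,
            critic_lr=3e-4,
            alpha_lr=3e-4,
            lr=None,  # Must stay None on the new stack, per the error message above.
        )
    )
    config.validate()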