[RLlib] Add lr_schedule support to SimpleQ and PG. #28381

Merged
34 changes: 33 additions & 1 deletion rllib/algorithms/pg/pg.py
@@ -1,4 +1,4 @@
from typing import Type
from typing import List, Optional, Type, Union

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
@@ -47,11 +47,43 @@ def __init__(self):
# __sphinx_doc_begin__
# Override some of AlgorithmConfig's default values with PG-specific values.
self.num_workers = 0
self.lr_schedule = None
self.lr = 0.0004
self._disable_preprocessor_api = True
# __sphinx_doc_end__
# fmt: on

@override(AlgorithmConfig)
def training(
self,
*,
lr_schedule: Optional[List[List[Union[int, float]]]] = None,
**kwargs,
) -> "PGConfig":
"""Sets the training related configuration.

Args:
gamma: Float specifying the discount factor of the Markov Decision process.
lr: The default learning rate.
train_batch_size: Training batch size, if applicable.
model: Arguments passed into the policy model. See models/catalog.py for a
full list of the available model options.
optimizer: Arguments to pass to the policy optimizer.
lr_schedule: Learning rate schedule. In the format of
[[timestep, lr-value], [timestep, lr-value], ...]
Intermediary timesteps will be assigned to interpolated learning rate
values. A schedule should normally start from timestep 0.

Returns:
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
if lr_schedule is not None:
self.lr_schedule = lr_schedule

return self


class PG(Algorithm):
"""Policy Gradient (PG) Trainer.
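For context, a minimal usage sketch of the new option (illustrative values mirroring the tests below; not part of the diff):

import ray.rllib.algorithms.pg as pg

# Anneal the learning rate linearly from 0.2 at timestep 0 down to 0.001 at
# timestep 500; later timesteps keep the final schedule value.
config = pg.PGConfig().training(lr=0.2, lr_schedule=[[0, 0.2], [500, 0.001]])
algo = config.build(env="CartPole-v0")
algo.train()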
13 changes: 12 additions & 1 deletion rllib/algorithms/pg/pg_tf_policy.py
@@ -21,6 +21,7 @@
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_mixins import LearningRateSchedule
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorType

@@ -41,6 +42,7 @@ def get_pg_tf_policy(name: str, base: TFPolicyV2Type) -> TFPolicyV2Type:
"""

class PGTFPolicy(
LearningRateSchedule,
base,
):
def __init__(
@@ -66,6 +68,8 @@ def __init__(
existing_model=existing_model,
)

LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"])

# Note: this is a bit ugly, but loss and optimizer initialization must
# happen after all the MixIns are initialized.
self.maybe_initialize_optimizer_and_loss()
@@ -124,9 +128,15 @@ def postprocess_trajectory(
self, sample_batch, other_agent_batches, episode
)

@override(base)
def extra_learn_fetches_fn(self) -> Dict[str, TensorType]:
return {
"learner_stats": {"cur_lr": self.cur_lr},
}

@override(base)
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
"""Returns the calculated loss in a stats dict.
"""Returns the calculated loss and learning rate in a stats dict.

Args:
policy: The Policy object.
@@ -138,6 +148,7 @@ def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:

return {
"policy_loss": self.policy_loss,
"cur_lr": self.cur_lr,
}

PGTFPolicy.__name__ = name
6 changes: 5 additions & 1 deletion rllib/algorithms/pg/pg_torch_policy.py
@@ -17,6 +17,7 @@
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_mixins import LearningRateSchedule
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType

@@ -25,7 +26,7 @@
logger = logging.getLogger(__name__)


class PGTorchPolicy(TorchPolicyV2):
class PGTorchPolicy(LearningRateSchedule, TorchPolicyV2):
"""PyTorch policy class used with PGTrainer."""

def __init__(self, observation_space, action_space, config):
@@ -40,6 +41,8 @@ def __init__(self, observation_space, action_space, config):
max_seq_len=config["model"]["max_seq_len"],
)

LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"])

# TODO: Don't require users to call this manually.
self._initialize_loss_from_dummy_batch()

@@ -100,6 +103,7 @@ def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
"policy_loss": torch.mean(
torch.stack(self.get_tower_stats("policy_loss"))
),
"cur_lr": self.cur_lr,
}
)

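As described in the config docstring, the schedule is a list of [timestep, lr] points with linear interpolation in between and the endpoint values held outside that range. An illustrative helper showing that semantics (not RLlib's actual LearningRateSchedule mixin):

def piecewise_linear_lr(schedule, t):
    # Outside the schedule's range, clamp to the endpoint values.
    if t <= schedule[0][0]:
        return schedule[0][1]
    if t >= schedule[-1][0]:
        return schedule[-1][1]
    # Otherwise, interpolate linearly between the two surrounding points.
    for (t0, lr0), (t1, lr1) in zip(schedule, schedule[1:]):
        if t0 <= t <= t1:
            return lr0 + (lr1 - lr0) * (t - t0) / (t1 - t0)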
52 changes: 52 additions & 0 deletions rllib/algorithms/pg/tests/test_pg.py
@@ -17,6 +17,11 @@
framework_iterator,
)
from ray import tune
from ray.rllib.utils.metrics.learner_info import (
LEARNER_INFO,
LEARNER_STATS_KEY,
DEFAULT_POLICY_ID,
)


class TestPG(unittest.TestCase):
@@ -31,6 +36,7 @@ def tearDownClass(cls) -> None:
def test_pg_compilation(self):
"""Test whether PG can be built with all frameworks."""
config = pg.PGConfig()

# Test with filter to see whether they work w/o preprocessing.
config.rollouts(
num_rollout_workers=1,
@@ -175,6 +181,52 @@ def test_pg_loss_functions(self):
expected_loss = -np.mean(expected_logp * adv)
check(results, expected_loss, decimals=4)

def test_pg_lr(self):
"""Test PG with learning rate schedule."""
config = pg.PGConfig()
config.reporting(
min_sample_timesteps_per_iteration=10,
# Make sure that results contain info on default policy
min_train_timesteps_per_iteration=10,
# Zero metrics reporting delay: this makes sure the timestep,
# which the lr depends on, is updated after each worker rollout.
min_time_s_per_iteration=0,
)
config.rollouts(
num_rollout_workers=1,
rollout_fragment_length=50,
)
config.training(lr=0.2, lr_schedule=[[0, 0.2], [500, 0.001]])

def _step_n_times(algo, n: int):
"""Step trainer n times.

Returns:
learning rate at the end of the execution.
"""
for _ in range(n):
results = algo.train()
return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
"cur_lr"
]

for _ in framework_iterator(config):
algo = config.build(env="CartPole-v0")

lr = _step_n_times(algo, 1) # 50 timesteps
# Close to 0.2
self.assertGreaterEqual(lr, 0.15)

lr = _step_n_times(algo, 8) # Close to 500 timesteps
# LR should have annealed most of the way toward 0.001.
self.assertLessEqual(float(lr), 0.5)

lr = _step_n_times(algo, 2) # > 500 timesteps
# LR == 0.001
self.assertAlmostEqual(lr, 0.001)

algo.stop()


if __name__ == "__main__":
import pytest
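As a rough sanity check on these bounds (assuming linear interpolation between the schedule points): after the first iteration of roughly 50 sampled timesteps, lr is about 0.2 + (0.001 - 0.2) * 50 / 500, i.e. roughly 0.18, which satisfies the >= 0.15 check; once the 500-timestep mark is passed, the schedule holds its final value of 0.001.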
31 changes: 28 additions & 3 deletions rllib/algorithms/simple_q/simple_q_tf_policy.py
@@ -10,7 +10,11 @@
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_mixins import TargetNetworkMixin, compute_gradients
from ray.rllib.policy.tf_mixins import (
TargetNetworkMixin,
compute_gradients,
LearningRateSchedule,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_utils import huber_loss
@@ -39,7 +43,7 @@ def get_simple_q_tf_policy(
A TF Policy to be used with SimpleQTrainer.
"""

class SimpleQTFPolicy(TargetNetworkMixin, base):
class SimpleQTFPolicy(LearningRateSchedule, TargetNetworkMixin, base):
def __init__(
self,
obs_space,
@@ -66,6 +70,8 @@ def __init__(
existing_model=existing_model,
)

LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"])

# Note: this is a bit ugly, but loss and optimizer initialization must
# happen after all the MixIns are initialized.
self.maybe_initialize_optimizer_and_loss()
@@ -184,7 +190,26 @@ def compute_gradients_fn(

@override(base)
def extra_learn_fetches_fn(self) -> Dict[str, TensorType]:
return {"td_error": self.td_error}
return {
"td_error": self.td_error,
"learner_stats": {"cur_lr": self.cur_lr},
}

@override(base)
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
"""Returns the learning rate in a stats dict.

Args:
train_batch: The data used for training.

Returns:
Dict[str, TensorType]: The stats dict.
"""

return {
"cur_lr": self.cur_lr,
}

def _compute_q_values(
self, model: ModelV2, obs_batch: TensorType, is_training=None
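Because cur_lr is now surfaced through the learner stats, it can be read back from training results. A sketch for the default single-policy case, using the same keys as the tests in this PR (algo is assumed to be a built Algorithm):

from ray.rllib.utils.metrics.learner_info import (
    LEARNER_INFO,
    LEARNER_STATS_KEY,
    DEFAULT_POLICY_ID,
)

results = algo.train()
cur_lr = results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY]["cur_lr"]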
10 changes: 8 additions & 2 deletions rllib/algorithms/simple_q/simple_q_torch_policy.py
@@ -11,7 +11,7 @@
TorchDistributionWrapper,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_mixins import TargetNetworkMixin
from ray.rllib.policy.torch_mixins import TargetNetworkMixin, LearningRateSchedule
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
@@ -28,6 +28,7 @@


class SimpleQTorchPolicy(
LearningRateSchedule,
TargetNetworkMixin,
TorchPolicyV2,
):
@@ -45,6 +46,8 @@ def __init__(self, observation_space, action_space, config):
max_seq_len=config["model"]["max_seq_len"],
)

LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"])

# TODO: Don't require users to call this manually.
self._initialize_loss_from_dummy_batch()

@@ -166,7 +169,10 @@ def extra_compute_grad_fetches(self) -> Dict[str, Any]:
@override(TorchPolicyV2)
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
return convert_to_numpy(
{"loss": torch.mean(torch.stack(self.get_tower_stats("loss")))}
{
"loss": torch.mean(torch.stack(self.get_tower_stats("loss"))),
"cur_lr": self.cur_lr,
}
)

def _compute_q_values(
51 changes: 51 additions & 0 deletions rllib/algorithms/simple_q/tests/test_simple_q.py
@@ -15,6 +15,11 @@
check_train_results,
framework_iterator,
)
from ray.rllib.utils.metrics.learner_info import (
LEARNER_INFO,
LEARNER_STATS_KEY,
DEFAULT_POLICY_ID,
)

tf1, tf, tfv = try_import_tf()

@@ -145,6 +150,52 @@ def test_simple_q_loss_function(self):
)
check(out, expected_loss, decimals=1)

def test_simple_q_lr_schedule(self):
"""Test PG with learning rate schedule."""
config = simple_q.SimpleQConfig()
config.reporting(
min_sample_timesteps_per_iteration=10,
# Make sure that results contain info on default policy
min_train_timesteps_per_iteration=10,
# Zero metrics reporting delay: this makes sure the timestep,
# which the lr depends on, is updated after each worker rollout.
min_time_s_per_iteration=0,
)
config.rollouts(
num_rollout_workers=1,
rollout_fragment_length=50,
)
config.training(lr=0.2, lr_schedule=[[0, 0.2], [500, 0.001]])

def _step_n_times(algo, n: int):
"""Step trainer n times.

Returns:
learning rate at the end of the execution.
"""
for _ in range(n):
results = algo.train()
return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
"cur_lr"
]

for _ in framework_iterator(config):
algo = config.build(env="CartPole-v0")

lr = _step_n_times(algo, 1) # 50 timesteps
# Close to 0.2
self.assertGreaterEqual(lr, 0.15)

lr = _step_n_times(algo, 8) # Close to 500 timesteps
# LR should have annealed most of the way toward 0.001.
self.assertLessEqual(float(lr), 0.5)

lr = _step_n_times(algo, 2) # > 500 timesteps
# LR == 0.001
self.assertAlmostEqual(lr, 0.001)

algo.stop()


if __name__ == "__main__":
import sys
3 changes: 3 additions & 0 deletions rllib/tests/test_nested_action_spaces.py
@@ -74,6 +74,9 @@ def test_nested_action_spaces(self):
# Pretend actions in offline files are already normalized.
config["actions_in_input_normalized"] = True

# Remove lr schedule from config, not needed here, and not supported by BC.
del config["lr_schedule"]

for _ in framework_iterator(config):
for name, action_space in SPACES.items():
config["env_config"] = {