[RLlib] PGTrainer config object class (PGConfig). #24295

Merged (1 commit) on Apr 28, 2022

2 changes: 1 addition & 1 deletion doc/source/rllib/rllib-algorithms.rst
@@ -360,7 +360,7 @@ Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rll

**PG-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):

.. literalinclude:: ../../../rllib/agents/pg/default_config.py
.. literalinclude:: ../../../rllib/agents/pg/pg.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
5 changes: 3 additions & 2 deletions rllib/agents/pg/__init__.py
@@ -1,13 +1,14 @@
from ray.rllib.agents.pg.pg import PGTrainer, DEFAULT_CONFIG
from ray.rllib.agents.pg.pg import DEFAULT_CONFIG, PGConfig, PGTrainer
from ray.rllib.agents.pg.pg_tf_policy import pg_tf_loss, PGTFPolicy
from ray.rllib.agents.pg.pg_torch_policy import pg_torch_loss, PGTorchPolicy
from ray.rllib.agents.pg.utils import post_process_advantages

__all__ = [
"DEFAULT_CONFIG",
"pg_tf_loss",
"pg_torch_loss",
"post_process_advantages",
"DEFAULT_CONFIG",
"PGConfig",
"PGTFPolicy",
"PGTorchPolicy",
"PGTrainer",
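For reference, a quick sketch (not part of this diff) of the updated public import surface of rllib/agents/pg after this change:

# PGConfig is now exported alongside the legacy symbols.
from ray.rllib.agents.pg import DEFAULT_CONFIG, PGConfig, PGTrainer
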
30 changes: 8 additions & 22 deletions rllib/agents/pg/default_config.py
@@ -1,22 +1,8 @@
from ray.rllib.agents.trainer import with_common_config

# fmt: off
# __sphinx_doc_begin__

# Add the following (PG-specific) updates to the (base) `Trainer` config in
# rllib/agents/trainer.py (`COMMON_CONFIG` dict).
DEFAULT_CONFIG = with_common_config({
# No remote workers by default.
"num_workers": 0,
# Learning rate.
"lr": 0.0004,

# Experimental: By default, switch off preprocessors for PG.
"_disable_preprocessor_api": True,

# Use new `training_iteration` API (instead of `execution_plan` method).
"_disable_execution_plan_api": True,
})

# __sphinx_doc_end__
# fmt: on
from ray.rllib.agents.pg import DEFAULT_CONFIG # noqa
from ray.rllib.utils.deprecation import deprecation_warning

deprecation_warning(
old="ray.rllib.agents.pg.default_config::DEFAULT_CONFIG (python dict)",
new="ray.rllib.agents.pg.pg::PGConfig() (RLlib TrainerConfig class)",
error=True,
)
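
With error=True, importing the old module-level dict now fails outright. A minimal migration sketch (not part of the diff; the environment name is just an example) using the new builder-style config:

# Old (now raises a deprecation error):
#   from ray.rllib.agents.pg.default_config import DEFAULT_CONFIG
# New builder-style config instead:
from ray.rllib.agents.pg import PGConfig

config = PGConfig().rollouts(num_rollout_workers=0).training(lr=0.0004)
trainer = config.build(env="CartPole-v1")
print(trainer.train())
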
75 changes: 73 additions & 2 deletions rllib/agents/pg/pg.py
@@ -1,14 +1,85 @@
from typing import Type

from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.pg.default_config import DEFAULT_CONFIG
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.typing import TrainerConfigDict


class PGConfig(TrainerConfig):
"""Defines a PGTrainer configuration class from which a PGTrainer can be built.

Example:
>>> config = PGConfig().training(lr=0.01).resources(num_gpus=1)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()

Example:
>>> config = PGConfig()
>>> # Print out some default values.
>>> print(config.lr)
... 0.0004
>>> # Update the config object.
>>> config.training(lr=tune.grid_search([0.001, 0.0001]))
>>> # Set the config object's env.
>>> config.environment(env="CartPole-v1")
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.run(
... "PG",
... stop={"episode_reward_mean": 200},
... config=config.to_dict(),
... )
"""

def __init__(self):
"""Initializes a PGConfig instance."""
super().__init__(trainer_class=PGTrainer)

# fmt: off
# __sphinx_doc_begin__
# Override some of TrainerConfig's default values with PG-specific values.
self.num_workers = 0
self.lr = 0.0004
self._disable_execution_plan_api = True
self._disable_preprocessor_api = True
# __sphinx_doc_end__
# fmt: on


# Deprecated: Use ray.rllib.agents.pg.PGConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(
with_common_config(
{
# TrainerConfig overrides:
"num_workers": 0,
"lr": 0.0004,
"_disable_execution_plan_api": True,
"_disable_preprocessor_api": True,
}
)
)

@Deprecated(
old="ray.rllib.agents.pg.default_config::DEFAULT_CONFIG",
new="ray.rllib.agents.pg.pg.PGConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()


class PGTrainer(Trainer):
"""Policy Gradient (PG) Trainer.

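A small sketch (not part of the diff) of how the backward-compatibility shim above behaves: dict-style lookups still return values but go through the Deprecated-decorated __getitem__ (error=False), while the new config object exposes the same defaults as attributes:

from ray.rllib.agents.pg import DEFAULT_CONFIG, PGConfig

lr_old = DEFAULT_CONFIG["lr"]  # Still works, but logs a deprecation warning.
lr_new = PGConfig().lr         # New attribute-style access.
assert lr_old == lr_new == 0.0004
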
30 changes: 19 additions & 11 deletions rllib/agents/pg/tests/test_pg.py
@@ -30,11 +30,13 @@ def tearDownClass(cls) -> None:

def test_pg_compilation(self):
"""Test whether a PGTrainer can be built with all frameworks."""
config = pg.DEFAULT_CONFIG.copy()
config["num_workers"] = 1
config["rollout_fragment_length"] = 500
config = pg.PGConfig()
# Test with filter to see whether they work w/o preprocessing.
config["observation_filter"] = "MeanStdFilter"
config.rollouts(
num_rollout_workers=1,
rollout_fragment_length=500,
observation_filter="MeanStdFilter",
)
num_iterations = 1

image_space = Box(-1.0, 1.0, shape=(84, 84, 3))
@@ -77,7 +79,7 @@ def test_pg_compilation(self):
"FrozenLake-v1",
]:
print(f"env={env}")
trainer = pg.PGTrainer(config=config, env=env)
trainer = config.build(env=env)
for i in range(num_iterations):
results = trainer.train()
check_train_results(results)
@@ -87,11 +89,17 @@

def test_pg_loss_functions(self):
"""Tests the PG loss function math."""
config = pg.DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Run locally.
config["gamma"] = 0.99
config["model"]["fcnet_hiddens"] = [10]
config["model"]["fcnet_activation"] = "linear"
config = (
pg.PGConfig()
.rollouts(num_rollout_workers=0)
.training(
gamma=0.99,
model={
"fcnet_hiddens": [10],
"fcnet_activation": "linear",
},
)
)

# Fake CartPole episode of n time steps.
train_batch = SampleBatch(
@@ -109,7 +117,7 @@

for fw, sess in framework_iterator(config, session=True):
dist_cls = Categorical if fw != "torch" else TorchCategorical
trainer = pg.PGTrainer(config=config, env="CartPole-v0")
trainer = config.build(env="CartPole-v0")
policy = trainer.get_policy()
vars = policy.model.trainable_variables()
if sess:
2 changes: 1 addition & 1 deletion rllib/agents/ppo/tests/test_ppo.py
@@ -136,7 +136,7 @@ def test_ppo_compilation_and_schedule_mixins(self):
)
)

trainer = ppo.PPOTrainer(config=config, env=env)
trainer = config.build(env=env)
policy = trainer.get_policy()
entropy_coeff = trainer.get_policy().entropy_coeff
lr = policy.cur_lr
14 changes: 14 additions & 0 deletions rllib/agents/trainer_config.py
@@ -229,6 +229,20 @@ def to_dict(self) -> TrainerConfigDict:
config["input"] = getattr(self, "input_")
config.pop("input_")

# Setup legacy multiagent sub-dict:
config["multiagent"] = {}
for k in [
"policies",
"policy_map_capacity",
"policy_map_cache",
"policy_mapping_fn",
"policies_to_train",
"observation_fn",
"replay_mode",
"count_steps_by",
]:
config["multiagent"][k] = config.pop(k)

# Switch out deprecated vs new config keys.
config["callbacks"] = config.pop("callbacks_class", DefaultCallbacks)
config["create_env_on_driver"] = config.pop("create_env_on_local_worker", 1)

Review comment (Contributor): good catch, didn't think of this one