[RLlib] POC: PGTrainer class that works by sub-classing, not trainer_template.py. #20055

Merged: 25 commits, Nov 11, 2021
doc/source/rllib-algorithms.rst (2 changes: 1 addition & 1 deletion)
@@ -358,7 +358,7 @@ Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rll

 **PG-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):

-.. literalinclude:: ../../rllib/agents/pg/pg.py
+.. literalinclude:: ../../rllib/agents/pg/default_config.py
    :language: python
    :start-after: __sphinx_doc_begin__
    :end-before: __sphinx_doc_end__
rllib/agents/pg/__init__.py (4 changes: 2 additions & 2 deletions)
@@ -1,7 +1,7 @@
 from ray.rllib.agents.pg.pg import PGTrainer, DEFAULT_CONFIG
-from ray.rllib.agents.pg.pg_tf_policy import pg_tf_loss, \
-    post_process_advantages, PGTFPolicy
+from ray.rllib.agents.pg.pg_tf_policy import pg_tf_loss, PGTFPolicy
 from ray.rllib.agents.pg.pg_torch_policy import pg_torch_loss, PGTorchPolicy
+from ray.rllib.agents.pg.utils import post_process_advantages

 __all__ = [
     "pg_tf_loss",
rllib/agents/pg/default_config.py (16 changes: 16 additions & 0 deletions)
@@ -0,0 +1,16 @@
+from ray.rllib.agents.trainer import with_common_config
+
+# yapf: disable
+# __sphinx_doc_begin__
+
+# Add the following (PG-specific) updates to the (base) `Trainer` config in
+# rllib/agents/trainer.py (`COMMON_CONFIG` dict).
+DEFAULT_CONFIG = with_common_config({
+    # No remote workers by default.
+    "num_workers": 0,
+    # Learning rate.
+    "lr": 0.0004,
+})
+
+# __sphinx_doc_end__
+# yapf: enable
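Note (illustrative, not part of the diff): in the usual RLlib pattern, these defaults are copied and selectively overridden before a trainer is built, e.g.:

```python
# Sketch only: override a few PG defaults. Keys other than "num_workers" and
# "lr" come from the common Trainer config (COMMON_CONFIG).
from ray.rllib.agents.pg.default_config import DEFAULT_CONFIG

config = DEFAULT_CONFIG.copy()
config.update({
    "framework": "torch",  # selects PGTorchPolicy (see pg.py below)
    "num_workers": 2,      # PG default is 0 (no remote rollout workers)
    "lr": 0.001,           # PG default is 0.0004
})
```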
rllib/agents/pg/pg.py (70 changes: 22 additions & 48 deletions)
@@ -1,60 +1,34 @@
-"""
-Policy Gradient (PG)
-====================
-
-This file defines the distributed Trainer class for policy gradients.
-See `pg_[tf|torch]_policy.py` for the definition of the policy loss.
-
-Detailed documentation: https://docs.ray.io/en/master/rllib-algorithms.html#pg
-"""
-
-import logging
-from typing import Optional, Type
+from typing import Type

-from ray.rllib.agents.trainer import with_common_config
-from ray.rllib.agents.trainer_template import build_trainer
+from ray.rllib.agents.trainer import Trainer
+from ray.rllib.agents.pg.default_config import DEFAULT_CONFIG
 from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
 from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
 from ray.rllib.policy.policy import Policy
+from ray.rllib.utils.annotations import override
 from ray.rllib.utils.typing import TrainerConfigDict

-logger = logging.getLogger(__name__)
-
-# yapf: disable
-# __sphinx_doc_begin__
-
-# Adds the following updates to the (base) `Trainer` config in
-# rllib/agents/trainer.py (`COMMON_CONFIG` dict).
-DEFAULT_CONFIG = with_common_config({
-    # No remote workers by default.
-    "num_workers": 0,
-    # Learning rate.
-    "lr": 0.0004,
-})
-
-# __sphinx_doc_end__
-# yapf: enable
-

-def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
-    """Policy class picker function. Class is chosen based on DL-framework.
-
-    Args:
-        config (TrainerConfigDict): The trainer's configuration dict.
-
-    Returns:
-        Optional[Type[Policy]]: The Policy class to use with PGTrainer.
-            If None, use `default_policy` provided in build_trainer().
-    """
-    if config["framework"] == "torch":
-        return PGTorchPolicy
-
-
-# Build a child class of `Trainer`, which uses the framework specific Policy
-# determined in `get_policy_class()` above.
-PGTrainer = build_trainer(
-    name="PG",
-    default_config=DEFAULT_CONFIG,
-    default_policy=PGTFPolicy,
-    get_policy_class=get_policy_class,
-)
+class PGTrainer(Trainer):
+    """Policy Gradient (PG) Trainer.
+
+    Defines the distributed Trainer class for policy gradients.
+    See `pg_[tf|torch]_policy.py` for the definition of the policy loss.
+
+    Detailed documentation:
+    https://docs.ray.io/en/master/rllib-algorithms.html#pg
+
+    Only overrides the default config- and policy selectors
+    (`get_default_policy` and `get_default_config`). Utilizes
+    the default execution plan of `Trainer`.
+    """
+
+    @override(Trainer)
+    def get_default_policy_class(self, config) -> Type[Policy]:
+        return PGTorchPolicy if config.get("framework") == "torch" \
+            else PGTFPolicy
+
+    @classmethod
+    @override(Trainer)
+    def get_default_config(cls) -> TrainerConfigDict:
+        return DEFAULT_CONFIG
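Note (illustrative, not part of the diff): with the sub-classing approach, PGTrainer is constructed and trained like any other Trainer; a minimal sketch, assuming a local Ray installation and Gym's `CartPole-v0`:

```python
# Sketch only: build and train the new class-based PGTrainer for a few iterations.
import ray
from ray.rllib.agents.pg import PGTrainer

ray.init()
trainer = PGTrainer(config={"framework": "torch", "num_workers": 0},
                    env="CartPole-v0")
# The framework-specific policy is now resolved via the overridden
# `get_default_policy_class()` instead of `build_trainer(get_policy_class=...)`.
for _ in range(3):
    result = trainer.train()
    print(result["episode_reward_mean"])
ray.shutdown()
```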
rllib/agents/pg/pg_tf_policy.py (2 changes: 1 addition & 1 deletion)
@@ -51,6 +51,6 @@ def pg_tf_loss(
 # - PG loss function
 PGTFPolicy = build_tf_policy(
     name="PGTFPolicy",
-    get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
+    get_default_config=lambda: ray.rllib.agents.pg.DEFAULT_CONFIG,
     postprocess_fn=post_process_advantages,
     loss_fn=pg_tf_loss)
rllib/agents/pg/pg_torch_policy.py (2 changes: 1 addition & 1 deletion)
@@ -79,7 +79,7 @@ def pg_loss_stats(policy: Policy,
 PGTorchPolicy = build_policy_class(
     name="PGTorchPolicy",
     framework="torch",
-    get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
+    get_default_config=lambda: ray.rllib.agents.pg.DEFAULT_CONFIG,
     loss_fn=pg_torch_loss,
     stats_fn=pg_loss_stats,
     postprocess_fn=post_process_advantages,
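Note (illustrative, not part of the diff): both policy builders now resolve `DEFAULT_CONFIG` through the package namespace (`ray.rllib.agents.pg`) instead of the `pg.pg` module; a quick sanity check of that wiring:

```python
# Illustrative only: the package-level DEFAULT_CONFIG carries the PG overrides.
import ray.rllib.agents.pg as pg

assert pg.DEFAULT_CONFIG["lr"] == 0.0004
assert pg.DEFAULT_CONFIG["num_workers"] == 0
```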
rllib/agents/pg/tests/test_pg.py (2 changes: 1 addition & 1 deletion)
@@ -22,7 +22,7 @@ def tearDownClass(cls) -> None:
         ray.shutdown()

     def test_pg_compilation(self):
-        """Test whether a PGTrainer can be built with both frameworks."""
+        """Test whether a PGTrainer can be built with all frameworks."""
         config = pg.DEFAULT_CONFIG.copy()
         config["num_workers"] = 1
         config["rollout_fragment_length"] = 500
rllib/agents/tests/test_trainer.py (2 changes: 1 addition & 1 deletion)
@@ -83,7 +83,7 @@ def new_mapping_fn(agent_id, episode, worker, **kwargs):
             pid = f"p{i}"
             new_pol = trainer.add_policy(
                 pid,
-                trainer._policy_class,
+                trainer.get_default_policy_class(config),
                 # Test changing the mapping fn.
                 policy_mapping_fn=new_mapping_fn,
                 # Change the list of policies to train.