Skip to content

Commit

Permalink
[RLlib] POC: PGTrainer class that works by sub-classing, not `train…
Browse files Browse the repository at this point in the history
…er_template.py`. (#20055)
  • Loading branch information
sven1977 authored Nov 11, 2021
1 parent 883fbd0 commit 6f85af4
Show file tree
Hide file tree
Showing 16 changed files with 368 additions and 247 deletions.
2 changes: 1 addition & 1 deletion doc/source/rllib-algorithms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rll

**PG-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):

.. literalinclude:: ../../rllib/agents/pg/pg.py
.. literalinclude:: ../../rllib/agents/pg/default_config.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
Expand Down
2 changes: 1 addition & 1 deletion rllib/agents/ars/ars.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def get_policy(self, policy=DEFAULT_POLICY_ID):
return self.policy

@override(Trainer)
def step(self):
def step_attempt(self):
config = self.config

theta = self.policy.get_flat_weights()
Expand Down
2 changes: 1 addition & 1 deletion rllib/agents/es/es.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def get_policy(self, policy=DEFAULT_POLICY_ID):
return self.policy

@override(Trainer)
def step(self):
def step_attempt(self):
config = self.config

theta = self.policy.get_flat_weights()
Expand Down
4 changes: 2 additions & 2 deletions rllib/agents/pg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ray.rllib.agents.pg.pg import PGTrainer, DEFAULT_CONFIG
from ray.rllib.agents.pg.pg_tf_policy import pg_tf_loss, \
post_process_advantages, PGTFPolicy
from ray.rllib.agents.pg.pg_tf_policy import pg_tf_loss, PGTFPolicy
from ray.rllib.agents.pg.pg_torch_policy import pg_torch_loss, PGTorchPolicy
from ray.rllib.agents.pg.utils import post_process_advantages

__all__ = [
"pg_tf_loss",
Expand Down
16 changes: 16 additions & 0 deletions rllib/agents/pg/default_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from ray.rllib.agents.trainer import with_common_config

# yapf: disable
# __sphinx_doc_begin__

# Add the following (PG-specific) updates to the (base) `Trainer` config in
# rllib/agents/trainer.py (`COMMON_CONFIG` dict).
DEFAULT_CONFIG = with_common_config({
# No remote workers by default.
"num_workers": 0,
# Learning rate.
"lr": 0.0004,
})

# __sphinx_doc_end__
# yapf: enable
71 changes: 23 additions & 48 deletions rllib/agents/pg/pg.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,35 @@
"""
Policy Gradient (PG)
====================
from typing import Type

This file defines the distributed Trainer class for policy gradients.
See `pg_[tf|torch]_policy.py` for the definition of the policy loss.
Detailed documentation: https://docs.ray.io/en/master/rllib-algorithms.html#pg
"""

import logging
from typing import Optional, Type

from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.pg.default_config import DEFAULT_CONFIG
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict

logger = logging.getLogger(__name__)

# yapf: disable
# __sphinx_doc_begin__
class PGTrainer(Trainer):
"""Policy Gradient (PG) Trainer.
# Adds the following updates to the (base) `Trainer` config in
# rllib/agents/trainer.py (`COMMON_CONFIG` dict).
DEFAULT_CONFIG = with_common_config({
# No remote workers by default.
"num_workers": 0,
# Learning rate.
"lr": 0.0004,
})
Defines the distributed Trainer class for policy gradients.
See `pg_[tf|torch]_policy.py` for the definition of the policy losses for
TensorFlow and PyTorch.
# __sphinx_doc_end__
# yapf: enable
Detailed documentation:
https://docs.ray.io/en/master/rllib-algorithms.html#pg
Only overrides the default config- and policy selectors
(`get_default_policy` and `get_default_config`). Utilizes
the default `execution_plan()` of `Trainer`.
"""

def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
"""Policy class picker function. Class is chosen based on DL-framework.
Args:
config (TrainerConfigDict): The trainer's configuration dict.
@override(Trainer)
def get_default_policy_class(self, config) -> Type[Policy]:
return PGTorchPolicy if config.get("framework") == "torch" \
else PGTFPolicy

Returns:
Optional[Type[Policy]]: The Policy class to use with PGTrainer.
If None, use `default_policy` provided in build_trainer().
"""
if config["framework"] == "torch":
return PGTorchPolicy


# Build a child class of `Trainer`, which uses the framework specific Policy
# determined in `get_policy_class()` above.
PGTrainer = build_trainer(
name="PG",
default_config=DEFAULT_CONFIG,
default_policy=PGTFPolicy,
get_policy_class=get_policy_class,
)
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
return DEFAULT_CONFIG
2 changes: 1 addition & 1 deletion rllib/agents/pg/pg_tf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ def pg_tf_loss(
# - PG loss function
PGTFPolicy = build_tf_policy(
name="PGTFPolicy",
get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
get_default_config=lambda: ray.rllib.agents.pg.DEFAULT_CONFIG,
postprocess_fn=post_process_advantages,
loss_fn=pg_tf_loss)
2 changes: 1 addition & 1 deletion rllib/agents/pg/pg_torch_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def pg_loss_stats(policy: Policy,
PGTorchPolicy = build_policy_class(
name="PGTorchPolicy",
framework="torch",
get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
get_default_config=lambda: ray.rllib.agents.pg.DEFAULT_CONFIG,
loss_fn=pg_torch_loss,
stats_fn=pg_loss_stats,
postprocess_fn=post_process_advantages,
Expand Down
2 changes: 1 addition & 1 deletion rllib/agents/pg/tests/test_pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def tearDownClass(cls) -> None:
ray.shutdown()

def test_pg_compilation(self):
"""Test whether a PGTrainer can be built with both frameworks."""
"""Test whether a PGTrainer can be built with all frameworks."""
config = pg.DEFAULT_CONFIG.copy()
config["num_workers"] = 1
config["rollout_fragment_length"] = 500
Expand Down
2 changes: 1 addition & 1 deletion rllib/agents/tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def new_mapping_fn(agent_id, episode, worker, **kwargs):
pid = f"p{i}"
new_pol = trainer.add_policy(
pid,
trainer._policy_class,
trainer.get_default_policy_class(config),
# Test changing the mapping fn.
policy_mapping_fn=new_mapping_fn,
# Change the list of policies to train.
Expand Down
Loading

0 comments on commit 6f85af4

Please sign in to comment.