From 55d685467fc5a4e8246a1c6d9a1970a946582877 Mon Sep 17 00:00:00 2001
From: sven1977
Date: Thu, 28 Apr 2022 13:46:34 +0200
Subject: [PATCH] wip

---
 doc/source/rllib/rllib-algorithms.rst |  2 +-
 rllib/agents/pg/__init__.py           |  5 +-
 rllib/agents/pg/default_config.py     | 30 +++--------
 rllib/agents/pg/pg.py                 | 75 ++++++++++++++++++++++++++-
 rllib/agents/pg/tests/test_pg.py      | 30 +++++++----
 rllib/agents/ppo/tests/test_ppo.py    |  2 +-
 rllib/agents/trainer_config.py        | 14 +++++
 7 files changed, 119 insertions(+), 39 deletions(-)

diff --git a/doc/source/rllib/rllib-algorithms.rst b/doc/source/rllib/rllib-algorithms.rst
index 8dfed85420a5..d48eecd5a168 100644
--- a/doc/source/rllib/rllib-algorithms.rst
+++ b/doc/source/rllib/rllib-algorithms.rst
@@ -360,7 +360,7 @@ Tuned examples: `CartPole-v0 `__):
-.. literalinclude:: ../../../rllib/agents/pg/default_config.py
+.. literalinclude:: ../../../rllib/agents/pg/pg.py
    :language: python
    :start-after: __sphinx_doc_begin__
    :end-before: __sphinx_doc_end__
diff --git a/rllib/agents/pg/__init__.py b/rllib/agents/pg/__init__.py
index 8b1044a859b1..3f5283c53021 100644
--- a/rllib/agents/pg/__init__.py
+++ b/rllib/agents/pg/__init__.py
@@ -1,13 +1,14 @@
-from ray.rllib.agents.pg.pg import PGTrainer, DEFAULT_CONFIG
+from ray.rllib.agents.pg.pg import DEFAULT_CONFIG, PGConfig, PGTrainer
 from ray.rllib.agents.pg.pg_tf_policy import pg_tf_loss, PGTFPolicy
 from ray.rllib.agents.pg.pg_torch_policy import pg_torch_loss, PGTorchPolicy
 from ray.rllib.agents.pg.utils import post_process_advantages
 
 __all__ = [
+    "DEFAULT_CONFIG",
     "pg_tf_loss",
     "pg_torch_loss",
     "post_process_advantages",
-    "DEFAULT_CONFIG",
+    "PGConfig",
     "PGTFPolicy",
     "PGTorchPolicy",
     "PGTrainer",
diff --git a/rllib/agents/pg/default_config.py b/rllib/agents/pg/default_config.py
index 473ba7eeb3b5..5b012fd3ac49 100644
--- a/rllib/agents/pg/default_config.py
+++ b/rllib/agents/pg/default_config.py
@@ -1,22 +1,8 @@
-from ray.rllib.agents.trainer import with_common_config
-
-# fmt: off
-# __sphinx_doc_begin__
-
-# Add the following (PG-specific) updates to the (base) `Trainer` config in
-# rllib/agents/trainer.py (`COMMON_CONFIG` dict).
-DEFAULT_CONFIG = with_common_config({
-    # No remote workers by default.
-    "num_workers": 0,
-    # Learning rate.
-    "lr": 0.0004,
-
-    # Experimental: By default, switch off preprocessors for PG.
-    "_disable_preprocessor_api": True,
-
-    # Use new `training_iteration` API (instead of `execution_plan` method).
- "_disable_execution_plan_api": True, -}) - -# __sphinx_doc_end__ -# fmt: on +from ray.rllib.agents.pg import DEFAULT_CONFIG # noqa +from ray.rllib.utils.deprecation import deprecation_warning + +deprecation_warning( + old="ray.rllib.agents.pg.default_config::DEFAULT_CONFIG (python dict)", + new="ray.rllib.agents.pg.pg::PGConfig() (RLlib TrainerConfig class)", + error=True, +) diff --git a/rllib/agents/pg/pg.py b/rllib/agents/pg/pg.py index 2e01b798417d..61de20cab54e 100644 --- a/rllib/agents/pg/pg.py +++ b/rllib/agents/pg/pg.py @@ -1,14 +1,85 @@ from typing import Type -from ray.rllib.agents.trainer import Trainer -from ray.rllib.agents.pg.default_config import DEFAULT_CONFIG +from ray.rllib.agents.trainer import Trainer, with_common_config +from ray.rllib.agents.trainer_config import TrainerConfig from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import override +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.typing import TrainerConfigDict +class PGConfig(TrainerConfig): + """Defines a PGTrainer configuration class from which a PGTrainer can be built. + + Example: + >>> config = PGConfig().training(lr=0.01).resources(num_gpus=1) + >>> print(config.to_dict()) + >>> # Build a Trainer object from the config and run 1 training iteration. + >>> trainer = config.build(env="CartPole-v1") + >>> trainer.train() + + Example: + >>> config = PGConfig() + >>> # Print out some default values. + >>> print(config.lr) + ... 0.0004 + >>> # Update the config object. + >>> config.training(lr=tune.grid_search([0.001, 0.0001])) + >>> # Set the config object's env. + >>> config.environment(env="CartPole-v1") + >>> # Use to_dict() to get the old-style python config dict + >>> # when running with tune. + >>> tune.run( + ... "PG", + ... stop={"episode_reward_mean": 200}, + ... config=config.to_dict(), + ... ) + """ + + def __init__(self): + """Initializes a PGConfig instance.""" + super().__init__(trainer_class=PGTrainer) + + # fmt: off + # __sphinx_doc_begin__ + # Override some of TrainerConfig's default values with PG-specific values. + self.num_workers = 0 + self.lr = 0.0004 + self._disable_execution_plan_api = True + self._disable_preprocessor_api = True + # __sphinx_doc_end__ + # fmt: on + + +# Deprecated: Use ray.rllib.agents.ppo.PGConfig instead! +class _deprecated_default_config(dict): + def __init__(self): + super().__init__( + with_common_config( + { + # TrainerConfig overrides: + "num_workers": 0, + "lr": 0.0004, + "_disable_execution_plan_api": True, + "_disable_preprocessor_api": True, + } + ) + ) + + @Deprecated( + old="ray.rllib.agents.pg.default_config::DEFAULT_CONFIG", + new="ray.rllib.agents.pg.pg.PGConfig(...)", + error=False, + ) + def __getitem__(self, item): + return super().__getitem__(item) + + +DEFAULT_CONFIG = _deprecated_default_config() + + class PGTrainer(Trainer): """Policy Gradient (PG) Trainer. 
diff --git a/rllib/agents/pg/tests/test_pg.py b/rllib/agents/pg/tests/test_pg.py
index 41cad796abdc..475580f136ff 100644
--- a/rllib/agents/pg/tests/test_pg.py
+++ b/rllib/agents/pg/tests/test_pg.py
@@ -30,11 +30,13 @@ def tearDownClass(cls) -> None:
 
     def test_pg_compilation(self):
         """Test whether a PGTrainer can be built with all frameworks."""
-        config = pg.DEFAULT_CONFIG.copy()
-        config["num_workers"] = 1
-        config["rollout_fragment_length"] = 500
+        config = pg.PGConfig()
         # Test with filter to see whether they work w/o preprocessing.
-        config["observation_filter"] = "MeanStdFilter"
+        config.rollouts(
+            num_rollout_workers=1,
+            rollout_fragment_length=500,
+            observation_filter="MeanStdFilter",
+        )
         num_iterations = 1
 
         image_space = Box(-1.0, 1.0, shape=(84, 84, 3))
@@ -77,7 +79,7 @@ def test_pg_compilation(self):
             "FrozenLake-v1",
         ]:
             print(f"env={env}")
-            trainer = pg.PGTrainer(config=config, env=env)
+            trainer = config.build(env=env)
             for i in range(num_iterations):
                 results = trainer.train()
                 check_train_results(results)
@@ -87,11 +89,17 @@ def test_pg_compilation(self):
 
     def test_pg_loss_functions(self):
         """Tests the PG loss function math."""
-        config = pg.DEFAULT_CONFIG.copy()
-        config["num_workers"] = 0  # Run locally.
-        config["gamma"] = 0.99
-        config["model"]["fcnet_hiddens"] = [10]
-        config["model"]["fcnet_activation"] = "linear"
+        config = (
+            pg.PGConfig()
+            .rollouts(num_rollout_workers=0)
+            .training(
+                gamma=0.99,
+                model={
+                    "fcnet_hiddens": [10],
+                    "fcnet_activation": "linear",
+                },
+            )
+        )
 
         # Fake CartPole episode of n time steps.
         train_batch = SampleBatch(
@@ -109,7 +117,7 @@ def test_pg_loss_functions(self):
 
         for fw, sess in framework_iterator(config, session=True):
             dist_cls = Categorical if fw != "torch" else TorchCategorical
-            trainer = pg.PGTrainer(config=config, env="CartPole-v0")
+            trainer = config.build(env="CartPole-v0")
             policy = trainer.get_policy()
             vars = policy.model.trainable_variables()
             if sess:
diff --git a/rllib/agents/ppo/tests/test_ppo.py b/rllib/agents/ppo/tests/test_ppo.py
index 3ece523a7df9..de7fdd67893e 100644
--- a/rllib/agents/ppo/tests/test_ppo.py
+++ b/rllib/agents/ppo/tests/test_ppo.py
@@ -136,7 +136,7 @@ def test_ppo_compilation_and_schedule_mixins(self):
                     )
                 )
 
-                trainer = ppo.PPOTrainer(config=config, env=env)
+                trainer = config.build(env=env)
                 policy = trainer.get_policy()
                 entropy_coeff = trainer.get_policy().entropy_coeff
                 lr = policy.cur_lr
diff --git a/rllib/agents/trainer_config.py b/rllib/agents/trainer_config.py
index 66f4795b78a0..2a783ef93478 100644
--- a/rllib/agents/trainer_config.py
+++ b/rllib/agents/trainer_config.py
@@ -229,6 +229,20 @@ def to_dict(self) -> TrainerConfigDict:
             config["input"] = getattr(self, "input_")
             config.pop("input_")
 
+        # Setup legacy multiagent sub-dict:
+        config["multiagent"] = {}
+        for k in [
+            "policies",
+            "policy_map_capacity",
+            "policy_map_cache",
+            "policy_mapping_fn",
+            "policies_to_train",
+            "observation_fn",
+            "replay_mode",
+            "count_steps_by",
+        ]:
+            config["multiagent"][k] = config.pop(k)
+
         # Switch out deprecated vs new config keys.
         config["callbacks"] = config.pop("callbacks_class", DefaultCallbacks)
         config["create_env_on_driver"] = config.pop("create_env_on_local_worker", 1)
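
Below is a minimal usage sketch of the new PGConfig API that this patch introduces, assembled from the docstring examples in rllib/agents/pg/pg.py and the updated tests above. The import path and the rollouts()/training()/build()/to_dict() calls are taken from the diff; the surrounding script (env choice, single train() call, printing) is illustrative only and not part of the change.

# Sketch only: mirrors the PGConfig docstring and the test_pg.py changes in this patch.
from ray.rllib.agents.pg import PGConfig

# Build a config object instead of copying and mutating the old DEFAULT_CONFIG dict.
config = (
    PGConfig()
    .rollouts(num_rollout_workers=1, rollout_fragment_length=500)
    .training(lr=0.0004, gamma=0.99)
)

# to_dict() returns the legacy python config dict, e.g. for
# tune.run("PG", config=config.to_dict()).
legacy_config_dict = config.to_dict()

# Build a PGTrainer directly from the config object and run one training iteration.
trainer = config.build(env="CartPole-v1")
print(trainer.train())

Since DEFAULT_CONFIG is still exported (now as a deprecation shim), old-style dict configs keep working, but key access emits a deprecation warning pointing users to PGConfig.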