diff --git a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index 497991fd3f71..ac7924e149bc 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -358,7 +358,7 @@ Tuned examples: `CartPole-v0 `__): -.. literalinclude:: ../../rllib/agents/pg/pg.py +.. literalinclude:: ../../rllib/agents/pg/default_config.py :language: python :start-after: __sphinx_doc_begin__ :end-before: __sphinx_doc_end__ diff --git a/rllib/agents/ars/ars.py b/rllib/agents/ars/ars.py index c57278152970..a4acf6a3f78e 100644 --- a/rllib/agents/ars/ars.py +++ b/rllib/agents/ars/ars.py @@ -245,7 +245,7 @@ def get_policy(self, policy=DEFAULT_POLICY_ID): return self.policy @override(Trainer) - def step(self): + def step_attempt(self): config = self.config theta = self.policy.get_flat_weights() diff --git a/rllib/agents/es/es.py b/rllib/agents/es/es.py index 1663d3fd2598..2f6c68fe9ab9 100644 --- a/rllib/agents/es/es.py +++ b/rllib/agents/es/es.py @@ -258,7 +258,7 @@ def get_policy(self, policy=DEFAULT_POLICY_ID): return self.policy @override(Trainer) - def step(self): + def step_attempt(self): config = self.config theta = self.policy.get_flat_weights() diff --git a/rllib/agents/pg/__init__.py b/rllib/agents/pg/__init__.py index b2592a6c67d2..8b1044a859b1 100644 --- a/rllib/agents/pg/__init__.py +++ b/rllib/agents/pg/__init__.py @@ -1,7 +1,7 @@ from ray.rllib.agents.pg.pg import PGTrainer, DEFAULT_CONFIG -from ray.rllib.agents.pg.pg_tf_policy import pg_tf_loss, \ - post_process_advantages, PGTFPolicy +from ray.rllib.agents.pg.pg_tf_policy import pg_tf_loss, PGTFPolicy from ray.rllib.agents.pg.pg_torch_policy import pg_torch_loss, PGTorchPolicy +from ray.rllib.agents.pg.utils import post_process_advantages __all__ = [ "pg_tf_loss", diff --git a/rllib/agents/pg/default_config.py b/rllib/agents/pg/default_config.py new file mode 100644 index 000000000000..4386a6c2fd03 --- /dev/null +++ b/rllib/agents/pg/default_config.py @@ -0,0 +1,16 @@ +from ray.rllib.agents.trainer import with_common_config + +# yapf: disable +# __sphinx_doc_begin__ + +# Add the following (PG-specific) updates to the (base) `Trainer` config in +# rllib/agents/trainer.py (`COMMON_CONFIG` dict). +DEFAULT_CONFIG = with_common_config({ + # No remote workers by default. + "num_workers": 0, + # Learning rate. + "lr": 0.0004, +}) + +# __sphinx_doc_end__ +# yapf: enable diff --git a/rllib/agents/pg/pg.py b/rllib/agents/pg/pg.py index 2a9522cd82a0..046fda496dcc 100644 --- a/rllib/agents/pg/pg.py +++ b/rllib/agents/pg/pg.py @@ -1,60 +1,35 @@ -""" -Policy Gradient (PG) -==================== +from typing import Type -This file defines the distributed Trainer class for policy gradients. -See `pg_[tf|torch]_policy.py` for the definition of the policy loss. - -Detailed documentation: https://docs.ray.io/en/master/rllib-algorithms.html#pg -""" - -import logging -from typing import Optional, Type - -from ray.rllib.agents.trainer import with_common_config -from ray.rllib.agents.trainer_template import build_trainer +from ray.rllib.agents.trainer import Trainer +from ray.rllib.agents.pg.default_config import DEFAULT_CONFIG from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy from ray.rllib.policy.policy import Policy +from ray.rllib.utils.annotations import override from ray.rllib.utils.typing import TrainerConfigDict -logger = logging.getLogger(__name__) -# yapf: disable -# __sphinx_doc_begin__ +class PGTrainer(Trainer): + """Policy Gradient (PG) Trainer. 
-# Adds the following updates to the (base) `Trainer` config in -# rllib/agents/trainer.py (`COMMON_CONFIG` dict). -DEFAULT_CONFIG = with_common_config({ - # No remote workers by default. - "num_workers": 0, - # Learning rate. - "lr": 0.0004, -}) + Defines the distributed Trainer class for policy gradients. + See `pg_[tf|torch]_policy.py` for the definition of the policy losses for + TensorFlow and PyTorch. -# __sphinx_doc_end__ -# yapf: enable + Detailed documentation: + https://docs.ray.io/en/master/rllib-algorithms.html#pg + Only overrides the default config- and policy selectors + (`get_default_policy` and `get_default_config`). Utilizes + the default `execution_plan()` of `Trainer`. + """ -def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]: - """Policy class picker function. Class is chosen based on DL-framework. - - Args: - config (TrainerConfigDict): The trainer's configuration dict. + @override(Trainer) + def get_default_policy_class(self, config) -> Type[Policy]: + return PGTorchPolicy if config.get("framework") == "torch" \ + else PGTFPolicy - Returns: - Optional[Type[Policy]]: The Policy class to use with PGTrainer. - If None, use `default_policy` provided in build_trainer(). - """ - if config["framework"] == "torch": - return PGTorchPolicy - - -# Build a child class of `Trainer`, which uses the framework specific Policy -# determined in `get_policy_class()` above. -PGTrainer = build_trainer( - name="PG", - default_config=DEFAULT_CONFIG, - default_policy=PGTFPolicy, - get_policy_class=get_policy_class, -) + @classmethod + @override(Trainer) + def get_default_config(cls) -> TrainerConfigDict: + return DEFAULT_CONFIG diff --git a/rllib/agents/pg/pg_tf_policy.py b/rllib/agents/pg/pg_tf_policy.py index f5cd970acc93..4a2808410dd1 100644 --- a/rllib/agents/pg/pg_tf_policy.py +++ b/rllib/agents/pg/pg_tf_policy.py @@ -51,6 +51,6 @@ def pg_tf_loss( # - PG loss function PGTFPolicy = build_tf_policy( name="PGTFPolicy", - get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG, + get_default_config=lambda: ray.rllib.agents.pg.DEFAULT_CONFIG, postprocess_fn=post_process_advantages, loss_fn=pg_tf_loss) diff --git a/rllib/agents/pg/pg_torch_policy.py b/rllib/agents/pg/pg_torch_policy.py index cbc7ab2306a8..8e71701aee70 100644 --- a/rllib/agents/pg/pg_torch_policy.py +++ b/rllib/agents/pg/pg_torch_policy.py @@ -79,7 +79,7 @@ def pg_loss_stats(policy: Policy, PGTorchPolicy = build_policy_class( name="PGTorchPolicy", framework="torch", - get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG, + get_default_config=lambda: ray.rllib.agents.pg.DEFAULT_CONFIG, loss_fn=pg_torch_loss, stats_fn=pg_loss_stats, postprocess_fn=post_process_advantages, diff --git a/rllib/agents/pg/tests/test_pg.py b/rllib/agents/pg/tests/test_pg.py index 8dcc492c1279..c1f8c1315fe9 100644 --- a/rllib/agents/pg/tests/test_pg.py +++ b/rllib/agents/pg/tests/test_pg.py @@ -22,7 +22,7 @@ def tearDownClass(cls) -> None: ray.shutdown() def test_pg_compilation(self): - """Test whether a PGTrainer can be built with both frameworks.""" + """Test whether a PGTrainer can be built with all frameworks.""" config = pg.DEFAULT_CONFIG.copy() config["num_workers"] = 1 config["rollout_fragment_length"] = 500 diff --git a/rllib/agents/tests/test_trainer.py b/rllib/agents/tests/test_trainer.py index 937206deac13..69da676010ee 100644 --- a/rllib/agents/tests/test_trainer.py +++ b/rllib/agents/tests/test_trainer.py @@ -83,7 +83,7 @@ def new_mapping_fn(agent_id, episode, worker, **kwargs): pid = f"p{i}" 
new_pol = trainer.add_policy( pid, - trainer._policy_class, + trainer.get_default_policy_class(config), # Test changing the mapping fn. policy_mapping_fn=new_mapping_fn, # Change the list of policies to train. diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 753f0e461755..822e9788e423 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -1,3 +1,4 @@ +import concurrent import copy from datetime import datetime import functools @@ -23,13 +24,16 @@ from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.worker_set import WorkerSet +from ray.rllib.execution.metric_ops import StandardMetricsReporting from ray.rllib.execution.replay_buffer import LocalReplayBuffer +from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches +from ray.rllib.execution.train_ops import TrainOneStep, MultiGPUTrainOneStep from ray.rllib.models import MODEL_DEFAULTS from ray.rllib.policy.policy import Policy, PolicySpec from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.utils import deep_update, FilterManager, merge_dicts -from ray.rllib.utils.annotations import DeveloperAPI, override, \ - PublicAPI +from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI, \ + override, PublicAPI from ray.rllib.utils.debug import update_global_seed_if_necessary from ray.rllib.utils.deprecation import Deprecated, deprecation_warning, \ DEPRECATED_VALUE @@ -586,9 +590,9 @@ class Trainer(Trainable): @PublicAPI def __init__(self, - config: TrainerConfigDict = None, - env: Union[str, EnvType, None] = None, - logger_creator: Callable[[], Logger] = None, + config: Optional[PartialTrainerConfigDict] = None, + env: Optional[Union[str, EnvType]] = None, + logger_creator: Optional[Callable[[], Logger]] = None, remote_checkpoint_dir: Optional[str] = None, sync_function_tpl: Optional[str] = None): """Initializes a Trainer instance. @@ -604,14 +608,17 @@ class directly. Note that this arg can also be specified via object. If unspecified, a default logger is created. """ - # User provided config (this is w/o the default Trainer's - # `COMMON_CONFIG` (see above)). Will get merged with COMMON_CONFIG - # in self.setup(). + # User provided (partial) config (this may be w/o the default + # Trainer's `COMMON_CONFIG` (see above)). Will get merged with + # COMMON_CONFIG in self.setup(). config = config or {} # Trainers allow env ids to be passed directly to the constructor. self._env_id = self._register_if_needed( env or config.get("env"), config) + # The env creator callable, taking an EnvContext (config dict) + # as arg and returning an RLlib supported Env type (e.g. a gym.Env). + self.env_creator: Callable[[EnvContext], EnvType] = None # Placeholder for a local replay buffer instance. self.local_replay_buffer = None @@ -659,8 +666,8 @@ def setup(self, config: PartialTrainerConfigDict): # Setup our config: Merge the user-supplied config (which could # be a partial config dict with the class' default). - self.config = self.merge_trainer_configs(self._default_config, config, - self._allow_unknown_configs) + self.config = self.merge_trainer_configs( + self.get_default_config(), config, self._allow_unknown_configs) # Setup the "env creator" callable. env = self._env_id @@ -705,18 +712,19 @@ def env_creator_from_classpath(env_context): logger.info( f"Executing eagerly (framework='{self.config['framework']}')," f" with eager_tracing={self.config['eager_tracing']}. 
For " - "production workloads, make sure to set eager_tracing=True in" - " order to match the speed of tf-static-graph " - "(framework='tf')") + "production workloads, make sure to set `eager_tracing=True` " + "in order to match the speed of tf-static-graph " + "(framework='tf'). For debugging purposes, " + "`eager_tracing=False` is the best choice.") # Tf-static-graph (framework=tf): Recommend upgrading to tf2 and # enabling eager tracing for similar speed. elif tf1 and self.config["framework"] == "tf": logger.info( "Your framework setting is 'tf', meaning you are using static" "-graph mode. Set framework='tf2' to enable eager execution " - "with tf2.x. You may also want to then set eager_tracing=True" - " in order to reach similar execution speed as with " - "static-graph mode.") + "with tf2.x. You may also want to then set " + "`eager_tracing=True` in order to reach similar execution " + "speed as with static-graph mode.") # Set Trainer's seed after we have - if necessary - enabled # tf eager-execution. @@ -742,11 +750,42 @@ def env_creator_from_classpath(env_context): self.local_replay_buffer = ( self._create_local_replay_buffer_if_necessary(self.config)) - # Make the call to self._init. Sub-classes should override this - # method to implement custom initialization logic. - self._init(self.config, self.env_creator) - - # Evaluation setup. + # Deprecated way of implementing Trainer sub-classes (or "templates" + # via the soon-to-be deprecated `build_trainer` utility function). + # Instead, sub-classes should override the Trainable's `setup()` + # method and call super().setup() from within that override at some + # point. + self.workers = None + self.train_exec_impl = None + + # Old design: Override `Trainer._init` (or use `build_trainer()`, which + # will do this for you). + try: + self._init(self.config, self.env_creator) + # New design: Override `Trainable.setup()` (as indented by Trainable) + # and do or don't call super().setup() from within your override. + # By default, `super().setup()` will create both worker sets: + # "rollout workers" for collecting samples for training and - if + # applicable - "evaluation workers" for evaluation runs in between or + # parallel to training. + # TODO: Deprecate `_init()` and remove this try/except block. + except NotImplementedError: + # Only if user did not override `_init()`: + # - Create rollout workers here automatically. + # - Run the execution plan to create the local iterator to `next()` + # in each training iteration. + # This matches the behavior of using `build_trainer()`, which + # should no longer be used. + self.workers = self._make_workers( + env_creator=self.env_creator, + validate_env=self.validate_env, + policy_class=self.get_default_policy_class(self.config), + config=self.config, + num_workers=self.config["num_workers"]) + self.train_exec_impl = self.execution_plan( + self.workers, self.config, **self._kwargs_for_execution_plan()) + + # Evaluation WorkerSet setup. self.evaluation_workers = None self.evaluation_metrics = {} # User would like to setup a separate evaluation worker set. 
@@ -777,58 +816,165 @@ def env_creator_from_classpath(env_context): self.evaluation_workers = self._make_workers( env_creator=self.env_creator, validate_env=None, - policy_class=self._policy_class, + policy_class=self.get_default_policy_class(self.config), config=evaluation_config, num_workers=self.config["evaluation_num_workers"]) - @DeveloperAPI + # TODO: Deprecated: In your sub-classes of Trainer, override `setup()` + # directly and call super().setup() from within it if you would like the + # default setup behavior plus some own setup logic. + # If you don't need the env/workers/config/etc.. setup for you by super, + # simply do not call super().setup() from your overridden setup. def _init(self, config: TrainerConfigDict, env_creator: Callable[[EnvContext], EnvType]) -> None: - """Subclasses should override this for custom initialization. + raise NotImplementedError - In the case of Trainer, this is called from inside `self.setup()`. + @ExperimentalAPI + def get_default_policy_class(self, config: PartialTrainerConfigDict): + """Returns a default Policy class to use, given a config. - Args: - config: Algorithm-specific configuration dict. - env_creator: A callable taking an EnvContext as only arg and - returning an environment (of any type: e.g. gym.Env, RLlib - BaseEnv, MultiAgentEnv, etc..). + This class will be used inside RolloutWorkers' PolicyMaps in case + the policy class is not provided by the user in any single- or + multi-agent PolicySpec. + + This method is experimental and currently only used, iff the Trainer + class was not created using the `build_trainer` utility and if + the Trainer sub-class does not override `_init()` and create it's + own WorkerSet in `_init()`. """ - raise NotImplementedError + return getattr(self, "_policy_class", None) @override(Trainable) - @PublicAPI - def train(self) -> ResultDict: - """Overrides super.train to synchronize global vars.""" + def step(self) -> ResultDict: + """Implements the main `Trainer.train()` logic. + + Takes n attempts to perform a single training step. Thereby + catches RayErrors resulting from worker failures. After n attempts, + fails gracefully. + Override this method in your Trainer sub-classes if you would like to + handle worker failures yourself. Otherwise, override + `self.step_attempt()` to keep the n attempts (catch worker failures). + + Returns: + The results dict with stats/infos on sampling, training, + and - if required - evaluation. + """ result = None for _ in range(1 + MAX_WORKER_FAILURE_RETRIES): + # Try to train one step. try: - result = Trainable.train(self) + result = self.step_attempt() + # @ray.remote RolloutWorker failure -> Try to recover, + # if necessary. except RayError as e: if self.config["ignore_worker_failures"]: logger.exception( "Error in train call, attempting to recover") - self._try_recover() + self.try_recover_from_step_attempt() else: logger.info( "Worker crashed during call to train(). To attempt to " "continue training without the failed worker, set " "`'ignore_worker_failures': True`.") raise e + # Any other exception. except Exception as e: - time.sleep(0.5) # allow logs messages to propagate + # Allow logs messages to propagate. + time.sleep(0.5) raise e else: break + + # Still no result (even after n retries). 
if result is None: - raise RuntimeError("Failed to recover from worker crash") + raise RuntimeError("Failed to recover from worker crash.") if hasattr(self, "workers") and isinstance(self.workers, WorkerSet): self._sync_filters_if_needed(self.workers) return result + @ExperimentalAPI + def step_attempt(self) -> ResultDict: + """Attempts a single training step, including evaluation, if required. + + Override this method in your Trainer sub-classes if you would like to + keep the n attempts (catch worker failures) or override `step()` + directly if you would like to handle worker failures yourself. + + Returns: + The results dict with stats/infos on sampling, training, + and - if required - evaluation. + """ + + # self._iteration gets incremented after this function returns, + # meaning that e. g. the first time this function is called, + # self._iteration will be 0. + evaluate_this_iter = \ + self.config["evaluation_interval"] and \ + (self._iteration + 1) % self.config["evaluation_interval"] == 0 + + # No evaluation necessary, just run the next training iteration. + if not evaluate_this_iter: + step_results = next(self.train_exec_impl) + # We have to evaluate in this training iteration. + else: + # No parallelism. + if not self.config["evaluation_parallel_to_training"]: + step_results = next(self.train_exec_impl) + + # Kick off evaluation-loop (and parallel train() call, + # if requested). + # Parallel eval + training. + if self.config["evaluation_parallel_to_training"]: + with concurrent.futures.ThreadPoolExecutor() as executor: + train_future = executor.submit( + lambda: next(self.train_exec_impl)) + if self.config["evaluation_num_episodes"] == "auto": + + # Run at least one `evaluate()` (num_episodes_done + # must be > 0), even if the training is very fast. + def episodes_left_fn(num_episodes_done): + if num_episodes_done > 0 and \ + train_future.done(): + return 0 + else: + return self.config["evaluation_num_workers"] + + evaluation_metrics = self.evaluate( + episodes_left_fn=episodes_left_fn) + else: + evaluation_metrics = self.evaluate() + # Collect the training results from the future. + step_results = train_future.result() + # Sequential: train (already done above), then eval. + else: + evaluation_metrics = self.evaluate() + + # Add evaluation results to train results. + assert isinstance(evaluation_metrics, dict), \ + "Trainer.evaluate() needs to return a dict." + step_results.update(evaluation_metrics) + + # Check `env_task_fn` for possible update of the env's task. + if self.config["env_task_fn"] is not None: + if not callable(self.config["env_task_fn"]): + raise ValueError( + "`env_task_fn` must be None or a callable taking " + "[train_results, env, env_ctx] as args!") + + def fn(env, env_context, task_fn): + new_task = task_fn(step_results, env, env_context) + cur_task = env.get_task() + if cur_task != new_task: + env.set_task(new_task) + + fn = functools.partial(fn, task_fn=self.config["env_task_fn"]) + self.workers.foreach_env_with_context(fn) + + return step_results + @PublicAPI def evaluate(self, episodes_left_fn: Optional[Callable[[int], int]] = None ) -> dict: @@ -943,6 +1089,40 @@ def episodes_left_fn(num_episodes_done): self.evaluation_workers.remote_workers()) return {"evaluation": metrics} + @DeveloperAPI + @staticmethod + def execution_plan(workers, config, **kwargs): + + # Collects experiences in parallel from multiple RolloutWorker actors. 
+ rollouts = ParallelRollouts(workers, mode="bulk_sync") + + # Combine experiences batches until we hit `train_batch_size` in size. + # Then, train the policy on those experiences and update the workers. + train_op = rollouts.combine( + ConcatBatches( + min_batch_size=config["train_batch_size"], + count_steps_by=config["multiagent"]["count_steps_by"], + )) + + if config.get("simple_optimizer") is True: + train_op = train_op.for_each(TrainOneStep(workers)) + else: + train_op = train_op.for_each( + MultiGPUTrainOneStep( + workers=workers, + sgd_minibatch_size=config.get("sgd_minibatch_size", + config["train_batch_size"]), + num_sgd_iter=config.get("num_sgd_iter", 1), + num_gpus=config["num_gpus"], + shuffle_sequences=config.get("shuffle_sequences", False), + _fake_gpus=config["_fake_gpus"], + framework=config["framework"])) + + # Add on the standard episode reward, etc. metrics reporting. This + # returns a LocalIterator[metrics_dict] representing metrics for each + # train step. + return StandardMetricsReporting(train_op, workers, config) + @PublicAPI def compute_single_action( self, @@ -1478,7 +1658,7 @@ def default_resource_request( # workers to determine their CPU/GPU resource needs. # Convenience config handles. - cf = dict(cls._default_config, **config) + cf = dict(cls.get_default_config(), **config) eval_cf = cf["evaluation_config"] # TODO(ekl): add custom resources here once tune supports them @@ -1577,13 +1757,20 @@ def _sync_weights_to_workers( @property def _name(self) -> str: - """Subclasses should override this to declare their name.""" - raise NotImplementedError + """Subclasses may override this to declare their name.""" + # By default, return the class' name. + return type(self).__name__ + # TODO: Deprecate. Instead, override `Trainer.get_default_config()`. @property def _default_config(self) -> TrainerConfigDict: """Subclasses should override this to declare their default config.""" - raise NotImplementedError + return {} + + @ExperimentalAPI + @classmethod + def get_default_config(cls) -> TrainerConfigDict: + return cls._default_config or COMMON_CONFIG @classmethod @override(Trainable) @@ -1789,17 +1976,39 @@ def _validate_config(config: PartialTrainerConfigDict, "`evaluation_num_episodes` ({}) must be an int and " ">0!".format(config["evaluation_num_episodes"])) - def _try_recover(self): + @ExperimentalAPI + @staticmethod + def validate_env(env: EnvType, env_context: EnvContext) -> None: + """Env validator function for this Trainer class. + + Override this in child classes to define custom validation + behavior. + + Args: + env: The (sub-)environment to validate. This is normally a + single sub-environment (e.g. a gym.Env) within a vectorized + setup. + env_context: The EnvContext to configure the environment. + + Raises: + Exception in case something is wrong with the given environment. + """ + pass + + def try_recover_from_step_attempt(self) -> None: """Try to identify and remove any unhealthy workers. This method is called after an unexpected remote error is encountered - from a worker. It issues check requests to all current workers and + from a worker during the call to `self.step_attempt()` (within + `self.step()`). It issues check requests to all current workers and removes any that respond with error. If no healthy workers remain, - an error is raised. + an error is raised. Otherwise, tries to re-build the execution plan + with the remaining (healthy) workers. 
""" - assert hasattr(self, "execution_plan") - workers = self.workers + workers = getattr(self, "workers", None) + if not isinstance(workers, WorkerSet): + return logger.info("Health checking all workers...") checks = [] @@ -1825,10 +2034,12 @@ def _try_recover(self): raise RuntimeError( "Not enough healthy workers remain to continue.") - logger.warning("Recreating execution plan after failure") + logger.warning("Recreating execution plan after failure.") workers.reset(healthy_workers) - self.train_exec_impl = self.execution_plan( - workers, self.config, **self._kwargs_for_execution_plan()) + if self.train_exec_impl is not None: + if callable(self.execution_plan): + self.train_exec_impl = self.execution_plan( + workers, self.config, **self._kwargs_for_execution_plan()) @override(Trainable) def _export_model(self, export_formats: List[str], @@ -1887,6 +2098,11 @@ def __getstate__(self) -> dict: self.config.get("store_buffer_in_checkpoints"): state["local_replay_buffer"] = \ self.local_replay_buffer.get_state() + + if self.train_exec_impl is not None: + state["train_exec_impl"] = ( + self.train_exec_impl.shared_metrics.get().save()) + return state def __setstate__(self, state: dict): @@ -1916,6 +2132,10 @@ def __setstate__(self, state: dict): "`store_buffer_in_checkpoints` is False, but some replay " "data found in state!") + if self.train_exec_impl is not None: + self.train_exec_impl.shared_metrics.get().restore( + state["train_exec_impl"]) + @staticmethod def with_updates(**overrides) -> Type["Trainer"]: raise NotImplementedError( @@ -2016,6 +2236,9 @@ def _is_multi_agent(self): "You can specify a custom env as either a class " "(e.g., YourEnvCls) or a registered env id (e.g., \"your_env\").") + def __repr__(self): + return self._name + @Deprecated(new="Trainer.evaluate", error=False) def _evaluate(self) -> dict: return self.evaluate() @@ -2024,5 +2247,6 @@ def _evaluate(self) -> dict: def compute_action(self, *args, **kwargs): return self.compute_single_action(*args, **kwargs) - def __repr__(self): - return self._name + @Deprecated(new="try_recover_from_step_attempt", error=False) + def _try_recover(self): + return self.try_recover_from_step_attempt() diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py index 450d91cefbbf..8a0f31c85b74 100644 --- a/rllib/agents/trainer_template.py +++ b/rllib/agents/trainer_template.py @@ -1,14 +1,9 @@ -import concurrent.futures -from functools import partial import logging from typing import Callable, Iterable, List, Optional, Type, Union from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG from ray.rllib.env.env_context import EnvContext from ray.rllib.evaluation.worker_set import WorkerSet -from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches -from ray.rllib.execution.train_ops import TrainOneStep, MultiGPUTrainOneStep -from ray.rllib.execution.metric_ops import StandardMetricsReporting from ray.rllib.policy import Policy from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override, DeveloperAPI @@ -18,41 +13,9 @@ logger = logging.getLogger(__name__) -def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict, - **kwargs): - assert len(kwargs) == 0, ( - "Default execution_plan does NOT take any additional parameters") - - # Collects experiences in parallel from multiple RolloutWorker actors. - rollouts = ParallelRollouts(workers, mode="bulk_sync") - - # Combine experiences batches until we hit `train_batch_size` in size. 
- # Then, train the policy on those experiences and update the workers. - train_op = rollouts.combine( - ConcatBatches( - min_batch_size=config["train_batch_size"], - count_steps_by=config["multiagent"]["count_steps_by"], - )) - - if config.get("simple_optimizer") is True: - train_op = train_op.for_each(TrainOneStep(workers)) - else: - train_op = train_op.for_each( - MultiGPUTrainOneStep( - workers=workers, - sgd_minibatch_size=config.get("sgd_minibatch_size", - config["train_batch_size"]), - num_sgd_iter=config.get("num_sgd_iter", 1), - num_gpus=config["num_gpus"], - shuffle_sequences=config.get("shuffle_sequences", False), - _fake_gpus=config["_fake_gpus"], - framework=config["framework"])) - - # Add on the standard episode reward, etc. metrics reporting. This returns - # a LocalIterator[metrics_dict] representing metrics for each train step. - return StandardMetricsReporting(train_op, workers, config) - - +# TODO: Deprecate Trainer template generated by this utility function. +# Instead, users should sub-class Trainer directly and override some of its +# methods, e.g. `Trainer.setup()`. @DeveloperAPI def build_trainer( name: str, @@ -70,7 +33,7 @@ def build_trainer( execution_plan: Optional[Union[Callable[ [WorkerSet, TrainerConfigDict], Iterable[ResultDict]], Callable[[ Trainer, WorkerSet, TrainerConfigDict - ], Iterable[ResultDict]]]] = default_execution_plan, + ], Iterable[ResultDict]]]] = None, allow_unknown_configs: bool = False, allow_unknown_subkeys: Optional[List[str]] = None, override_all_subkeys_if_type_changes: Optional[List[str]] = None, @@ -148,10 +111,9 @@ def _init(self, config: TrainerConfigDict, # No `get_policy_class` function. if get_policy_class is None: # Default_policy must be provided (unless in multi-agent mode, - # where each policy can have its own default policy class. + # where each policy can have its own default policy class). if not config["multiagent"]["policies"]: assert default_policy is not None - self._policy_class = default_policy # Query the function for a class to use. else: self._policy_class = get_policy_class(config) @@ -170,83 +132,17 @@ def _init(self, config: TrainerConfigDict, policy_class=self._policy_class, config=config, num_workers=self.config["num_workers"]) - self.execution_plan = execution_plan - self.train_exec_impl = execution_plan( + # If execution plan is not provided (None), the Trainer will use + # it's already existing default `execution_plan()` static method + # instead. + if execution_plan is not None: + self.execution_plan = execution_plan + self.train_exec_impl = self.execution_plan( self.workers, config, **self._kwargs_for_execution_plan()) if after_init: after_init(self) - @override(Trainer) - def step(self): - # self._iteration gets incremented after this function returns, - # meaning that e. g. the first time this function is called, - # self._iteration will be 0. - evaluate_this_iter = \ - self.config["evaluation_interval"] and \ - (self._iteration + 1) % self.config["evaluation_interval"] == 0 - - # No evaluation necessary, just run the next training iteration. - if not evaluate_this_iter: - step_results = next(self.train_exec_impl) - # We have to evaluate in this training iteration. - else: - # No parallelism. - if not self.config["evaluation_parallel_to_training"]: - step_results = next(self.train_exec_impl) - - # Kick off evaluation-loop (and parallel train() call, - # if requested). - # Parallel eval + training. 
- if self.config["evaluation_parallel_to_training"]: - with concurrent.futures.ThreadPoolExecutor() as executor: - train_future = executor.submit( - lambda: next(self.train_exec_impl)) - if self.config["evaluation_num_episodes"] == "auto": - - # Run at least one `evaluate()` (num_episodes_done - # must be > 0), even if the training is very fast. - def episodes_left_fn(num_episodes_done): - if num_episodes_done > 0 and \ - train_future.done(): - return 0 - else: - return self.config[ - "evaluation_num_workers"] - - evaluation_metrics = self.evaluate( - episodes_left_fn=episodes_left_fn) - else: - evaluation_metrics = self.evaluate() - # Collect the training results from the future. - step_results = train_future.result() - # Sequential: train (already done above), then eval. - else: - evaluation_metrics = self.evaluate() - - # Add evaluation results to train results. - assert isinstance(evaluation_metrics, dict), \ - "Trainer.evaluate() needs to return a dict." - step_results.update(evaluation_metrics) - - # Check `env_task_fn` for possible update of the env's task. - if self.config["env_task_fn"] is not None: - if not callable(self.config["env_task_fn"]): - raise ValueError( - "`env_task_fn` must be None or a callable taking " - "[train_results, env, env_ctx] as args!") - - def fn(env, env_context, task_fn): - new_task = task_fn(step_results, env, env_context) - cur_task = env.get_task() - if cur_task != new_task: - env.set_task(new_task) - - fn = partial(fn, task_fn=self.config["env_task_fn"]) - self.workers.foreach_env_with_context(fn) - - return step_results - @staticmethod @override(Trainer) def _validate_config(config: PartialTrainerConfigDict, @@ -262,19 +158,6 @@ def _before_evaluate(self): if before_evaluate_fn: before_evaluate_fn(self) - @override(Trainer) - def __getstate__(self): - state = Trainer.__getstate__(self) - state["train_exec_impl"] = ( - self.train_exec_impl.shared_metrics.get().save()) - return state - - @override(Trainer) - def __setstate__(self, state): - Trainer.__setstate__(self, state) - self.train_exec_impl.shared_metrics.get().restore( - state["train_exec_impl"]) - @staticmethod @override(Trainer) def with_updates(**overrides) -> Type[Trainer]: diff --git a/rllib/examples/rock_paper_scissors_multiagent.py b/rllib/examples/rock_paper_scissors_multiagent.py index 49e03dd46360..2d115a59a975 100644 --- a/rllib/examples/rock_paper_scissors_multiagent.py +++ b/rllib/examples/rock_paper_scissors_multiagent.py @@ -164,8 +164,9 @@ def entropy_policy_gradient_loss(policy, model, dist_class, train_batch): EntropyPolicy = policy_cls.with_updates( loss_fn=entropy_policy_gradient_loss) - EntropyLossPG = PGTrainer.with_updates( - name="EntropyPG", get_policy_class=lambda _: EntropyPolicy) + class EntropyLossPG(PGTrainer): + def get_default_policy_class(self, config): + return EntropyPolicy run_heuristic_vs_learned(args, use_lstm=True, trainer=EntropyLossPG) diff --git a/rllib/tests/test_ignore_worker_failure.py b/rllib/tests/test_ignore_worker_failure.py index a49d068f4ec0..a9909fa95469 100644 --- a/rllib/tests/test_ignore_worker_failure.py +++ b/rllib/tests/test_ignore_worker_failure.py @@ -9,6 +9,21 @@ class FaultInjectEnv(gym.Env): + """Env that fails upon calling `step()`, but only for some remote workers. + + The worker indices that should produce the failure (a ValueError) can be + provided by a list (of ints) under the "bad_indices" key in the env's + config. 
+ + Examples: + >>> from ray.rllib.env.env_context import EnvContext + >>> # This env will fail for workers 1 and 2 (not for the local worker + >>> # or any others with an index > 2). + >>> bad_env = FaultInjectEnv( + ... EnvContext({"bad_indices": [1, 2]}, + ... worker_index=1, num_workers=3)) + """ + def __init__(self, config): self.env = gym.make("CartPole-v0") self.action_space = self.env.action_space @@ -20,8 +35,8 @@ def reset(self): def step(self, action): if self.config.worker_index in self.config["bad_indices"]: - raise ValueError("This is a simulated error from {}".format( - self.config.worker_index)) + raise ValueError("This is a simulated error from " + f"worker-idx={self.config.worker_index}.") return self.env.step(action) @@ -42,7 +57,9 @@ def _do_test_fault_recover(self, alg, config): # Test fault handling config["num_workers"] = 2 config["ignore_worker_failures"] = True + # Make worker idx=1 fail. Other workers will be ok. config["env_config"] = {"bad_indices": [1]} + for _ in framework_iterator(config, frameworks=("torch", "tf")): a = agent_cls(config=config, env="fault_env") result = a.train() @@ -52,9 +69,11 @@ def _do_test_fault_recover(self, alg, config): def _do_test_fault_fatal(self, alg, config): register_env("fault_env", lambda c: FaultInjectEnv(c)) agent_cls = get_trainer_class(alg) + # Test raises real error when out of workers config["num_workers"] = 2 config["ignore_worker_failures"] = True + # Make both worker idx=1 and 2 fail. config["env_config"] = {"bad_indices": [1, 2]} for _ in framework_iterator(config, frameworks=("torch", "tf")): diff --git a/rllib/tests/test_placement_groups.py b/rllib/tests/test_placement_groups.py index 8a5f68635dd2..af5f0625369b 100644 --- a/rllib/tests/test_placement_groups.py +++ b/rllib/tests/test_placement_groups.py @@ -68,7 +68,10 @@ def test_overriding_default_resource_request(self): config["env"] = "CartPole-v0" config["framework"] = "tf" - class DefaultResourceRequest: + # Create a trainer with an overridden default_resource_request + # method that returns a PlacementGroupFactory. + + class MyTrainer(PGTrainer): @classmethod def default_resource_request(cls, config): head_bundle = {"CPU": 1, "GPU": 0} @@ -77,9 +80,6 @@ def default_resource_request(cls, config): [head_bundle, child_bundle, child_bundle], strategy=config["placement_strategy"]) - # Create a trainer with an overridden default_resource_request - # method that returns a PlacementGroupFactory. - MyTrainer = PGTrainer.with_updates(mixins=[DefaultResourceRequest]) tune.register_trainable("my_trainable", MyTrainer) global trial_executor diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index 51c698169d90..07078c60d62e 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -422,8 +422,11 @@ def _test(what, method_to_test, obs_space, full_fetch, explore, timestep, if what is trainer: # Get the obs-space from Workers.env (not Policy) due to possible # pre-processor up front. - worker_set = getattr(trainer, "workers", - getattr(trainer, "_workers", None)) + worker_set = getattr(trainer, "workers") + # TODO: ES and ARS use `self._workers` instead of `self.workers` to + # store their rollout worker set. Change to `self.workers`. + if worker_set is None: + worker_set = getattr(trainer, "_workers", None) assert worker_set if isinstance(worker_set, list): obs_space = trainer.get_policy().observation_space