diff --git a/release/rllib_tests/learning_tests/hard_learning_tests.yaml b/release/rllib_tests/learning_tests/hard_learning_tests.yaml
index af234388eb7c..c5ef1453b7ea 100644
--- a/release/rllib_tests/learning_tests/hard_learning_tests.yaml
+++ b/release/rllib_tests/learning_tests/hard_learning_tests.yaml
@@ -18,41 +18,40 @@ a2c-breakoutnoframeskip-v4:
             [20000000, 0.000000000001],
         ]
 
-# a3c-pongdeterministic-v4:
-#    env: PongDeterministic-v4
-#    run: A3C
-#    # Minimum reward and total ts (in given time_total_s) to pass this test.
-#    pass_criteria:
-#        episode_reward_mean: 18.0
-#        timesteps_total: 5000000
-#    stop:
-#        time_total_s: 3600
-#    config:
-#        ignore_worker_failures: true
-#        num_gpus: 0
-#        num_workers: 16
-#        rollout_fragment_length: 20
-#        vf_loss_coeff: 0.5
-#        entropy_coeff: 0.01
-#        gamma: 0.99
-#        grad_clip: 40.0
-#        lambda: 1.0
-#        lr: 0.0001
-#        observation_filter: NoFilter
-#        preprocessor_pref: rllib
-#        model:
-#            use_lstm: true
-#            conv_activation: elu
-#            dim: 42
-#            grayscale: true
-#            zero_mean: false
-#            # Reduced channel depth and kernel size from default.
-#            conv_filters: [
-#                [32, [3, 3], 2],
-#                [32, [3, 3], 2],
-#                [32, [3, 3], 2],
-#                [32, [3, 3], 2],
-#            ]
+a3c-pongdeterministic-v4:
+    env: PongDeterministic-v4
+    run: A3C
+    # Minimum reward and total ts (in given time_total_s) to pass this test.
+    pass_criteria:
+        episode_reward_mean: 18.0
+        timesteps_total: 5000000
+    stop:
+        time_total_s: 3600
+    config:
+        num_gpus: 0
+        num_workers: 16
+        rollout_fragment_length: 20
+        vf_loss_coeff: 0.5
+        entropy_coeff: 0.01
+        gamma: 0.99
+        grad_clip: 40.0
+        lambda: 1.0
+        lr: 0.0001
+        observation_filter: NoFilter
+        preprocessor_pref: rllib
+        model:
+            use_lstm: true
+            conv_activation: elu
+            dim: 42
+            grayscale: true
+            zero_mean: false
+            # Reduced channel depth and kernel size from default.
+            conv_filters: [
+                [32, [3, 3], 2],
+                [32, [3, 3], 2],
+                [32, [3, 3], 2],
+                [32, [3, 3], 2],
+            ]
 
 apex-breakoutnoframeskip-v4:
     env: BreakoutNoFrameskip-v4
diff --git a/rllib/BUILD b/rllib/BUILD
index 5f31471f91dc..78b82d8020c1 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -609,7 +609,7 @@ py_test(
 py_test(
     name = "test_a3c",
     tags = ["team:ml", "trainers_dir"],
-    size = "medium",
+    size = "large",
     srcs = ["agents/a3c/tests/test_a3c.py"]
 )
 
diff --git a/rllib/agents/a3c/a2c.py b/rllib/agents/a3c/a2c.py
index faf0d949331c..7bf12c41b9df 100644
--- a/rllib/agents/a3c/a2c.py
+++ b/rllib/agents/a3c/a2c.py
@@ -28,6 +28,8 @@
         # training with batch sizes much larger than can fit in GPU memory.
        # To enable, set this to a value less than the train batch size.
         "microbatch_size": None,
+        # Use `execution_plan` for A2C (no `training_iteration` implementation yet).
+ "_disable_execution_plan_api": False, }, ) diff --git a/rllib/agents/a3c/a3c.py b/rllib/agents/a3c/a3c.py index 3a888543b1d9..7fd573faf887 100644 --- a/rllib/agents/a3c/a3c.py +++ b/rllib/agents/a3c/a3c.py @@ -1,15 +1,28 @@ import logging -from typing import Type +from typing import Any, Dict, Type +from ray.actor import ActorHandle from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy from ray.rllib.agents.trainer import Trainer, with_common_config +from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.worker_set import WorkerSet +from ray.rllib.execution.metric_ops import StandardMetricsReporting +from ray.rllib.execution.parallel_requests import asynchronous_parallel_requests from ray.rllib.execution.rollout_ops import AsyncGradients from ray.rllib.execution.train_ops import ApplyGradients -from ray.rllib.execution.metric_ops import StandardMetricsReporting from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import override -from ray.rllib.utils.typing import TrainerConfigDict +from ray.rllib.utils.metrics import ( + APPLY_GRADS_TIMER, + GRAD_WAIT_TIMER, + NUM_AGENT_STEPS_SAMPLED, + NUM_AGENT_STEPS_TRAINED, + NUM_ENV_STEPS_SAMPLED, + NUM_ENV_STEPS_TRAINED, + SYNCH_WORKER_WEIGHTS_TIMER, +) +from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder +from ray.rllib.utils.typing import ResultDict, TrainerConfigDict from ray.util.iter import LocalIterator logger = logging.getLogger(__name__) @@ -39,11 +52,20 @@ "entropy_coeff": 0.01, # Entropy coefficient schedule "entropy_coeff_schedule": None, - # Min time per reporting + # Min time (in seconds) per reporting. + # This causes not every call to `training_iteration` to be reported, + # but to wait until n seconds have passed and then to summarize the + # thus far collected results. "min_time_s_per_reporting": 5, # Workers sample async. Note that this increases the effective # rollout_fragment_length by up to 5x due to async buffering of batches. "sample_async": True, + + # Use the Trainer's `training_iteration` function instead of `execution_plan`. + # Fixes a severe performance problem with A3C. Setting this to True leads to a + # speedup of up to 3x for a large number of workers and heavier + # gradient computations (e.g. ray/rllib/tuned_examples/a3c/pong-a3c.yaml)). + "_disable_execution_plan_api": True, }) # __sphinx_doc_end__ # yapf: enable @@ -74,6 +96,73 @@ def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]: else: return A3CTFPolicy + def training_iteration(self) -> ResultDict: + # Shortcut. + local_worker = self.workers.local_worker() + + # Define the function executed in parallel by all RolloutWorkers to collect + # samples + compute and return gradients (and other information). + + def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]: + """Call sample() and compute_gradients() remotely on workers.""" + samples = worker.sample() + grads, infos = worker.compute_gradients(samples) + return { + "grads": grads, + "infos": infos, + "agent_steps": samples.agent_steps(), + "env_steps": samples.env_steps(), + } + + # Perform rollouts and gradient calculations asynchronously. + with self._timers[GRAD_WAIT_TIMER]: + # Results are a mapping from ActorHandle (RolloutWorker) to their + # returned gradient calculation results. 
+            async_results: Dict[ActorHandle, Dict] = asynchronous_parallel_requests(
+                remote_requests_in_flight=self.remote_requests_in_flight,
+                actors=self.workers.remote_workers(),
+                ray_wait_timeout_s=0.0,
+                max_remote_requests_in_flight_per_actor=1,
+                remote_fn=sample_and_compute_grads,
+            )
+
+        # Loop through all fetched worker-computed gradients (if any)
+        # and apply them - one by one - to the local worker's model.
+        # After each apply step (one step per worker that returned some gradients),
+        # update that particular worker's weights.
+        global_vars = None
+        learner_info_builder = LearnerInfoBuilder(num_devices=1)
+        for worker, result in async_results.items():
+            # Apply gradients to local worker.
+            with self._timers[APPLY_GRADS_TIMER]:
+                logger.debug("Calling local worker's `apply_gradients()` ...")
+                local_worker.apply_gradients(result["grads"])
+            self._timers[APPLY_GRADS_TIMER].push_units_processed(result["agent_steps"])
+
+            # Update all step counters.
+            self._counters[NUM_AGENT_STEPS_SAMPLED] += result["agent_steps"]
+            self._counters[NUM_ENV_STEPS_SAMPLED] += result["env_steps"]
+            self._counters[NUM_AGENT_STEPS_TRAINED] += result["agent_steps"]
+            self._counters[NUM_ENV_STEPS_TRAINED] += result["env_steps"]
+
+            # Create current global vars.
+            global_vars = {
+                "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
+            }
+
+            # Synch updated weights back to the particular worker.
+            with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
+                weights = local_worker.get_weights(local_worker.get_policies_to_train())
+                worker.set_weights.remote(weights, global_vars)
+
+            learner_info_builder.add_learn_on_batch_results(result["infos"])
+
+        # Update global vars of the local worker.
+        if global_vars:
+            local_worker.set_global_vars(global_vars)
+
+        return learner_info_builder.finalize()
+
     @staticmethod
     @override(Trainer)
     def execution_plan(
diff --git a/rllib/agents/a3c/tests/test_a3c.py b/rllib/agents/a3c/tests/test_a3c.py
index 53c7094879e0..de2758cf11fc 100644
--- a/rllib/agents/a3c/tests/test_a3c.py
+++ b/rllib/agents/a3c/tests/test_a3c.py
@@ -26,7 +26,7 @@ def test_a3c_compilation(self):
         config["num_workers"] = 2
         config["num_envs_per_worker"] = 2
 
-        num_iterations = 1
+        num_iterations = 2
 
         # Test against all frameworks.
         for _ in framework_iterator(config, with_eager_tracing=True):
diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index 8110673becd9..37c822b11b4a 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -75,6 +75,7 @@
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.from_config import from_config
 from ray.rllib.utils.metrics import (
+    TRAINING_ITERATION_TIMER,
     NUM_ENV_STEPS_SAMPLED,
     NUM_AGENT_STEPS_SAMPLED,
     NUM_ENV_STEPS_TRAINED,
@@ -904,19 +905,21 @@ def setup(self, config: PartialTrainerConfigDict):
         # in each training iteration.
         # This matches the behavior of using `build_trainer()`, which
         # should no longer be used.
-        self.workers = self._make_workers(
+        self.workers = WorkerSet(
             env_creator=self.env_creator,
             validate_env=self.validate_env,
             policy_class=self.get_default_policy_class(self.config),
-            config=self.config,
+            trainer_config=self.config,
             num_workers=self.config["num_workers"],
+            local_worker=True,
         )
 
         # Function defining one single training iteration's behavior.
         if self.config["_disable_execution_plan_api"]:
-            # Ensure remote workers are initially in sync with the
+            # TODO: Ensure remote workers are initially in sync with the
             # local worker.
-            self.workers.sync_weights()
+            # self.workers.sync_weights()
+            pass  # TODO: Uncommenting line above breaks tf2+eager_tracing for A3C.
             # LocalIterator-creating "execution plan".
             # Only call this once here to create `self.train_exec_impl`,
             # which is a ray.util.iter.LocalIterator that will be `next`'d
@@ -1002,11 +1005,11 @@ def setup(self, config: PartialTrainerConfigDict):
             # If evaluation_num_workers=0, use the evaluation set's local
             # worker for evaluation, otherwise, use its remote workers
             # (parallelized evaluation).
-            self.evaluation_workers: WorkerSet = self._make_workers(
+            self.evaluation_workers: WorkerSet = WorkerSet(
                 env_creator=self.env_creator,
                 validate_env=None,
                 policy_class=self.get_default_policy_class(self.config),
-                config=eval_config,
+                trainer_config=eval_config,
                 num_workers=self.config["evaluation_num_workers"],
                 # Don't even create a local worker if num_workers > 0.
                 local_worker=False,
@@ -2122,52 +2125,6 @@ def env_creator_from_classpath(env_context):
         else:
             return lambda env_config: None
 
-    @DeveloperAPI
-    def _make_workers(
-        self,
-        *,
-        env_creator: EnvCreator,
-        validate_env: Optional[Callable[[EnvType, EnvContext], None]],
-        policy_class: Type[Policy],
-        config: TrainerConfigDict,
-        num_workers: int,
-        local_worker: bool = True,
-    ) -> WorkerSet:
-        """Default factory method for a WorkerSet running under this Trainer.
-
-        Override this method by passing a custom `make_workers` into
-        `build_trainer`.
-
-        Args:
-            env_creator: A function that return and Env given an env
-                config.
-            validate_env: Optional callable to validate the generated
-                environment. The env to be checked is the one returned from
-                the env creator, which may be a (single, not-yet-vectorized)
-                gym.Env or your custom RLlib env type (e.g. MultiAgentEnv,
-                VectorEnv, BaseEnv, etc..).
-            policy_class: The Policy class to use for creating the policies
-                of the workers.
-            config: The Trainer's config.
-            num_workers: Number of remote rollout workers to create.
-                0 for local only.
-            local_worker: Whether to create a local (non @ray.remote) worker
-                in the returned set as well (default: True). If `num_workers`
-                is 0, always create a local worker.
-
-        Returns:
-            The created WorkerSet.
- """ - return WorkerSet( - env_creator=env_creator, - validate_env=validate_env, - policy_class=policy_class, - trainer_config=config, - num_workers=num_workers, - local_worker=local_worker, - logdir=self.logdir, - ) - def _sync_filters_if_needed(self, workers: WorkerSet): if self.config.get("observation_filter", "NoFilter") != "NoFilter": FilterManager.synchronize( @@ -2194,10 +2151,11 @@ def _sync_weights_to_workers( worker_set.foreach_worker(lambda w: w.restore(ray.get(weights))) def _exec_plan_or_training_iteration_fn(self): - if self.config["_disable_execution_plan_api"]: - results = self.training_iteration() - else: - results = next(self.train_exec_impl) + with self._timers[TRAINING_ITERATION_TIMER]: + if self.config["_disable_execution_plan_api"]: + results = self.training_iteration() + else: + results = next(self.train_exec_impl) return results @classmethod @@ -3006,13 +2964,34 @@ def _compile_step_results(self, *, step_ctx, step_attempt_results=None): def __repr__(self): return type(self).__name__ + @Deprecated(new="Trainer.compute_single_action()", error=False) + def compute_action(self, *args, **kwargs): + return self.compute_single_action(*args, **kwargs) + @Deprecated(new="Trainer.evaluate()", error=True) def _evaluate(self) -> dict: return self.evaluate() - @Deprecated(new="Trainer.compute_single_action()", error=False) - def compute_action(self, *args, **kwargs): - return self.compute_single_action(*args, **kwargs) + @Deprecated(new="construct WorkerSet(...) instance directly", error=False) + def _make_workers( + self, + *, + env_creator: EnvCreator, + validate_env: Optional[Callable[[EnvType, EnvContext], None]], + policy_class: Type[Policy], + config: TrainerConfigDict, + num_workers: int, + local_worker: bool = True, + ) -> WorkerSet: + return WorkerSet( + env_creator=env_creator, + validate_env=validate_env, + policy_class=policy_class, + trainer_config=config, + num_workers=num_workers, + local_worker=local_worker, + logdir=self.logdir, + ) @Deprecated(new="Trainer.try_recover_from_step_attempt()", error=False) def _try_recover(self): diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py index 6db7784b194a..7c3bffba20ab 100644 --- a/rllib/agents/trainer_template.py +++ b/rllib/agents/trainer_template.py @@ -151,11 +151,11 @@ def _init(self, config: TrainerConfigDict, env_creator: EnvCreator): before_init(self) # Creating all workers (excluding evaluation workers). 
-        self.workers = self._make_workers(
+        self.workers = WorkerSet(
             env_creator=env_creator,
             validate_env=validate_env,
             policy_class=self._policy_class,
-            config=config,
+            trainer_config=config,
             num_workers=self.config["num_workers"],
         )
 
diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py
index 1d84ac0c0aa6..7d9f076f484d 100644
--- a/rllib/evaluation/sampler.py
+++ b/rllib/evaluation/sampler.py
@@ -37,6 +37,7 @@
 from ray.rllib.utils.debug import summarize
 from ray.rllib.utils.deprecation import deprecation_warning
 from ray.rllib.utils.filter import Filter
+from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.numpy import convert_to_numpy
 from ray.rllib.utils.spaces.space_utils import clip_action, unsquash_action, unbatch
 from ray.rllib.utils.typing import (
@@ -52,14 +53,12 @@
 )
 
 if TYPE_CHECKING:
+    from gym.envs.classic_control.rendering import SimpleImageViewer
     from ray.rllib.agents.callbacks import DefaultCallbacks
     from ray.rllib.evaluation.observation_function import ObservationFunction
     from ray.rllib.evaluation.rollout_worker import RolloutWorker
-    from ray.rllib.utils import try_import_tf
-
-    _, tf, _ = try_import_tf()
-    from gym.envs.classic_control.rendering import SimpleImageViewer
 
+tf1, tf, _ = try_import_tf()
 logger = logging.getLogger(__name__)
 
 PolicyEvalData = namedtuple(
@@ -453,6 +452,14 @@ def run(self):
             raise e
 
     def _run(self):
+        # We are in a thread: Switch on eager execution mode, iff framework==tf2|tfe.
+        if (
+            tf1
+            and self.worker.policy_config.get("framework", "tf") in ["tf2", "tfe"]
+            and not tf1.executing_eagerly()
+        ):
+            tf1.enable_eager_execution()
+
         if self.blackhole_outputs:
             queue_putter = lambda x: None
             extra_batches_putter = lambda x: None
diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py
index 9999a4eed5ce..70b07434f1cb 100644
--- a/rllib/evaluation/worker_set.py
+++ b/rllib/evaluation/worker_set.py
@@ -45,9 +45,9 @@
 
 @DeveloperAPI
 class WorkerSet:
-    """Set of RolloutWorkers with n @ray.remote workers and one local worker.
+    """Set of RolloutWorkers with n @ray.remote workers and zero or one local worker.
 
-    Where n may be 0.
+    Where: n >= 0.
     """
 
     def __init__(
diff --git a/rllib/execution/parallel_requests.py b/rllib/execution/parallel_requests.py
index 10b57e76f7c4..0a6becd21b2c 100644
--- a/rllib/execution/parallel_requests.py
+++ b/rllib/execution/parallel_requests.py
@@ -14,7 +14,7 @@ def asynchronous_parallel_requests(
     actors: List[ActorHandle],
     ray_wait_timeout_s: Optional[float] = None,
     max_remote_requests_in_flight_per_actor: int = 2,
-    remote_fn: Optional[Callable[[ActorHandle, Any, Any], Any]] = None,
+    remote_fn: Optional[Callable[[Any, Optional[Any], Optional[Any]], Any]] = None,
     remote_args: Optional[List[List[Any]]] = None,
     remote_kwargs: Optional[List[Dict[str, Any]]] = None,
 ) -> Dict[ActorHandle, Any]:
diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py
index d1f1d81693af..cd41cc6feafa 100644
--- a/rllib/policy/eager_tf_policy.py
+++ b/rllib/policy/eager_tf_policy.py
@@ -168,7 +168,7 @@ def compute_actions_from_input_dict(
             self._traced_compute_actions_helper = True
 
         # Now that the helper method is traced, call super's
-        # apply_gradients (which will call the traced helper).
+        # `compute_actions_from_input_dict()` (which will call the traced helper).
             return super(TracedEagerPolicy, self).compute_actions_from_input_dict(
                 input_dict=input_dict,
                 explore=explore,
@@ -214,7 +214,7 @@ def compute_gradients(self, samples: SampleBatch) -> ModelGradients:
                 self._traced_compute_gradients_helper = True
 
             # Now that the helper method is traced, call super's
-            # apply_gradients (which will call the traced helper).
+            # `compute_gradients()` (which will call the traced helper).
             return super(TracedEagerPolicy, self).compute_gradients(samples)
 
         @check_too_many_retraces
@@ -222,8 +222,11 @@ def compute_gradients(self, samples: SampleBatch) -> ModelGradients:
         def apply_gradients(self, grads: ModelGradients) -> None:
             """Traced version of Policy.apply_gradients."""
 
+            logger.debug("Eager-tracing-policy: inside `apply_gradients()`.")
+
             # Create a traced version of `self._apply_gradients_helper`.
             if self._traced_apply_gradients_helper is False and not self._no_tracing:
+                logger.debug("Tracing `_apply_gradients_helper` ...")
                 self._apply_gradients_helper = convert_eager_inputs(
                     tf.function(
                         super(TracedEagerPolicy, self)._apply_gradients_helper,
@@ -234,7 +237,8 @@ def apply_gradients(self, grads: ModelGradients) -> None:
                 self._traced_apply_gradients_helper = True
 
             # Now that the helper method is traced, call super's
-            # apply_gradients (which will call the traced helper).
+            # `apply_gradients()` (which will call the traced helper).
+            logger.debug("Calling super's `apply_gradients()` ...")
             return super(TracedEagerPolicy, self).apply_gradients(grads)
 
     TracedEagerPolicy.__name__ = eager_policy_cls.__name__ + "_traced"
diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py
index 3e8471feebad..022451d94b64 100644
--- a/rllib/policy/sample_batch.py
+++ b/rllib/policy/sample_batch.py
@@ -174,6 +174,14 @@ def agent_steps(self) -> int:
         """
         return len(self)
 
+    @PublicAPI
+    def env_steps(self) -> int:
+        """Returns the same as len(self) (number of steps in this batch).
+
+        Provided for compatibility with `MultiAgentBatch.env_steps()`.
+        """
+        return len(self)
+
     @staticmethod
     @PublicAPI
     def concat_samples(
diff --git a/rllib/utils/metrics/__init__.py b/rllib/utils/metrics/__init__.py
index 27caf30c33f7..7cd3411edffa 100644
--- a/rllib/utils/metrics/__init__.py
+++ b/rllib/utils/metrics/__init__.py
@@ -9,6 +9,7 @@
 NUM_TARGET_UPDATES = "num_target_updates"
 
 # Performance timers (keys for Trainer._timers or metrics.timers).
+TRAINING_ITERATION_TIMER = "training_iteration"
 APPLY_GRADS_TIMER = "apply_grad"
 COMPUTE_GRADS_TIMER = "compute_grads"
 SYNCH_WORKER_WEIGHTS_TIMER = "synch_weights"
diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py
index 5ec35a8b0126..3fe2064991aa 100644
--- a/rllib/utils/test_utils.py
+++ b/rllib/utils/test_utils.py
@@ -405,7 +405,7 @@ def _test(
             isinstance(action_space, Box)
             and not unsquash
             and what.config.get("normalize_actions")
-            and np.any(np.abs(action) > 3.0)
+            and np.any(np.abs(action) > 15.0)
         ):
             raise ValueError(
                 f"Returned action ({action}) of trainer/policy {what} "
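
For reference, a minimal sketch of the pattern that the new `A3CTrainer.training_iteration()` above implements: gradients are collected asynchronously from the rollout workers, applied to the local model, and the updated weights are pushed back to whichever worker just reported. The sketch below uses only the public `ray` core API; the `Worker` actor and its methods are illustrative stand-ins (not RLlib classes), and the gradient math is dummy arithmetic.

import ray

ray.init()


@ray.remote
class Worker:
    """Toy stand-in for an RLlib RolloutWorker (illustration only)."""

    def __init__(self):
        self.weights = 0.0

    def set_weights(self, weights):
        self.weights = weights

    def sample_and_compute_grads(self):
        # Pretend to sample 20 env steps and compute a gradient for them.
        return {"grad": 0.1 * self.weights + 0.05, "env_steps": 20}


workers = [Worker.remote() for _ in range(4)]
weights = 0.0
env_steps = 0

# At most one request in flight per worker, mirroring
# max_remote_requests_in_flight_per_actor=1 in training_iteration().
in_flight = {w.sample_and_compute_grads.remote(): w for w in workers}

while env_steps < 1000:
    # Non-blocking wait (like ray_wait_timeout_s=0.0): handle whichever
    # workers happen to be done right now.
    ready, _ = ray.wait(list(in_flight), num_returns=1, timeout=0.0)
    for ref in ready:
        worker = in_flight.pop(ref)
        result = ray.get(ref)
        weights -= 0.01 * result["grad"]    # apply the gradient to the "local" model
        env_steps += result["env_steps"]
        worker.set_weights.remote(weights)  # sync the updated weights back
        # Immediately launch the next request for this worker.
        in_flight[worker.sample_and_compute_grads.remote()] = worker

ray.shutdown()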