
[RLlib] Speedup A3C up to 3x (new training_iteration function instead of execution_plan) and re-instate Pong learning test. #22126

Merged (17 commits) on Feb 8, 2022
69 changes: 34 additions & 35 deletions release/rllib_tests/learning_tests/hard_learning_tests.yaml
@@ -18,41 +18,40 @@ a2c-breakoutnoframeskip-v4:
[20000000, 0.000000000001],
]

# a3c-pongdeterministic-v4:
# env: PongDeterministic-v4
# run: A3C
# # Minimum reward and total ts (in given time_total_s) to pass this test.
# pass_criteria:
# episode_reward_mean: 18.0
# timesteps_total: 5000000
# stop:
# time_total_s: 3600
# config:
# ignore_worker_failures: true
# num_gpus: 0
# num_workers: 16
# rollout_fragment_length: 20
# vf_loss_coeff: 0.5
# entropy_coeff: 0.01
# gamma: 0.99
# grad_clip: 40.0
# lambda: 1.0
# lr: 0.0001
# observation_filter: NoFilter
# preprocessor_pref: rllib
# model:
# use_lstm: true
# conv_activation: elu
# dim: 42
# grayscale: true
# zero_mean: false
# # Reduced channel depth and kernel size from default.
# conv_filters: [
# [32, [3, 3], 2],
# [32, [3, 3], 2],
# [32, [3, 3], 2],
# [32, [3, 3], 2],
# ]
a3c-pongdeterministic-v4:
env: PongDeterministic-v4
run: A3C
# Minimum reward and total ts (in given time_total_s) to pass this test.
pass_criteria:
episode_reward_mean: 18.0
timesteps_total: 5000000
stop:
time_total_s: 3600
config:
num_gpus: 0
num_workers: 16
rollout_fragment_length: 20
vf_loss_coeff: 0.5
entropy_coeff: 0.01
gamma: 0.99
grad_clip: 40.0
lambda: 1.0
lr: 0.0001
observation_filter: NoFilter
preprocessor_pref: rllib
model:
use_lstm: true
conv_activation: elu
dim: 42
grayscale: true
zero_mean: false
# Reduced channel depth and kernel size from default.
conv_filters: [
[32, [3, 3], 2],
[32, [3, 3], 2],
[32, [3, 3], 2],
[32, [3, 3], 2],
]
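For reference, a hedged sketch (not part of this PR) of what the reinstated Pong release test corresponds to when launched directly from Python; the `ray.tune.run("A3C", ...)` call and the stopping criterion are assumptions about how such a config would typically be run, mirroring the YAML keys above.

# Hypothetical Python equivalent of the a3c-pongdeterministic-v4 release test above.
import ray
from ray import tune

if __name__ == "__main__":
    ray.init()
    tune.run(
        "A3C",
        stop={"time_total_s": 3600},
        config={
            "env": "PongDeterministic-v4",
            "num_gpus": 0,
            "num_workers": 16,
            "rollout_fragment_length": 20,
            "vf_loss_coeff": 0.5,
            "entropy_coeff": 0.01,
            "gamma": 0.99,
            "grad_clip": 40.0,
            "lambda": 1.0,
            "lr": 0.0001,
            "observation_filter": "NoFilter",
            "preprocessor_pref": "rllib",
            "model": {
                "use_lstm": True,
                "conv_activation": "elu",
                "dim": 42,
                "grayscale": True,
                "zero_mean": False,
                # Reduced channel depth and kernel size from default.
                "conv_filters": [[32, [3, 3], 2]] * 4,
            },
        },
    )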

apex-breakoutnoframeskip-v4:
env: BreakoutNoFrameskip-v4
2 changes: 1 addition & 1 deletion rllib/BUILD
@@ -609,7 +609,7 @@ py_test(
py_test(
name = "test_a3c",
tags = ["team:ml", "trainers_dir"],
size = "medium",
size = "large",
srcs = ["agents/a3c/tests/test_a3c.py"]
)

2 changes: 2 additions & 0 deletions rllib/agents/a3c/a2c.py
@@ -28,6 +28,8 @@
# training with batch sizes much larger than can fit in GPU memory.
# To enable, set this to a value less than the train batch size.
"microbatch_size": None,
# Use `execution_plan` for A2C (no `training_iteration` implementation yet).
"_disable_execution_plan_api": False,
},
)
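As a quick illustration (assumed usage, not part of the diff), the new key can be read back from A2C's default config to confirm that A2C still goes through `execution_plan`; this assumes the module keeps exporting `A2C_DEFAULT_CONFIG`.

# Minimal sketch: A2C keeps the legacy execution_plan path by default.
from ray.rllib.agents.a3c.a2c import A2C_DEFAULT_CONFIG

assert A2C_DEFAULT_CONFIG["_disable_execution_plan_api"] is False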

97 changes: 93 additions & 4 deletions rllib/agents/a3c/a3c.py
@@ -1,15 +1,28 @@
import logging
from typing import Type
from typing import Any, Dict, Type

from ray.actor import ActorHandle
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.parallel_requests import asynchronous_parallel_requests
from ray.rllib.execution.rollout_ops import AsyncGradients
from ray.rllib.execution.train_ops import ApplyGradients
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.metrics import (
APPLY_GRADS_TIMER,
GRAD_WAIT_TIMER,
NUM_AGENT_STEPS_SAMPLED,
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
from ray.util.iter import LocalIterator

logger = logging.getLogger(__name__)
@@ -39,11 +52,20 @@
"entropy_coeff": 0.01,
# Entropy coefficient schedule
"entropy_coeff_schedule": None,
# Min time per reporting
# Min time (in seconds) per reporting.
# Not every call to `training_iteration` is reported individually; results
# are accumulated until this many seconds have passed and are then
# summarized in a single report.
"min_time_s_per_reporting": 5,
# Workers sample async. Note that this increases the effective
# rollout_fragment_length by up to 5x due to async buffering of batches.
"sample_async": True,

# Use the Trainer's `training_iteration` function instead of `execution_plan`.
# Fixes a severe performance problem with A3C. Setting this to True leads to a
# speedup of up to 3x for a large number of workers and heavier
# gradient computations (e.g. ray/rllib/tuned_examples/a3c/pong-a3c.yaml).
"_disable_execution_plan_api": True,
})
# __sphinx_doc_end__
# yapf: enable
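For context, a short usage sketch (assumed API of this Ray version, not taken from the diff): building an `A3CTrainer` with the new default runs the `training_iteration` path, while setting `_disable_execution_plan_api` to False falls back to the old execution plan.

# Minimal sketch, assuming ray.rllib.agents.a3c.A3CTrainer of this era.
import ray
from ray.rllib.agents.a3c import A3CTrainer

ray.init()
trainer = A3CTrainer(
    env="CartPole-v0",
    config={
        "num_workers": 2,
        # Default after this PR; set to False to use the legacy execution_plan.
        "_disable_execution_plan_api": True,
    },
)
for _ in range(3):
    results = trainer.train()
    print(results["episode_reward_mean"])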
@@ -74,6 +96,73 @@ def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
else:
return A3CTFPolicy

def training_iteration(self) -> ResultDict:
# Shortcut.
local_worker = self.workers.local_worker()

# Define the function executed in parallel by all RolloutWorkers to collect
# samples + compute and return gradients (and other information).

def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]:
"""Call sample() and compute_gradients() remotely on workers."""
samples = worker.sample()
grads, infos = worker.compute_gradients(samples)
return {
"grads": grads,
"infos": infos,
"agent_steps": samples.agent_steps(),
"env_steps": samples.env_steps(),
}

# Perform rollouts and gradient calculations asynchronously.
with self._timers[GRAD_WAIT_TIMER]:
# Results are a mapping from ActorHandle (RolloutWorker) to their
# returned gradient calculation results.
async_results: Dict[ActorHandle, Dict] = asynchronous_parallel_requests(
remote_requests_in_flight=self.remote_requests_in_flight,
actors=self.workers.remote_workers(),
ray_wait_timeout_s=0.0,
max_remote_requests_in_flight_per_actor=1,
remote_fn=sample_and_compute_grads,
)

# Loop through all fetched worker-computed gradients (if any)
# and apply them - one by one - to the local worker's model.
# After each apply step (one step per worker that returned some gradients),
# update that particular worker's weights.
global_vars = None
learner_info_builder = LearnerInfoBuilder(num_devices=1)
for worker, result in async_results.items():
# Apply gradients to local worker.
with self._timers[APPLY_GRADS_TIMER]:
print("Calling local-worker's `apply_gradients()` ...")
local_worker.apply_gradients(result["grads"])
self._timers[APPLY_GRADS_TIMER].push_units_processed(result["agent_steps"])

# Update all step counters.
self._counters[NUM_AGENT_STEPS_SAMPLED] += result["agent_steps"]
self._counters[NUM_ENV_STEPS_SAMPLED] += result["env_steps"]
self._counters[NUM_AGENT_STEPS_TRAINED] += result["agent_steps"]
self._counters[NUM_ENV_STEPS_TRAINED] += result["env_steps"]

# Create current global vars.
global_vars = {
"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
}

# Synch updated weights back to the particular worker.
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
weights = local_worker.get_weights(local_worker.get_policies_to_train())
worker.set_weights.remote(weights, global_vars)

learner_info_builder.add_learn_on_batch_results(result["infos"])

# Update global vars of the local worker.
if global_vars:
local_worker.set_global_vars(global_vars)

return learner_info_builder.finalize()

@staticmethod
@override(Trainer)
def execution_plan(
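To make the control flow of the new `training_iteration` easier to follow, here is a toy, self-contained sketch of the same asynchronous-gradient pattern using plain Ray (no RLlib): workers compute gradients in parallel, the driver applies whichever results are ready and ships fresh weights back to just those workers. `ToyWorker` and all numbers are illustrative only, not RLlib code.

# Toy analogue of A3C-style asynchronous gradient application with plain Ray.
import numpy as np
import ray

@ray.remote
class ToyWorker:
    def __init__(self):
        self.weights = np.zeros(4)

    def set_weights(self, weights):
        self.weights = weights

    def sample_and_compute_grads(self):
        # Pretend to roll out and differentiate: return a random "gradient".
        return {"grads": np.random.randn(4), "env_steps": 20}

ray.init()
workers = [ToyWorker.remote() for _ in range(4)]
local_weights = np.zeros(4)
in_flight = {w.sample_and_compute_grads.remote(): w for w in workers}

for _ in range(100):
    # Wait for at least one worker's result; the others keep sampling.
    ready, _ = ray.wait(list(in_flight), num_returns=1)
    for ref in ready:
        worker = in_flight.pop(ref)
        result = ray.get(ref)
        # Apply this worker's gradients to the "local" model ...
        local_weights -= 0.01 * result["grads"]
        # ... then sync updated weights back to only that worker and
        # immediately launch its next sample + gradient request.
        worker.set_weights.remote(local_weights)
        in_flight[worker.sample_and_compute_grads.remote()] = worker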
2 changes: 1 addition & 1 deletion rllib/agents/a3c/tests/test_a3c.py
@@ -26,7 +26,7 @@ def test_a3c_compilation(self):
config["num_workers"] = 2
config["num_envs_per_worker"] = 2

num_iterations = 1
num_iterations = 2

# Test against all frameworks.
for _ in framework_iterator(config, with_eager_tracing=True):
97 changes: 38 additions & 59 deletions rllib/agents/trainer.py
@@ -75,6 +75,7 @@
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.metrics import (
TRAINING_ITERATION_TIMER,
NUM_ENV_STEPS_SAMPLED,
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
@@ -904,19 +905,21 @@ def setup(self, config: PartialTrainerConfigDict):
# in each training iteration.
# This matches the behavior of using `build_trainer()`, which
# should no longer be used.
self.workers = self._make_workers(
self.workers = WorkerSet(
env_creator=self.env_creator,
validate_env=self.validate_env,
policy_class=self.get_default_policy_class(self.config),
config=self.config,
trainer_config=self.config,
num_workers=self.config["num_workers"],
local_worker=True,
)

# Function defining one single training iteration's behavior.
if self.config["_disable_execution_plan_api"]:
# Ensure remote workers are initially in sync with the
# TODO: Ensure remote workers are initially in sync with the
# local worker.
self.workers.sync_weights()
# self.workers.sync_weights()
pass # TODO: Uncommenting line above breaks tf2+eager_tracing for A3C.
# LocalIterator-creating "execution plan".
# Only call this once here to create `self.train_exec_impl`,
# which is a ray.util.iter.LocalIterator that will be `next`'d
@@ -1002,11 +1005,11 @@ def setup(self, config: PartialTrainerConfigDict):
# If evaluation_num_workers=0, use the evaluation set's local
# worker for evaluation, otherwise, use its remote workers
# (parallelized evaluation).
self.evaluation_workers: WorkerSet = self._make_workers(
self.evaluation_workers: WorkerSet = WorkerSet(
env_creator=self.env_creator,
validate_env=None,
policy_class=self.get_default_policy_class(self.config),
config=eval_config,
trainer_config=eval_config,
num_workers=self.config["evaluation_num_workers"],
# Don't even create a local worker if num_workers > 0.
local_worker=False,
@@ -2122,52 +2125,6 @@ def env_creator_from_classpath(env_context):
else:
return lambda env_config: None

@DeveloperAPI
def _make_workers(
self,
*,
env_creator: EnvCreator,
validate_env: Optional[Callable[[EnvType, EnvContext], None]],
policy_class: Type[Policy],
config: TrainerConfigDict,
num_workers: int,
local_worker: bool = True,
) -> WorkerSet:
"""Default factory method for a WorkerSet running under this Trainer.

Override this method by passing a custom `make_workers` into
`build_trainer`.

Args:
env_creator: A function that returns an Env given an env
config.
validate_env: Optional callable to validate the generated
environment. The env to be checked is the one returned from
the env creator, which may be a (single, not-yet-vectorized)
gym.Env or your custom RLlib env type (e.g. MultiAgentEnv,
VectorEnv, BaseEnv, etc..).
policy_class: The Policy class to use for creating the policies
of the workers.
config: The Trainer's config.
num_workers: Number of remote rollout workers to create.
0 for local only.
local_worker: Whether to create a local (non @ray.remote) worker
in the returned set as well (default: True). If `num_workers`
is 0, always create a local worker.

Returns:
The created WorkerSet.
"""
return WorkerSet(
env_creator=env_creator,
validate_env=validate_env,
policy_class=policy_class,
trainer_config=config,
num_workers=num_workers,
local_worker=local_worker,
logdir=self.logdir,
)

def _sync_filters_if_needed(self, workers: WorkerSet):
if self.config.get("observation_filter", "NoFilter") != "NoFilter":
FilterManager.synchronize(
@@ -2194,10 +2151,11 @@ def _sync_weights_to_workers(
worker_set.foreach_worker(lambda w: w.restore(ray.get(weights)))

def _exec_plan_or_training_iteration_fn(self):
if self.config["_disable_execution_plan_api"]:
results = self.training_iteration()
else:
results = next(self.train_exec_impl)
with self._timers[TRAINING_ITERATION_TIMER]:
if self.config["_disable_execution_plan_api"]:
results = self.training_iteration()
else:
results = next(self.train_exec_impl)
return results

@classmethod
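As a side note, a toy sketch of the dispatch-plus-timing pattern used in `_exec_plan_or_training_iteration_fn` above; the `Timer` class here is illustrative only (loosely analogous to RLlib's internal timers), the point being that one timer now wraps whichever of the two code paths is active.

import time
from collections import defaultdict

class Timer:
    # Minimal context-manager timer, loosely analogous to RLlib's internal timers.
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def __enter__(self):
        self._start = time.time()
        return self

    def __exit__(self, *exc):
        self.total += time.time() - self._start
        self.count += 1

    @property
    def mean(self):
        return self.total / max(self.count, 1)

timers = defaultdict(Timer)

def step(disable_execution_plan_api: bool):
    # The same timer wraps both paths, mirroring TRAINING_ITERATION_TIMER above.
    with timers["training_iteration"]:
        if disable_execution_plan_api:
            return {"path": "training_iteration"}
        return {"path": "execution_plan"}

step(True)
print(timers["training_iteration"].mean)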
@@ -3006,13 +2964,34 @@ def _compile_step_results(self, *, step_ctx, step_attempt_results=None):
def __repr__(self):
return type(self).__name__

@Deprecated(new="Trainer.compute_single_action()", error=False)
def compute_action(self, *args, **kwargs):
return self.compute_single_action(*args, **kwargs)

@Deprecated(new="Trainer.evaluate()", error=True)
def _evaluate(self) -> dict:
return self.evaluate()

@Deprecated(new="Trainer.compute_single_action()", error=False)
def compute_action(self, *args, **kwargs):
return self.compute_single_action(*args, **kwargs)
@Deprecated(new="construct WorkerSet(...) instance directly", error=False)
def _make_workers(
self,
*,
env_creator: EnvCreator,
validate_env: Optional[Callable[[EnvType, EnvContext], None]],
policy_class: Type[Policy],
config: TrainerConfigDict,
num_workers: int,
local_worker: bool = True,
) -> WorkerSet:
return WorkerSet(
env_creator=env_creator,
validate_env=validate_env,
policy_class=policy_class,
trainer_config=config,
num_workers=num_workers,
local_worker=local_worker,
logdir=self.logdir,
)
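A hedged migration sketch (illustrative, not from the PR) of calling `WorkerSet(...)` directly instead of the now-deprecated `Trainer._make_workers()` helper; note that the helper's `config=` keyword maps to `trainer_config=` on `WorkerSet`. The choice of A3C's default config and TF policy here is an assumption made for the example only.

# Build a WorkerSet directly instead of via the deprecated helper.
import gym
import ray
from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.evaluation.worker_set import WorkerSet

ray.init()
config = DEFAULT_CONFIG.copy()
config["num_workers"] = 2

# Previously: trainer._make_workers(..., config=config, ...)
workers = WorkerSet(
    env_creator=lambda env_ctx: gym.make("CartPole-v0"),
    validate_env=None,
    policy_class=A3CTFPolicy,
    trainer_config=config,  # was `config=` on the deprecated helper
    num_workers=config["num_workers"],
    local_worker=True,
)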

@Deprecated(new="Trainer.try_recover_from_step_attempt()", error=False)
def _try_recover(self):
4 changes: 2 additions & 2 deletions rllib/agents/trainer_template.py
@@ -151,11 +151,11 @@ def _init(self, config: TrainerConfigDict, env_creator: EnvCreator):
before_init(self)

# Creating all workers (excluding evaluation workers).
self.workers = self._make_workers(
self.workers = WorkerSet(
env_creator=env_creator,
validate_env=validate_env,
policy_class=self._policy_class,
config=config,
trainer_config=config,
num_workers=self.config["num_workers"],
)
