
[RLlib] APPO Training iteration fn. #24545

Merged · 25 commits · May 17, 2022
1 change: 1 addition & 0 deletions rllib/agents/ddpg/apex.py
@@ -32,6 +32,7 @@
"learning_starts": 50000,
"train_batch_size": 512,
"rollout_fragment_length": 50,
# Update the target network every `target_network_update_freq` sample timesteps.
"target_network_update_freq": 500000,
"min_sample_timesteps_per_reporting": 25000,
"worker_side_prioritization": True,
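For reference, a minimal usage sketch (assumed, not part of this diff) of the key the new comment documents; with the APEX-DDPG defaults shown above, the target network is refreshed every 500k sampled timesteps:

# Minimal sketch (assumed usage): the threshold is compared against *sampled*
# timesteps, as the new comment above states.
config_overrides = {
    "rollout_fragment_length": 50,
    "train_batch_size": 512,
    # Copy the online weights to the target network every 500k sampled timesteps.
    "target_network_update_freq": 500_000,
}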
1 change: 1 addition & 0 deletions rllib/agents/ddpg/td3.py
@@ -48,6 +48,7 @@
"tau": 5e-3,
"train_batch_size": 100,
"use_huber": False,
# Update the target network every `target_network_update_freq` sample timesteps.
"target_network_update_freq": 0,
"num_workers": 0,
"num_gpus_per_worker": 0,
7 changes: 2 additions & 5 deletions rllib/agents/dqn/apex.py
@@ -115,6 +115,7 @@
"learning_starts": 50000,
"train_batch_size": 512,
"rollout_fragment_length": 50,
# Update the target network every `target_network_update_freq` sample timesteps.
"target_network_update_freq": 500000,
# Minimum env sampling timesteps to accumulate within a single `train()` call.
# This value does not affect learning, only the number of times
@@ -571,13 +572,9 @@ def wait_on_replay_actors(timeout: float) -> None:
except queue.Full:
break

def update_replay_sample_priority(self) -> int:
def update_replay_sample_priority(self) -> None:
"""Update the priorities of the sample batches with new priorities that are
computed by the learner thread.

Returns:
The number of samples trained by the learner thread since the last
training iteration.
"""
num_samples_trained_this_itr = 0
for _ in range(self.learner_thread.outqueue.qsize()):
10 changes: 7 additions & 3 deletions rllib/agents/dqn/dqn.py
@@ -344,7 +344,7 @@ def training_iteration(self) -> ResultDict:
- Sample training batch (MultiAgentBatch) from replay buffer.
- Learn on training batch.
- Update remote workers' new policy weights.
- Update target network every target_network_update_freq steps.
- Update target network every `target_network_update_freq` sample steps.
- Return all collected metrics for the iteration.

Returns:
@@ -403,8 +403,12 @@ def training_iteration(self) -> ResultDict:
train_results,
)

# Update target network every `target_network_update_freq` steps.
cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED]
# Update target network every `target_network_update_freq` sample steps.
cur_ts = self._counters[
NUM_AGENT_STEPS_SAMPLED
if self._by_agent_steps
else NUM_ENV_STEPS_SAMPLED
]
last_update = self._counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update >= self.config["target_network_update_freq"]:
to_update = self.workers.local_worker().get_policies_to_train()
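The hunk above makes the target-update check respect multi-agent counting: it compares against agent steps when the trainer counts by agent steps, and env steps otherwise. A self-contained sketch of that gate (counter keys are simplified stand-ins, not RLlib's real metrics constants):

# Self-contained sketch of the target-update gate; keys are illustrative.
def should_update_target(counters, config, by_agent_steps):
    cur_ts = counters["agent_steps_sampled" if by_agent_steps else "env_steps_sampled"]
    last_update = counters["last_target_update_ts"]
    return cur_ts - last_update >= config["target_network_update_freq"]

# Example: 120k env steps sampled, last update at 100k, freq of 50k -> not yet.
print(should_update_target(
    {"env_steps_sampled": 120_000, "last_target_update_ts": 100_000},
    {"target_network_update_freq": 50_000},
    by_agent_steps=False,
))  # -> False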
2 changes: 2 additions & 0 deletions rllib/agents/dqn/learner_thread.py
@@ -69,6 +69,8 @@ def step(self):
)
self.learner_info = learner_info_builder.finalize()
self.grad_timer.push_units_processed(ma_batch.count)
# Put tuple: replay_actor, prio-dict, env-steps, and agent-steps into
# the queue.
self.outqueue.put(
(replay_actor, prio_dict, ma_batch.count, ma_batch.agent_steps())
)
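The learner thread now reports env steps and agent steps separately, so the trainer can increment both NUM_ENV_STEPS_TRAINED and NUM_AGENT_STEPS_TRAINED (see the impala.py changes below). A toy sketch of the 4-tuple queue protocol, with placeholder values:

import queue

# Producer side (what the learner thread does after a train step); the actor
# handle and priority dict are placeholders here.
outqueue = queue.Queue()
outqueue.put(("replay_actor_handle", {"default_policy": [0.5, 0.9]}, 50, 100))

# Consumer side (what the trainer does when processing trained results).
num_env_steps_trained = 0
num_agent_steps_trained = 0
for _ in range(outqueue.qsize()):
    _replay_actor, _prio_dict, env_steps, agent_steps = outqueue.get(timeout=0.001)
    num_env_steps_trained += env_steps
    num_agent_steps_trained += agent_steps
print(num_env_steps_trained, num_agent_steps_trained)  # 50 100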
2 changes: 1 addition & 1 deletion rllib/agents/dqn/r2d2.py
@@ -68,7 +68,7 @@
# if `use_h_function`=True.
"h_function_epsilon": 1e-3,

# Update the target network every `target_network_update_freq` steps.
# Update the target network every `target_network_update_freq` sample steps.
"target_network_update_freq": 2500,

# Deprecated keys:
9 changes: 5 additions & 4 deletions rllib/agents/dqn/simple_q.py
@@ -31,7 +31,6 @@
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
NUM_TARGET_UPDATES,
TARGET_NET_UPDATE_TIMER,
@@ -308,7 +307,7 @@ def training_iteration(self) -> ResultDict:
- Store new samples in the replay buffer.
- Sample one training MultiAgentBatch from the replay buffer.
- Learn on the training batch.
- Update the target network every `target_network_update_freq` steps.
- Update the target network every `target_network_update_freq` sample steps.
- Return all collected training metrics for the iteration.

Returns:
@@ -355,8 +354,10 @@ def training_iteration(self) -> ResultDict:
# self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
# self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

# Update target network every `target_network_update_freq` steps.
cur_ts = self._counters[NUM_ENV_STEPS_TRAINED]
# Update target network every `target_network_update_freq` sample steps.
cur_ts = self._counters[
NUM_AGENT_STEPS_SAMPLED if self._by_agent_steps else NUM_ENV_STEPS_SAMPLED
]
last_update = self._counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update >= self.config["target_network_update_freq"]:
with self._timers[TARGET_NET_UPDATE_TIMER]:
2 changes: 1 addition & 1 deletion rllib/agents/dqn/tests/test_apex_dqn.py
@@ -124,7 +124,7 @@ def _step_n_times(trainer, n: int):
for _ in framework_iterator(config):
trainer = apex.ApexTrainer(config=config, env="CartPole-v0")

lr = _step_n_times(trainer, 5) # 10 timesteps
lr = _step_n_times(trainer, 5) # 50 timesteps
# Close to 0.2
self.assertGreaterEqual(lr, 0.1)

33 changes: 20 additions & 13 deletions rllib/agents/impala/impala.py
@@ -37,6 +37,7 @@
NUM_AGENT_STEPS_SAMPLED,
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
)

# from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
@@ -512,13 +513,6 @@ def setup(self, config: PartialTrainerConfigDict):
] = defaultdict(set)

if self.config["_disable_execution_plan_api"]:
# Setup after_train_step callback.
self._after_train_step = lambda *a, **k: None
if self.config["after_train_step"]:
self._after_train_step = self.config["after_train_step"](
self.workers, self.config
)

# Create extra aggregation workers and assign each rollout worker to
# one of them.
self.batches_to_place_on_learner = []
@@ -587,15 +581,24 @@ def training_iteration(self) -> ResultDict:

self.concatenate_batches_and_pre_queue(batch)
self.place_processed_samples_on_learner_queue()
learner_results = self.process_trained_results()
train_results = self.process_trained_results()

self.update_workers_if_necessary()

# Callback for APPO to use to update KL, target network periodically.
# The input to the callback is the learner fetches dict.
self._after_train_step(learner_results)
self.after_train_step(train_results)

return train_results

return learner_results
def after_train_step(self, train_results: ResultDict) -> None:
"""Called by the training_iteration method after each train step.

Args:
train_results: The train results dict.
"""
# By default, do nothing.
pass

@staticmethod
@override(Trainer)
@@ -766,15 +769,18 @@ def place_processed_samples_on_learner_queue(self) -> None:
def process_trained_results(self) -> ResultDict:
# Get learner outputs/stats from output queue.
learner_infos = []
num_env_steps_trained = 0
num_agent_steps_trained = 0

for _ in range(self._learner_thread.outqueue.qsize()):
if self._learner_thread.is_alive():
(
num_trained_samples,
env_steps,
agent_steps,
learner_results,
) = self._learner_thread.outqueue.get(timeout=0.001)
num_agent_steps_trained += num_trained_samples
num_env_steps_trained += env_steps
num_agent_steps_trained += agent_steps
if learner_results:
learner_infos.append(learner_results)
else:
@@ -783,6 +789,7 @@ def process_trained_results(self) -> ResultDict:

# Update the steps trained counters.
self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = num_agent_steps_trained
self._counters[NUM_ENV_STEPS_TRAINED] += num_env_steps_trained
self._counters[NUM_AGENT_STEPS_TRAINED] += num_agent_steps_trained

return learner_info
@@ -845,7 +852,7 @@ def process_experiences_tree_aggregation(

def update_workers_if_necessary(self) -> None:
# Only need to update workers if there are remote workers.
global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_TRAINED]}
global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED]}
self._counters["steps_since_broadcast"] += 1
if (
self.workers.remote_workers()
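With the config-based `after_train_step` callback removed from `setup()`, subclasses now customize the iteration by overriding the new `after_train_step` method, as APPO does below. A hypothetical minimal subclass (class and counter names invented for illustration):

from ray.rllib.agents import impala

class MyImpalaVariant(impala.ImpalaTrainer):
    def after_train_step(self, train_results):
        # Runs once per `training_iteration`, right after worker weights are
        # synced; e.g. count iterations that actually produced learner results.
        if train_results:
            self._counters["train_steps_with_results"] += 1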
2 changes: 1 addition & 1 deletion rllib/agents/maddpg/maddpg.py
@@ -94,7 +94,7 @@
"critic_lr": 1e-2,
# Learning rate for the actor (policy) optimizer.
"actor_lr": 1e-2,
# Update the target network every `target_network_update_freq` steps.
# Update the target network every `target_network_update_freq` sample steps.
"target_network_update_freq": 0,
# Update the target by \tau * policy + (1-\tau) * target_policy
"tau": 0.01,
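The `tau` comment above describes a Polyak (soft) target update; a plain-Python sketch of that formula for reference:

# target <- tau * policy + (1 - tau) * target, applied element-wise.
def soft_update(target_weights, policy_weights, tau=0.01):
    return [tau * p + (1.0 - tau) * t for t, p in zip(target_weights, policy_weights)]

# With tau=0.01 the target moves 1% of the way toward the policy each update.
print(soft_update([0.0, 0.0], [1.0, 2.0]))  # [0.01, 0.02]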
89 changes: 48 additions & 41 deletions rllib/agents/ppo/appo.py
@@ -11,20 +11,23 @@
"""
from typing import Optional, Type

from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.ppo.appo_tf_policy import AsyncPPOTFPolicy
from ray.rllib.agents.ppo.ppo import UpdateKL
from ray.rllib.agents import impala
from ray.rllib.policy.policy import Policy
from ray.rllib.execution.common import (
STEPS_SAMPLED_COUNTER,
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED,
NUM_TARGET_UPDATES,
_get_shared_metrics,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.typing import PartialTrainerConfigDict, TrainerConfigDict
from ray.rllib.utils.typing import (
PartialTrainerConfigDict,
ResultDict,
TrainerConfigDict,
)


class APPOConfig(impala.ImpalaConfig):
@@ -101,8 +104,6 @@ def __init__(self, trainer_class=None):
self.vf_loss_coeff = 0.5
self.entropy_coeff = 0.01
self.entropy_coeff_schedule = None

self._disable_execution_plan_api = False
# __sphinx_doc_end__
# fmt: on

@@ -163,51 +164,57 @@ def training(
return self


class UpdateTargetAndKL:
def __init__(self, workers, config):
self.workers = workers
self.config = config
self.update_kl = UpdateKL(workers)
self.target_update_freq = (
config["num_sgd_iter"] * config["minibatch_buffer_size"]
)

def __call__(self, fetches):
metrics = _get_shared_metrics()
cur_ts = metrics.counters[STEPS_SAMPLED_COUNTER]
last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update > self.target_update_freq:
metrics.counters[NUM_TARGET_UPDATES] += 1
metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts
# Update Target Network
self.workers.local_worker().foreach_policy_to_train(
lambda p, _: p.update_target()
)
# Also update KL Coeff
if self.config["use_kl_loss"]:
self.update_kl(fetches)


class APPOTrainer(impala.ImpalaTrainer):
def __init__(self, config, *args, **kwargs):
# Before init: Add the update target and kl hook.
# This hook is called explicitly after each learner step in the
# execution setup for IMPALA.
config["after_train_step"] = UpdateTargetAndKL

super().__init__(config, *args, **kwargs)

self.update_kl = UpdateKL(self.workers)

# After init: Initialize target net.
self.workers.local_worker().foreach_policy_to_train(
lambda p, _: p.update_target()
)

@override(impala.ImpalaTrainer)
def after_train_step(self, train_results: ResultDict) -> None:
"""Updates the target network and the KL coefficient for the APPO-loss.

This method is called from within the `training_iteration` method after each
train update.

The target network update frequency is calculated automatically as the product
of the `num_sgd_iter` setting (usually 1 for APPO) and `minibatch_buffer_size`.

Args:
train_results: The results dict collected during the most recent
training step.
"""
cur_ts = self._counters[
NUM_AGENT_STEPS_SAMPLED if self._by_agent_steps else NUM_ENV_STEPS_SAMPLED
]
last_update = self._counters[LAST_TARGET_UPDATE_TS]
target_update_freq = (
self.config["num_sgd_iter"] * self.config["minibatch_buffer_size"]
)
if cur_ts - last_update > target_update_freq:
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

# Update our target network.
self.workers.local_worker().foreach_policy_to_train(
lambda p, _: p.update_target()
)

# Also update the KL-coefficient for the APPO loss, if necessary.
if self.config["use_kl_loss"]:
self.update_kl(train_results)

@classmethod
@override(Trainer)
@override(impala.ImpalaTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return APPOConfig().to_dict()

@override(Trainer)
@override(impala.ImpalaTrainer)
def get_default_policy_class(
self, config: PartialTrainerConfigDict
) -> Optional[Type[Policy]]:
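To make the new `after_train_step` gate concrete, a small worked example (config values are assumed, not asserted as APPO defaults):

# The threshold is the product of two config values, measured in sampled steps.
num_sgd_iter = 1
minibatch_buffer_size = 4
target_update_freq = num_sgd_iter * minibatch_buffer_size  # 4 sampled steps

cur_ts = 10_000      # sampled (env or agent) steps so far
last_update = 9_990  # value of the counter at the previous target update
if cur_ts - last_update > target_update_freq:
    print("update target network (and the KL coefficient if use_kl_loss=True)")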