[RLlib] APPO Training iteration fn. #24545

Merged · 25 commits · May 17, 2022 (diff below shows changes from 1 commit)
54 changes: 21 additions & 33 deletions rllib/agents/ppo/appo.py
@@ -20,7 +20,6 @@
STEPS_SAMPLED_COUNTER,
LAST_TARGET_UPDATE_TS,
NUM_TARGET_UPDATES,
_get_shared_metrics,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
@@ -101,8 +100,6 @@ def __init__(self, trainer_class=None):
self.vf_loss_coeff = 0.5
self.entropy_coeff = 0.01
self.entropy_coeff_schedule = None

self._disable_execution_plan_api = False
# __sphinx_doc_end__
# fmt: on

@@ -163,45 +160,36 @@ def training(
return self


class UpdateTargetAndKL:
def __init__(self, workers, config):
self.workers = workers
self.config = config
self.update_kl = UpdateKL(workers)
self.target_update_freq = (
config["num_sgd_iter"] * config["minibatch_buffer_size"]
)

def __call__(self, fetches):
metrics = _get_shared_metrics()
cur_ts = metrics.counters[STEPS_SAMPLED_COUNTER]
last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update > self.target_update_freq:
metrics.counters[NUM_TARGET_UPDATES] += 1
metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts
# Update Target Network
self.workers.local_worker().foreach_policy_to_train(
lambda p, _: p.update_target()
)
# Also update KL Coeff
if self.config["use_kl_loss"]:
self.update_kl(fetches)


class APPOTrainer(impala.ImpalaTrainer):
def __init__(self, config, *args, **kwargs):
# Before init: Add the update target and kl hook.
# This hook is called explicitly after each learner step in the
# execution setup for IMPALA.
config["after_train_step"] = UpdateTargetAndKL

super().__init__(config, *args, **kwargs)

self.config["after_train_step"] = self.update_target_and_kl
self.update_kl = UpdateKL(self.workers)

# After init: Initialize target net.
self.workers.local_worker().foreach_policy_to_train(
lambda p, _: p.update_target()
)

def update_target_and_kl(self, train_results):
cur_ts = self._counters[STEPS_SAMPLED_COUNTER]
last_update = self._counters[LAST_TARGET_UPDATE_TS]
target_update_freq = self.config["num_sgd_iter"] * self.config[
"minibatch_buffer_size"]
if cur_ts - last_update > target_update_freq:
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

# Update Target Network.
self.workers.local_worker().foreach_policy_to_train(
lambda p, _: p.update_target()
)

# Also update KL Coeff
if self.config["use_kl_loss"]:
self.update_kl(train_results)

@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
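
For reference, a minimal usage sketch of the refactored trainer (not part of this PR): it assumes the Ray version this PR targets, where rllib/agents/ppo/appo.py and the APPOTrainer class shown above exist; the environment name, worker count, and iteration count are illustrative placeholders, not values from the PR.

import ray
from ray.rllib.agents.ppo.appo import APPOTrainer

ray.init()

config = {
    # Illustrative values, not taken from this PR.
    "env": "CartPole-v0",
    "num_workers": 1,
    # Enables the KL-coeff update branch in update_target_and_kl().
    "use_kl_loss": True,
    # The target net is updated whenever sampled timesteps since the last
    # update exceed num_sgd_iter * minibatch_buffer_size (see the logic above).
    "num_sgd_iter": 1,
    "minibatch_buffer_size": 1,
}

trainer = APPOTrainer(config=config)
for _ in range(3):
    # update_target_and_kl() is invoked via the "after_train_step" hook
    # assigned in APPOTrainer.__init__ above.
    results = trainer.train()
trainer.stop()
ray.shutdown()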