[rllib] Eager execution for centralized critic example, fix simple optimizer for multiagent #5683

Merged: 6 commits, Sep 11, 2019
23 changes: 18 additions & 5 deletions doc/source/rllib-toc.rst
@@ -3,33 +3,46 @@ RLlib Table of Contents

Training APIs
-------------
* `Command-line <rllib-training.html>`__
* `Configuration <rllib-training.html#configuration>`__

- `Specifying Parameters <rllib-training.html#specifying-parameters>`__

- `Specifying Resources <rllib-training.html#specifying-resources>`__

- `Common Parameters <rllib-training.html#common-parameters>`__

- `Tuned Examples <rllib-training.html#tuned-examples>`__

* `Python API <rllib-training.html#python-api>`__

- `Custom Training Workflows <rllib-training.html#custom-training-workflows>`__

- `Accessing Policy State <rllib-training.html#accessing-policy-state>`__

- `Accessing Model State <rllib-training.html#accessing-model-state>`__

- `Global Coordination <rllib-training.html#global-coordination>`__

- `Callbacks and Custom Metrics <rllib-training.html#callbacks-and-custom-metrics>`__

- `Rewriting Trajectories <rllib-training.html#rewriting-trajectories>`__

- `Curriculum Learning <rllib-training.html#curriculum-learning>`__

* `Debugging <rllib-training.html#debugging>`__

- `Gym Monitor <rllib-training.html#gym-monitor>`__

- `Eager Mode <rllib-training.html#eager-mode>`__

- `Episode Traces <rllib-training.html#episode-traces>`__

- `Log Verbosity <rllib-training.html#log-verbosity>`__

- `Stack Traces <rllib-training.html#stack-traces>`__

* `REST API <rllib-training.html#rest-api>`__

Environments
------------
3 changes: 2 additions & 1 deletion rllib/agents/ppo/ppo.py
@@ -74,7 +74,8 @@ def choose_policy_optimizer(workers, config):
workers,
num_sgd_iter=config["num_sgd_iter"],
train_batch_size=config["train_batch_size"],
sgd_minibatch_size=config["sgd_minibatch_size"])
sgd_minibatch_size=config["sgd_minibatch_size"],
standardize_fields=["advantages"])

return LocalMultiGPUOptimizer(
workers,
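With this change, the synchronous (simple) optimizer standardizes the advantages field before each SGD pass, matching what the multi-GPU optimizer already did. A minimal usage sketch, assuming the `simple_optimizer` config flag is what routes PPO through `choose_policy_optimizer` to `SyncSamplesOptimizer`:

```python
import ray
from ray import tune

ray.init()
tune.run(
    "PPO",
    stop={"training_iteration": 3},
    config={
        "env": "CartPole-v0",
        # Force the SyncSamplesOptimizer path patched in this file, so
        # advantages are standardized per train batch before SGD.
        "simple_optimizer": True,
        "num_sgd_iter": 10,
        "sgd_minibatch_size": 128,
        "train_batch_size": 4000,
    })
```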
28 changes: 10 additions & 18 deletions rllib/examples/centralized_critic.py
@@ -32,6 +32,7 @@
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.utils.tf_ops import make_tf_callable
from ray.rllib.utils import try_import_tf

tf = try_import_tf()
@@ -83,21 +84,11 @@ def value_function(self):


class CentralizedValueMixin(object):
"""Add methods to evaluate the central value function from the model."""
"""Add method to evaluate the central value function from the model."""

def __init__(self):
self.central_value_function = self.model.central_value_function(
self.get_placeholder(SampleBatch.CUR_OBS),
self.get_placeholder(OPPONENT_OBS),
self.get_placeholder(OPPONENT_ACTION))

def compute_central_vf(self, obs, opponent_obs, opponent_actions):
feed_dict = {
self.get_placeholder(SampleBatch.CUR_OBS): obs,
self.get_placeholder(OPPONENT_OBS): opponent_obs,
self.get_placeholder(OPPONENT_ACTION): opponent_actions,
}
return self.get_session().run(self.central_value_function, feed_dict)
self.compute_central_vf = make_tf_callable(self.get_session())(
self.model.central_value_function)


# Grabs the opponent obs/act and includes it in the experience train_batch,
@@ -144,6 +135,9 @@ def loss_with_central_critic(policy, model, dist_class, train_batch):

logits, state = model.from_batch(train_batch)
action_dist = dist_class(logits, model)
policy.central_value_out = policy.model.central_value_function(
train_batch[SampleBatch.CUR_OBS], train_batch[OPPONENT_OBS],
train_batch[OPPONENT_ACTION])

policy.loss_obj = PPOLoss(
policy.action_space,
@@ -156,7 +150,7 @@ def loss_with_central_critic(policy, model, dist_class, train_batch):
train_batch[ACTION_LOGP],
train_batch[SampleBatch.VF_PREDS],
action_dist,
policy.central_value_function,
policy.central_value_out,
policy.kl_coeff,
tf.ones_like(train_batch[Postprocessing.ADVANTAGES], dtype=tf.bool),
entropy_coeff=policy.entropy_coeff,
@@ -175,17 +169,14 @@ def setup_mixins(policy, obs_space, action_space, config):
EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
config["entropy_coeff_schedule"])
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
# hack: put in a noop VF so some of the inherited PPO code runs
policy.value_function = tf.zeros(
tf.shape(policy.get_placeholder(SampleBatch.CUR_OBS))[0])


def central_vf_stats(policy, train_batch, grads):
# Report the explained variance of the central value function.
return {
"vf_explained_var": explained_variance(
train_batch[Postprocessing.VALUE_TARGETS],
policy.central_value_function),
policy.central_value_out),
}


@@ -214,6 +205,7 @@ def central_vf_stats(policy, train_batch, grads):
config={
"env": TwoStepGame,
"batch_mode": "complete_episodes",
"eager": False,
"num_workers": 0,
"multiagent": {
"policies": {
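The mixin rewrite above is the heart of the eager fix: rather than building placeholders and calling the session directly, the central value function is wrapped with `make_tf_callable`, which returns a function that accepts numpy arrays and works whether TF is in graph or eager mode (`get_session()` returns None under eager). A hedged sketch of how the wrapped callable is then used during postprocessing; the helper below is illustrative, not part of the diff:

```python
def fill_in_central_vf_preds(policy, sample_batch):
    # `policy` is assumed to carry CentralizedValueMixin, so
    # policy.compute_central_vf is the wrapped central value function.
    # String keys stand in for SampleBatch.CUR_OBS, OPPONENT_OBS, etc.
    # Numpy arrays go in; a numpy array of value predictions comes out,
    # under both graph and eager execution.
    sample_batch["vf_preds"] = policy.compute_central_vf(
        sample_batch["obs"],
        sample_batch["opponent_obs"],
        sample_batch["opponent_action"])
    return sample_batch
```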
12 changes: 11 additions & 1 deletion rllib/examples/custom_tf_policy.py
@@ -7,6 +7,7 @@
import ray
from ray import tune
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.evaluation.postprocessing import discount
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils import try_import_tf

@@ -20,13 +21,22 @@ def policy_gradient_loss(policy, model, dist_class, train_batch):
logits, _ = model.from_batch(train_batch)
action_dist = dist_class(logits, model)
return -tf.reduce_mean(
action_dist.logp(train_batch["actions"]) * train_batch["rewards"])
action_dist.logp(train_batch["actions"]) * train_batch["advantages"])


def calculate_advantages(policy,
sample_batch,
other_agent_batches=None,
episode=None):
sample_batch["advantages"] = discount(sample_batch["rewards"], 0.99)
(Review thread on the line above)
Contributor: This should try to fetch gamma from the policy config instead, right?
Contributor Author: Not for the example.
return sample_batch


# <class 'ray.rllib.policy.tf_policy_template.MyTFPolicy'>
MyTFPolicy = build_tf_policy(
name="MyTFPolicy",
loss_fn=policy_gradient_loss,
postprocess_fn=calculate_advantages,
)

# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
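The loss now trains against a discounted-return `advantages` column computed by the postprocessor above, and the review thread notes that the discount factor could come from the policy config rather than a literal 0.99. A hedged sketch of that variant; `discounted_returns` below is a plain numpy stand-in for RLlib's `discount` helper, not its actual implementation:

```python
import numpy as np


def discounted_returns(rewards, gamma):
    # Backward pass: returns[t] = rewards[t] + gamma * returns[t + 1].
    returns = np.zeros(len(rewards), dtype=np.float32)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def calculate_advantages(policy, sample_batch,
                         other_agent_batches=None, episode=None):
    # Reviewer's suggestion: read gamma from the trainer config if present,
    # falling back to the example's hard-coded 0.99.
    gamma = policy.config.get("gamma", 0.99)
    sample_batch["advantages"] = discounted_returns(
        sample_batch["rewards"], gamma)
    return sample_batch
```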
60 changes: 48 additions & 12 deletions rllib/optimizers/sync_samples_optimizer.py
@@ -4,11 +4,14 @@

import logging
import random
from collections import defaultdict

import ray
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.optimizers.multi_gpu_optimizer import _averaged
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.filter import RunningStat
from ray.rllib.utils.timer import TimerStat
@@ -29,17 +32,22 @@ def __init__(self,
workers,
num_sgd_iter=1,
train_batch_size=1,
sgd_minibatch_size=0):
sgd_minibatch_size=0,
standardize_fields=frozenset([])):
PolicyOptimizer.__init__(self, workers)

self.update_weights_timer = TimerStat()
self.standardize_fields = standardize_fields
(Review thread on the line above)
Contributor: or []
Contributor Author: Changed to frozenset.
self.sample_timer = TimerStat()
self.grad_timer = TimerStat()
self.throughput = RunningStat()
self.num_sgd_iter = num_sgd_iter
self.sgd_minibatch_size = sgd_minibatch_size
self.train_batch_size = train_batch_size
self.learner_stats = {}
self.policies = dict(self.workers.local_worker()
.foreach_trainable_policy(lambda p, i: (i, p)))
logger.debug("Policies to train: {}".format(self.policies))

@override(PolicyOptimizer)
def step(self):
@@ -63,16 +71,44 @@ def step(self):
samples = SampleBatch.concat_samples(samples)
self.sample_timer.push_units_processed(samples.count)

with self.grad_timer:
for i in range(self.num_sgd_iter):
for minibatch in self._minibatches(samples):
fetches = self.workers.local_worker().learn_on_batch(
minibatch)
self.learner_stats = get_learner_stats(fetches)
if self.num_sgd_iter > 1:
logger.debug("{} {}".format(i, fetches))
self.grad_timer.push_units_processed(samples.count)
# Handle everything as if multiagent
if isinstance(samples, SampleBatch):
samples = MultiAgentBatch({
DEFAULT_POLICY_ID: samples
}, samples.count)

fetches = {}
with self.grad_timer:
for policy_id, policy in self.policies.items():
if policy_id not in samples.policy_batches:
continue

batch = samples.policy_batches[policy_id]
for field in self.standardize_fields:
value = batch[field]
standardized = (value - value.mean()) / max(
1e-4, value.std())
batch[field] = standardized

for i in range(self.num_sgd_iter):
iter_extra_fetches = defaultdict(list)
for minibatch in self._minibatches(batch):
batch_fetches = (
self.workers.local_worker().learn_on_batch(
MultiAgentBatch({
policy_id: minibatch
}, minibatch.count)))[policy_id]
for k, v in batch_fetches[LEARNER_STATS_KEY].items():
iter_extra_fetches[k].append(v)
logger.debug("{} {}".format(i,
_averaged(iter_extra_fetches)))
fetches[policy_id] = _averaged(iter_extra_fetches)

self.grad_timer.push_units_processed(samples.count)
if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
self.learner_stats = fetches[DEFAULT_POLICY_ID]
else:
self.learner_stats = fetches
self.num_steps_sampled += samples.count
self.num_steps_trained += samples.count
return self.learner_stats
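The key new behavior in `step()` is that every batch is handled as a `MultiAgentBatch`, each policy's batch is standardized on the configured fields, and SGD minibatching plus stats averaging happens per policy. The standardization itself is just zero-mean, unit-variance scaling with a floored std; a small standalone sketch:

```python
import numpy as np


def standardize(value, eps=1e-4):
    # Same math as the loop above: shift to zero mean, scale to unit std,
    # with a floor on the std so an all-equal field does not divide by zero.
    return (value - value.mean()) / max(eps, value.std())


advantages = np.array([1.0, 3.0, -2.0, 0.5], dtype=np.float32)
print(standardize(advantages))  # mean ~0, std ~1
```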
9 changes: 8 additions & 1 deletion rllib/policy/eager_tf_policy.py
@@ -127,7 +128,8 @@ def postprocess_trajectory(self,
episode=None):
assert tf.executing_eagerly()
if postprocess_fn:
return postprocess_fn(self, samples)
return postprocess_fn(self, samples, other_agent_batches,
episode)
else:
return samples

@@ -224,6 +225,12 @@ def num_state_tensors(self):
def get_session(self):
return None # None implies eager

def get_placeholder(self, ph):
raise ValueError(
"get_placeholder() is not allowed in eager mode. Try using "
"rllib.utils.tf_ops.make_tf_callable() to write "
"functions that work in both graph and eager mode.")

def loss_initialized(self):
return self._loss_initialized

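Two small but important eager-mode fixes: `postprocess_fn` now receives all four arguments (so multiagent postprocessors like the centralized-critic one can see `other_agent_batches` under eager execution), and `get_placeholder()` fails loudly with a pointer to `make_tf_callable`. A hedged sketch of the four-argument contract; the structure of `other_agent_batches` (agent id mapped to a (policy, SampleBatch) pair) is assumed from RLlib's postprocessing API of this era:

```python
def my_postprocess(policy, sample_batch, other_agent_batches=None,
                   episode=None):
    # In a multiagent episode, peek at one opponent's batch; this only
    # works because the eager policy now forwards these arguments.
    if other_agent_batches:
        _, (_, opponent_batch) = next(iter(other_agent_batches.items()))
        sample_batch["opponent_obs"] = opponent_batch["obs"]
    return sample_batch
```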