diff --git a/doc/source/rllib-toc.rst b/doc/source/rllib-toc.rst
index cdf225fe128a..0d081012e563 100644
--- a/doc/source/rllib-toc.rst
+++ b/doc/source/rllib-toc.rst
@@ -3,33 +3,46 @@ RLlib Table of Contents
 Training APIs
 -------------
-* `Command-line `__
-* `Configuration `__
+* `Command-line `__
+* `Configuration `__

   - `Specifying Parameters `__
+  - `Specifying Resources `__
+  - `Common Parameters `__
+  - `Tuned Examples `__

-* `Python API `__
+* `Python API `__

   - `Custom Training Workflows `__
+  - `Accessing Policy State `__
+  - `Accessing Model State `__
+  - `Global Coordination `__
+  - `Callbacks and Custom Metrics `__
+  - `Rewriting Trajectories `__
+  - `Curriculum Learning `__

-* `Debugging `__
+* `Debugging `__

   - `Gym Monitor `__
+  - `Eager Mode `__
+  - `Episode Traces `__
+  - `Log Verbosity `__
+  - `Stack Traces `__

-* `REST API `__
+* `REST API `__

 Environments
 ------------
diff --git a/rllib/agents/ppo/ppo.py b/rllib/agents/ppo/ppo.py
index 8cdfb4ae2c87..3df44b412773 100644
--- a/rllib/agents/ppo/ppo.py
+++ b/rllib/agents/ppo/ppo.py
@@ -74,7 +74,8 @@ def choose_policy_optimizer(workers, config):
             workers,
             num_sgd_iter=config["num_sgd_iter"],
             train_batch_size=config["train_batch_size"],
-            sgd_minibatch_size=config["sgd_minibatch_size"])
+            sgd_minibatch_size=config["sgd_minibatch_size"],
+            standardize_fields=["advantages"])

     return LocalMultiGPUOptimizer(
         workers,
diff --git a/rllib/examples/centralized_critic.py b/rllib/examples/centralized_critic.py
index d2bebe301a4f..6c8a450c4648 100644
--- a/rllib/examples/centralized_critic.py
+++ b/rllib/examples/centralized_critic.py
@@ -32,6 +32,7 @@
 from ray.rllib.models.tf.tf_modelv2 import TFModelV2
 from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
 from ray.rllib.utils.explained_variance import explained_variance
+from ray.rllib.utils.tf_ops import make_tf_callable
 from ray.rllib.utils import try_import_tf

 tf = try_import_tf()
@@ -83,21 +84,11 @@ def value_function(self):


 class CentralizedValueMixin(object):
-    """Add methods to evaluate the central value function from the model."""
+    """Add method to evaluate the central value function from the model."""

     def __init__(self):
-        self.central_value_function = self.model.central_value_function(
-            self.get_placeholder(SampleBatch.CUR_OBS),
-            self.get_placeholder(OPPONENT_OBS),
-            self.get_placeholder(OPPONENT_ACTION))
-
-    def compute_central_vf(self, obs, opponent_obs, opponent_actions):
-        feed_dict = {
-            self.get_placeholder(SampleBatch.CUR_OBS): obs,
-            self.get_placeholder(OPPONENT_OBS): opponent_obs,
-            self.get_placeholder(OPPONENT_ACTION): opponent_actions,
-        }
-        return self.get_session().run(self.central_value_function, feed_dict)
+        self.compute_central_vf = make_tf_callable(self.get_session())(
+            self.model.central_value_function)


 # Grabs the opponent obs/act and includes it in the experience train_batch,
@@ -144,6 +135,9 @@ def loss_with_central_critic(policy, model, dist_class, train_batch):

     logits, state = model.from_batch(train_batch)
     action_dist = dist_class(logits, model)
+    policy.central_value_out = policy.model.central_value_function(
+        train_batch[SampleBatch.CUR_OBS], train_batch[OPPONENT_OBS],
+        train_batch[OPPONENT_ACTION])

     policy.loss_obj = PPOLoss(
         policy.action_space,
@@ -156,7 +150,7 @@ def loss_with_central_critic(policy, model, dist_class, train_batch):
         train_batch[ACTION_LOGP],
         train_batch[SampleBatch.VF_PREDS],
         action_dist,
-        policy.central_value_function,
+        policy.central_value_out,
         policy.kl_coeff,
         tf.ones_like(train_batch[Postprocessing.ADVANTAGES], dtype=tf.bool),
         entropy_coeff=policy.entropy_coeff,
@@ -175,9 +169,6 @@ def setup_mixins(policy, obs_space, action_space, config):
     EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
                                   config["entropy_coeff_schedule"])
     LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
-    # hack: put in a noop VF so some of the inherited PPO code runs
-    policy.value_function = tf.zeros(
-        tf.shape(policy.get_placeholder(SampleBatch.CUR_OBS))[0])


 def central_vf_stats(policy, train_batch, grads):
@@ -185,7 +176,7 @@ def central_vf_stats(policy, train_batch, grads):
     return {
         "vf_explained_var": explained_variance(
             train_batch[Postprocessing.VALUE_TARGETS],
-            policy.central_value_function),
+            policy.central_value_out),
     }


@@ -214,6 +205,7 @@ def central_vf_stats(policy, train_batch, grads):
         config={
             "env": TwoStepGame,
             "batch_mode": "complete_episodes",
+            "eager": False,
             "num_workers": 0,
             "multiagent": {
                 "policies": {
diff --git a/rllib/examples/custom_tf_policy.py b/rllib/examples/custom_tf_policy.py
index a3e5698e981f..fbde9201f55d 100644
--- a/rllib/examples/custom_tf_policy.py
+++ b/rllib/examples/custom_tf_policy.py
@@ -7,6 +7,7 @@
 import ray
 from ray import tune
 from ray.rllib.agents.trainer_template import build_trainer
+from ray.rllib.evaluation.postprocessing import discount
 from ray.rllib.policy.tf_policy_template import build_tf_policy
 from ray.rllib.utils import try_import_tf

@@ -20,13 +21,22 @@ def policy_gradient_loss(policy, model, dist_class, train_batch):
     logits, _ = model.from_batch(train_batch)
     action_dist = dist_class(logits, model)
     return -tf.reduce_mean(
-        action_dist.logp(train_batch["actions"]) * train_batch["rewards"])
+        action_dist.logp(train_batch["actions"]) * train_batch["advantages"])
+
+
+def calculate_advantages(policy,
+                         sample_batch,
+                         other_agent_batches=None,
+                         episode=None):
+    sample_batch["advantages"] = discount(sample_batch["rewards"], 0.99)
+    return sample_batch


 #
 MyTFPolicy = build_tf_policy(
     name="MyTFPolicy",
     loss_fn=policy_gradient_loss,
+    postprocess_fn=calculate_advantages,
 )
 #

diff --git a/rllib/optimizers/sync_samples_optimizer.py b/rllib/optimizers/sync_samples_optimizer.py
index 1679e8c2caaf..a6c945a8a984 100644
--- a/rllib/optimizers/sync_samples_optimizer.py
+++ b/rllib/optimizers/sync_samples_optimizer.py
@@ -4,11 +4,14 @@

 import logging
 import random
+from collections import defaultdict

 import ray
-from ray.rllib.evaluation.metrics import get_learner_stats
+from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
+from ray.rllib.optimizers.multi_gpu_optimizer import _averaged
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
-from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
+from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
+    MultiAgentBatch
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.filter import RunningStat
 from ray.rllib.utils.timer import TimerStat
@@ -29,10 +32,12 @@ def __init__(self,
                  workers,
                  num_sgd_iter=1,
                  train_batch_size=1,
-                 sgd_minibatch_size=0):
+                 sgd_minibatch_size=0,
+                 standardize_fields=frozenset([])):
         PolicyOptimizer.__init__(self, workers)

         self.update_weights_timer = TimerStat()
+        self.standardize_fields = standardize_fields
         self.sample_timer = TimerStat()
         self.grad_timer = TimerStat()
         self.throughput = RunningStat()
@@ -40,6 +45,9 @@ def __init__(self,
         self.sgd_minibatch_size = sgd_minibatch_size
         self.train_batch_size = train_batch_size
         self.learner_stats = {}
+        self.policies = dict(self.workers.local_worker()
+                             .foreach_trainable_policy(lambda p, i: (i, p)))
+        logger.debug("Policies to train: {}".format(self.policies))

     @override(PolicyOptimizer)
     def step(self):
@@ -63,16 +71,44 @@ def step(self):
             samples = SampleBatch.concat_samples(samples)
             self.sample_timer.push_units_processed(samples.count)

-        with self.grad_timer:
-            for i in range(self.num_sgd_iter):
-                for minibatch in self._minibatches(samples):
-                    fetches = self.workers.local_worker().learn_on_batch(
-                        minibatch)
-                    self.learner_stats = get_learner_stats(fetches)
-                    if self.num_sgd_iter > 1:
-                        logger.debug("{} {}".format(i, fetches))
-            self.grad_timer.push_units_processed(samples.count)
+        # Handle everything as if multiagent
+        if isinstance(samples, SampleBatch):
+            samples = MultiAgentBatch({
+                DEFAULT_POLICY_ID: samples
+            }, samples.count)
+        fetches = {}
+        with self.grad_timer:
+            for policy_id, policy in self.policies.items():
+                if policy_id not in samples.policy_batches:
+                    continue
+
+                batch = samples.policy_batches[policy_id]
+                for field in self.standardize_fields:
+                    value = batch[field]
+                    standardized = (value - value.mean()) / max(
+                        1e-4, value.std())
+                    batch[field] = standardized
+
+                for i in range(self.num_sgd_iter):
+                    iter_extra_fetches = defaultdict(list)
+                    for minibatch in self._minibatches(batch):
+                        batch_fetches = (
+                            self.workers.local_worker().learn_on_batch(
+                                MultiAgentBatch({
+                                    policy_id: minibatch
+                                }, minibatch.count)))[policy_id]
+                        for k, v in batch_fetches[LEARNER_STATS_KEY].items():
+                            iter_extra_fetches[k].append(v)
+                    logger.debug("{} {}".format(i,
+                                                _averaged(iter_extra_fetches)))
+                fetches[policy_id] = _averaged(iter_extra_fetches)
+
+        self.grad_timer.push_units_processed(samples.count)
+        if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
+            self.learner_stats = fetches[DEFAULT_POLICY_ID]
+        else:
+            self.learner_stats = fetches
         self.num_steps_sampled += samples.count
         self.num_steps_trained += samples.count
         return self.learner_stats
diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py
index 9ecb341e97e6..b8150d3c42ae 100644
--- a/rllib/policy/eager_tf_policy.py
+++ b/rllib/policy/eager_tf_policy.py
@@ -127,7 +127,8 @@ def postprocess_trajectory(self,
                                    episode=None):
             assert tf.executing_eagerly()
             if postprocess_fn:
-                return postprocess_fn(self, samples)
+                return postprocess_fn(self, samples, other_agent_batches,
+                                      episode)
             else:
                 return samples

@@ -224,6 +225,12 @@ def num_state_tensors(self):

         def get_session(self):
            return None  # None implies eager

+        def get_placeholder(self, ph):
+            raise ValueError(
+                "get_placeholder() is not allowed in eager mode. Try using "
+                "rllib.utils.tf_ops.make_tf_callable() to write "
+                "functions that work in both graph and eager mode.")
+
         def loss_initialized(self):
             return self._loss_initialized
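
Note (reviewer sketch, not part of the patch): the `standardize_fields=["advantages"]` argument threaded through `choose_policy_optimizer` above makes `SyncSamplesOptimizer` rescale the listed batch columns to zero mean and roughly unit variance before running SGD, clamping the denominator so near-constant columns don't blow up. A minimal standalone illustration of that normalization with plain numpy follows; the helper name `standardize` and the toy batch are invented for this example and are not part of the RLlib API.

```python
import numpy as np


def standardize(batch, fields, eps=1e-4):
    """Rescale each listed field to zero mean and ~unit variance.

    Mirrors the per-field normalization the patch adds to
    SyncSamplesOptimizer.step(): (value - mean) / max(eps, std).
    """
    for field in fields:
        value = batch[field]
        batch[field] = (value - value.mean()) / max(eps, value.std())
    return batch


# Toy usage: standardize the "advantages" column of a fake sample batch.
fake_batch = {"advantages": np.array([1.0, 3.0, -2.0, 0.5])}
standardize(fake_batch, ["advantages"])
print(fake_batch["advantages"].mean(), fake_batch["advantages"].std())
# Prints approximately 0.0 and 1.0.
```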