From e4cb71fa799f793800516b7751c7fdca265452f4 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 11 Mar 2019 18:06:46 -0700 Subject: [PATCH] Revert "[wingman -> rllib] IMPALA MultiDiscrete changes (#3967)" This reverts commit 962b17f567d4c7d14b5fdfbb1059d0a5d6dcef8d. --- ci/jenkins_tests/run_rllib_tests.sh | 3 - python/ray/rllib/agents/impala/vtrace.py | 175 +++--------- .../agents/impala/vtrace_policy_graph.py | 144 +++------- python/ray/rllib/agents/impala/vtrace_test.py | 268 ------------------ .../ray/rllib/agents/ppo/appo_policy_graph.py | 160 ++++------- .../ray/rllib/evaluation/policy_evaluator.py | 2 +- python/ray/rllib/models/action_dist.py | 25 -- python/ray/rllib/models/catalog.py | 12 +- 8 files changed, 131 insertions(+), 658 deletions(-) delete mode 100644 python/ray/rllib/agents/impala/vtrace_test.py diff --git a/ci/jenkins_tests/run_rllib_tests.sh b/ci/jenkins_tests/run_rllib_tests.sh index 91d4acd9ceaa..31f16c64ba86 100644 --- a/ci/jenkins_tests/run_rllib_tests.sh +++ b/ci/jenkins_tests/run_rllib_tests.sh @@ -410,6 +410,3 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ --stop='{"timesteps_total": 40000}' \ --ray-object-store-memory=500000000 \ --config '{"num_workers": 1, "num_gpus": 0, "num_envs_per_worker": 64, "sample_batch_size": 50, "train_batch_size": 50, "learner_queue_size": 1}' - -docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/rllib/agents/impala/vtrace_test.py diff --git a/python/ray/rllib/agents/impala/vtrace.py b/python/ray/rllib/agents/impala/vtrace.py index 4031ee4c8d45..ac5abf0e6592 100644 --- a/python/ray/rllib/agents/impala/vtrace.py +++ b/python/ray/rllib/agents/impala/vtrace.py @@ -20,12 +20,6 @@ by Espeholt, Soyer, Munos et al. See https://arxiv.org/abs/1802.01561 for the full paper. - -In addition to the original paper's code, changes have been made -to support MultiDiscrete action spaces. behaviour_policy_logits, -target_policy_logits and actions parameters in the entry point -multi_from_logits method accepts lists of tensors instead of just -tensors. """ from __future__ import absolute_import @@ -47,48 +41,29 @@ def log_probs_from_logits_and_actions(policy_logits, actions): - return multi_log_probs_from_logits_and_actions([policy_logits], - [actions])[0] - - -def multi_log_probs_from_logits_and_actions(policy_logits, actions): """Computes action log-probs from policy logits and actions. In the notation used throughout documentation and comments, T refers to the time dimension ranging from 0 to T-1. B refers to the batch size and - ACTION_SPACE refers to the list of numbers each representing a number of - actions. + NUM_ACTIONS refers to the number of actions. Args: - policy_logits: A list with length of ACTION_SPACE of float32 - tensors of shapes - [T, B, ACTION_SPACE[0]], - ..., - [T, B, ACTION_SPACE[-1]] - with un-normalized log-probabilities parameterizing a softmax policy. - actions: A list with length of ACTION_SPACE of int32 - tensors of shapes - [T, B], - ..., - [T, B] - with actions. + policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with + un-normalized log-probabilities parameterizing a softmax policy. + actions: An int32 tensor of shape [T, B] with actions. Returns: - A list with length of ACTION_SPACE of float32 - tensors of shapes - [T, B], - ..., - [T, B] - corresponding to the sampling log probability - of the chosen action w.r.t. the policy. + A float32 tensor of shape [T, B] corresponding to the sampling log + probability of the chosen action w.r.t. the policy. """ + policy_logits = tf.convert_to_tensor(policy_logits, dtype=tf.float32) + actions = tf.convert_to_tensor(actions, dtype=tf.int32) - log_probs = [] - for i in range(len(policy_logits)): - log_probs.append(-tf.nn.sparse_softmax_cross_entropy_with_logits( - logits=policy_logits[i], labels=actions[i])) + policy_logits.shape.assert_has_rank(3) + actions.shape.assert_has_rank(2) - return log_probs + return -tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=policy_logits, labels=actions) def from_logits(behaviour_policy_logits, @@ -101,39 +76,6 @@ def from_logits(behaviour_policy_logits, clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0, name='vtrace_from_logits'): - """multi_from_logits wrapper used only for tests""" - - res = multi_from_logits( - [behaviour_policy_logits], [target_policy_logits], [actions], - discounts, - rewards, - values, - bootstrap_value, - clip_rho_threshold=clip_rho_threshold, - clip_pg_rho_threshold=clip_pg_rho_threshold, - name=name) - - return VTraceFromLogitsReturns( - vs=res.vs, - pg_advantages=res.pg_advantages, - log_rhos=res.log_rhos, - behaviour_action_log_probs=tf.squeeze( - res.behaviour_action_log_probs, axis=0), - target_action_log_probs=tf.squeeze( - res.target_action_log_probs, axis=0), - ) - - -def multi_from_logits(behaviour_policy_logits, - target_policy_logits, - actions, - discounts, - rewards, - values, - bootstrap_value, - clip_rho_threshold=1.0, - clip_pg_rho_threshold=1.0, - name='vtrace_from_logits'): r"""V-trace for softmax policies. Calculates V-trace actor critic targets for softmax polices as described in @@ -148,30 +90,16 @@ def multi_from_logits(behaviour_policy_logits, In the notation used throughout documentation and comments, T refers to the time dimension ranging from 0 to T-1. B refers to the batch size and - ACTION_SPACE refers to the list of numbers each representing a number of - actions. + NUM_ACTIONS refers to the number of actions. Args: - behaviour_policy_logits: A list with length of ACTION_SPACE of float32 - tensors of shapes - [T, B, ACTION_SPACE[0]], - ..., - [T, B, ACTION_SPACE[-1]] - with un-normalized log-probabilities parameterizing the softmax behaviour - policy. - target_policy_logits: A list with length of ACTION_SPACE of float32 - tensors of shapes - [T, B, ACTION_SPACE[0]], - ..., - [T, B, ACTION_SPACE[-1]] - with un-normalized log-probabilities parameterizing the softmax target + behaviour_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with + un-normalized log-probabilities parametrizing the softmax behaviour policy. - actions: A list with length of ACTION_SPACE of int32 - tensors of shapes - [T, B], - ..., - [T, B] - with actions sampled from the behaviour policy. + target_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with + un-normalized log-probabilities parametrizing the softmax target policy. + actions: An int32 tensor of shape [T, B] of actions sampled from the + behaviour policy. discounts: A float32 tensor of shape [T, B] with the discount encountered when following the behaviour policy. rewards: A float32 tensor of shape [T, B] with the rewards generated by @@ -200,19 +128,17 @@ def multi_from_logits(behaviour_policy_logits, target_action_log_probs: A float32 tensor of shape [T, B] containing target policy action probabilities (log \pi(a_t)). """ - - for i in range(len(behaviour_policy_logits)): - behaviour_policy_logits[i] = tf.convert_to_tensor( - behaviour_policy_logits[i], dtype=tf.float32) - target_policy_logits[i] = tf.convert_to_tensor( - target_policy_logits[i], dtype=tf.float32) - actions[i] = tf.convert_to_tensor(actions[i], dtype=tf.int32) - - # Make sure tensor ranks are as expected. - # The rest will be checked by from_action_log_probs. - behaviour_policy_logits[i].shape.assert_has_rank(3) - target_policy_logits[i].shape.assert_has_rank(3) - actions[i].shape.assert_has_rank(2) + behaviour_policy_logits = tf.convert_to_tensor( + behaviour_policy_logits, dtype=tf.float32) + target_policy_logits = tf.convert_to_tensor( + target_policy_logits, dtype=tf.float32) + actions = tf.convert_to_tensor(actions, dtype=tf.int32) + + # Make sure tensor ranks are as expected. + # The rest will be checked by from_action_log_probs. + behaviour_policy_logits.shape.assert_has_rank(3) + target_policy_logits.shape.assert_has_rank(3) + actions.shape.assert_has_rank(2) with tf.name_scope( name, @@ -220,14 +146,11 @@ def multi_from_logits(behaviour_policy_logits, behaviour_policy_logits, target_policy_logits, actions, discounts, rewards, values, bootstrap_value ]): - target_action_log_probs = multi_log_probs_from_logits_and_actions( + target_action_log_probs = log_probs_from_logits_and_actions( target_policy_logits, actions) - behaviour_action_log_probs = multi_log_probs_from_logits_and_actions( + behaviour_action_log_probs = log_probs_from_logits_and_actions( behaviour_policy_logits, actions) - - log_rhos = get_log_rhos(target_action_log_probs, - behaviour_action_log_probs) - + log_rhos = target_action_log_probs - behaviour_action_log_probs vtrace_returns = from_importance_weights( log_rhos=log_rhos, discounts=discounts, @@ -236,7 +159,6 @@ def multi_from_logits(behaviour_policy_logits, bootstrap_value=bootstrap_value, clip_rho_threshold=clip_rho_threshold, clip_pg_rho_threshold=clip_pg_rho_threshold) - return VTraceFromLogitsReturns( log_rhos=log_rhos, behaviour_action_log_probs=behaviour_action_log_probs, @@ -261,13 +183,13 @@ def from_importance_weights(log_rhos, by Espeholt, Soyer, Munos et al. In the notation used throughout documentation and comments, T refers to the - time dimension ranging from 0 to T-1. B refers to the batch size. This code - also supports the case where all tensors have the same number of additional - dimensions, e.g., `rewards` is [T, B, C], `values` is [T, B, C], - `bootstrap_value` is [B, C]. + time dimension ranging from 0 to T-1. B refers to the batch size and + NUM_ACTIONS refers to the number of actions. This code also supports the + case where all tensors have the same number of additional dimensions, e.g., + `rewards` is [T, B, C], `values` is [T, B, C], `bootstrap_value` is [B, C]. Args: - log_rhos: A float32 tensor of shape [T, B] representing the + log_rhos: A float32 tensor of shape [T, B, NUM_ACTIONS] representing the log importance sampling weights, i.e. log(target_policy(a) / behaviour_policy(a)). V-trace performs operations on rhos in log-space for numerical stability. @@ -324,14 +246,6 @@ def from_importance_weights(log_rhos, if clip_rho_threshold is not None: clipped_rhos = tf.minimum( clip_rho_threshold, rhos, name='clipped_rhos') - - tf.summary.histogram('clipped_rhos_1000', tf.minimum(1000.0, rhos)) - tf.summary.scalar( - 'num_of_clipped_rhos', - tf.reduce_sum( - tf.cast( - tf.equal(clipped_rhos, clip_rho_threshold), tf.int32))) - tf.summary.scalar('size_of_clipped_rhos', tf.size(clipped_rhos)) else: clipped_rhos = rhos @@ -384,16 +298,3 @@ def scanfunc(acc, sequence_item): return VTraceReturns( vs=tf.stop_gradient(vs), pg_advantages=tf.stop_gradient(pg_advantages)) - - -def get_log_rhos(behaviour_action_log_probs, target_action_log_probs): - """With the selected log_probs for multi-discrete actions of behaviour - and target policies we compute the log_rhos for calculating the vtrace.""" - log_rhos = [ - t - b - for t, b in zip(target_action_log_probs, behaviour_action_log_probs) - ] - log_rhos = [tf.convert_to_tensor(l, dtype=tf.float32) for l in log_rhos] - log_rhos = tf.reduce_sum(tf.stack(log_rhos), axis=0) - - return log_rhos diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy_graph.py index 9d16c337d1eb..7f36e78f75c8 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py +++ b/python/ray/rllib/agents/impala/vtrace_policy_graph.py @@ -6,19 +6,19 @@ from __future__ import division from __future__ import print_function +import tensorflow as tf import gym + import ray -import numpy as np -import tensorflow as tf from ray.rllib.agents.impala import vtrace from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ LearningRateSchedule -from ray.rllib.models.action_dist import MultiCategorical from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.explained_variance import explained_variance +from ray.rllib.models.action_dist import Categorical class VTraceLoss(object): @@ -45,20 +45,12 @@ def __init__(self, handle episode cut boundaries. Args: - actions: An int32 tensor of shape [T, B, ACTION_SPACE]. + actions: An int32 tensor of shape [T, B, NUM_ACTIONS]. actions_logp: A float32 tensor of shape [T, B]. actions_entropy: A float32 tensor of shape [T, B]. dones: A bool tensor of shape [T, B]. - behaviour_logits: A list with length of ACTION_SPACE of float32 - tensors of shapes - [T, B, ACTION_SPACE[0]], - ..., - [T, B, ACTION_SPACE[-1]] - target_logits: A list with length of ACTION_SPACE of float32 - tensors of shapes - [T, B, ACTION_SPACE[0]], - ..., - [T, B, ACTION_SPACE[-1]] + behaviour_logits: A float32 tensor of shape [T, B, NUM_ACTIONS]. + target_logits: A float32 tensor of shape [T, B, NUM_ACTIONS]. discount: A float32 scalar. rewards: A float32 tensor of shape [T, B]. values: A float32 tensor of shape [T, B]. @@ -68,10 +60,10 @@ def __init__(self, # Compute vtrace on the CPU for better perf. with tf.device("/cpu:0"): - self.vtrace_returns = vtrace.multi_from_logits( + self.vtrace_returns = vtrace.from_logits( behaviour_policy_logits=behaviour_logits, target_policy_logits=target_logits, - actions=tf.unstack(tf.cast(actions, tf.int32), axis=2), + actions=tf.cast(actions, tf.int32), discounts=tf.to_float(~dones) * discount, rewards=rewards, values=values, @@ -109,20 +101,6 @@ def __init__(self, "Must use `truncate_episodes` batch mode with V-trace." self.config = config self.sess = tf.get_default_session() - self.grads = None - - if isinstance(action_space, gym.spaces.Discrete): - is_multidiscrete = False - actions_shape = [None] - output_hidden_shape = [action_space.n] - elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete): - is_multidiscrete = True - actions_shape = [None, len(action_space.nvec)] - output_hidden_shape = action_space.nvec.astype(np.int32) - else: - raise UnsupportedSpaceException( - "Action space {} is not supported for IMPALA.".format( - action_space)) # Create input placeholders if existing_inputs: @@ -131,21 +109,22 @@ def __init__(self, existing_state_in = existing_inputs[7:-1] existing_seq_lens = existing_inputs[-1] else: - actions = tf.placeholder(tf.int64, actions_shape, name="ac") + if isinstance(action_space, gym.spaces.Discrete): + ac_size = action_space.n + actions = tf.placeholder(tf.int64, [None], name="ac") + else: + raise UnsupportedSpaceException( + "Action space {} is not supported for IMPALA.".format( + action_space)) dones = tf.placeholder(tf.bool, [None], name="dones") rewards = tf.placeholder(tf.float32, [None], name="rewards") behaviour_logits = tf.placeholder( - tf.float32, [None, sum(output_hidden_shape)], - name="behaviour_logits") + tf.float32, [None, ac_size], name="behaviour_logits") observations = tf.placeholder( tf.float32, [None] + list(observation_space.shape)) existing_state_in = None existing_seq_lens = None - # Unpack behaviour logits - unpacked_behaviour_logits = tf.split( - behaviour_logits, output_hidden_shape, axis=1) - # Setup the policy dist_class, logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"]) @@ -164,30 +143,12 @@ def __init__(self, self.config["model"], state_in=existing_state_in, seq_lens=existing_seq_lens) - unpacked_outputs = tf.split( - self.model.outputs, output_hidden_shape, axis=1) - - dist_inputs = unpacked_outputs if is_multidiscrete else \ - self.model.outputs - action_dist = dist_class(dist_inputs) - + action_dist = dist_class(self.model.outputs) values = self.model.value_function() self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) - def make_time_major(tensor, drop_last=False): - """Swaps batch and trajectory axis. - Args: - tensor: A tensor or list of tensors to reshape. - drop_last: A bool indicating whether to drop the last - trajectory item. - Returns: - res: A tensor with swapped axes or a list of tensors with - swapped axes. - """ - if isinstance(tensor, list): - return [make_time_major(t, drop_last) for t in tensor] - + def to_batches(tensor): if self.model.state_init: B = tf.shape(self.model.seq_lens)[0] T = tf.shape(tensor)[0] // B @@ -198,16 +159,11 @@ def make_time_major(tensor, drop_last=False): B = tf.shape(tensor)[0] // T rs = tf.reshape(tensor, tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0)) - # swap B and T axes - res = tf.transpose( + return tf.transpose( rs, [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0])))) - if drop_last: - return res[:-1] - return res - if self.model.state_in: max_seq_len = tf.reduce_max(self.model.seq_lens) - 1 mask = tf.sequence_mask(self.model.seq_lens, max_seq_len) @@ -215,52 +171,31 @@ def make_time_major(tensor, drop_last=False): else: mask = tf.ones_like(rewards, dtype=tf.bool) - # Prepare actions for loss - loss_actions = actions if is_multidiscrete else tf.expand_dims( - actions, axis=1) - # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc. self.loss = VTraceLoss( - actions=make_time_major(loss_actions, drop_last=True), - actions_logp=make_time_major( - action_dist.logp(actions), drop_last=True), - actions_entropy=make_time_major( - action_dist.entropy(), drop_last=True), - dones=make_time_major(dones, drop_last=True), - behaviour_logits=make_time_major( - unpacked_behaviour_logits, drop_last=True), - target_logits=make_time_major(unpacked_outputs, drop_last=True), + actions=to_batches(actions)[:-1], + actions_logp=to_batches(action_dist.logp(actions))[:-1], + actions_entropy=to_batches(action_dist.entropy())[:-1], + dones=to_batches(dones)[:-1], + behaviour_logits=to_batches(behaviour_logits)[:-1], + target_logits=to_batches(self.model.outputs)[:-1], discount=config["gamma"], - rewards=make_time_major(rewards, drop_last=True), - values=make_time_major(values, drop_last=True), - bootstrap_value=make_time_major(values)[-1], - valid_mask=make_time_major(mask, drop_last=True), + rewards=to_batches(rewards)[:-1], + values=to_batches(values)[:-1], + bootstrap_value=to_batches(values)[-1], + valid_mask=to_batches(mask)[:-1], vf_loss_coeff=self.config["vf_loss_coeff"], entropy_coeff=self.config["entropy_coeff"], clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"]) # KL divergence between worker and learner logits for debugging - model_dist = MultiCategorical(unpacked_outputs) - behaviour_dist = MultiCategorical(unpacked_behaviour_logits) - - kls = model_dist.kl(behaviour_dist) - if len(kls) > 1: - self.KL_stats = {} - - for i, kl in enumerate(kls): - self.KL_stats.update({ - "mean_KL_{}".format(i): tf.reduce_mean(kl), - "max_KL_{}".format(i): tf.reduce_max(kl), - "median_KL_{}".format(i): tf.contrib.distributions. - percentile(kl, 50.0), - }) - else: - self.KL_stats = { - "mean_KL": tf.reduce_mean(kls[0]), - "max_KL": tf.reduce_max(kls[0]), - "median_KL": tf.contrib.distributions.percentile(kls[0], 50.0), - } + model_dist = Categorical(self.model.outputs) + behaviour_dist = Categorical(behaviour_logits) + self.KLs = model_dist.kl(behaviour_dist) + self.mean_KL = tf.reduce_mean(self.KLs) + self.max_KL = tf.reduce_max(self.KLs) + self.median_KL = tf.contrib.distributions.percentile(self.KLs, 50.0) # Initialize TFPolicyGraph loss_in = [ @@ -296,7 +231,7 @@ def make_time_major(tensor, drop_last=False): self.sess.run(tf.global_variables_initializer()) self.stats_fetches = { - "stats": dict({ + "stats": { "cur_lr": tf.cast(self.cur_lr, tf.float64), "policy_loss": self.loss.pi_loss, "entropy": self.loss.entropy, @@ -305,8 +240,11 @@ def make_time_major(tensor, drop_last=False): "vf_loss": self.loss.vf_loss, "vf_explained_var": explained_variance( tf.reshape(self.loss.vtrace_returns.vs, [-1]), - tf.reshape(make_time_major(values, drop_last=True), [-1])), - }, **self.KL_stats), + tf.reshape(to_batches(values)[:-1], [-1])), + "mean_KL": self.mean_KL, + "max_KL": self.max_KL, + "median_KL": self.median_KL, + }, } @override(TFPolicyGraph) diff --git a/python/ray/rllib/agents/impala/vtrace_test.py b/python/ray/rllib/agents/impala/vtrace_test.py deleted file mode 100644 index f74798fffdbb..000000000000 --- a/python/ray/rllib/agents/impala/vtrace_test.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for V-trace. - -For details and theory see: - -"IMPALA: Scalable Distributed Deep-RL with -Importance Weighted Actor-Learner Architectures" -by Espeholt, Soyer, Munos et al. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl.testing import parameterized -import numpy as np -import tensorflow as tf -import vtrace - - -def _shaped_arange(*shape): - """Runs np.arange, converts to float and reshapes.""" - return np.arange(np.prod(shape), dtype=np.float32).reshape(*shape) - - -def _softmax(logits): - """Applies softmax non-linearity on inputs.""" - return np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True) - - -def _ground_truth_calculation(discounts, log_rhos, rewards, values, - bootstrap_value, clip_rho_threshold, - clip_pg_rho_threshold): - """Calculates the ground truth for V-trace in Python/Numpy.""" - vs = [] - seq_len = len(discounts) - rhos = np.exp(log_rhos) - cs = np.minimum(rhos, 1.0) - clipped_rhos = rhos - if clip_rho_threshold: - clipped_rhos = np.minimum(rhos, clip_rho_threshold) - clipped_pg_rhos = rhos - if clip_pg_rho_threshold: - clipped_pg_rhos = np.minimum(rhos, clip_pg_rho_threshold) - - # This is a very inefficient way to calculate the V-trace ground truth. - # We calculate it this way because it is close to the mathematical notation - # of - # V-trace. - # v_s = V(x_s) - # + \sum^{T-1}_{t=s} \gamma^{t-s} - # * \prod_{i=s}^{t-1} c_i - # * \rho_t (r_t + \gamma V(x_{t+1}) - V(x_t)) - # Note that when we take the product over c_i, we write `s:t` as the - # notation - # of the paper is inclusive of the `t-1`, but Python is exclusive. - # Also note that np.prod([]) == 1. - values_t_plus_1 = np.concatenate( - [values, bootstrap_value[None, :]], axis=0) - for s in range(seq_len): - v_s = np.copy(values[s]) # Very important copy. - for t in range(s, seq_len): - v_s += (np.prod(discounts[s:t], axis=0) * np.prod(cs[s:t], axis=0) - * clipped_rhos[t] * (rewards[t] + discounts[t] * - values_t_plus_1[t + 1] - values[t])) - vs.append(v_s) - vs = np.stack(vs, axis=0) - pg_advantages = (clipped_pg_rhos * (rewards + discounts * np.concatenate( - [vs[1:], bootstrap_value[None, :]], axis=0) - values)) - - return vtrace.VTraceReturns(vs=vs, pg_advantages=pg_advantages) - - -class LogProbsFromLogitsAndActionsTest(tf.test.TestCase, - parameterized.TestCase): - @parameterized.named_parameters(('Batch1', 1), ('Batch2', 2)) - def test_log_probs_from_logits_and_actions(self, batch_size): - """Tests log_probs_from_logits_and_actions.""" - seq_len = 7 - num_actions = 3 - - policy_logits = _shaped_arange(seq_len, batch_size, num_actions) + 10 - actions = np.random.randint( - 0, num_actions - 1, size=(seq_len, batch_size), dtype=np.int32) - - action_log_probs_tensor = vtrace.log_probs_from_logits_and_actions( - policy_logits, actions) - - # Ground Truth - # Using broadcasting to create a mask that indexes action logits - action_index_mask = actions[..., None] == np.arange(num_actions) - - def index_with_mask(array, mask): - return array[mask].reshape(*array.shape[:-1]) - - # Note: Normally log(softmax) is not a good idea because it's not - # numerically stable. However, in this test we have well-behaved - # values. - ground_truth_v = index_with_mask( - np.log(_softmax(policy_logits)), action_index_mask) - - with self.test_session() as session: - self.assertAllClose(ground_truth_v, - session.run(action_log_probs_tensor)) - - -class VtraceTest(tf.test.TestCase, parameterized.TestCase): - @parameterized.named_parameters(('Batch1', 1), ('Batch5', 5)) - def test_vtrace(self, batch_size): - """Tests V-trace against ground truth data calculated in python.""" - seq_len = 5 - - # Create log_rhos such that rho will span from near-zero to above the - # clipping thresholds. In particular, calculate log_rhos in - # [-2.5, 2.5), - # so that rho is in approx [0.08, 12.2). - log_rhos = _shaped_arange(seq_len, batch_size) / (batch_size * seq_len) - log_rhos = 5 * (log_rhos - 0.5) # [0.0, 1.0) -> [-2.5, 2.5). - values = { - 'log_rhos': log_rhos, - # T, B where B_i: [0.9 / (i+1)] * T - 'discounts': np.array([[0.9 / (b + 1) for b in range(batch_size)] - for _ in range(seq_len)]), - 'rewards': _shaped_arange(seq_len, batch_size), - 'values': _shaped_arange(seq_len, batch_size) / batch_size, - 'bootstrap_value': _shaped_arange(batch_size) + 1.0, - 'clip_rho_threshold': 3.7, - 'clip_pg_rho_threshold': 2.2, - } - - output = vtrace.from_importance_weights(**values) - - with self.test_session() as session: - output_v = session.run(output) - - ground_truth_v = _ground_truth_calculation(**values) - for a, b in zip(ground_truth_v, output_v): - self.assertAllClose(a, b) - - @parameterized.named_parameters(('Batch1', 1), ('Batch2', 2)) - def test_vtrace_from_logits(self, batch_size): - """Tests V-trace calculated from logits.""" - seq_len = 5 - num_actions = 3 - clip_rho_threshold = None # No clipping. - clip_pg_rho_threshold = None # No clipping. - - # Intentionally leaving shapes unspecified to test if V-trace can - # deal with that. - placeholders = { - # T, B, NUM_ACTIONS - 'behaviour_policy_logits': tf.placeholder( - dtype=tf.float32, shape=[None, None, None]), - # T, B, NUM_ACTIONS - 'target_policy_logits': tf.placeholder( - dtype=tf.float32, shape=[None, None, None]), - 'actions': tf.placeholder(dtype=tf.int32, shape=[None, None]), - 'discounts': tf.placeholder(dtype=tf.float32, shape=[None, None]), - 'rewards': tf.placeholder(dtype=tf.float32, shape=[None, None]), - 'values': tf.placeholder(dtype=tf.float32, shape=[None, None]), - 'bootstrap_value': tf.placeholder(dtype=tf.float32, shape=[None]), - } - - from_logits_output = vtrace.from_logits( - clip_rho_threshold=clip_rho_threshold, - clip_pg_rho_threshold=clip_pg_rho_threshold, - **placeholders) - - target_log_probs = vtrace.log_probs_from_logits_and_actions( - placeholders['target_policy_logits'], placeholders['actions']) - behaviour_log_probs = vtrace.log_probs_from_logits_and_actions( - placeholders['behaviour_policy_logits'], placeholders['actions']) - log_rhos = target_log_probs - behaviour_log_probs - ground_truth = (log_rhos, behaviour_log_probs, target_log_probs) - - values = { - 'behaviour_policy_logits': _shaped_arange(seq_len, batch_size, - num_actions), - 'target_policy_logits': _shaped_arange(seq_len, batch_size, - num_actions), - 'actions': np.random.randint( - 0, num_actions - 1, size=(seq_len, batch_size)), - 'discounts': np.array( # T, B where B_i: [0.9 / (i+1)] * T - [[0.9 / (b + 1) for b in range(batch_size)] - for _ in range(seq_len)]), - 'rewards': _shaped_arange(seq_len, batch_size), - 'values': _shaped_arange(seq_len, batch_size) / batch_size, - 'bootstrap_value': _shaped_arange(batch_size) + 1.0, # B - } - - feed_dict = {placeholders[k]: v for k, v in values.items()} - with self.test_session() as session: - from_logits_output_v = session.run( - from_logits_output, feed_dict=feed_dict) - (ground_truth_log_rhos, ground_truth_behaviour_action_log_probs, - ground_truth_target_action_log_probs) = session.run( - ground_truth, feed_dict=feed_dict) - - # Calculate V-trace using the ground truth logits. - from_iw = vtrace.from_importance_weights( - log_rhos=ground_truth_log_rhos, - discounts=values['discounts'], - rewards=values['rewards'], - values=values['values'], - bootstrap_value=values['bootstrap_value'], - clip_rho_threshold=clip_rho_threshold, - clip_pg_rho_threshold=clip_pg_rho_threshold) - - with self.test_session() as session: - from_iw_v = session.run(from_iw) - - self.assertAllClose(from_iw_v.vs, from_logits_output_v.vs) - self.assertAllClose(from_iw_v.pg_advantages, - from_logits_output_v.pg_advantages) - self.assertAllClose(ground_truth_behaviour_action_log_probs, - from_logits_output_v.behaviour_action_log_probs) - self.assertAllClose(ground_truth_target_action_log_probs, - from_logits_output_v.target_action_log_probs) - self.assertAllClose(ground_truth_log_rhos, - from_logits_output_v.log_rhos) - - def test_higher_rank_inputs_for_importance_weights(self): - """Checks support for additional dimensions in inputs.""" - placeholders = { - 'log_rhos': tf.placeholder( - dtype=tf.float32, shape=[None, None, 1]), - 'discounts': tf.placeholder( - dtype=tf.float32, shape=[None, None, 1]), - 'rewards': tf.placeholder( - dtype=tf.float32, shape=[None, None, 42]), - 'values': tf.placeholder(dtype=tf.float32, shape=[None, None, 42]), - 'bootstrap_value': tf.placeholder( - dtype=tf.float32, shape=[None, 42]) - } - output = vtrace.from_importance_weights(**placeholders) - self.assertEqual(output.vs.shape.as_list()[-1], 42) - - def test_inconsistent_rank_inputs_for_importance_weights(self): - """Test one of many possible errors in shape of inputs.""" - placeholders = { - 'log_rhos': tf.placeholder( - dtype=tf.float32, shape=[None, None, 1]), - 'discounts': tf.placeholder( - dtype=tf.float32, shape=[None, None, 1]), - 'rewards': tf.placeholder( - dtype=tf.float32, shape=[None, None, 42]), - 'values': tf.placeholder(dtype=tf.float32, shape=[None, None, 42]), - # Should be [None, 42]. - 'bootstrap_value': tf.placeholder(dtype=tf.float32, shape=[None]) - } - with self.assertRaisesRegexp(ValueError, 'must have rank 2'): - vtrace.from_importance_weights(**placeholders) - - -if __name__ == '__main__': - tf.test.main() diff --git a/python/ray/rllib/agents/ppo/appo_policy_graph.py b/python/ray/rllib/agents/ppo/appo_policy_graph.py index 378e089c5d0d..4e4b2480a7ef 100644 --- a/python/ray/rllib/agents/ppo/appo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/appo_policy_graph.py @@ -6,7 +6,6 @@ from __future__ import division from __future__ import print_function -import numpy as np import tensorflow as tf import logging import gym @@ -18,7 +17,7 @@ from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.explained_variance import explained_variance -from ray.rllib.models.action_dist import MultiCategorical +from ray.rllib.models.action_dist import Categorical from ray.rllib.evaluation.postprocessing import compute_advantages logger = logging.getLogger(__name__) @@ -30,7 +29,7 @@ class PPOSurrogateLoss(object): Arguments: prev_actions_logp: A float32 tensor of shape [T, B]. actions_logp: A float32 tensor of shape [T, B]. - action_kl: A float32 tensor of shape [T, B]. + actions_kl: A float32 tensor of shape [T, B]. actions_entropy: A float32 tensor of shape [T, B]. values: A float32 tensor of shape [T, B]. valid_mask: A bool tensor of valid RNN input elements (#2992). @@ -105,7 +104,7 @@ def __init__(self, actions: An int32 tensor of shape [T, B, NUM_ACTIONS]. prev_actions_logp: A float32 tensor of shape [T, B]. actions_logp: A float32 tensor of shape [T, B]. - action_kl: A float32 tensor of shape [T, B]. + actions_kl: A float32 tensor of shape [T, B]. actions_entropy: A float32 tensor of shape [T, B]. dones: A bool tensor of shape [T, B]. behaviour_logits: A float32 tensor of shape [T, B, NUM_ACTIONS]. @@ -119,10 +118,10 @@ def __init__(self, # Compute vtrace on the CPU for better perf. with tf.device("/cpu:0"): - self.vtrace_returns = vtrace.multi_from_logits( + self.vtrace_returns = vtrace.from_logits( behaviour_policy_logits=behaviour_logits, target_policy_logits=target_logits, - actions=tf.unstack(tf.cast(actions, tf.int32), axis=2), + actions=tf.cast(actions, tf.int32), discounts=tf.to_float(~dones) * discount, rewards=rewards, values=values, @@ -167,21 +166,6 @@ def __init__(self, "Must use `truncate_episodes` batch mode with V-trace." self.config = config self.sess = tf.get_default_session() - self.grads = None - - if isinstance(action_space, gym.spaces.Discrete): - is_multidiscrete = False - output_hidden_shape = [action_space.n] - elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete): - is_multidiscrete = True - output_hidden_shape = action_space.nvec.astype(np.int32) - elif self.config["vtrace"]: - raise UnsupportedSpaceException( - "Action space {} is not supported for APPO + VTrace.", - format(action_space)) - else: - is_multidiscrete = False - output_hidden_shape = 1 # Policy network model dist_class, logit_dim = ModelCatalog.get_action_dist( @@ -202,6 +186,11 @@ def __init__(self, existing_seq_lens = existing_inputs[-1] else: actions = ModelCatalog.get_action_placeholder(action_space) + if (not isinstance(action_space, gym.spaces.Discrete) + and self.config["vtrace"]): + raise UnsupportedSpaceException( + "Action space {} is not supported with vtrace.".format( + action_space)) dones = tf.placeholder(tf.bool, [None], name="dones") rewards = tf.placeholder(tf.float32, [None], name="rewards") behaviour_logits = tf.placeholder( @@ -210,7 +199,6 @@ def __init__(self, tf.float32, [None] + list(observation_space.shape)) existing_state_in = None existing_seq_lens = None - if not self.config["vtrace"]: adv_ph = tf.placeholder( tf.float32, name="advantages", shape=(None, )) @@ -218,13 +206,7 @@ def __init__(self, tf.float32, name="value_targets", shape=(None, )) self.observations = observations - # Unpack behaviour logits - unpacked_behaviour_logits = tf.split( - behaviour_logits, output_hidden_shape, axis=1) - # Setup the policy - dist_class, logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) prev_actions = ModelCatalog.get_action_placeholder(action_space) prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward") self.model = ModelCatalog.get_model( @@ -232,7 +214,6 @@ def __init__(self, "obs": observations, "prev_actions": prev_actions, "prev_rewards": prev_rewards, - "is_training": self._get_is_training_placeholder(), }, observation_space, action_space, @@ -240,35 +221,16 @@ def __init__(self, self.config["model"], state_in=existing_state_in, seq_lens=existing_seq_lens) - unpacked_outputs = tf.split( - self.model.outputs, output_hidden_shape, axis=1) - - dist_inputs = unpacked_outputs if is_multidiscrete else \ - self.model.outputs - prev_dist_inputs = unpacked_behaviour_logits if is_multidiscrete else \ - behaviour_logits - action_dist = dist_class(dist_inputs) - prev_action_dist = dist_class(prev_dist_inputs) + action_dist = dist_class(self.model.outputs) + prev_action_dist = dist_class(behaviour_logits) values = self.model.value_function() self.value_function = values self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) - def make_time_major(tensor, drop_last=False): - """Swaps batch and trajectory axis. - Args: - tensor: A tensor or list of tensors to reshape. - drop_last: A bool indicating whether to drop the last - trajectory item. - Returns: - res: A tensor with swapped axes or a list of tensors with - swapped axes. - """ - if isinstance(tensor, list): - return [make_time_major(t, drop_last) for t in tensor] - + def to_batches(tensor): if self.model.state_init: B = tf.shape(self.model.seq_lens)[0] T = tf.shape(tensor)[0] // B @@ -279,16 +241,11 @@ def make_time_major(tensor, drop_last=False): B = tf.shape(tensor)[0] // T rs = tf.reshape(tensor, tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0)) - # swap B and T axes - res = tf.transpose( + return tf.transpose( rs, [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0])))) - if drop_last: - return res[:-1] - return res - if self.model.state_in: max_seq_len = tf.reduce_max(self.model.seq_lens) - 1 mask = tf.sequence_mask(self.model.seq_lens, max_seq_len) @@ -299,30 +256,21 @@ def make_time_major(tensor, drop_last=False): # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc. if self.config["vtrace"]: logger.info("Using V-Trace surrogate loss (vtrace=True)") - - # Prepare actions for loss - loss_actions = actions if is_multidiscrete else tf.expand_dims( - actions, axis=1) - self.loss = VTraceSurrogateLoss( - actions=make_time_major(loss_actions, drop_last=True), - prev_actions_logp=make_time_major( - prev_action_dist.logp(actions), drop_last=True), - actions_logp=make_time_major( - action_dist.logp(actions), drop_last=True), + actions=to_batches(actions)[:-1], + prev_actions_logp=to_batches( + prev_action_dist.logp(actions))[:-1], + actions_logp=to_batches(action_dist.logp(actions))[:-1], action_kl=prev_action_dist.kl(action_dist), - actions_entropy=make_time_major( - action_dist.entropy(), drop_last=True), - dones=make_time_major(dones, drop_last=True), - behaviour_logits=make_time_major( - unpacked_behaviour_logits, drop_last=True), - target_logits=make_time_major( - unpacked_outputs, drop_last=True), + actions_entropy=to_batches(action_dist.entropy())[:-1], + dones=to_batches(dones)[:-1], + behaviour_logits=to_batches(behaviour_logits)[:-1], + target_logits=to_batches(self.model.outputs)[:-1], discount=config["gamma"], - rewards=make_time_major(rewards, drop_last=True), - values=make_time_major(values, drop_last=True), - bootstrap_value=make_time_major(values)[-1], - valid_mask=make_time_major(mask, drop_last=True), + rewards=to_batches(rewards)[:-1], + values=to_batches(values)[:-1], + bootstrap_value=to_batches(values)[-1], + valid_mask=to_batches(mask)[:-1], vf_loss_coeff=self.config["vf_loss_coeff"], entropy_coeff=self.config["entropy_coeff"], clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], @@ -332,41 +280,25 @@ def make_time_major(tensor, drop_last=False): else: logger.info("Using PPO surrogate loss (vtrace=False)") self.loss = PPOSurrogateLoss( - prev_actions_logp=make_time_major( - prev_action_dist.logp(actions)), - actions_logp=make_time_major(action_dist.logp(actions)), + prev_actions_logp=to_batches(prev_action_dist.logp(actions)), + actions_logp=to_batches(action_dist.logp(actions)), action_kl=prev_action_dist.kl(action_dist), - actions_entropy=make_time_major(action_dist.entropy()), - values=make_time_major(values), - valid_mask=make_time_major(mask), - advantages=make_time_major(adv_ph), - value_targets=make_time_major(value_targets), + actions_entropy=to_batches(action_dist.entropy()), + values=to_batches(values), + valid_mask=to_batches(mask), + advantages=to_batches(adv_ph), + value_targets=to_batches(value_targets), vf_loss_coeff=self.config["vf_loss_coeff"], entropy_coeff=self.config["entropy_coeff"], clip_param=self.config["clip_param"]) # KL divergence between worker and learner logits for debugging - model_dist = MultiCategorical(unpacked_outputs) - behaviour_dist = MultiCategorical(unpacked_behaviour_logits) - - kls = model_dist.kl(behaviour_dist) - if len(kls) > 1: - self.KL_stats = {} - - for i, kl in enumerate(kls): - self.KL_stats.update({ - "mean_KL_{}".format(i): tf.reduce_mean(kl), - "max_KL_{}".format(i): tf.reduce_max(kl), - "median_KL_{}".format(i): tf.contrib.distributions. - percentile(kl, 50.0), - }) - else: - self.KL_stats = { - "mean_KL": tf.reduce_mean(kls[0]), - "max_KL": tf.reduce_max(kls[0]), - "median_KL": tf.contrib.distributions.percentile(kls[0], 50.0), - } - + model_dist = Categorical(self.model.outputs) + behaviour_dist = Categorical(behaviour_logits) + self.KLs = model_dist.kl(behaviour_dist) + self.mean_KL = tf.reduce_mean(self.KLs) + self.max_KL = tf.reduce_max(self.KLs) + self.median_KL = tf.contrib.distributions.percentile(self.KLs, 50.0) # Initialize TFPolicyGraph loss_in = [ ("actions", actions), @@ -403,10 +335,12 @@ def make_time_major(tensor, drop_last=False): self.sess.run(tf.global_variables_initializer()) - values_batched = make_time_major( - values, drop_last=self.config["vtrace"]) + if self.config["vtrace"]: + values_batched = to_batches(values)[:-1] + else: + values_batched = to_batches(values) self.stats_fetches = { - "stats": dict({ + "stats": { "cur_lr": tf.cast(self.cur_lr, tf.float64), "policy_loss": self.loss.pi_loss, "entropy": self.loss.entropy, @@ -416,8 +350,12 @@ def make_time_major(tensor, drop_last=False): "vf_explained_var": explained_variance( tf.reshape(self.loss.value_targets, [-1]), tf.reshape(values_batched, [-1])), - }, **self.KL_stats), + "mean_KL": self.mean_KL, + "max_KL": self.max_KL, + "median_KL": self.median_KL, + }, } + self.stats_fetches["kl"] = self.loss.mean_kl def optimizer(self): if self.config["opt_type"] == "adam": diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index e83b6a209a46..9bf2cca532c2 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -662,7 +662,7 @@ def _build_policy_map(self, policy_dict, policy_config): return policy_map, preprocessors def __del__(self): - if hasattr(self, "sampler") and isinstance(self.sampler, AsyncSampler): + if isinstance(self.sampler, AsyncSampler): self.sampler.shutdown = True diff --git a/python/ray/rllib/models/action_dist.py b/python/ray/rllib/models/action_dist.py index 138fd9f8a6a8..724e54fd1fac 100644 --- a/python/ray/rllib/models/action_dist.py +++ b/python/ray/rllib/models/action_dist.py @@ -114,31 +114,6 @@ def _build_sample_op(self): return tf.squeeze(tf.multinomial(self.inputs, 1), axis=1) -class MultiCategorical(ActionDistribution): - """Categorical distribution for discrete action spaces.""" - - def __init__(self, inputs): - self.cats = [Categorical(input_) for input_ in inputs] - self.sample_op = self._build_sample_op() - - def logp(self, actions): - # If tensor is provided, unstack it into list - if isinstance(actions, tf.Tensor): - actions = tf.unstack(actions, axis=1) - logps = tf.stack( - [cat.logp(act) for cat, act in zip(self.cats, actions)]) - return tf.reduce_sum(logps, axis=0) - - def entropy(self): - return tf.stack([cat.entropy() for cat in self.cats], axis=1) - - def kl(self, other): - return [cat.kl(oth_cat) for cat, oth_cat in zip(self.cats, other.cats)] - - def _build_sample_op(self): - return tf.stack([cat.sample() for cat in self.cats], axis=1) - - class DiagGaussian(ActionDistribution): """Action distribution where each vector element is a gaussian. diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index bcd79cbfe092..73b55675c8f7 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -12,8 +12,8 @@ _global_registry from ray.rllib.models.extra_spaces import Simplex -from ray.rllib.models.action_dist import (Categorical, MultiCategorical, - Deterministic, DiagGaussian, +from ray.rllib.models.action_dist import (Categorical, Deterministic, + DiagGaussian, MultiActionDistribution, Dirichlet) from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.models.fcnet import FullyConnectedNetwork @@ -136,9 +136,6 @@ def get_action_dist(action_space, config, dist_type=None): input_lens=input_lens), sum(input_lens) elif isinstance(action_space, Simplex): return Dirichlet, action_space.shape[0] - elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete): - return MultiCategorical, sum(action_space.nvec) - raise NotImplementedError("Unsupported args: {} {}".format( action_space, dist_type)) @@ -174,11 +171,6 @@ def get_action_placeholder(action_space): elif isinstance(action_space, Simplex): return tf.placeholder( tf.float32, shape=(None, action_space.shape[0]), name="action") - elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete): - return tf.placeholder( - tf.as_dtype(action_space.dtype), - shape=(None, len(action_space.nvec)), - name="action") else: raise NotImplementedError("action space {}" " not supported".format(action_space))