Skip to content

Commit

Permalink
Revert "[wingman -> rllib] IMPALA MultiDiscrete changes (#3967)" (#4332)
Browse files Browse the repository at this point in the history
This reverts commit 962b17f.
  • Loading branch information
ericl authored Mar 12, 2019
1 parent 7ff56ce commit 3c41cb9
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 658 deletions.
3 changes: 0 additions & 3 deletions ci/jenkins_tests/run_rllib_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,3 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
--stop='{"timesteps_total": 40000}' \
--ray-object-store-memory=500000000 \
--config '{"num_workers": 1, "num_gpus": 0, "num_envs_per_worker": 64, "sample_batch_size": 50, "train_batch_size": 50, "learner_queue_size": 1}'

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/rllib/agents/impala/vtrace_test.py
175 changes: 38 additions & 137 deletions python/ray/rllib/agents/impala/vtrace.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,6 @@
by Espeholt, Soyer, Munos et al.
See https://arxiv.org/abs/1802.01561 for the full paper.
In addition to the original paper's code, changes have been made
to support MultiDiscrete action spaces. behaviour_policy_logits,
target_policy_logits and actions parameters in the entry point
multi_from_logits method accepts lists of tensors instead of just
tensors.
"""

from __future__ import absolute_import
Expand All @@ -47,48 +41,29 @@


def log_probs_from_logits_and_actions(policy_logits, actions):
return multi_log_probs_from_logits_and_actions([policy_logits],
[actions])[0]


def multi_log_probs_from_logits_and_actions(policy_logits, actions):
"""Computes action log-probs from policy logits and actions.
In the notation used throughout documentation and comments, T refers to the
time dimension ranging from 0 to T-1. B refers to the batch size and
ACTION_SPACE refers to the list of numbers each representing a number of
actions.
NUM_ACTIONS refers to the number of actions.
Args:
policy_logits: A list with length of ACTION_SPACE of float32
tensors of shapes
[T, B, ACTION_SPACE[0]],
...,
[T, B, ACTION_SPACE[-1]]
with un-normalized log-probabilities parameterizing a softmax policy.
actions: A list with length of ACTION_SPACE of int32
tensors of shapes
[T, B],
...,
[T, B]
with actions.
policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
un-normalized log-probabilities parameterizing a softmax policy.
actions: An int32 tensor of shape [T, B] with actions.
Returns:
A list with length of ACTION_SPACE of float32
tensors of shapes
[T, B],
...,
[T, B]
corresponding to the sampling log probability
of the chosen action w.r.t. the policy.
A float32 tensor of shape [T, B] corresponding to the sampling log
probability of the chosen action w.r.t. the policy.
"""
policy_logits = tf.convert_to_tensor(policy_logits, dtype=tf.float32)
actions = tf.convert_to_tensor(actions, dtype=tf.int32)

log_probs = []
for i in range(len(policy_logits)):
log_probs.append(-tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=policy_logits[i], labels=actions[i]))
policy_logits.shape.assert_has_rank(3)
actions.shape.assert_has_rank(2)

return log_probs
return -tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=policy_logits, labels=actions)


def from_logits(behaviour_policy_logits,
Expand All @@ -101,39 +76,6 @@ def from_logits(behaviour_policy_logits,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0,
name='vtrace_from_logits'):
"""multi_from_logits wrapper used only for tests"""

res = multi_from_logits(
[behaviour_policy_logits], [target_policy_logits], [actions],
discounts,
rewards,
values,
bootstrap_value,
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold,
name=name)

return VTraceFromLogitsReturns(
vs=res.vs,
pg_advantages=res.pg_advantages,
log_rhos=res.log_rhos,
behaviour_action_log_probs=tf.squeeze(
res.behaviour_action_log_probs, axis=0),
target_action_log_probs=tf.squeeze(
res.target_action_log_probs, axis=0),
)


def multi_from_logits(behaviour_policy_logits,
target_policy_logits,
actions,
discounts,
rewards,
values,
bootstrap_value,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0,
name='vtrace_from_logits'):
r"""V-trace for softmax policies.
Calculates V-trace actor critic targets for softmax polices as described in
Expand All @@ -148,30 +90,16 @@ def multi_from_logits(behaviour_policy_logits,
In the notation used throughout documentation and comments, T refers to the
time dimension ranging from 0 to T-1. B refers to the batch size and
ACTION_SPACE refers to the list of numbers each representing a number of
actions.
NUM_ACTIONS refers to the number of actions.
Args:
behaviour_policy_logits: A list with length of ACTION_SPACE of float32
tensors of shapes
[T, B, ACTION_SPACE[0]],
...,
[T, B, ACTION_SPACE[-1]]
with un-normalized log-probabilities parameterizing the softmax behaviour
policy.
target_policy_logits: A list with length of ACTION_SPACE of float32
tensors of shapes
[T, B, ACTION_SPACE[0]],
...,
[T, B, ACTION_SPACE[-1]]
with un-normalized log-probabilities parameterizing the softmax target
behaviour_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
un-normalized log-probabilities parametrizing the softmax behaviour
policy.
actions: A list with length of ACTION_SPACE of int32
tensors of shapes
[T, B],
...,
[T, B]
with actions sampled from the behaviour policy.
target_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
un-normalized log-probabilities parametrizing the softmax target policy.
actions: An int32 tensor of shape [T, B] of actions sampled from the
behaviour policy.
discounts: A float32 tensor of shape [T, B] with the discount encountered
when following the behaviour policy.
rewards: A float32 tensor of shape [T, B] with the rewards generated by
Expand Down Expand Up @@ -200,34 +128,29 @@ def multi_from_logits(behaviour_policy_logits,
target_action_log_probs: A float32 tensor of shape [T, B] containing
target policy action probabilities (log \pi(a_t)).
"""

for i in range(len(behaviour_policy_logits)):
behaviour_policy_logits[i] = tf.convert_to_tensor(
behaviour_policy_logits[i], dtype=tf.float32)
target_policy_logits[i] = tf.convert_to_tensor(
target_policy_logits[i], dtype=tf.float32)
actions[i] = tf.convert_to_tensor(actions[i], dtype=tf.int32)

# Make sure tensor ranks are as expected.
# The rest will be checked by from_action_log_probs.
behaviour_policy_logits[i].shape.assert_has_rank(3)
target_policy_logits[i].shape.assert_has_rank(3)
actions[i].shape.assert_has_rank(2)
behaviour_policy_logits = tf.convert_to_tensor(
behaviour_policy_logits, dtype=tf.float32)
target_policy_logits = tf.convert_to_tensor(
target_policy_logits, dtype=tf.float32)
actions = tf.convert_to_tensor(actions, dtype=tf.int32)

# Make sure tensor ranks are as expected.
# The rest will be checked by from_action_log_probs.
behaviour_policy_logits.shape.assert_has_rank(3)
target_policy_logits.shape.assert_has_rank(3)
actions.shape.assert_has_rank(2)

with tf.name_scope(
name,
values=[
behaviour_policy_logits, target_policy_logits, actions,
discounts, rewards, values, bootstrap_value
]):
target_action_log_probs = multi_log_probs_from_logits_and_actions(
target_action_log_probs = log_probs_from_logits_and_actions(
target_policy_logits, actions)
behaviour_action_log_probs = multi_log_probs_from_logits_and_actions(
behaviour_action_log_probs = log_probs_from_logits_and_actions(
behaviour_policy_logits, actions)

log_rhos = get_log_rhos(target_action_log_probs,
behaviour_action_log_probs)

log_rhos = target_action_log_probs - behaviour_action_log_probs
vtrace_returns = from_importance_weights(
log_rhos=log_rhos,
discounts=discounts,
Expand All @@ -236,7 +159,6 @@ def multi_from_logits(behaviour_policy_logits,
bootstrap_value=bootstrap_value,
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold)

return VTraceFromLogitsReturns(
log_rhos=log_rhos,
behaviour_action_log_probs=behaviour_action_log_probs,
Expand All @@ -261,13 +183,13 @@ def from_importance_weights(log_rhos,
by Espeholt, Soyer, Munos et al.
In the notation used throughout documentation and comments, T refers to the
time dimension ranging from 0 to T-1. B refers to the batch size. This code
also supports the case where all tensors have the same number of additional
dimensions, e.g., `rewards` is [T, B, C], `values` is [T, B, C],
`bootstrap_value` is [B, C].
time dimension ranging from 0 to T-1. B refers to the batch size and
NUM_ACTIONS refers to the number of actions. This code also supports the
case where all tensors have the same number of additional dimensions, e.g.,
`rewards` is [T, B, C], `values` is [T, B, C], `bootstrap_value` is [B, C].
Args:
log_rhos: A float32 tensor of shape [T, B] representing the
log_rhos: A float32 tensor of shape [T, B, NUM_ACTIONS] representing the
log importance sampling weights, i.e.
log(target_policy(a) / behaviour_policy(a)). V-trace performs operations
on rhos in log-space for numerical stability.
Expand Down Expand Up @@ -324,14 +246,6 @@ def from_importance_weights(log_rhos,
if clip_rho_threshold is not None:
clipped_rhos = tf.minimum(
clip_rho_threshold, rhos, name='clipped_rhos')

tf.summary.histogram('clipped_rhos_1000', tf.minimum(1000.0, rhos))
tf.summary.scalar(
'num_of_clipped_rhos',
tf.reduce_sum(
tf.cast(
tf.equal(clipped_rhos, clip_rho_threshold), tf.int32)))
tf.summary.scalar('size_of_clipped_rhos', tf.size(clipped_rhos))
else:
clipped_rhos = rhos

Expand Down Expand Up @@ -384,16 +298,3 @@ def scanfunc(acc, sequence_item):
return VTraceReturns(
vs=tf.stop_gradient(vs),
pg_advantages=tf.stop_gradient(pg_advantages))


def get_log_rhos(behaviour_action_log_probs, target_action_log_probs):
"""With the selected log_probs for multi-discrete actions of behaviour
and target policies we compute the log_rhos for calculating the vtrace."""
log_rhos = [
t - b
for t, b in zip(target_action_log_probs, behaviour_action_log_probs)
]
log_rhos = [tf.convert_to_tensor(l, dtype=tf.float32) for l in log_rhos]
log_rhos = tf.reduce_sum(tf.stack(log_rhos), axis=0)

return log_rhos
Loading

0 comments on commit 3c41cb9

Please sign in to comment.