Fix multi discrete #4338

Merged 4 commits on Mar 13, 2019
3 changes: 3 additions & 0 deletions ci/jenkins_tests/run_rllib_tests.sh
@@ -410,3 +410,6 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
--stop='{"timesteps_total": 40000}' \
--ray-object-store-memory=500000000 \
--config '{"num_workers": 1, "num_gpus": 0, "num_envs_per_worker": 64, "sample_batch_size": 50, "train_batch_size": 50, "learner_queue_size": 1}'

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/rllib/agents/impala/vtrace_test.py
171 changes: 133 additions & 38 deletions python/ray/rllib/agents/impala/vtrace.py
@@ -20,6 +20,12 @@
by Espeholt, Soyer, Munos et al.

See https://arxiv.org/abs/1802.01561 for the full paper.

In addition to the original paper's code, changes have been made
to support MultiDiscrete action spaces. The behaviour_policy_logits,
target_policy_logits and actions parameters of the entry point
multi_from_logits now accept lists of tensors instead of single
tensors.
"""

from __future__ import absolute_import
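To make the list-of-tensors convention from the module docstring above concrete, here is a minimal sketch of how a MultiDiscrete action space with component sizes 3 and 7 (a made-up example, not taken from this PR) maps onto the per-component logits and action arrays that the multi_* functions below expect:

import numpy as np

# Hypothetical MultiDiscrete action space with two components of sizes 3 and 7.
ACTION_SPACE = [3, 7]
T, B = 5, 4  # rollout length and batch size

# One float32 logits tensor per component: shapes [T, B, 3] and [T, B, 7].
policy_logits = [
    np.random.randn(T, B, n).astype(np.float32) for n in ACTION_SPACE
]
# One int32 action tensor per component, each of shape [T, B].
actions = [
    np.random.randint(n, size=(T, B)).astype(np.int32) for n in ACTION_SPACE
]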
@@ -41,29 +47,48 @@


def log_probs_from_logits_and_actions(policy_logits, actions):
return multi_log_probs_from_logits_and_actions([policy_logits],
[actions])[0]


def multi_log_probs_from_logits_and_actions(policy_logits, actions):
"""Computes action log-probs from policy logits and actions.

In the notation used throughout documentation and comments, T refers to the
time dimension ranging from 0 to T-1. B refers to the batch size and
NUM_ACTIONS refers to the number of actions.
ACTION_SPACE refers to the list of integers, each giving the number of
actions in one component of the action space.

Args:
policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
un-normalized log-probabilities parameterizing a softmax policy.
actions: An int32 tensor of shape [T, B] with actions.
policy_logits: A list, with one entry per ACTION_SPACE component, of
float32 tensors of shapes
[T, B, ACTION_SPACE[0]],
...,
[T, B, ACTION_SPACE[-1]]
with un-normalized log-probabilities parameterizing a softmax policy.
actions: A list, with one entry per ACTION_SPACE component, of int32
tensors, each of shape [T, B], with the chosen actions.

Returns:
A float32 tensor of shape [T, B] corresponding to the sampling log
probability of the chosen action w.r.t. the policy.
A list, with one entry per ACTION_SPACE component, of float32
tensors, each of shape [T, B], corresponding to the sampling log
probability of the chosen action w.r.t. the policy.
"""
policy_logits = tf.convert_to_tensor(policy_logits, dtype=tf.float32)
actions = tf.convert_to_tensor(actions, dtype=tf.int32)

policy_logits.shape.assert_has_rank(3)
actions.shape.assert_has_rank(2)
log_probs = []
for i in range(len(policy_logits)):
log_probs.append(-tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=policy_logits[i], labels=actions[i]))

return -tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=policy_logits, labels=actions)
return log_probs
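As an illustration of how the two functions above relate (a minimal sketch, assuming TF 1.x graph mode and that the module is importable as ray.rllib.agents.impala.vtrace, the path used by the CI test added in this PR), the single-tensor helper is simply element 0 of the multi version applied to one-element lists:

import numpy as np
import tensorflow as tf

from ray.rllib.agents.impala import vtrace  # assumed import path

T, B, NUM_ACTIONS = 5, 4, 6
logits = tf.constant(np.random.randn(T, B, NUM_ACTIONS), dtype=tf.float32)
acts = tf.constant(
    np.random.randint(NUM_ACTIONS, size=(T, B)).astype(np.int32))

single = vtrace.log_probs_from_logits_and_actions(logits, acts)
multi = vtrace.multi_log_probs_from_logits_and_actions([logits], [acts])

with tf.Session() as sess:
    single_val, multi_val = sess.run([single, multi[0]])
# The wrapper returns element 0 of the multi version, so the two agree.
np.testing.assert_allclose(single_val, multi_val, rtol=1e-5)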


def from_logits(behaviour_policy_logits,
@@ -76,6 +101,39 @@ def from_logits(behaviour_policy_logits,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0,
name='vtrace_from_logits'):
"""multi_from_logits wrapper used only for tests"""

res = multi_from_logits(
[behaviour_policy_logits], [target_policy_logits], [actions],
discounts,
rewards,
values,
bootstrap_value,
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold,
name=name)

return VTraceFromLogitsReturns(
vs=res.vs,
pg_advantages=res.pg_advantages,
log_rhos=res.log_rhos,
behaviour_action_log_probs=tf.squeeze(
res.behaviour_action_log_probs, axis=0),
target_action_log_probs=tf.squeeze(
res.target_action_log_probs, axis=0),
)


def multi_from_logits(behaviour_policy_logits,
target_policy_logits,
actions,
discounts,
rewards,
values,
bootstrap_value,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0,
name='vtrace_from_logits'):
r"""V-trace for softmax policies.

Calculates V-trace actor critic targets for softmax polices as described in
@@ -90,16 +148,30 @@ def from_logits(behaviour_policy_logits,

In the notation used throughout documentation and comments, T refers to the
time dimension ranging from 0 to T-1. B refers to the batch size and
NUM_ACTIONS refers to the number of actions.
ACTION_SPACE refers to the list of integers, each giving the number of
actions in one component of the action space.

Args:
behaviour_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
un-normalized log-probabilities parametrizing the softmax behaviour
behaviour_policy_logits: A list, with one entry per ACTION_SPACE
component, of float32 tensors of shapes
[T, B, ACTION_SPACE[0]],
...,
[T, B, ACTION_SPACE[-1]]
with un-normalized log-probabilities parameterizing the softmax behaviour
policy.
target_policy_logits: A list, with one entry per ACTION_SPACE
component, of float32 tensors of shapes
[T, B, ACTION_SPACE[0]],
...,
[T, B, ACTION_SPACE[-1]]
with un-normalized log-probabilities parameterizing the softmax target
policy.
target_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
un-normalized log-probabilities parametrizing the softmax target policy.
actions: An int32 tensor of shape [T, B] of actions sampled from the
behaviour policy.
actions: A list, with one entry per ACTION_SPACE component, of int32
tensors, each of shape [T, B], with actions sampled from the behaviour
policy.
discounts: A float32 tensor of shape [T, B] with the discount encountered
when following the behaviour policy.
rewards: A float32 tensor of shape [T, B] with the rewards generated by
@@ -128,29 +200,34 @@ def from_logits(behaviour_policy_logits,
target_action_log_probs: A float32 tensor of shape [T, B] containing
target policy action log probabilities (log \pi(a_t)).
"""
behaviour_policy_logits = tf.convert_to_tensor(
behaviour_policy_logits, dtype=tf.float32)
target_policy_logits = tf.convert_to_tensor(
target_policy_logits, dtype=tf.float32)
actions = tf.convert_to_tensor(actions, dtype=tf.int32)

# Make sure tensor ranks are as expected.
# The rest will be checked by from_action_log_probs.
behaviour_policy_logits.shape.assert_has_rank(3)
target_policy_logits.shape.assert_has_rank(3)
actions.shape.assert_has_rank(2)

for i in range(len(behaviour_policy_logits)):
behaviour_policy_logits[i] = tf.convert_to_tensor(
behaviour_policy_logits[i], dtype=tf.float32)
target_policy_logits[i] = tf.convert_to_tensor(
target_policy_logits[i], dtype=tf.float32)
actions[i] = tf.convert_to_tensor(actions[i], dtype=tf.int32)

# Make sure tensor ranks are as expected.
# The rest will be checked by from_action_log_probs.
behaviour_policy_logits[i].shape.assert_has_rank(3)
target_policy_logits[i].shape.assert_has_rank(3)
actions[i].shape.assert_has_rank(2)

with tf.name_scope(
name,
values=[
behaviour_policy_logits, target_policy_logits, actions,
discounts, rewards, values, bootstrap_value
]):
target_action_log_probs = log_probs_from_logits_and_actions(
target_action_log_probs = multi_log_probs_from_logits_and_actions(
target_policy_logits, actions)
behaviour_action_log_probs = log_probs_from_logits_and_actions(
behaviour_action_log_probs = multi_log_probs_from_logits_and_actions(
behaviour_policy_logits, actions)
log_rhos = target_action_log_probs - behaviour_action_log_probs

log_rhos = get_log_rhos(target_action_log_probs,
behaviour_action_log_probs)

vtrace_returns = from_importance_weights(
log_rhos=log_rhos,
discounts=discounts,
Expand All @@ -159,6 +236,7 @@ def from_logits(behaviour_policy_logits,
bootstrap_value=bootstrap_value,
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold)

return VTraceFromLogitsReturns(
log_rhos=log_rhos,
behaviour_action_log_probs=behaviour_action_log_probs,
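To show how the multi_from_logits entry point above is called end to end, here is a minimal sketch. It assumes TF 1.x graph mode, an importable ray.rllib.agents.impala.vtrace module as in the CI test added by this PR, and a made-up two-component action space of sizes 3 and 7; it is an illustration, not the canonical IMPALA usage:

import numpy as np
import tensorflow as tf

from ray.rllib.agents.impala import vtrace  # assumed import path

T, B = 5, 4
ACTION_SPACE = [3, 7]  # hypothetical MultiDiscrete component sizes

behaviour_logits = [
    np.random.randn(T, B, n).astype(np.float32) for n in ACTION_SPACE
]
target_logits = [
    np.random.randn(T, B, n).astype(np.float32) for n in ACTION_SPACE
]
actions = [
    np.random.randint(n, size=(T, B)).astype(np.int32) for n in ACTION_SPACE
]
discounts = 0.99 * np.ones((T, B), dtype=np.float32)
rewards = np.random.randn(T, B).astype(np.float32)
values = np.random.randn(T, B).astype(np.float32)
bootstrap_value = np.random.randn(B).astype(np.float32)

outputs = vtrace.multi_from_logits(
    behaviour_logits, target_logits, actions,
    discounts, rewards, values, bootstrap_value)

with tf.Session() as sess:
    res = sess.run(outputs)
# res.vs and res.pg_advantages have shape [T, B]; res.log_rhos sums the
# per-component log importance ratios, as computed by get_log_rhos below.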
@@ -183,13 +261,13 @@ def from_importance_weights(log_rhos,
by Espeholt, Soyer, Munos et al.

In the notation used throughout documentation and comments, T refers to the
time dimension ranging from 0 to T-1. B refers to the batch size and
NUM_ACTIONS refers to the number of actions. This code also supports the
case where all tensors have the same number of additional dimensions, e.g.,
`rewards` is [T, B, C], `values` is [T, B, C], `bootstrap_value` is [B, C].
time dimension ranging from 0 to T-1. B refers to the batch size. This code
also supports the case where all tensors have the same number of additional
dimensions, e.g., `rewards` is [T, B, C], `values` is [T, B, C],
`bootstrap_value` is [B, C].

Args:
log_rhos: A float32 tensor of shape [T, B, NUM_ACTIONS] representing the
log_rhos: A float32 tensor of shape [T, B] representing the
log importance sampling weights, i.e.
log(target_policy(a) / behaviour_policy(a)). V-trace performs operations
on rhos in log-space for numerical stability.
@@ -246,6 +324,14 @@ def from_importance_weights(log_rhos,
if clip_rho_threshold is not None:
clipped_rhos = tf.minimum(
clip_rho_threshold, rhos, name='clipped_rhos')

tf.summary.histogram('clipped_rhos_1000', tf.minimum(1000.0, rhos))
tf.summary.scalar(
'num_of_clipped_rhos',
tf.reduce_sum(
tf.cast(
tf.equal(clipped_rhos, clip_rho_threshold), tf.int32)))
tf.summary.scalar('size_of_clipped_rhos', tf.size(clipped_rhos))
else:
clipped_rhos = rhos
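The clipping above is the rho-bar truncation from the V-trace paper, and the summaries added in this PR track how often the threshold is hit. A minimal numpy sketch of the same computation, assuming the default threshold of 1.0:

import numpy as np

log_rhos = np.random.randn(5, 4).astype(np.float32)  # [T, B] log importance weights
rhos = np.exp(log_rhos)
clip_rho_threshold = 1.0  # assumed default

clipped_rhos = np.minimum(clip_rho_threshold, rhos)
# Fraction of importance weights that were truncated at the threshold.
clipped_fraction = np.mean(clipped_rhos == clip_rho_threshold)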

@@ -298,3 +384,12 @@ def scanfunc(acc, sequence_item):
return VTraceReturns(
vs=tf.stop_gradient(vs),
pg_advantages=tf.stop_gradient(pg_advantages))


def get_log_rhos(target_action_log_probs, behaviour_action_log_probs):
"""Computes the log_rhos used in the V-trace calculation from the
selected per-component log_probs of the behaviour and target policies."""
t = tf.stack(target_action_log_probs)
b = tf.stack(behaviour_action_log_probs)
log_rhos = tf.reduce_sum(t - b, axis=0)
return log_rhos
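Because a multi-discrete policy factorizes over action components, the overall importance ratio is the product of the per-component ratios, which in log space becomes the sum that get_log_rhos computes. A minimal numpy sketch of that equivalence (shapes are made up for illustration):

import numpy as np

T, B, K = 5, 4, 2  # time steps, batch size, number of action components
target_lp = [np.random.randn(T, B).astype(np.float32) for _ in range(K)]
behaviour_lp = [np.random.randn(T, B).astype(np.float32) for _ in range(K)]

# What get_log_rhos does: stack per-component log-prob differences and sum
# over the component axis, giving a [T, B] tensor of log importance weights.
log_rhos = np.sum(np.stack(target_lp) - np.stack(behaviour_lp), axis=0)

# Equivalent product of per-component importance ratios.
rhos = np.prod(
    [np.exp(t - b) for t, b in zip(target_lp, behaviour_lp)], axis=0)
np.testing.assert_allclose(np.exp(log_rhos), rhos, rtol=1e-5)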