[wingman -> rllib] IMPALA MultiDiscrete changes #3967
Merged
Changes from 33 commits (of 48 total)

Commits
a6b1b7a
impala changes
088985b
fixed newlines
23029d3
Merge branch 'master' into impala
3b38ebb
reformatting impalla.py
1858404
aligned vtrace.py formatting some more
9840eb6
aligned formatting some more
e48f9ae
aligned formatting some more
26eed71
Merge branch 'master' into impala
3171c8a
fixed impala stuff
9d62dd1
Address vtrace comments (#6)
pimpke 6597295
Made APPO work with VTrace
stefanpantic 1d31991
Variable is no longer a member
stefanpantic 252f6b3
Optimized imports
stefanpantic 5ef2e30
Changed is_discrete to is_multidiscrete, fixed KL distribution
stefanpantic cf5c1c5
Fixed KL divergence
stefanpantic 54f4f79
Removed if statement
stefanpantic 38c1896
Merge branch 'master' into impala
2dd604f
Merge branch 'impala' of https://github.com/wingman-ai/ray into impala
7cb1f97
revert appo file
65c82d4
revered stefans appo changes
946df01
old appo policy graph
b6b2c52
returned stefan appo changes and returned newline
56fe32e
fixed newlines in appo_policy_graph
1d46d7a
Merge branch 'master' into impala
76046f3
aligned with action_dist changes in ray master
d017c9f
small appo fixes
c02b9f5
add vtrace test
ericl fbbed63
fix appo impala integration
ericl bcb2113
add to jenkins
ericl e020527
merged with master
63e119a
Merge branch 'impala' of https://github.com/wingman-ai/ray into impala
6e06ba6
fixing appo policy graph changes
f161edf
fixed vtrace tests
1dbed08
lint and py2 compat
ericl aa50f98
kl
ericl 57594c0
Merge branch 'master' into impala
3f4883b
Merge branch 'impala' of https://github.com/wingman-ai/ray into impala
d32d253
removed dist_type as it is actually not needed for IMPALA
967db5c
fixing issue with new gym version
1da470f
Merge branch 'master' into impala
9584a7c
lint
ericl 8999621
fix multigpu test
ericl 0cbeb7c
merged with master
abe797a
Merge branch 'impala' of https://github.com/wingman-ai/ray into impala
280b21c
Merge branch 'master' into impala
e0e3060
Merge branch 'master' into impala
afb462f
Merge remote-tracking branch 'upstream/master' into impala
ericl eb18cff
fix tests
ericl
Diff: vtrace.py
@@ -20,6 +20,12 @@
 by Espeholt, Soyer, Munos et al.

 See https://arxiv.org/abs/1802.01561 for the full paper.
+
+In addition to the original paper's code, changes have been made
+to support MultiDiscrete action spaces. The behaviour_policy_logits,
+target_policy_logits and actions parameters in the entry point
+multi_from_logits method accept lists of tensors instead of just
+tensors.
 """

 from __future__ import absolute_import
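As a side note for readers, a minimal sketch of what the list-of-tensors format looks like in practice; the MultiDiscrete([3, 5]) space, the placeholder names, and the T/B values below are made-up illustrations, not part of this PR:

import tensorflow as tf

# Hypothetical MultiDiscrete([3, 5]) space, so ACTION_SPACE == [3, 5].
T, B = 20, 16                       # made-up trajectory length and batch size
ACTION_SPACE = [3, 5]

# A model typically emits one flat logits tensor of size sum(ACTION_SPACE);
# splitting it per sub-action yields the lists the new API accepts.
flat_logits = tf.placeholder(tf.float32, [T, B, sum(ACTION_SPACE)])
policy_logits = tf.split(flat_logits, ACTION_SPACE, axis=-1)  # [T, B, 3], [T, B, 5]

# Actions become a list of [T, B] int32 tensors, one per sub-action.
flat_actions = tf.placeholder(tf.int32, [T, B, len(ACTION_SPACE)])
actions = tf.unstack(flat_actions, axis=-1)                   # [T, B], [T, B]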
@@ -41,29 +47,47 @@


 def log_probs_from_logits_and_actions(policy_logits, actions):
+    return multi_log_probs_from_logits_and_actions(
+        [policy_logits], [actions])[0]
+
+
+def multi_log_probs_from_logits_and_actions(policy_logits, actions):
     """Computes action log-probs from policy logits and actions.

     In the notation used throughout documentation and comments, T refers to the
     time dimension ranging from 0 to T-1. B refers to the batch size and
-    NUM_ACTIONS refers to the number of actions.
+    ACTION_SPACE refers to the list of numbers each representing a number of actions.

     Args:
-      policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
-        un-normalized log-probabilities parameterizing a softmax policy.
-      actions: An int32 tensor of shape [T, B] with actions.
+      policy_logits: A list with length of ACTION_SPACE of float32
+        tensors of shapes
+        [T, B, ACTION_SPACE[0]],
+        ...,
+        [T, B, ACTION_SPACE[-1]]
+        with un-normalized log-probabilities parameterizing a softmax policy.
+      actions: A list with length of ACTION_SPACE of int32
+        tensors of shapes
+        [T, B],
+        ...,
+        [T, B]
+        with actions.

     Returns:
-      A float32 tensor of shape [T, B] corresponding to the sampling log
-        probability of the chosen action w.r.t. the policy.
+      A list with length of ACTION_SPACE of float32
+        tensors of shapes
+        [T, B],
+        ...,
+        [T, B]
+        corresponding to the sampling log probability
+        of the chosen action w.r.t. the policy.
     """
-    policy_logits = tf.convert_to_tensor(policy_logits, dtype=tf.float32)
-    actions = tf.convert_to_tensor(actions, dtype=tf.int32)
-
-    policy_logits.shape.assert_has_rank(3)
-    actions.shape.assert_has_rank(2)
+    log_probs = []
+    for i in range(len(policy_logits)):
+        log_probs.append(-tf.nn.sparse_softmax_cross_entropy_with_logits(
+            logits=policy_logits[i], labels=actions[i]))

-    return -tf.nn.sparse_softmax_cross_entropy_with_logits(
-        logits=policy_logits, labels=actions)
+    return log_probs


 def from_logits(behaviour_policy_logits,
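A rough NumPy sketch (illustrative only, not code from this PR) of what the new per-component loop computes: the log-softmax of each sub-action's logits evaluated at the chosen action, giving one [T, B] array per component:

import numpy as np

def multi_log_probs_np(policy_logits, actions):
    """NumPy sketch: per-component log pi(a | s) for a MultiDiscrete action."""
    out = []
    for logits, acts in zip(policy_logits, actions):
        z = logits - logits.max(axis=-1, keepdims=True)            # stable log-softmax
        log_p = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
        out.append(np.take_along_axis(log_p, acts[..., None], axis=-1)[..., 0])
    return out                                                       # list of [T, B] arrays

T, B = 4, 2
rng = np.random.RandomState(0)
logits = [rng.randn(T, B, 3), rng.randn(T, B, 5)]                    # ACTION_SPACE == [3, 5]
acts = [rng.randint(3, size=(T, B)), rng.randint(5, size=(T, B))]
print([lp.shape for lp in multi_log_probs_np(logits, acts)])         # [(4, 2), (4, 2)]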
@@ -76,6 +100,40 @@ def from_logits(behaviour_policy_logits,
                 clip_rho_threshold=1.0,
                 clip_pg_rho_threshold=1.0,
                 name='vtrace_from_logits'):
+    """multi_from_logits wrapper used only for tests"""
+
+    res = multi_from_logits(
+        [behaviour_policy_logits],
+        [target_policy_logits],
+        [actions],
+        discounts,
+        rewards,
+        values,
+        bootstrap_value,
+        clip_rho_threshold=clip_rho_threshold,
+        clip_pg_rho_threshold=clip_pg_rho_threshold,
+        name=name)
+
+    return VTraceFromLogitsReturns(
+        vs=res.vs,
+        pg_advantages=res.pg_advantages,
+        log_rhos=res.log_rhos,
+        behaviour_action_log_probs=tf.squeeze(
+            res.behaviour_action_log_probs, axis=0),
+        target_action_log_probs=tf.squeeze(
+            res.target_action_log_probs, axis=0),
+    )
+
+
+def multi_from_logits(
+        behaviour_policy_logits,
+        target_policy_logits,
+        actions,
+        discounts,
+        rewards,
+        values,
+        bootstrap_value,
+        clip_rho_threshold=1.0,
+        clip_pg_rho_threshold=1.0,
+        name='vtrace_from_logits'):
     r"""V-trace for softmax policies.

     Calculates V-trace actor critic targets for softmax polices as described in

Inline review comment on the VTraceFromLogitsReturns block above: Ah, thanks for fixing this bit.
@@ -90,16 +148,27 @@ def from_logits(behaviour_policy_logits,

     In the notation used throughout documentation and comments, T refers to the
     time dimension ranging from 0 to T-1. B refers to the batch size and
-    NUM_ACTIONS refers to the number of actions.
+    ACTION_SPACE refers to the list of numbers each representing a number of actions.

     Args:
-      behaviour_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
-        un-normalized log-probabilities parametrizing the softmax behaviour
-        policy.
-      target_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
-        un-normalized log-probabilities parametrizing the softmax target policy.
-      actions: An int32 tensor of shape [T, B] of actions sampled from the
-        behaviour policy.
+      behaviour_policy_logits: A list with length of ACTION_SPACE of float32
+        tensors of shapes
+        [T, B, ACTION_SPACE[0]],
+        ...,
+        [T, B, ACTION_SPACE[-1]]
+        with un-normalized log-probabilities parameterizing the softmax behaviour policy.
+      target_policy_logits: A list with length of ACTION_SPACE of float32
+        tensors of shapes
+        [T, B, ACTION_SPACE[0]],
+        ...,
+        [T, B, ACTION_SPACE[-1]]
+        with un-normalized log-probabilities parameterizing the softmax target policy.
+      actions: A list with length of ACTION_SPACE of int32
+        tensors of shapes
+        [T, B],
+        ...,
+        [T, B]
+        with actions sampled from the behaviour policy.
       discounts: A float32 tensor of shape [T, B] with the discount encountered
         when following the behaviour policy.
       rewards: A float32 tensor of shape [T, B] with the rewards generated by
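To make the expected shapes concrete, here is a hedged usage sketch of multi_from_logits; the import path, ACTION_SPACE = [3, 5], and all values are assumptions for illustration, and it only builds the TF1-style graph rather than running it:

import numpy as np
from ray.rllib.agents.impala import vtrace  # assumed import path for this module

T, B, ACTION_SPACE = 5, 3, [3, 5]
rng = np.random.RandomState(0)

out = vtrace.multi_from_logits(
    behaviour_policy_logits=[rng.randn(T, B, n).astype(np.float32)
                             for n in ACTION_SPACE],
    target_policy_logits=[rng.randn(T, B, n).astype(np.float32)
                          for n in ACTION_SPACE],
    actions=[rng.randint(n, size=(T, B)).astype(np.int32) for n in ACTION_SPACE],
    discounts=0.99 * np.ones((T, B), dtype=np.float32),
    rewards=rng.randn(T, B).astype(np.float32),
    values=rng.randn(T, B).astype(np.float32),
    bootstrap_value=rng.randn(B).astype(np.float32))

# vs and pg_advantages are symbolic [T, B] tensors in graph mode.
print(out.vs.shape, out.pg_advantages.shape)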
@@ -128,29 +197,31 @@ def from_logits(behaviour_policy_logits,
       target_action_log_probs: A float32 tensor of shape [T, B] containing
         target policy action probabilities (log \pi(a_t)).
     """
-    behaviour_policy_logits = tf.convert_to_tensor(
-        behaviour_policy_logits, dtype=tf.float32)
-    target_policy_logits = tf.convert_to_tensor(
-        target_policy_logits, dtype=tf.float32)
-    actions = tf.convert_to_tensor(actions, dtype=tf.int32)
-
-    # Make sure tensor ranks are as expected.
-    # The rest will be checked by from_action_log_probs.
-    behaviour_policy_logits.shape.assert_has_rank(3)
-    target_policy_logits.shape.assert_has_rank(3)
-    actions.shape.assert_has_rank(2)
-
-    with tf.name_scope(
-            name,
-            values=[
-                behaviour_policy_logits, target_policy_logits, actions,
-                discounts, rewards, values, bootstrap_value
-            ]):
-        target_action_log_probs = log_probs_from_logits_and_actions(
+    for i in range(len(behaviour_policy_logits)):
+        behaviour_policy_logits[i] = tf.convert_to_tensor(
+            behaviour_policy_logits[i], dtype=tf.float32)
+        target_policy_logits[i] = tf.convert_to_tensor(
+            target_policy_logits[i], dtype=tf.float32)
+        actions[i] = tf.convert_to_tensor(actions[i], dtype=tf.int32)
+
+        # Make sure tensor ranks are as expected.
+        # The rest will be checked by from_action_log_probs.
+        behaviour_policy_logits[i].shape.assert_has_rank(3)
+        target_policy_logits[i].shape.assert_has_rank(3)
+        actions[i].shape.assert_has_rank(2)
+
+    with tf.name_scope(name,
+                       values=[behaviour_policy_logits, target_policy_logits,
+                               actions, discounts, rewards, values,
+                               bootstrap_value]):
+        target_action_log_probs = multi_log_probs_from_logits_and_actions(
             target_policy_logits, actions)
-        behaviour_action_log_probs = log_probs_from_logits_and_actions(
+        behaviour_action_log_probs = multi_log_probs_from_logits_and_actions(
             behaviour_policy_logits, actions)
-        log_rhos = target_action_log_probs - behaviour_action_log_probs
+
+        log_rhos = get_log_rhos(target_action_log_probs,
+                                behaviour_action_log_probs)

         vtrace_returns = from_importance_weights(
             log_rhos=log_rhos,
             discounts=discounts,
@@ -159,6 +230,7 @@ def from_logits(behaviour_policy_logits,
             bootstrap_value=bootstrap_value,
             clip_rho_threshold=clip_rho_threshold,
             clip_pg_rho_threshold=clip_pg_rho_threshold)
+
         return VTraceFromLogitsReturns(
             log_rhos=log_rhos,
             behaviour_action_log_probs=behaviour_action_log_probs,
@@ -183,13 +255,13 @@ def from_importance_weights(log_rhos,
     by Espeholt, Soyer, Munos et al.

     In the notation used throughout documentation and comments, T refers to the
-    time dimension ranging from 0 to T-1. B refers to the batch size and
-    NUM_ACTIONS refers to the number of actions. This code also supports the
-    case where all tensors have the same number of additional dimensions, e.g.,
-    `rewards` is [T, B, C], `values` is [T, B, C], `bootstrap_value` is [B, C].
+    time dimension ranging from 0 to T-1. B refers to the batch size. This code
+    also supports the case where all tensors have the same number of additional
+    dimensions, e.g., `rewards` is [T, B, C], `values` is [T, B, C],
+    `bootstrap_value` is [B, C].

     Args:
-      log_rhos: A float32 tensor of shape [T, B, NUM_ACTIONS] representing the
+      log_rhos: A float32 tensor of shape [T, B] representing the
        log importance sampling weights, i.e.
        log(target_policy(a) / behaviour_policy(a)). V-trace performs operations
        on rhos in log-space for numerical stability.
@@ -246,6 +318,14 @@ def from_importance_weights(log_rhos,
         if clip_rho_threshold is not None:
             clipped_rhos = tf.minimum(
                 clip_rho_threshold, rhos, name='clipped_rhos')
+
+            tf.summary.histogram('clipped_rhos_1000', tf.minimum(1000.0, rhos))
+            tf.summary.scalar(
+                'num_of_clipped_rhos',
+                tf.reduce_sum(
+                    tf.cast(
+                        tf.equal(clipped_rhos, clip_rho_threshold), tf.int32)))
+            tf.summary.scalar('size_of_clipped_rhos', tf.size(clipped_rhos))
         else:
             clipped_rhos = rhos

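The added summaries track how often the importance ratios hit the clip threshold; a NumPy sketch with made-up numbers of the quantities they report:

import numpy as np

rng = np.random.RandomState(0)
log_rhos = rng.randn(50, 4)                  # made-up [T, B] log importance weights
rhos = np.exp(log_rhos)

clip_rho_threshold = 1.0
clipped_rhos = np.minimum(clip_rho_threshold, rhos)   # clip ratios above the threshold

# Roughly what the new tf.summary calls report: how many rhos were clipped
# ('num_of_clipped_rhos') out of how many total ('size_of_clipped_rhos');
# the histogram shows the rho distribution capped at 1000 for readability.
num_clipped = int((clipped_rhos == clip_rho_threshold).sum())
print(num_clipped, clipped_rhos.size)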
@@ -298,3 +378,14 @@ def scanfunc(acc, sequence_item):
         return VTraceReturns(
             vs=tf.stop_gradient(vs),
             pg_advantages=tf.stop_gradient(pg_advantages))
+
+
+def get_log_rhos(target_action_log_probs, behaviour_action_log_probs):
+    """With the selected log_probs for multi-discrete actions of behaviour
+    and target policies we compute the log_rhos for calculating the vtrace."""
+    log_rhos = [t - b for t, b in zip(target_action_log_probs,
+                                      behaviour_action_log_probs)]
+    log_rhos = [tf.convert_to_tensor(x, dtype=tf.float32) for x in log_rhos]
+    log_rhos = tf.reduce_sum(tf.stack(log_rhos), axis=0)
+
+    return log_rhos
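Summing per-component log rhos is the right combination because, for a MultiDiscrete action, the factorized policy probability of the joint action is the product of the per-component probabilities, so its log importance ratio is the sum of the per-component log ratios. A small NumPy check (illustrative only, with made-up values):

import numpy as np

rng = np.random.RandomState(0)
T, B = 4, 2
# Made-up per-component action log-probs under target and behaviour policies.
target_logp = [rng.randn(T, B), rng.randn(T, B)]
behaviour_logp = [rng.randn(T, B), rng.randn(T, B)]

# What get_log_rhos computes: sum over components of (target - behaviour).
log_rhos = np.sum(
    np.stack([t - b for t, b in zip(target_logp, behaviour_logp)]), axis=0)

# Equivalent joint-ratio form: log(prod_i pi(a_i) / prod_i mu(a_i)).
joint_ratio = np.exp(sum(target_logp)) / np.exp(sum(behaviour_logp))
assert np.allclose(log_rhos, np.log(joint_ratio))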
This isn't needed right?
It is, as default is DiagGaussian, and we use Categorical (which is MultiCategorical for MultiDiscrete action space). We didn't change the default, which is DiagGaussian and Discrete action space (that's how IMPALA operates at the moment).
I'm a bit confused, isn't it Categorical for discrete spaces?
I don't think DiagGaussian ever gets used in IMPALA does it (maybe you meant APPO?)
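For context on the distribution discussion above, a small sketch (not from this PR) of how a MultiDiscrete space factorizes into per-component categorical choices, which is what a MultiCategorical-style distribution sums log-probs over; gym is assumed to be installed and the sizes are made up:

from gym.spaces import Discrete, MultiDiscrete

disc = Discrete(4)                 # one categorical choice with 4 options
multi = MultiDiscrete([3, 5])      # two categorical choices: 3 and 5 options

print(disc.n)                      # 4
print(multi.nvec)                  # [3 5], i.e. the ACTION_SPACE list in the vtrace docstrings
print(multi.sample())              # e.g. [2 0], one sub-action per component
# Treating each component as its own Categorical and summing the per-component
# log-probs is exactly what get_log_rhos in the diff above relies on.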