[wingman -> rllib] IMPALA MultiDiscrete changes #3967
Changes from 7 commits
@@ -33,14 +33,14 @@
 nest = tf.contrib.framework.nest

Review comment: Should update the file comment to say modified to support MultiDiscrete spaces.
Reply: Done.

 VTraceFromLogitsReturns = collections.namedtuple('VTraceFromLogitsReturns', [
-    'vs', 'pg_advantages', 'log_rhos', 'behaviour_action_log_probs',
-    'target_action_log_probs'
+    'vs', 'pg_advantages', 'rhos', 'behaviour_action_policy',
+    'target_action_policy'

Review comment: Can we keep these in log space?
Reply: Yup.

 ])

 VTraceReturns = collections.namedtuple('VTraceReturns', 'vs pg_advantages')


-def log_probs_from_logits_and_actions(policy_logits, actions):
+def select_policy_values_using_actions(policy_logits, actions):

Review comment: Should update doc comment here.
Reply: done.

   """Computes action log-probs from policy logits and actions.

   In the notation used throughout documentation and comments, T refers to the
@@ -56,18 +56,17 @@ def log_probs_from_logits_and_actions(policy_logits, actions):
     A float32 tensor of shape [T, B] corresponding to the sampling log
     probability of the chosen action w.r.t. the policy.
   """
-  policy_logits = tf.convert_to_tensor(policy_logits, dtype=tf.float32)
-  actions = tf.convert_to_tensor(actions, dtype=tf.int32)
-
-  policy_logits.shape.assert_has_rank(3)
-  actions.shape.assert_has_rank(2)
-
-  return -tf.nn.sparse_softmax_cross_entropy_with_logits(
-      logits=policy_logits, labels=actions)
+  log_probs = []
+  for i in range(len(policy_logits)):
+    log_probs.append(-tf.nn.sparse_softmax_cross_entropy_with_logits(
+        logits=policy_logits[i], labels=actions[i]))
+
+  return log_probs


-def from_logits(behaviour_policy_logits,
-                target_policy_logits,
+def from_logits(behaviour_policy,
+                target_policy,
                 actions,
                 discounts,
                 rewards,
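As background for the loop added above: -tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels) is the log-probability of the chosen action under the softmax of the logits, and the new code simply applies it once per MultiDiscrete component, returning a list of [T, B] tensors instead of a single one. A minimal numpy sketch of the same computation, with a hypothetical helper name and made-up shapes (not part of the PR):

import numpy as np

def log_prob_of_actions(logits, actions):
    # logits: [T, B, K] unnormalized scores, actions: [T, B] integer indices.
    # Log-softmax over the last axis, then pick the entry for each chosen action.
    logits = logits - logits.max(axis=-1, keepdims=True)  # for numerical stability
    log_softmax = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return np.take_along_axis(log_softmax, actions[..., None], axis=-1)[..., 0]

T, B = 4, 2
rng = np.random.default_rng(0)
# Two MultiDiscrete components with 3 and 5 choices respectively.
policy_logits = [rng.normal(size=(T, B, 3)), rng.normal(size=(T, B, 5))]
actions = [rng.integers(0, 3, size=(T, B)), rng.integers(0, 5, size=(T, B))]

log_probs = [log_prob_of_actions(l, a) for l, a in zip(policy_logits, actions)]
print([lp.shape for lp in log_probs])  # [(4, 2), (4, 2)]: one [T, B] result per component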
@@ -93,10 +92,10 @@ def from_logits(behaviour_policy_logits,
   NUM_ACTIONS refers to the number of actions.

   Args:
-    behaviour_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
+    behaviour_policy: A float32 tensor of shape [T, B, NUM_ACTIONS] with

Review comment: Should update doc comments here.
Reply: done.

       un-normalized log-probabilities parametrizing the softmax behaviour
       policy.
-    target_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
+    target_policy: A float32 tensor of shape [T, B, NUM_ACTIONS] with
       un-normalized log-probabilities parametrizing the softmax target policy.
     actions: An int32 tensor of shape [T, B] of actions sampled from the
       behaviour policy.
@@ -128,45 +127,47 @@ def from_logits(behaviour_policy_logits,
     target_action_log_probs: A float32 tensor of shape [T, B] containing
       target policy action probabilities (log \pi(a_t)).
   """
-  behaviour_policy_logits = tf.convert_to_tensor(
-      behaviour_policy_logits, dtype=tf.float32)
-  target_policy_logits = tf.convert_to_tensor(
-      target_policy_logits, dtype=tf.float32)
-  actions = tf.convert_to_tensor(actions, dtype=tf.int32)
-
-  # Make sure tensor ranks are as expected.
-  # The rest will be checked by from_action_log_probs.
-  behaviour_policy_logits.shape.assert_has_rank(3)
-  target_policy_logits.shape.assert_has_rank(3)
-  actions.shape.assert_has_rank(2)
-
-  with tf.name_scope(
-      name,
-      values=[
-          behaviour_policy_logits, target_policy_logits, actions,
-          discounts, rewards, values, bootstrap_value
-      ]):
-    target_action_log_probs = log_probs_from_logits_and_actions(
-        target_policy_logits, actions)
-    behaviour_action_log_probs = log_probs_from_logits_and_actions(
-        behaviour_policy_logits, actions)
-    log_rhos = target_action_log_probs - behaviour_action_log_probs
+  for i in range(len(behaviour_policy)):
+    behaviour_policy[i] = tf.convert_to_tensor(
+        behaviour_policy[i], dtype=tf.float32)
+    target_policy[i] = tf.convert_to_tensor(
+        target_policy[i], dtype=tf.float32)
+    actions[i] = tf.convert_to_tensor(actions[i], dtype=tf.int32)
+
+    # Make sure tensor ranks are as expected.
+    # The rest will be checked by from_action_log_probs.
+    behaviour_policy[i].shape.assert_has_rank(3)
+    target_policy[i].shape.assert_has_rank(3)
+    actions[i].shape.assert_has_rank(2)
+
+  with tf.name_scope(name, values=[behaviour_policy, target_policy, actions,
+                                   discounts, rewards, values,
+                                   bootstrap_value]):
+    target_action_policy = select_policy_values_using_actions(
+        target_policy, actions)
+    behaviour_action_policy = select_policy_values_using_actions(
+        behaviour_policy, actions)
+
+    rhos = get_rhos(target_action_policy, behaviour_action_policy)
+
     vtrace_returns = from_importance_weights(
-        log_rhos=log_rhos,
+        rhos=rhos,
         discounts=discounts,
         rewards=rewards,
         values=values,
         bootstrap_value=bootstrap_value,
         clip_rho_threshold=clip_rho_threshold,
         clip_pg_rho_threshold=clip_pg_rho_threshold)

     return VTraceFromLogitsReturns(
-        log_rhos=log_rhos,
-        behaviour_action_log_probs=behaviour_action_log_probs,
-        target_action_log_probs=target_action_log_probs,
+        rhos=rhos,
+        behaviour_action_policy=behaviour_action_policy,
+        target_action_policy=target_action_policy,
         **vtrace_returns._asdict())


-def from_importance_weights(log_rhos,
+def from_importance_weights(rhos,
                             discounts,
                             rewards,
                             values,
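Why summing per-component log-ratios (as the get_rhos helper added at the bottom of the file does) gives the right importance weight: for a MultiDiscrete action the policy factorizes across components, so the joint ratio is the product of the per-component ratios, which equals the exponential of the sum of the per-component log-ratios. A small numpy check of that identity with made-up values (not PR code):

import numpy as np

rng = np.random.default_rng(1)
T, B = 3, 2
# Per-component action log-probs under target and behaviour policies.
target_log_probs = [rng.uniform(-2.0, 0.0, size=(T, B)) for _ in range(2)]
behaviour_log_probs = [rng.uniform(-2.0, 0.0, size=(T, B)) for _ in range(2)]

# exp(sum_i (log pi_target_i - log pi_behaviour_i)) ...
log_rhos = sum(t - b for t, b in zip(target_log_probs, behaviour_log_probs))
rhos_from_logs = np.exp(log_rhos)

# ... equals prod_i (pi_target_i / pi_behaviour_i) for a factorized policy.
rhos_from_probs = np.prod(
    [np.exp(t) / np.exp(b) for t, b in zip(target_log_probs, behaviour_log_probs)],
    axis=0)

assert np.allclose(rhos_from_logs, rhos_from_probs)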
@@ -189,10 +190,8 @@ def from_importance_weights(log_rhos,
   `rewards` is [T, B, C], `values` is [T, B, C], `bootstrap_value` is [B, C].

   Args:
-    log_rhos: A float32 tensor of shape [T, B, NUM_ACTIONS] representing the
-      log importance sampling weights, i.e.
-      log(target_policy(a) / behaviour_policy(a)). V-trace performs operations
-      on rhos in log-space for numerical stability.
+    rhos: A float32 tensor of shape [T, B, NUM_ACTIONS] representing the
+      importance sampling weights, i.e. target_policy(a) / behaviour_policy(a).

Review comment: Why change it from log space?
Reply: We were experimenting with both log prob and prob spaces, will return it to log space.

     discounts: A float32 tensor of shape [T, B] with discounts encountered when
       following the behaviour policy.
     rewards: A float32 tensor of shape [T, B] containing rewards generated by
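The numerical issue behind this exchange: working in log space means subtracting log-probabilities and exponentiating once, instead of dividing two probabilities that can each underflow to zero. A toy illustration with made-up numbers (not part of the PR):

import numpy as np

# Log-probs of a long action sequence under two nearly identical policies.
target_log_prob = np.float32(-400.0)     # exp(-400) underflows float32 to 0.0
behaviour_log_prob = np.float32(-400.5)

# Probability space: 0.0 / 0.0 gives nan (and a runtime warning).
naive_rho = np.exp(target_log_prob) / np.exp(behaviour_log_prob)

# Log space: subtract first, exponentiate once, giving exp(0.5).
stable_rho = np.exp(target_log_prob - behaviour_log_prob)

print(naive_rho, stable_rho)  # nan vs ~1.6487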
@@ -216,9 +215,10 @@ def from_importance_weights(log_rhos,
     pg_advantages: A float32 tensor of shape [T, B]. Can be used as the
       advantage in the calculation of policy gradients.
   """
-  log_rhos = tf.convert_to_tensor(log_rhos, dtype=tf.float32)
+  rhos = tf.convert_to_tensor(rhos, dtype=tf.float32)
   discounts = tf.convert_to_tensor(discounts, dtype=tf.float32)
-  rewards = tf.convert_to_tensor(rewards, dtype=tf.float32)
+  rewards = tf.cast(rewards, dtype=tf.float32)
   values = tf.convert_to_tensor(values, dtype=tf.float32)
   bootstrap_value = tf.convert_to_tensor(bootstrap_value, dtype=tf.float32)
   if clip_rho_threshold is not None:
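A possible reason for switching rewards from tf.convert_to_tensor to tf.cast (not stated in the thread, so treat it as an assumption): convert_to_tensor refuses to change the dtype of a tensor that is already, say, int32, while cast performs the conversion. A minimal sketch of the difference:

import tensorflow as tf

int_rewards = tf.constant([[1, 0], [2, 1]], dtype=tf.int32)  # e.g. integer env rewards

# tf.cast converts the dtype.
float_rewards = tf.cast(int_rewards, dtype=tf.float32)

# tf.convert_to_tensor refuses to change the dtype of an existing tensor.
try:
    tf.convert_to_tensor(int_rewards, dtype=tf.float32)
except ValueError as e:
    print("convert_to_tensor failed:", e)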
@@ -229,7 +229,7 @@ def from_importance_weights(log_rhos,
         clip_pg_rho_threshold, dtype=tf.float32)

   # Make sure tensor ranks are consistent.
-  rho_rank = log_rhos.shape.ndims  # Usually 2.
+  rho_rank = rhos.shape.ndims  # Usually 2.
   values.shape.assert_has_rank(rho_rank)
   bootstrap_value.shape.assert_has_rank(rho_rank - 1)
   discounts.shape.assert_has_rank(rho_rank)
@@ -241,14 +241,24 @@ def from_importance_weights(log_rhos,
   with tf.name_scope(
       name,
-      values=[log_rhos, discounts, rewards, values, bootstrap_value]):
-    rhos = tf.exp(log_rhos)
+      values=[rhos, discounts, rewards, values, bootstrap_value]):
     if clip_rho_threshold is not None:
       clipped_rhos = tf.minimum(
           clip_rho_threshold, rhos, name='clipped_rhos')
     else:
       clipped_rhos = rhos

+    tf.summary.histogram('clipped_rhos_1000', tf.minimum(1000.0, rhos))
+    tf.summary.scalar(
+        'num_of_clipped_rhos',
+        tf.reduce_sum(
+            tf.cast(
+                tf.equal(
+                    clipped_rhos,
+                    clip_rho_threshold),
+                tf.int32)))
+    tf.summary.scalar('size_of_clipped_rhos', tf.size(clipped_rhos))
+
     cs = tf.minimum(1.0, rhos, name='cs')
     # Append bootstrapped value to get [v1, ..., v_t+1]
     values_t_plus_1 = tf.concat(
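What the new summaries record: clipped_rhos caps each importance weight at clip_rho_threshold (the rho-bar of the V-trace paper), and the added scalars report how many weights actually hit that cap out of how many there are. A numpy sketch of the same bookkeeping (illustrative only, not the PR's code):

import numpy as np

rng = np.random.default_rng(2)
rhos = rng.lognormal(mean=0.0, sigma=1.0, size=(10, 4))  # importance weights, [T, B]
clip_rho_threshold = 1.0

clipped_rhos = np.minimum(clip_rho_threshold, rhos)
num_clipped = int((clipped_rhos == clip_rho_threshold).sum())

print(num_clipped, clipped_rhos.size)  # how many weights were capped, out of how many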
@@ -298,3 +308,14 @@ def scanfunc(acc, sequence_item):
     return VTraceReturns(
         vs=tf.stop_gradient(vs),
         pg_advantages=tf.stop_gradient(pg_advantages))
+
+
+def get_rhos(behaviour_action_log_probs, target_action_log_probs):
+  """With the selected policy values (logits or probs) subclasses compute
+  the rhos for calculating the vtrace."""
+  log_rhos = [t - b for t,
+              b in zip(target_action_log_probs, behaviour_action_log_probs)]
+  log_rhos = [tf.convert_to_tensor(l, dtype=tf.float32) for l in log_rhos]
+  log_rhos = tf.reduce_sum(tf.stack(log_rhos), axis=0)
+
+  return tf.exp(log_rhos)
Review comment: This isn't needed right?
Reply: It is, as the default is DiagGaussian, and we use Categorical (which is MultiCategorical for a MultiDiscrete action space). We didn't change the default, which is DiagGaussian and a Discrete action space (that's how IMPALA operates at the moment).
Review comment: I'm a bit confused, isn't it Categorical for discrete spaces? I don't think DiagGaussian ever gets used in IMPALA, does it (maybe you meant APPO?)
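For readers unfamiliar with the space under discussion: a gym MultiDiscrete space is a tuple of independent discrete choices, which is why the policy head becomes one Categorical distribution per component (RLlib's MultiCategorical). A small example of such a space (illustrative only, not tied to this PR's config):

from gym.spaces import MultiDiscrete

# Two independent discrete sub-actions: one with 3 choices, one with 5.
space = MultiDiscrete([3, 5])

sample = space.sample()      # e.g. array([2, 0])
print(space.nvec, sample)    # [3 5] and one index per component

# A factorized policy therefore needs one softmax head per entry of nvec,
# with logits of shape [T, B, 3] and [T, B, 5] respectively.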