[wingman -> rllib] IMPALA MultiDiscrete changes #3967

Merged · 48 commits · Mar 2, 2019
The diff shown below covers the first 7 of the 48 commits.
Commits
a6b1b7a
impala changes
Feb 6, 2019
088985b
fixed newlines
Feb 6, 2019
23029d3
Merge branch 'master' into impala
Feb 7, 2019
3b38ebb
reformatting impalla.py
Feb 7, 2019
1858404
aligned vtrace.py formatting some more
Feb 7, 2019
9840eb6
aligned formatting some more
Feb 7, 2019
e48f9ae
aligned formatting some more
Feb 7, 2019
26eed71
Merge branch 'master' into impala
Feb 8, 2019
3171c8a
fixed impala stuff
Feb 8, 2019
9d62dd1
Address vtrace comments (#6)
pimpke Feb 8, 2019
6597295
Made APPO work with VTrace
stefanpantic Feb 11, 2019
1d31991
Variable is no longer a member
stefanpantic Feb 11, 2019
252f6b3
Optimized imports
stefanpantic Feb 11, 2019
5ef2e30
Changed is_discrete to is_multidiscrete, fixed KL distribution
stefanpantic Feb 11, 2019
cf5c1c5
Fixed KL divergence
stefanpantic Feb 11, 2019
54f4f79
Removed if statement
stefanpantic Feb 11, 2019
38c1896
Merge branch 'master' into impala
Feb 11, 2019
2dd604f
Merge branch 'impala' of https://github.com/wingman-ai/ray into impala
Feb 11, 2019
7cb1f97
revert appo file
Feb 14, 2019
65c82d4
revered stefans appo changes
Feb 14, 2019
946df01
old appo policy graph
Feb 14, 2019
b6b2c52
returned stefan appo changes and returned newline
Feb 14, 2019
56fe32e
fixed newlines in appo_policy_graph
Feb 14, 2019
1d46d7a
Merge branch 'master' into impala
Feb 14, 2019
76046f3
aligned with action_dist changes in ray master
Feb 14, 2019
d017c9f
small appo fixes
Feb 14, 2019
c02b9f5
add vtrace test
ericl Feb 15, 2019
fbbed63
fix appo impala integration
ericl Feb 15, 2019
bcb2113
add to jenkins
ericl Feb 15, 2019
e020527
merged with master
Feb 18, 2019
63e119a
Merge branch 'impala' of https://github.com/wingman-ai/ray into impala
Feb 18, 2019
6e06ba6
fixing appo policy graph changes
Feb 18, 2019
f161edf
fixed vtrace tests
Feb 18, 2019
1dbed08
lint and py2 compat
ericl Feb 18, 2019
aa50f98
kl
ericl Feb 18, 2019
57594c0
Merge branch 'master' into impala
Feb 19, 2019
3f4883b
Merge branch 'impala' of https://github.com/wingman-ai/ray into impala
Feb 19, 2019
d32d253
removed dist_type as it is actually not needed for IMPALA
Feb 19, 2019
967db5c
fixing issue with new gym version
Feb 19, 2019
1da470f
Merge branch 'master' into impala
Feb 20, 2019
9584a7c
lint
ericl Feb 20, 2019
8999621
fix multigpu test
ericl Feb 20, 2019
0cbeb7c
merged with master
Feb 25, 2019
abe797a
Merge branch 'impala' of https://github.com/wingman-ai/ray into impala
Feb 25, 2019
280b21c
Merge branch 'master' into impala
Feb 26, 2019
e0e3060
Merge branch 'master' into impala
Feb 27, 2019
afb462f
Merge remote-tracking branch 'upstream/master' into impala
ericl Mar 1, 2019
eb18cff
fix tests
ericl Mar 1, 2019
3 changes: 3 additions & 0 deletions python/ray/rllib/agents/impala/impala.py
@@ -70,6 +70,9 @@
# max number of workers to broadcast one set of weights to
"broadcast_interval": 1,

# Actions are chosen based on this distribution, if provided
"dist_type": None,
Contributor:
This isn't needed right?

Contributor Author (@bjg2, Feb 8, 2019):
It is, as default is DiagGaussian, and we use Categorical (which is MultiCategorical for MultiDiscrete action space). We didn't change the default, which is DiagGaussian and Discrete action space (that's how IMPALA operates at the moment).

Contributor:
I'm a bit confused, isn't it Categorical for discrete spaces?

        elif isinstance(action_space, gym.spaces.Discrete):
            return Categorical, action_space.n

I don't think DiagGaussian ever gets used in IMPALA, does it? (Maybe you meant APPO?)


# Learning params.
"grad_clip": 40.0,
# either "adam" or "rmsprop"
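To make the dist_type discussion above concrete: the mapping from a gym action space to an action distribution works roughly as sketched below. This is a hypothetical helper, not code from this PR or from RLlib's ModelCatalog; the class names only mirror the names used in the thread (DiagGaussian, Categorical, MultiCategorical).

    import numpy as np
    import gym.spaces

    def pick_action_dist(action_space):
        """Hypothetical sketch of choosing an action distribution."""
        if isinstance(action_space, gym.spaces.Box):
            # Continuous actions: diagonal Gaussian (mean and log-std per dim).
            return "DiagGaussian", 2 * int(np.prod(action_space.shape))
        elif isinstance(action_space, gym.spaces.Discrete):
            # A single discrete choice: one Categorical over n outcomes.
            return "Categorical", action_space.n
        elif isinstance(action_space, gym.spaces.MultiDiscrete):
            # Several discrete sub-actions: one Categorical per component,
            # i.e. the MultiCategorical mentioned in the thread.
            return "MultiCategorical", int(np.sum(action_space.nvec))
        raise NotImplementedError(action_space)

    # Example: MultiDiscrete([3, 2]) -> ("MultiCategorical", 5 logits in total).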
117 changes: 69 additions & 48 deletions python/ray/rllib/agents/impala/vtrace.py
@@ -33,14 +33,14 @@
nest = tf.contrib.framework.nest
Contributor:
Should update the file comment to say modified to support MultiDiscrete spaces.

Contributor Author:
Done.


VTraceFromLogitsReturns = collections.namedtuple('VTraceFromLogitsReturns', [
'vs', 'pg_advantages', 'log_rhos', 'behaviour_action_log_probs',
'target_action_log_probs'
'vs', 'pg_advantages', 'rhos', 'behaviour_action_policy',
'target_action_policy'
Contributor:
Can we keep these in log space?

Yup.

])
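On the log-space question above: V-trace needs the importance ratio rho_t = pi(a_t|x_t) / mu(a_t|x_t), and keeping the per-action policy values as log-probabilities turns that ratio into a subtraction followed by a single exp, rather than a division of two possibly tiny probabilities. A small numpy illustration (not part of the diff):

    import numpy as np

    # Log-probs of the sampled actions under behaviour (mu) and target (pi).
    behaviour_action_log_probs = np.array([-0.7, -2.3, -0.1], dtype=np.float32)
    target_action_log_probs = np.array([-0.4, -2.0, -0.3], dtype=np.float32)

    log_rhos = target_action_log_probs - behaviour_action_log_probs
    rhos = np.exp(log_rhos)  # e.g. exp(0.3) ~= 1.35 for the first timestep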

VTraceReturns = collections.namedtuple('VTraceReturns', 'vs pg_advantages')


def log_probs_from_logits_and_actions(policy_logits, actions):
def select_policy_values_using_actions(policy_logits, actions):
Contributor:
Should update doc comment here.

Contributor Author:
done.

"""Computes action log-probs from policy logits and actions.

In the notation used throughout documentation and comments, T refers to the
@@ -56,18 +56,17 @@ def log_probs_from_logits_and_actions(policy_logits, actions):
A float32 tensor of shape [T, B] corresponding to the sampling log
probability of the chosen action w.r.t. the policy.
"""
policy_logits = tf.convert_to_tensor(policy_logits, dtype=tf.float32)
actions = tf.convert_to_tensor(actions, dtype=tf.int32)

policy_logits.shape.assert_has_rank(3)
actions.shape.assert_has_rank(2)
log_probs = []
for i in range(len(policy_logits)):
log_probs.append(-tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=policy_logits[i], labels=actions[i]))

return -tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=policy_logits, labels=actions)
return log_probs
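A hedged usage sketch of the list-based version above, assuming a MultiDiscrete action space with two components of 3 and 2 choices, TensorFlow 1.x as used by this file, and select_policy_values_using_actions as defined at this revision; shapes are illustrative only:

    import numpy as np
    import tensorflow as tf

    T, B = 4, 2  # rollout length and batch size

    # One [T, B, NUM_ACTIONS_i] logits tensor and one [T, B] actions tensor
    # per MultiDiscrete component.
    policy_logits = [
        tf.constant(np.random.randn(T, B, 3), dtype=tf.float32),
        tf.constant(np.random.randn(T, B, 2), dtype=tf.float32),
    ]
    actions = [
        tf.constant(np.random.randint(3, size=(T, B)), dtype=tf.int32),
        tf.constant(np.random.randint(2, size=(T, B)), dtype=tf.int32),
    ]

    # Returns a list with one [T, B] log-prob tensor per component.
    log_probs = select_policy_values_using_actions(policy_logits, actions)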


def from_logits(behaviour_policy_logits,
target_policy_logits,
def from_logits(behaviour_policy,
target_policy,
actions,
discounts,
rewards,
@@ -93,10 +92,10 @@ def from_logits(behaviour_policy_logits,
NUM_ACTIONS refers to the number of actions.

Args:
behaviour_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
behaviour_policy: A float32 tensor of shape [T, B, NUM_ACTIONS] with
Contributor:
Should update doc comments here.

Contributor Author:
done.

un-normalized log-probabilities parametrizing the softmax behaviour
policy.
target_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
target_policy: A float32 tensor of shape [T, B, NUM_ACTIONS] with
un-normalized log-probabilities parametrizing the softmax target policy.
actions: An int32 tensor of shape [T, B] of actions sampled from the
behaviour policy.
@@ -128,45 +127,47 @@ def from_logits(behaviour_policy_logits,
target_action_log_probs: A float32 tensor of shape [T, B] containing
target policy action probabilities (log \pi(a_t)).
"""
behaviour_policy_logits = tf.convert_to_tensor(
behaviour_policy_logits, dtype=tf.float32)
target_policy_logits = tf.convert_to_tensor(
target_policy_logits, dtype=tf.float32)
actions = tf.convert_to_tensor(actions, dtype=tf.int32)

# Make sure tensor ranks are as expected.
# The rest will be checked by from_action_log_probs.
behaviour_policy_logits.shape.assert_has_rank(3)
target_policy_logits.shape.assert_has_rank(3)
actions.shape.assert_has_rank(2)

with tf.name_scope(
name,
values=[
behaviour_policy_logits, target_policy_logits, actions,
discounts, rewards, values, bootstrap_value
]):
target_action_log_probs = log_probs_from_logits_and_actions(
target_policy_logits, actions)
behaviour_action_log_probs = log_probs_from_logits_and_actions(
behaviour_policy_logits, actions)
log_rhos = target_action_log_probs - behaviour_action_log_probs
for i in range(len(behaviour_policy)):
behaviour_policy[i] = tf.convert_to_tensor(
behaviour_policy[i], dtype=tf.float32)
target_policy[i] = tf.convert_to_tensor(
target_policy[i], dtype=tf.float32)
actions[i] = tf.convert_to_tensor(actions[i], dtype=tf.int32)

# Make sure tensor ranks are as expected.
# The rest will be checked by from_action_log_probs.
behaviour_policy[i].shape.assert_has_rank(3)
target_policy[i].shape.assert_has_rank(3)
actions[i].shape.assert_has_rank(2)

with tf.name_scope(name, values=[behaviour_policy, target_policy, actions,
discounts, rewards, values,
bootstrap_value]):
target_action_policy = select_policy_values_using_actions(
target_policy, actions)
behaviour_action_policy = select_policy_values_using_actions(
behaviour_policy, actions)

rhos = get_rhos(target_action_policy, behaviour_action_policy)

vtrace_returns = from_importance_weights(
log_rhos=log_rhos,
rhos=rhos,
discounts=discounts,
rewards=rewards,
values=values,
bootstrap_value=bootstrap_value,
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold)

return VTraceFromLogitsReturns(
log_rhos=log_rhos,
behaviour_action_log_probs=behaviour_action_log_probs,
target_action_log_probs=target_action_log_probs,
rhos=rhos,
behaviour_action_policy=behaviour_action_policy,
target_action_policy=target_action_policy,
**vtrace_returns._asdict())


def from_importance_weights(log_rhos,
def from_importance_weights(rhos,
discounts,
rewards,
values,
@@ -189,10 +190,8 @@ def from_importance_weights(log_rhos,
`rewards` is [T, B, C], `values` is [T, B, C], `bootstrap_value` is [B, C].

Args:
log_rhos: A float32 tensor of shape [T, B, NUM_ACTIONS] representing the
log importance sampling weights, i.e.
log(target_policy(a) / behaviour_policy(a)). V-trace performs operations
on rhos in log-space for numerical stability.
rhos: A float32 tensor of shape [T, B, NUM_ACTIONS] representing the
importance sampling weights, i.e. target_policy(a) / behaviour_policy(a).
Contributor:
Why change it from log space?

We were experimenting with both log prob and prob spaces, will return it to log space.

discounts: A float32 tensor of shape [T, B] with discounts encountered when
following the behaviour policy.
rewards: A float32 tensor of shape [T, B] containing rewards generated by
@@ -216,9 +215,10 @@ def from_importance_weights(log_rhos,
pg_advantages: A float32 tensor of shape [T, B]. Can be used as the
advantage in the calculation of policy gradients.
"""
log_rhos = tf.convert_to_tensor(log_rhos, dtype=tf.float32)
rhos = tf.convert_to_tensor(rhos, dtype=tf.float32)
discounts = tf.convert_to_tensor(discounts, dtype=tf.float32)
rewards = tf.convert_to_tensor(rewards, dtype=tf.float32)
rewards = tf.cast(rewards, dtype=tf.float32)
values = tf.convert_to_tensor(values, dtype=tf.float32)
bootstrap_value = tf.convert_to_tensor(bootstrap_value, dtype=tf.float32)
if clip_rho_threshold is not None:
@@ -229,7 +229,7 @@ def from_importance_weights(log_rhos,
clip_pg_rho_threshold, dtype=tf.float32)

# Make sure tensor ranks are consistent.
rho_rank = log_rhos.shape.ndims # Usually 2.
rho_rank = rhos.shape.ndims # Usually 2.
values.shape.assert_has_rank(rho_rank)
bootstrap_value.shape.assert_has_rank(rho_rank - 1)
discounts.shape.assert_has_rank(rho_rank)
@@ -241,14 +241,24 @@ def from_importance_weights(log_rhos,

with tf.name_scope(
name,
values=[log_rhos, discounts, rewards, values, bootstrap_value]):
rhos = tf.exp(log_rhos)
values=[rhos, discounts, rewards, values, bootstrap_value]):
if clip_rho_threshold is not None:
clipped_rhos = tf.minimum(
clip_rho_threshold, rhos, name='clipped_rhos')
else:
clipped_rhos = rhos

tf.summary.histogram('clipped_rhos_1000', tf.minimum(1000.0, rhos))
tf.summary.scalar(
'num_of_clipped_rhos',
tf.reduce_sum(
tf.cast(
tf.equal(
clipped_rhos,
clip_rho_threshold),
tf.int32)))
tf.summary.scalar('size_of_clipped_rhos', tf.size(clipped_rhos))

cs = tf.minimum(1.0, rhos, name='cs')
# Append bootstrapped value to get [v1, ..., v_t+1]
values_t_plus_1 = tf.concat(
@@ -298,3 +308,14 @@ def scanfunc(acc, sequence_item):
return VTraceReturns(
vs=tf.stop_gradient(vs),
pg_advantages=tf.stop_gradient(pg_advantages))
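A minimal usage sketch of from_importance_weights as defined above, assuming the clip thresholds keep their upstream defaults and the result is evaluated with the usual TF 1.x graph/session machinery; shapes and values are illustrative only:

    import numpy as np

    T, B = 4, 2
    rhos = np.random.uniform(0.5, 1.5, size=(T, B)).astype(np.float32)
    discounts = 0.99 * np.ones((T, B), dtype=np.float32)  # zero at terminal steps
    rewards = np.random.randn(T, B).astype(np.float32)
    values = np.random.randn(T, B).astype(np.float32)
    bootstrap_value = np.random.randn(B).astype(np.float32)

    returns = from_importance_weights(rhos, discounts, rewards, values,
                                      bootstrap_value)
    # returns.vs: [T, B] V-trace value targets; returns.pg_advantages: [T, B]
    # clipped-importance-weighted advantages; both have stop_gradient applied.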


def get_rhos(behaviour_action_log_probs, target_action_log_probs):
"""With the selected policy values (logits or probs) subclasses compute
the rhos for calculating the vtrace."""
log_rhos = [t - b for t,
b in zip(target_action_log_probs, behaviour_action_log_probs)]
log_rhos = [tf.convert_to_tensor(l, dtype=tf.float32) for l in log_rhos]
log_rhos = tf.reduce_sum(tf.stack(log_rhos), axis=0)

return tf.exp(log_rhos)
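For the MultiDiscrete case, get_rhos sums the per-component log ratios, which is the same as multiplying the per-component importance ratios of a factorised policy. A tiny worked example with hypothetical numbers (the result is a TF tensor to be evaluated in a session):

    import numpy as np

    # Two MultiDiscrete components, T = B = 1 for readability.
    behaviour_action_log_probs = [np.log([[0.5]]).astype(np.float32),
                                  np.log([[0.25]]).astype(np.float32)]
    target_action_log_probs = [np.log([[0.6]]).astype(np.float32),
                               np.log([[0.5]]).astype(np.float32)]

    rhos = get_rhos(behaviour_action_log_probs, target_action_log_probs)
    # Per-component ratios are 0.6/0.5 = 1.2 and 0.5/0.25 = 2.0, so the joint
    # ratio evaluates to 1.2 * 2.0 = 2.4.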