[RLlib] Optionally don't drop last ts in v-trace calculations (APPO and IMPALA). #19601

Merged: 23 commits, Nov 3, 2021
12 changes: 8 additions & 4 deletions .buildkite/pipeline.yml
@@ -352,21 +352,25 @@
--test_tag_filters=learning_tests_continuous,-fake_gpus,-torch_only,-flaky
--test_arg=--framework=tf
rllib/...
- label: ":brain: RLlib: Learning discr. actions TF2-eager (from rllib/tuned_examples/*.yaml)"
- label: ":brain: RLlib: Learning discr. actions TF2-eager-tracing (from rllib/tuned_examples/*.yaml)"
conditions: ["RAY_CI_RLLIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT
- RLLIB_TESTING=1 ./ci/travis/install-dependencies.sh
- RLLIB_TESTING=1 PYTHON=3.7 ./ci/travis/install-dependencies.sh
# Because Python version changed, we need to re-install Ray here
- rm -rf ./python/ray/thirdparty_files; rm -rf ./python/ray/pickle5_files; ./ci/travis/ci.sh build
- bazel test --config=ci $(./scripts/bazel_export_options)
--build_tests_only
--test_tag_filters=learning_tests_discrete,-fake_gpus,-torch_only,-flaky,-multi_gpu,-no_tf_eager_tracing
--test_arg=--framework=tf2
rllib/...
- label: ":brain: RLlib: Learning cont. actions TF2-eager (from rllib/tuned_examples/*.yaml)"
- label: ":brain: RLlib: Learning cont. actions TF2-eager-tracing (from rllib/tuned_examples/*.yaml)"
conditions: ["RAY_CI_RLLIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT
- RLLIB_TESTING=1 ./ci/travis/install-dependencies.sh
- RLLIB_TESTING=1 PYTHON=3.7 ./ci/travis/install-dependencies.sh
# Because Python version changed, we need to re-install Ray here
- rm -rf ./python/ray/thirdparty_files; rm -rf ./python/ray/pickle5_files; ./ci/travis/ci.sh build
- bazel test --config=ci $(./scripts/bazel_export_options)
--build_tests_only
--test_tag_filters=learning_tests_continuous,-fake_gpus,-torch_only,-flaky,-multi_gpu
2 changes: 1 addition & 1 deletion ci/travis/install-dependencies.sh
@@ -409,7 +409,7 @@ install_dependencies() {

# RLlib testing with TF 1.x.
if [ "${RLLIB_TESTING-}" = 1 ] && { [ -n "${TF_VERSION-}" ] || [ -n "${TFP_VERSION-}" ]; }; then
pip install --upgrade tensorflow-probability=="${TFP_VERSION}" tensorflow=="${TF_VERSION}" gym==0.19
pip install --upgrade tensorflow-probability=="${TFP_VERSION}" tensorflow=="${TF_VERSION}"
fi

# Additional Tune dependency for Horovod.
10 changes: 10 additions & 0 deletions rllib/BUILD
@@ -147,6 +147,16 @@ py_test(
args = ["--yaml-dir=tuned_examples/ppo"]
)

py_test(
name = "run_regression_tests_frozenlake_appo",
main = "tests/run_regression_tests.py",
tags = ["team:ml", "learning_tests", "learning_tests_discrete"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/ppo/frozenlake-appo-vtrace.yaml"],
args = ["--yaml-dir=tuned_examples/ppo"]
)

py_test(
name = "learning_cartpole_appo_fake_gpus",
main = "tests/run_regression_tests.py",
7 changes: 7 additions & 0 deletions rllib/agents/impala/impala.py
@@ -28,6 +28,13 @@
"vtrace": True,
"vtrace_clip_rho_threshold": 1.0,
"vtrace_clip_pg_rho_threshold": 1.0,
# If True, drop the last timestep for the vtrace calculations, such that
# all data goes into the calculations as [B x T-1] (+ the bootstrap value).
# This is the default and legacy RLlib behavior; however, it could
# potentially have a destabilizing effect on learning, especially in
# sparse-reward or reward-at-goal environments.
# Set to False to not drop the last timestep.
"vtrace_drop_last_ts": True,
# System params.
#
# == Overview of data flow in IMPALA ==
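Not part of the diff, but for orientation: the new key is an ordinary trainer-config entry, so once this PR lands it can be toggled like any other IMPALA/APPO setting. A minimal sketch, assuming a Ray version contemporary with this PR (the `ImpalaTrainer` import path, `CartPole-v0` env, and the `episode_reward_mean` result key reflect that era's API and may differ in your install):

```python
import ray
from ray.rllib.agents.impala import ImpalaTrainer

ray.init()

config = {
    "env": "CartPole-v0",
    "framework": "tf",
    "num_workers": 2,
    # Keep v-trace enabled, but feed all T timesteps (plus the bootstrap
    # value) into the v-trace calculation instead of dropping the last one.
    "vtrace": True,
    "vtrace_drop_last_ts": False,
}

trainer = ImpalaTrainer(config=config)
for _ in range(3):
    result = trainer.train()
    print(result["episode_reward_mean"])
```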
27 changes: 15 additions & 12 deletions rllib/agents/impala/vtrace_tf_policy.py
@@ -199,26 +199,27 @@ def make_time_major(*args, **kw):
loss_actions = actions if is_multidiscrete else tf.expand_dims(
actions, axis=1)

# Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
# Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc.
drop_last = policy.config["vtrace_drop_last_ts"]
policy.loss = VTraceLoss(
actions=make_time_major(loss_actions, drop_last=True),
actions=make_time_major(loss_actions, drop_last=drop_last),
actions_logp=make_time_major(
action_dist.logp(actions), drop_last=True),
action_dist.logp(actions), drop_last=drop_last),
actions_entropy=make_time_major(
action_dist.multi_entropy(), drop_last=True),
dones=make_time_major(dones, drop_last=True),
action_dist.multi_entropy(), drop_last=drop_last),
dones=make_time_major(dones, drop_last=drop_last),
behaviour_action_logp=make_time_major(
behaviour_action_logp, drop_last=True),
behaviour_action_logp, drop_last=drop_last),
behaviour_logits=make_time_major(
unpacked_behaviour_logits, drop_last=True),
target_logits=make_time_major(unpacked_outputs, drop_last=True),
unpacked_behaviour_logits, drop_last=drop_last),
target_logits=make_time_major(unpacked_outputs, drop_last=drop_last),
discount=policy.config["gamma"],
rewards=make_time_major(rewards, drop_last=True),
values=make_time_major(values, drop_last=True),
rewards=make_time_major(rewards, drop_last=drop_last),
values=make_time_major(values, drop_last=drop_last),
bootstrap_value=make_time_major(values)[-1],
dist_class=Categorical if is_multidiscrete else dist_class,
model=model,
valid_mask=make_time_major(mask, drop_last=True),
valid_mask=make_time_major(mask, drop_last=drop_last),
config=policy.config,
vf_loss_coeff=policy.config["vf_loss_coeff"],
entropy_coeff=policy.entropy_coeff,
@@ -232,11 +233,13 @@ def make_time_major(*args, **kw):


def stats(policy, train_batch):
drop_last = policy.config["vtrace"] and \
policy.config["vtrace_drop_last_ts"]
values_batched = _make_time_major(
policy,
train_batch.get(SampleBatch.SEQ_LENS),
policy.model.value_function(),
drop_last=policy.config["vtrace"])
drop_last=drop_last)

return {
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
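An editor's aside (not from the PR): the `[B * T] => [(T|T-1), B]` comment above is easiest to see with concrete shapes. Below is a NumPy sketch of just the shape handling that `make_time_major(..., drop_last=...)` implies; the real RLlib helper also handles sequence lengths and padding, which this toy version ignores:

```python
import numpy as np

B, T = 4, 5                      # number of rollouts and rollout length
flat = np.arange(B * T)          # stands in for a [B * T]-flattened tensor

def make_time_major(x, drop_last=False):
    """Toy version: reshape [B * T] -> [T, B], optionally dropping the last ts."""
    x = x.reshape(B, T).swapaxes(0, 1)   # -> [T, B] (time-major)
    return x[:-1] if drop_last else x

print(make_time_major(flat, drop_last=True).shape)   # (4, 4): [T - 1, B]
print(make_time_major(flat, drop_last=False).shape)  # (5, 4): [T, B]
```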
25 changes: 13 additions & 12 deletions rllib/agents/impala/vtrace_torch_policy.py
@@ -158,26 +158,27 @@ def _make_time_major(*args, **kw):
loss_actions = actions if is_multidiscrete else torch.unsqueeze(
actions, dim=1)

# Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
# Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc.
drop_last = policy.config["vtrace_drop_last_ts"]
loss = VTraceLoss(
actions=_make_time_major(loss_actions, drop_last=True),
actions=_make_time_major(loss_actions, drop_last=drop_last),
actions_logp=_make_time_major(
action_dist.logp(actions), drop_last=True),
action_dist.logp(actions), drop_last=drop_last),
actions_entropy=_make_time_major(
action_dist.entropy(), drop_last=True),
dones=_make_time_major(dones, drop_last=True),
action_dist.entropy(), drop_last=drop_last),
dones=_make_time_major(dones, drop_last=drop_last),
behaviour_action_logp=_make_time_major(
behaviour_action_logp, drop_last=True),
behaviour_action_logp, drop_last=drop_last),
behaviour_logits=_make_time_major(
unpacked_behaviour_logits, drop_last=True),
target_logits=_make_time_major(unpacked_outputs, drop_last=True),
unpacked_behaviour_logits, drop_last=drop_last),
target_logits=_make_time_major(unpacked_outputs, drop_last=drop_last),
discount=policy.config["gamma"],
rewards=_make_time_major(rewards, drop_last=True),
values=_make_time_major(values, drop_last=True),
rewards=_make_time_major(rewards, drop_last=drop_last),
values=_make_time_major(values, drop_last=drop_last),
bootstrap_value=_make_time_major(values)[-1],
dist_class=TorchCategorical if is_multidiscrete else dist_class,
model=model,
valid_mask=_make_time_major(mask, drop_last=True),
valid_mask=_make_time_major(mask, drop_last=drop_last),
config=policy.config,
vf_loss_coeff=policy.config["vf_loss_coeff"],
entropy_coeff=policy.entropy_coeff,
@@ -196,7 +197,7 @@ def _make_time_major(*args, **kw):
policy,
train_batch.get(SampleBatch.SEQ_LENS),
values,
drop_last=policy.config["vtrace"])
drop_last=policy.config["vtrace"] and drop_last)
model.tower_stats["vf_explained_var"] = explained_variance(
torch.reshape(loss.value_targets, [-1]),
torch.reshape(values_batched, [-1]))
35 changes: 22 additions & 13 deletions rllib/agents/ppo/appo_tf_policy.py
@@ -144,7 +144,9 @@ def reduce_mean_valid(t):
reduce_mean_valid = tf.reduce_mean

if policy.config["vtrace"]:
logger.debug("Using V-Trace surrogate loss (vtrace=True)")
drop_last = policy.config["vtrace_drop_last_ts"]
logger.debug("Using V-Trace surrogate loss (vtrace=True; "
f"drop_last={drop_last})")

# Prepare actions for loss.
loss_actions = actions if is_multidiscrete else tf.expand_dims(
@@ -155,7 +157,7 @@ def reduce_mean_valid(t):

# Prepare KL for Loss
mean_kl = make_time_major(
old_policy_action_dist.multi_kl(action_dist), drop_last=True)
old_policy_action_dist.multi_kl(action_dist), drop_last=drop_last)

unpacked_behaviour_logits = tf.split(
behaviour_logits, output_hidden_shape, axis=1)
@@ -166,16 +168,19 @@
with tf.device("/cpu:0"):
vtrace_returns = vtrace.multi_from_logits(
behaviour_policy_logits=make_time_major(
unpacked_behaviour_logits, drop_last=True),
unpacked_behaviour_logits, drop_last=drop_last),
target_policy_logits=make_time_major(
unpacked_old_policy_behaviour_logits, drop_last=True),
unpacked_old_policy_behaviour_logits, drop_last=drop_last),
actions=tf.unstack(
make_time_major(loss_actions, drop_last=True), axis=2),
make_time_major(loss_actions, drop_last=drop_last),
axis=2),
discounts=tf.cast(
~make_time_major(tf.cast(dones, tf.bool), drop_last=True),
~make_time_major(
Member: what's ~make_time_major mean?

Contributor Author: ~ means NOT. make_time_major transforms a tensor of shape [B, T, ...] into [T, B, ...].

tf.cast(dones, tf.bool), drop_last=drop_last),
tf.float32) * policy.config["gamma"],
rewards=make_time_major(rewards, drop_last=True),
values=values_time_major[:-1], # drop-last=True
rewards=make_time_major(rewards, drop_last=drop_last),
values=values_time_major[:-1]
if drop_last else values_time_major,
Member: indentation, 4 spaces in front?

Contributor Author: Yeah, not sure. LINTer says it's ok. You mean the if stuff, right?

Member: yeah, since it's a continuation of last line. no idea, minor comment.

bootstrap_value=values_time_major[-1],
dist_class=Categorical if is_multidiscrete else dist_class,
model=model,
@@ -186,11 +191,11 @@
)

actions_logp = make_time_major(
action_dist.logp(actions), drop_last=True)
action_dist.logp(actions), drop_last=drop_last)
prev_actions_logp = make_time_major(
prev_action_dist.logp(actions), drop_last=True)
prev_action_dist.logp(actions), drop_last=drop_last)
old_policy_actions_logp = make_time_major(
old_policy_action_dist.logp(actions), drop_last=True)
old_policy_action_dist.logp(actions), drop_last=drop_last)

is_ratio = tf.clip_by_value(
tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
@@ -210,7 +215,10 @@
mean_policy_loss = -reduce_mean_valid(surrogate_loss)

# The value function loss.
delta = values_time_major[:-1] - vtrace_returns.vs
if drop_last:
delta = values_time_major[:-1] - vtrace_returns.vs
else:
delta = values_time_major - vtrace_returns.vs
value_targets = vtrace_returns.vs
mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

@@ -294,7 +302,8 @@ def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
policy,
train_batch.get(SampleBatch.SEQ_LENS),
policy.model.value_function(),
drop_last=policy.config["vtrace"])
drop_last=policy.config["vtrace"]
and policy.config["vtrace_drop_last_ts"])

stats_dict = {
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
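To illustrate the ~make_time_major question from the review thread above (an editor's sketch, not code from the PR): on a boolean tensor, `~` is element-wise logical NOT, so the expression turns `dones` into per-step discount factors that are zero exactly where an episode terminated:

```python
import tensorflow as tf

gamma = 0.99
dones = tf.constant([False, False, True, False])

# "~" = logical NOT on a bool tensor: True wherever the episode continues.
discounts = tf.cast(~dones, tf.float32) * gamma
print(discounts.numpy())  # [0.99 0.99 0.   0.99]
```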
40 changes: 23 additions & 17 deletions rllib/agents/ppo/appo_torch_policy.py
@@ -85,11 +85,14 @@ def _make_time_major(*args, **kwargs):
values = model.value_function()
values_time_major = _make_time_major(values)

drop_last = policy.config["vtrace"] and \
policy.config["vtrace_drop_last_ts"]

if policy.is_recurrent():
max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
mask = torch.reshape(mask, [-1])
mask = _make_time_major(mask, drop_last=policy.config["vtrace"])
mask = _make_time_major(mask, drop_last=drop_last)
num_valid = torch.sum(mask)

def reduce_mean_valid(t):
@@ -99,7 +102,8 @@ def reduce_mean_valid(t):
reduce_mean_valid = torch.mean

if policy.config["vtrace"]:
logger.debug("Using V-Trace surrogate loss (vtrace=True)")
logger.debug("Using V-Trace surrogate loss (vtrace=True; "
f"drop_last={drop_last})")

old_policy_behaviour_logits = target_model_out.detach()
old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
@@ -121,20 +125,20 @@

# Prepare KL for loss.
action_kl = _make_time_major(
old_policy_action_dist.kl(action_dist), drop_last=True)
old_policy_action_dist.kl(action_dist), drop_last=drop_last)

# Compute vtrace on the CPU for better perf.
vtrace_returns = vtrace.multi_from_logits(
behaviour_policy_logits=_make_time_major(
unpacked_behaviour_logits, drop_last=True),
unpacked_behaviour_logits, drop_last=drop_last),
target_policy_logits=_make_time_major(
unpacked_old_policy_behaviour_logits, drop_last=True),
unpacked_old_policy_behaviour_logits, drop_last=drop_last),
actions=torch.unbind(
_make_time_major(loss_actions, drop_last=True), dim=2),
discounts=(1.0 - _make_time_major(dones, drop_last=True).float()) *
policy.config["gamma"],
rewards=_make_time_major(rewards, drop_last=True),
values=values_time_major[:-1], # drop-last=True
_make_time_major(loss_actions, drop_last=drop_last), dim=2),
discounts=(1.0 - _make_time_major(
dones, drop_last=drop_last).float()) * policy.config["gamma"],
rewards=_make_time_major(rewards, drop_last=drop_last),
values=values_time_major[:-1] if drop_last else values_time_major,
bootstrap_value=values_time_major[-1],
dist_class=TorchCategorical if is_multidiscrete else dist_class,
model=model,
Expand All @@ -143,11 +147,11 @@ def reduce_mean_valid(t):
"vtrace_clip_pg_rho_threshold"])

actions_logp = _make_time_major(
action_dist.logp(actions), drop_last=True)
action_dist.logp(actions), drop_last=drop_last)
prev_actions_logp = _make_time_major(
prev_action_dist.logp(actions), drop_last=True)
prev_action_dist.logp(actions), drop_last=drop_last)
old_policy_actions_logp = _make_time_major(
old_policy_action_dist.logp(actions), drop_last=True)
old_policy_action_dist.logp(actions), drop_last=drop_last)
is_ratio = torch.clamp(
torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
@@ -165,12 +169,15 @@

# The value function loss.
value_targets = vtrace_returns.vs.to(values_time_major.device)
delta = values_time_major[:-1] - value_targets
if drop_last:
delta = values_time_major[:-1] - value_targets
else:
delta = values_time_major - value_targets
mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

# The entropy loss.
mean_entropy = reduce_mean_valid(
_make_time_major(action_dist.entropy(), drop_last=True))
_make_time_major(action_dist.entropy(), drop_last=drop_last))

else:
logger.debug("Using PPO surrogate loss (vtrace=False)")
@@ -222,8 +229,7 @@ def reduce_mean_valid(t):
model.tower_stats["vf_explained_var"] = explained_variance(
torch.reshape(value_targets, [-1]),
torch.reshape(
values_time_major[:-1]
if policy.config["vtrace"] else values_time_major, [-1]),
values_time_major[:-1] if drop_last else values_time_major, [-1]),
)

return total_loss
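One more editor's sketch (hypothetical shapes, not PR code) of the shape bookkeeping behind the `delta` branches above: with `drop_last=True` the v-trace targets have one row fewer than the time-major value estimates, so the last row of the values must be sliced off; with `drop_last=False` both already have T rows:

```python
import torch

T, B = 5, 4
values_time_major = torch.randn(T, B)   # V(s_t) for every timestep, time-major

def vf_delta(vs, drop_last):
    # vs: v-trace value targets; [T - 1, B] if the last ts was dropped,
    # otherwise [T, B]. Slice the value estimates to match.
    values = values_time_major[:-1] if drop_last else values_time_major
    return values - vs

print(vf_delta(torch.randn(T - 1, B), drop_last=True).shape)   # torch.Size([4, 4])
print(vf_delta(torch.randn(T, B), drop_last=False).shape)      # torch.Size([5, 4])
```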
2 changes: 2 additions & 0 deletions rllib/tuned_examples/impala/cartpole-impala-fake-gpus.yaml
@@ -17,3 +17,5 @@ cartpole-impala-fake-gpus:
# Fake 2 GPUs.
num_gpus: 2
_fake_gpus: true

vtrace_drop_last_ts: false
1 change: 1 addition & 0 deletions rllib/tuned_examples/impala/cartpole-impala.yaml
@@ -8,3 +8,4 @@ cartpole-impala:
# Works for both torch and tf.
framework: tf
num_gpus: 0
vtrace_drop_last_ts: false
@@ -13,6 +13,7 @@ cartpole-appo-vtrace-fake-gpus:
num_sgd_iter: 6
vf_loss_coeff: 0.01
vtrace: true
vtrace_drop_last_ts: false

# Double batch size (2 GPUs).
train_batch_size: 1000
1 change: 1 addition & 0 deletions rllib/tuned_examples/ppo/cartpole-appo-vtrace.yaml
@@ -14,6 +14,7 @@ cartpole-appo-vtrace:
num_sgd_iter: 6
vf_loss_coeff: 0.01
vtrace: true
vtrace_drop_last_ts: false
model:
fcnet_hiddens: [32]
fcnet_activation: linear