[RLlib] Optionally don't drop last ts in v-trace calculations (APPO and IMPALA). (#19601)
sven1977 authored Nov 3, 2021
1 parent cf21c63 commit e6ae08f
Showing 13 changed files with 137 additions and 59 deletions.
12 changes: 8 additions & 4 deletions .buildkite/pipeline.yml
@@ -352,21 +352,25 @@
--test_tag_filters=learning_tests_continuous,-fake_gpus,-torch_only,-flaky
--test_arg=--framework=tf
rllib/...
- label: ":brain: RLlib: Learning discr. actions TF2-eager (from rllib/tuned_examples/*.yaml)"
- label: ":brain: RLlib: Learning discr. actions TF2-eager-tracing (from rllib/tuned_examples/*.yaml)"
conditions: ["RAY_CI_RLLIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT
- RLLIB_TESTING=1 ./ci/travis/install-dependencies.sh
- RLLIB_TESTING=1 PYTHON=3.7 ./ci/travis/install-dependencies.sh
# Because Python version changed, we need to re-install Ray here
- rm -rf ./python/ray/thirdparty_files; rm -rf ./python/ray/pickle5_files; ./ci/travis/ci.sh build
- bazel test --config=ci $(./scripts/bazel_export_options)
--build_tests_only
--test_tag_filters=learning_tests_discrete,-fake_gpus,-torch_only,-flaky,-multi_gpu,-no_tf_eager_tracing
--test_arg=--framework=tf2
rllib/...
- label: ":brain: RLlib: Learning cont. actions TF2-eager (from rllib/tuned_examples/*.yaml)"
- label: ":brain: RLlib: Learning cont. actions TF2-eager-tracing (from rllib/tuned_examples/*.yaml)"
conditions: ["RAY_CI_RLLIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT
- RLLIB_TESTING=1 ./ci/travis/install-dependencies.sh
- RLLIB_TESTING=1 PYTHON=3.7 ./ci/travis/install-dependencies.sh
# Because Python version changed, we need to re-install Ray here
- rm -rf ./python/ray/thirdparty_files; rm -rf ./python/ray/pickle5_files; ./ci/travis/ci.sh build
- bazel test --config=ci $(./scripts/bazel_export_options)
--build_tests_only
--test_tag_filters=learning_tests_continuous,-fake_gpus,-torch_only,-flaky,-multi_gpu
2 changes: 1 addition & 1 deletion ci/travis/install-dependencies.sh
@@ -409,7 +409,7 @@ install_dependencies() {

# RLlib testing with TF 1.x.
if [ "${RLLIB_TESTING-}" = 1 ] && { [ -n "${TF_VERSION-}" ] || [ -n "${TFP_VERSION-}" ]; }; then
pip install --upgrade tensorflow-probability=="${TFP_VERSION}" tensorflow=="${TF_VERSION}" gym==0.19
pip install --upgrade tensorflow-probability=="${TFP_VERSION}" tensorflow=="${TF_VERSION}"
fi

# Additional Tune dependency for Horovod.
10 changes: 10 additions & 0 deletions rllib/BUILD
@@ -147,6 +147,16 @@ py_test(
args = ["--yaml-dir=tuned_examples/ppo"]
)

py_test(
name = "run_regression_tests_frozenlake_appo",
main = "tests/run_regression_tests.py",
tags = ["team:ml", "learning_tests", "learning_tests_discrete"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/ppo/frozenlake-appo-vtrace.yaml"],
args = ["--yaml-dir=tuned_examples/ppo"]
)

py_test(
name = "learning_cartpole_appo_fake_gpus",
main = "tests/run_regression_tests.py",
7 changes: 7 additions & 0 deletions rllib/agents/impala/impala.py
@@ -28,6 +28,13 @@
"vtrace": True,
"vtrace_clip_rho_threshold": 1.0,
"vtrace_clip_pg_rho_threshold": 1.0,
# If True, drop the last timestep for the V-trace calculations, such that
# all data goes into the calculations as [B x T-1] (+ the bootstrap value).
# This is the default and legacy RLlib behavior; however, it could
# potentially have a destabilizing effect on learning, especially in
# sparse-reward or reward-at-goal environments.
# Set this to False to not drop the last timestep.
"vtrace_drop_last_ts": True,
# System params.
#
# == Overview of data flow in IMPALA ==
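For readers who want to try the new option, here is a minimal usage sketch (not part of this commit's diff), assuming the Ray 1.x trainer API (`ray.rllib.agents.impala.ImpalaTrainer`) and a standard Gym CartPole environment:

import ray
from ray.rllib.agents.impala import ImpalaTrainer

ray.init()
# Illustrative config only: keep the last timestep in the V-trace
# calculation instead of dropping it (the legacy default is True).
trainer = ImpalaTrainer(
    env="CartPole-v0",
    config={
        "framework": "tf",
        "num_workers": 1,
        "vtrace_drop_last_ts": False,
    },
)
result = trainer.train()
print(result["episode_reward_mean"])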
27 changes: 15 additions & 12 deletions rllib/agents/impala/vtrace_tf_policy.py
@@ -199,26 +199,27 @@ def make_time_major(*args, **kw):
loss_actions = actions if is_multidiscrete else tf.expand_dims(
actions, axis=1)

# Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
# Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc.
drop_last = policy.config["vtrace_drop_last_ts"]
policy.loss = VTraceLoss(
actions=make_time_major(loss_actions, drop_last=True),
actions=make_time_major(loss_actions, drop_last=drop_last),
actions_logp=make_time_major(
action_dist.logp(actions), drop_last=True),
action_dist.logp(actions), drop_last=drop_last),
actions_entropy=make_time_major(
action_dist.multi_entropy(), drop_last=True),
dones=make_time_major(dones, drop_last=True),
action_dist.multi_entropy(), drop_last=drop_last),
dones=make_time_major(dones, drop_last=drop_last),
behaviour_action_logp=make_time_major(
behaviour_action_logp, drop_last=True),
behaviour_action_logp, drop_last=drop_last),
behaviour_logits=make_time_major(
unpacked_behaviour_logits, drop_last=True),
target_logits=make_time_major(unpacked_outputs, drop_last=True),
unpacked_behaviour_logits, drop_last=drop_last),
target_logits=make_time_major(unpacked_outputs, drop_last=drop_last),
discount=policy.config["gamma"],
rewards=make_time_major(rewards, drop_last=True),
values=make_time_major(values, drop_last=True),
rewards=make_time_major(rewards, drop_last=drop_last),
values=make_time_major(values, drop_last=drop_last),
bootstrap_value=make_time_major(values)[-1],
dist_class=Categorical if is_multidiscrete else dist_class,
model=model,
valid_mask=make_time_major(mask, drop_last=True),
valid_mask=make_time_major(mask, drop_last=drop_last),
config=policy.config,
vf_loss_coeff=policy.config["vf_loss_coeff"],
entropy_coeff=policy.entropy_coeff,
@@ -232,11 +233,13 @@ def make_time_major(*args, **kw):


def stats(policy, train_batch):
drop_last = policy.config["vtrace"] and \
policy.config["vtrace_drop_last_ts"]
values_batched = _make_time_major(
policy,
train_batch.get(SampleBatch.SEQ_LENS),
policy.model.value_function(),
drop_last=policy.config["vtrace"])
drop_last=drop_last)

return {
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
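To illustrate the reshaping mentioned in the comment above ([B * T] => [(T|T-1), B]), here is a small standalone sketch of what dropping or keeping the last timestep means for the tensors entering the V-trace loss. This is a hypothetical illustration, not RLlib's actual make_time_major implementation:

import numpy as np

B, T = 4, 5                                  # rollout fragments x timesteps
flat = np.arange(B * T, dtype=np.float32)    # stand-in for a flat [B * T] batch column

time_major = flat.reshape(B, T).T            # [T, B]
bootstrap_value = time_major[-1]             # last timestep always provides the bootstrap

loss_input_legacy = time_major[:-1]          # drop_last=True  -> shape (T - 1, B)
loss_input_new = time_major                  # drop_last=False -> shape (T, B)

print(loss_input_legacy.shape, loss_input_new.shape)   # (4, 4) (5, 4)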
25 changes: 13 additions & 12 deletions rllib/agents/impala/vtrace_torch_policy.py
@@ -158,26 +158,27 @@ def _make_time_major(*args, **kw):
loss_actions = actions if is_multidiscrete else torch.unsqueeze(
actions, dim=1)

# Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
# Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc.
drop_last = policy.config["vtrace_drop_last_ts"]
loss = VTraceLoss(
actions=_make_time_major(loss_actions, drop_last=True),
actions=_make_time_major(loss_actions, drop_last=drop_last),
actions_logp=_make_time_major(
action_dist.logp(actions), drop_last=True),
action_dist.logp(actions), drop_last=drop_last),
actions_entropy=_make_time_major(
action_dist.entropy(), drop_last=True),
dones=_make_time_major(dones, drop_last=True),
action_dist.entropy(), drop_last=drop_last),
dones=_make_time_major(dones, drop_last=drop_last),
behaviour_action_logp=_make_time_major(
behaviour_action_logp, drop_last=True),
behaviour_action_logp, drop_last=drop_last),
behaviour_logits=_make_time_major(
unpacked_behaviour_logits, drop_last=True),
target_logits=_make_time_major(unpacked_outputs, drop_last=True),
unpacked_behaviour_logits, drop_last=drop_last),
target_logits=_make_time_major(unpacked_outputs, drop_last=drop_last),
discount=policy.config["gamma"],
rewards=_make_time_major(rewards, drop_last=True),
values=_make_time_major(values, drop_last=True),
rewards=_make_time_major(rewards, drop_last=drop_last),
values=_make_time_major(values, drop_last=drop_last),
bootstrap_value=_make_time_major(values)[-1],
dist_class=TorchCategorical if is_multidiscrete else dist_class,
model=model,
valid_mask=_make_time_major(mask, drop_last=True),
valid_mask=_make_time_major(mask, drop_last=drop_last),
config=policy.config,
vf_loss_coeff=policy.config["vf_loss_coeff"],
entropy_coeff=policy.entropy_coeff,
Expand All @@ -196,7 +197,7 @@ def _make_time_major(*args, **kw):
policy,
train_batch.get(SampleBatch.SEQ_LENS),
values,
drop_last=policy.config["vtrace"])
drop_last=policy.config["vtrace"] and drop_last)
model.tower_stats["vf_explained_var"] = explained_variance(
torch.reshape(loss.value_targets, [-1]),
torch.reshape(values_batched, [-1]))
35 changes: 22 additions & 13 deletions rllib/agents/ppo/appo_tf_policy.py
@@ -144,7 +144,9 @@ def reduce_mean_valid(t):
reduce_mean_valid = tf.reduce_mean

if policy.config["vtrace"]:
logger.debug("Using V-Trace surrogate loss (vtrace=True)")
drop_last = policy.config["vtrace_drop_last_ts"]
logger.debug("Using V-Trace surrogate loss (vtrace=True; "
f"drop_last={drop_last})")

# Prepare actions for loss.
loss_actions = actions if is_multidiscrete else tf.expand_dims(
Expand All @@ -155,7 +157,7 @@ def reduce_mean_valid(t):

# Prepare KL for Loss
mean_kl = make_time_major(
old_policy_action_dist.multi_kl(action_dist), drop_last=True)
old_policy_action_dist.multi_kl(action_dist), drop_last=drop_last)

unpacked_behaviour_logits = tf.split(
behaviour_logits, output_hidden_shape, axis=1)
Expand All @@ -166,16 +168,19 @@ def reduce_mean_valid(t):
with tf.device("/cpu:0"):
vtrace_returns = vtrace.multi_from_logits(
behaviour_policy_logits=make_time_major(
unpacked_behaviour_logits, drop_last=True),
unpacked_behaviour_logits, drop_last=drop_last),
target_policy_logits=make_time_major(
unpacked_old_policy_behaviour_logits, drop_last=True),
unpacked_old_policy_behaviour_logits, drop_last=drop_last),
actions=tf.unstack(
make_time_major(loss_actions, drop_last=True), axis=2),
make_time_major(loss_actions, drop_last=drop_last),
axis=2),
discounts=tf.cast(
~make_time_major(tf.cast(dones, tf.bool), drop_last=True),
~make_time_major(
tf.cast(dones, tf.bool), drop_last=drop_last),
tf.float32) * policy.config["gamma"],
rewards=make_time_major(rewards, drop_last=True),
values=values_time_major[:-1], # drop-last=True
rewards=make_time_major(rewards, drop_last=drop_last),
values=values_time_major[:-1]
if drop_last else values_time_major,
bootstrap_value=values_time_major[-1],
dist_class=Categorical if is_multidiscrete else dist_class,
model=model,
Expand All @@ -186,11 +191,11 @@ def reduce_mean_valid(t):
)

actions_logp = make_time_major(
action_dist.logp(actions), drop_last=True)
action_dist.logp(actions), drop_last=drop_last)
prev_actions_logp = make_time_major(
prev_action_dist.logp(actions), drop_last=True)
prev_action_dist.logp(actions), drop_last=drop_last)
old_policy_actions_logp = make_time_major(
old_policy_action_dist.logp(actions), drop_last=True)
old_policy_action_dist.logp(actions), drop_last=drop_last)

is_ratio = tf.clip_by_value(
tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
Expand All @@ -210,7 +215,10 @@ def reduce_mean_valid(t):
mean_policy_loss = -reduce_mean_valid(surrogate_loss)

# The value function loss.
delta = values_time_major[:-1] - vtrace_returns.vs
if drop_last:
delta = values_time_major[:-1] - vtrace_returns.vs
else:
delta = values_time_major - vtrace_returns.vs
value_targets = vtrace_returns.vs
mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

@@ -294,7 +302,8 @@ def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
policy,
train_batch.get(SampleBatch.SEQ_LENS),
policy.model.value_function(),
drop_last=policy.config["vtrace"])
drop_last=policy.config["vtrace"]
and policy.config["vtrace_drop_last_ts"])

stats_dict = {
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
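The value-function branch above only changes how the predicted values are aligned with the V-trace targets. A rough standalone sketch, with hypothetical shapes and random stand-in data rather than the actual RLlib tensors:

import numpy as np

T, B = 5, 4
values_time_major = np.random.randn(T, B).astype(np.float32)   # V(s_t) predictions, time-major

def vf_loss(values, vtrace_vs):
    # 0.5 * mean squared error between predictions and V-trace value targets.
    delta = values - vtrace_vs
    return 0.5 * np.mean(np.square(delta))

vs_legacy = np.random.randn(T - 1, B).astype(np.float32)   # targets when drop_last=True
vs_full = np.random.randn(T, B).astype(np.float32)         # targets when drop_last=False

loss_legacy = vf_loss(values_time_major[:-1], vs_legacy)    # drop the last prediction
loss_new = vf_loss(values_time_major, vs_full)              # use all T predictions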
40 changes: 23 additions & 17 deletions rllib/agents/ppo/appo_torch_policy.py
@@ -85,11 +85,14 @@ def _make_time_major(*args, **kwargs):
values = model.value_function()
values_time_major = _make_time_major(values)

drop_last = policy.config["vtrace"] and \
policy.config["vtrace_drop_last_ts"]

if policy.is_recurrent():
max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
mask = torch.reshape(mask, [-1])
mask = _make_time_major(mask, drop_last=policy.config["vtrace"])
mask = _make_time_major(mask, drop_last=drop_last)
num_valid = torch.sum(mask)

def reduce_mean_valid(t):
Expand All @@ -99,7 +102,8 @@ def reduce_mean_valid(t):
reduce_mean_valid = torch.mean

if policy.config["vtrace"]:
logger.debug("Using V-Trace surrogate loss (vtrace=True)")
logger.debug("Using V-Trace surrogate loss (vtrace=True; "
f"drop_last={drop_last})")

old_policy_behaviour_logits = target_model_out.detach()
old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
Expand All @@ -121,20 +125,20 @@ def reduce_mean_valid(t):

# Prepare KL for loss.
action_kl = _make_time_major(
old_policy_action_dist.kl(action_dist), drop_last=True)
old_policy_action_dist.kl(action_dist), drop_last=drop_last)

# Compute vtrace on the CPU for better perf.
vtrace_returns = vtrace.multi_from_logits(
behaviour_policy_logits=_make_time_major(
unpacked_behaviour_logits, drop_last=True),
unpacked_behaviour_logits, drop_last=drop_last),
target_policy_logits=_make_time_major(
unpacked_old_policy_behaviour_logits, drop_last=True),
unpacked_old_policy_behaviour_logits, drop_last=drop_last),
actions=torch.unbind(
_make_time_major(loss_actions, drop_last=True), dim=2),
discounts=(1.0 - _make_time_major(dones, drop_last=True).float()) *
policy.config["gamma"],
rewards=_make_time_major(rewards, drop_last=True),
values=values_time_major[:-1], # drop-last=True
_make_time_major(loss_actions, drop_last=drop_last), dim=2),
discounts=(1.0 - _make_time_major(
dones, drop_last=drop_last).float()) * policy.config["gamma"],
rewards=_make_time_major(rewards, drop_last=drop_last),
values=values_time_major[:-1] if drop_last else values_time_major,
bootstrap_value=values_time_major[-1],
dist_class=TorchCategorical if is_multidiscrete else dist_class,
model=model,
Expand All @@ -143,11 +147,11 @@ def reduce_mean_valid(t):
"vtrace_clip_pg_rho_threshold"])

actions_logp = _make_time_major(
action_dist.logp(actions), drop_last=True)
action_dist.logp(actions), drop_last=drop_last)
prev_actions_logp = _make_time_major(
prev_action_dist.logp(actions), drop_last=True)
prev_action_dist.logp(actions), drop_last=drop_last)
old_policy_actions_logp = _make_time_major(
old_policy_action_dist.logp(actions), drop_last=True)
old_policy_action_dist.logp(actions), drop_last=drop_last)
is_ratio = torch.clamp(
torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
Expand All @@ -165,12 +169,15 @@ def reduce_mean_valid(t):

# The value function loss.
value_targets = vtrace_returns.vs.to(values_time_major.device)
delta = values_time_major[:-1] - value_targets
if drop_last:
delta = values_time_major[:-1] - value_targets
else:
delta = values_time_major - value_targets
mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

# The entropy loss.
mean_entropy = reduce_mean_valid(
_make_time_major(action_dist.entropy(), drop_last=True))
_make_time_major(action_dist.entropy(), drop_last=drop_last))

else:
logger.debug("Using PPO surrogate loss (vtrace=False)")
@@ -222,8 +229,7 @@ def reduce_mean_valid(t):
model.tower_stats["vf_explained_var"] = explained_variance(
torch.reshape(value_targets, [-1]),
torch.reshape(
values_time_major[:-1]
if policy.config["vtrace"] else values_time_major, [-1]),
values_time_major[:-1] if drop_last else values_time_major, [-1]),
)

return total_loss
2 changes: 2 additions & 0 deletions rllib/tuned_examples/impala/cartpole-impala-fake-gpus.yaml
@@ -17,3 +17,5 @@ cartpole-impala-fake-gpus:
# Fake 2 GPUs.
num_gpus: 2
_fake_gpus: true

vtrace_drop_last_ts: false
1 change: 1 addition & 0 deletions rllib/tuned_examples/impala/cartpole-impala.yaml
@@ -8,3 +8,4 @@ cartpole-impala:
# Works for both torch and tf.
framework: tf
num_gpus: 0
vtrace_drop_last_ts: false
rllib/tuned_examples/ppo/cartpole-appo-vtrace-fake-gpus.yaml
@@ -13,6 +13,7 @@ cartpole-appo-vtrace-fake-gpus:
num_sgd_iter: 6
vf_loss_coeff: 0.01
vtrace: true
vtrace_drop_last_ts: false

# Double batch size (2 GPUs).
train_batch_size: 1000
1 change: 1 addition & 0 deletions rllib/tuned_examples/ppo/cartpole-appo-vtrace.yaml
@@ -14,6 +14,7 @@ cartpole-appo-vtrace:
num_sgd_iter: 6
vf_loss_coeff: 0.01
vtrace: true
vtrace_drop_last_ts: false
model:
fcnet_hiddens: [32]
fcnet_activation: linear