[RLlib] Fix A2C release test crash (rollout_fragment_length vs train_batch_size). #30361

Merged (11 commits) on Nov 21, 2022
6 changes: 4 additions & 2 deletions release/rllib_tests/app_config.yaml
@@ -15,9 +15,11 @@ python:

post_build_cmds:
- pip3 uninstall -y ray || true && pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }}
# TODO(jungong): remove once nightly image gets upgraded.
- pip install -U pybullet==3.2.0
- {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }}
# Clone the rl-experiments repo for offline-RL files.
- git clone https://github.com/ray-project/rl-experiments.git
- cp rl-experiments/halfcheetah-sac/2021-09-06/halfcheetah_expert_sac.zip ~/.
# Use torch+CUDA10.2 for our release tests. CUDA11.x has known performance issues in combination with torch+GPU+CNNs
# TODO(sven): remove once nightly image gets upgraded.
- pip3 install torch==1.12.1+cu102 torchvision==0.13.1+cu102 --extra-index-url https://download.pytorch.org/whl/cu102

4 changes: 4 additions & 0 deletions rllib/algorithms/a2c/a2c.py
@@ -114,6 +114,10 @@ def validate(self) -> None:
# Call super's validation method.
super().validate()

# Synchronous sampling, on-policy PG algo -> Check mismatches between
# `rollout_fragment_length` and `train_batch_size` to avoid user confusion.
self.validate_train_batch_size_vs_rollout_fragment_length()

if self.microbatch_size:
if self.num_gpus > 1:
raise AttributeError(
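
For illustration (not part of this PR), here is the kind of A2C config the new check is meant to reject. This is a minimal sketch assuming the RLlib config API used elsewhere in this diff; the environment and numbers are made up:

from ray.rllib.algorithms.a2c import A2CConfig

# 5 workers x 5 envs x 20 steps = 500 timesteps per sampling round, so a
# requested train_batch_size of 700 is more than 10% away from any reachable
# batch size (500, 1000, ...). The new validation should raise here.
config = (
    A2CConfig()
    .environment("CartPole-v1")
    .rollouts(
        num_rollout_workers=5,
        num_envs_per_worker=5,
        rollout_fragment_length=20,
    )
    .training(train_batch_size=700)
)

try:
    config.validate()
except ValueError as e:
    print(e)  # suggests rollout_fragment_length="auto" or a matching value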
51 changes: 51 additions & 0 deletions rllib/algorithms/algorithm_config.py
@@ -2256,6 +2256,57 @@ def is_policy_to_train(pid, batch=None):

return policies, is_policy_to_train

def validate_train_batch_size_vs_rollout_fragment_length(self) -> None:
Contributor

Stupid questions: 1) Why do we need to let users specify both values such that they only roughly match each other? Why not just error out when train_batch_size does not match the value expected from rollout_fragment_length? 2) Is setting rollout_fragment_length = "auto" always recommended?

Contributor Author

Well, for off-policy algos, rollout_fragment_length can be anything; it is not linked to the train batch size. For on-policy algos, I'm thinking that users sometimes want to set the rollout fragment length manually to force a certain rollout behavior, and this new error forces them to be aware that doing so also affects their train batch size. (See the standalone arithmetic sketch after this diff.)

"""Detects mismatches for `train_batch_size` vs `rollout_fragment_length`.

Only applicable to algorithms whose train_batch_size should be directly
dependent on rollout_fragment_length (synchronous sampling, on-policy PG algos).

If rollout_fragment_length != "auto", makes sure that the product of
`rollout_fragment_length` x `num_rollout_workers` x `num_envs_per_worker`
roughly (within 10%) matches the provided `train_batch_size`. Otherwise, raises
an error asking the user to set `rollout_fragment_length` to "auto" or to a
matching value.

Also, only checks this if `train_batch_size` > 0 (DDPPO sets this
to -1 to auto-calculate the actual batch size later).

Raises:
ValueError: If there is a mismatch between user provided
`rollout_fragment_length` and `train_batch_size`.
"""
if (
self.rollout_fragment_length != "auto"
and not self.in_evaluation
and self.train_batch_size > 0
):
min_batch_size = (
max(self.num_rollout_workers, 1)
* self.num_envs_per_worker
* self.rollout_fragment_length
)
batch_size = min_batch_size
while batch_size < self.train_batch_size:
batch_size += min_batch_size
if (
batch_size - self.train_batch_size > 0.1 * self.train_batch_size
or batch_size - min_batch_size - self.train_batch_size
> (0.1 * self.train_batch_size)
):
suggested_rollout_fragment_length = self.train_batch_size // (
self.num_envs_per_worker * (self.num_rollout_workers or 1)
)
raise ValueError(
f"Your desired `train_batch_size` ({self.train_batch_size}) or a "
"value 10% off of that cannot be achieved with your other "
f"settings (num_rollout_workers={self.num_rollout_workers}; "
f"num_envs_per_worker={self.num_envs_per_worker}; "
f"rollout_fragment_length={self.rollout_fragment_length})! "
"Try setting `rollout_fragment_length` to 'auto' OR "
f"{suggested_rollout_fragment_length}."
)


def __setattr__(self, key, value):
"""Gatekeeper in case we are in frozen state and need to error."""

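
To make the 10% rule concrete, here is a dependency-free restatement of the check above with the numbers from the Atari example further down. It mirrors the method's arithmetic for illustration only; the authoritative logic is the RLlib code in this diff:

def batch_size_is_reachable(train_batch_size, num_rollout_workers,
                            num_envs_per_worker, rollout_fragment_length):
    # Smallest batch a single, full sampling round can produce.
    min_batch = (max(num_rollout_workers, 1)
                 * num_envs_per_worker
                 * rollout_fragment_length)
    # Walk up in whole sampling rounds until the target is reached.
    batch = min_batch
    while batch < train_batch_size:
        batch += min_batch
    # Accept only if the reachable batch lands within 10% of the target
    # (the negation of the error condition in the method above).
    return (batch - train_batch_size <= 0.1 * train_batch_size
            and batch - min_batch - train_batch_size <= 0.1 * train_batch_size)

print(batch_size_is_reachable(500, 5, 5, 20))  # True: 5 * 5 * 20 == 500 exactly
print(batch_size_is_reachable(700, 5, 5, 20))  # False: only 500 or 1000 reachable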
37 changes: 3 additions & 34 deletions rllib/algorithms/pg/pg.py
@@ -90,40 +90,9 @@ def validate(self) -> None:
# Call super's validation method.
super().validate()

# Check for mismatches between `train_batch_size` and
# `rollout_fragment_length` (if not "auto")..
# Note: Only check this if `train_batch_size` > 0 (DDPPO sets this
# to -1 to auto-calculate the actual batch size later).
if (
self.rollout_fragment_length != "auto"
and not self.in_evaluation
and self.train_batch_size > 0
):
min_batch_size = (
max(self.num_rollout_workers, 1)
* self.num_envs_per_worker
* self.rollout_fragment_length
)
batch_size = min_batch_size
while batch_size < self.train_batch_size:
batch_size += min_batch_size
if (
batch_size - self.train_batch_size > 0.1 * self.train_batch_size
or batch_size - min_batch_size - self.train_batch_size
> (0.1 * self.train_batch_size)
):
suggested_rollout_fragment_length = self.train_batch_size // (
self.num_envs_per_worker * (self.num_rollout_workers or 1)
)
raise ValueError(
f"Your desired `train_batch_size` ({self.train_batch_size}) or a "
"value 10% off of that cannot be achieved with your other "
f"settings (num_rollout_workers={self.num_rollout_workers}; "
f"num_envs_per_worker={self.num_envs_per_worker}; "
f"rollout_fragment_length={self.rollout_fragment_length})! "
"Try setting `rollout_fragment_length` to 'auto' OR "
f"{suggested_rollout_fragment_length}."
)
# Synchronous sampling, on-policy PG algo -> Check mismatches between
# `rollout_fragment_length` and `train_batch_size` to avoid user confusion.
self.validate_train_batch_size_vs_rollout_fragment_length()


class PG(Algorithm):
3 changes: 2 additions & 1 deletion rllib/tuned_examples/a2c/atari-a2c.yaml
@@ -11,7 +11,8 @@ atari-a2c:
config:
# Works for both torch and tf.
framework: tf
rollout_fragment_length: 20
train_batch_size: 500
rollout_fragment_length: auto
clip_rewards: True
num_workers: 5
num_envs_per_worker: 5
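
As a quick sanity check on this change (an illustration; it assumes "auto" derives the per-env fragment from the target batch size as described above): with num_workers: 5 and num_envs_per_worker: 5, a train_batch_size of 500 works out to the 20-step fragments this file previously hard-coded.

# Assumed derivation of the per-env fragment under rollout_fragment_length: auto.
print(500 // (5 * 5))  # 20 timesteps per env, matching the old hard-coded value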
2 changes: 2 additions & 0 deletions rllib/tuned_examples/a2c/cartpole-a2c-fake-gpus.yaml
@@ -7,6 +7,8 @@ cartpole-a2c-fake-gpus:
config:
# Works for both torch and tf.
framework: tf
train_batch_size: 20
rollout_fragment_length: auto
num_workers: 0
lr: 0.001
# Fake 2 GPUs.
2 changes: 2 additions & 0 deletions rllib/tuned_examples/a2c/cartpole-a2c.yaml
@@ -7,5 +7,7 @@ cartpole-a2c:
config:
# Works for both torch and tf.
framework: tf
train_batch_size: 40
rollout_fragment_length: auto
num_workers: 0
lr: 0.001
2 changes: 1 addition & 1 deletion rllib/tuned_examples/a2c/cartpole_a2c.py
@@ -7,7 +7,7 @@
config = (
A2CConfig()
.environment("CartPole-v1")
.training(lr=0.001)
.training(lr=0.001, train_batch_size=20)
.framework("tf")
.rollouts(num_rollout_workers=0)
)
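
A hypothetical follow-up (not part of the PR) showing one way to exercise the new validation on this config. Pinning rollout_fragment_length to "auto" is an explicit assumption here so that the single-worker example stays compatible with the check:

# Not part of the PR: pin the fragment length to "auto" and validate directly.
config = config.rollouts(rollout_fragment_length="auto")
config.validate()  # would raise ValueError on a rollout/train-batch mismatch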