[RLlib] Move all config validation logic into AlgorithmConfig classes. #29854

Merged (39 commits, Nov 4, 2022)

Commits
141efa8  wip (sven1977, Oct 28, 2022)
82d5dae  Merge branch 'master' of https://github.com/ray-project/ray into algo… (sven1977, Oct 28, 2022)
31fa93a  LINT (sven1977, Oct 28, 2022)
4a0f858  wip (sven1977, Oct 28, 2022)
833be55  wip (sven1977, Oct 28, 2022)
028bfd6  wip (sven1977, Oct 28, 2022)
a2d99d4  wip (sven1977, Oct 31, 2022)
f8ac69e  Merge branch 'master' of https://github.com/ray-project/ray into algo… (sven1977, Oct 31, 2022)
78d439a  Merge branch 'master' of https://github.com/ray-project/ray into algo… (sven1977, Oct 31, 2022)
0f67ef5  wip (sven1977, Oct 31, 2022)
d68c480  wip (sven1977, Oct 31, 2022)
1e832b6  wip (sven1977, Oct 31, 2022)
e7a5660  wip (sven1977, Oct 31, 2022)
0801205  wip (sven1977, Oct 31, 2022)
4134dab  wip (sven1977, Oct 31, 2022)
213f8ae  Merge branch 'algo_configs_next_steps_3' into algo_configs_next_steps_4 (sven1977, Oct 31, 2022)
cc83c43  wip (sven1977, Oct 31, 2022)
575f17d  wip (sven1977, Oct 31, 2022)
35ca9e5  wip (sven1977, Oct 31, 2022)
942c2ca  wip (sven1977, Oct 31, 2022)
7ceba17  wip (sven1977, Oct 31, 2022)
b6f94b0  wip (sven1977, Nov 1, 2022)
8f94401  wip (sven1977, Nov 2, 2022)
aea75e1  wip (sven1977, Nov 2, 2022)
9493090  wip (sven1977, Nov 2, 2022)
e502cb5  wip (sven1977, Nov 2, 2022)
26821ce  wip (sven1977, Nov 2, 2022)
db1e957  Merge branch 'master' of https://github.com/ray-project/ray into algo… (sven1977, Nov 2, 2022)
dd9d557  Merge branch 'master' of https://github.com/ray-project/ray into algo… (sven1977, Nov 2, 2022)
067cbd9  wip (sven1977, Nov 2, 2022)
9d300e2  wip (sven1977, Nov 2, 2022)
001cfe8  wip (sven1977, Nov 2, 2022)
a126727  wip (sven1977, Nov 2, 2022)
c0bd63a  Merge branch 'master' of https://github.com/ray-project/ray into algo… (sven1977, Nov 3, 2022)
f1af880  wip (sven1977, Nov 3, 2022)
9475568  wip (sven1977, Nov 4, 2022)
d7228cd  wip (sven1977, Nov 4, 2022)
2ccab68  wip (sven1977, Nov 4, 2022)
a9a7913  wip (sven1977, Nov 4, 2022)

Files changed
10 changes: 5 additions & 5 deletions rllib/BUILD
@@ -1418,7 +1418,7 @@ py_test(
         "--env", "CartPole-v0",
         "--run", "PG",
         "--stop", "'{\"training_iteration\": 1}'",
-        "--config", "'{\"framework\": \"tf\", \"rollout_fragment_length\": 500, \"num_workers\": 1, \"model\": {\"use_lstm\": true, \"max_seq_len\": 100}}'"
+        "--config", "'{\"framework\": \"tf\", \"train_batch_size\": 500, \"num_workers\": 1, \"model\": {\"use_lstm\": true, \"max_seq_len\": 100}}'"
     ]
 )

@@ -1431,7 +1431,7 @@ py_test(
         "--env", "CartPole-v0",
         "--run", "PG",
         "--stop", "'{\"training_iteration\": 1}'",
-        "--config", "'{\"framework\": \"tf\", \"rollout_fragment_length\": 500, \"num_workers\": 1, \"num_envs_per_worker\": 10}'"
+        "--config", "'{\"framework\": \"tf\", \"train_batch_size\": 5000, \"num_workers\": 1, \"num_envs_per_worker\": 10}'"
     ]
 )

@@ -1444,7 +1444,7 @@ py_test(
         "--env", "Pong-v0",
         "--run", "PG",
         "--stop", "'{\"training_iteration\": 1}'",
-        "--config", "'{\"framework\": \"tf\", \"rollout_fragment_length\": 500, \"num_workers\": 1}'"
+        "--config", "'{\"framework\": \"tf\", \"train_batch_size\": 500, \"num_workers\": 1}'"
     ]
 )

@@ -1718,14 +1718,14 @@ py_test(
 py_test(
     name = "evaluation/tests/test_rollout_worker",
     tags = ["team:rllib", "evaluation"],
-    size = "medium",
+    size = "large",
     srcs = ["evaluation/tests/test_rollout_worker.py"]
 )

 py_test(
     name = "evaluation/tests/test_trajectory_view_api",
     tags = ["team:rllib", "evaluation"],
-    size = "medium",
+    size = "large",
     srcs = ["evaluation/tests/test_trajectory_view_api.py"]
 )
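The three PG tests above now size sampling via `train_batch_size` instead of pinning `rollout_fragment_length`. One plausible reading, consistent with the `rollout_fragment_length = "auto"` default introduced in the a2c.py diff below, is that the fragment length gets derived from the train batch size. A rough sketch of that derivation; the variable names and the integer division are illustrative assumptions, not RLlib source:

    # How an "auto" rollout fragment length can fall out of the second
    # test's --config values; the exact rounding rule is an assumption.
    train_batch_size = 5000
    num_workers = 1           # parallel rollout workers
    num_envs_per_worker = 10  # vectorized sub-environments per worker

    # Every env on every worker contributes equally to one train batch:
    rollout_fragment_length = train_batch_size // (num_workers * num_envs_per_worker)
    print(rollout_fragment_length)  # -> 500 timesteps per env per sample round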
81 changes: 32 additions & 49 deletions rllib/algorithms/a2c/a2c.py
@@ -23,7 +23,6 @@
 from ray.rllib.utils.typing import (
     PartialAlgorithmConfigDict,
     ResultDict,
-    AlgorithmConfigDict,
 )

 logger = logging.getLogger(__name__)
@@ -74,7 +73,7 @@ def __init__(self):

         # Override some of A3CConfig's default values with A2C-specific values.
         self.num_rollout_workers = 2
-        self.rollout_fragment_length = 20
+        self.rollout_fragment_length = "auto"
         self.sample_async = False
         self.min_time_s_per_iteration = 10
         # __sphinx_doc_end__
@@ -106,64 +105,48 @@ def training(

         return self

-
-class A2C(A3C):
-    @classmethod
-    @override(A3C)
-    def get_default_config(cls) -> AlgorithmConfig:
-        return A2CConfig()
-
-    @override(A3C)
-    def validate_config(self, config: AlgorithmConfigDict) -> None:
+    @override(A3CConfig)
+    def validate(self) -> None:
         # Call super's validation method.
-        super().validate_config(config)
+        super().validate()

-        if config["microbatch_size"]:
-            # Train batch size needs to be significantly larger than microbatch_size.
-            if config["train_batch_size"] / config["microbatch_size"] < 3:
+        if self.microbatch_size:
+            if self.num_gpus > 1:
+                raise AttributeError(
+                    "A2C does not support multiple GPUs when micro-batching is set."
+                )
+
+            # Train batch size needs to be significantly larger than microbatch
+            # size.
+            if self.train_batch_size / self.microbatch_size < 3:
                 logger.warning(
-                    "`train_batch_size` should be considerably larger (at least 3x) "
-                    "than `microbatch_size` for a microbatching setup to make sense!"
+                    "`train_batch_size` should be considerably larger (at least 3x)"
+                    " than `microbatch_size` for a microbatching setup to make "
+                    "sense!"
                 )
             # Rollout fragment length needs to be less than microbatch_size.
-            if config["rollout_fragment_length"] > config["microbatch_size"]:
+            if (
+                self.rollout_fragment_length != "auto"
+                and self.rollout_fragment_length > self.microbatch_size
+            ):
                 logger.warning(
                     "`rollout_fragment_length` should not be larger than "
                     "`microbatch_size` (try setting them to the same value)! "
                     "Otherwise, microbatches of desired size won't be achievable."
                 )

-            if config["num_gpus"] > 1:
-                raise AttributeError(
-                    "A2C does not support multiple GPUs when micro-batching is set."
-                )
-        else:
-            sample_batch_size = (
-                config["rollout_fragment_length"]
-                * config["num_workers"]
-                * config["num_envs_per_worker"]
-            )
-            if config["train_batch_size"] < sample_batch_size:
-                logger.warning(
-                    f"`train_batch_size` ({config['train_batch_size']}) "
-                    "cannot be smaller than sample_batch_size "
-                    "(`rollout_fragment_length` x `num_workers` x "
-                    f"`num_envs_per_worker`) ({sample_batch_size}) when micro-batching"
-                    " is not set. This is to"
-                    " ensure that only on gradient update is applied to policy in every"
-                    " iteration on the entire collected batch. As a result of we do not"
-                    " change the policy too much before we sample again and stay on"
-                    " policy as much as possible. This will help the learning"
-                    " stability."
-                    f" Setting train_batch_size = {sample_batch_size}."
-                )
-                config["train_batch_size"] = sample_batch_size
+    def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
+        if self.rollout_fragment_length == "auto":
+            if self.microbatch_size:
+                return self.microbatch_size
+        return super().get_rollout_fragment_length(worker_index)

-        if "sgd_minibatch_size" in config:
-            raise AttributeError(
-                "A2C does not support sgd mini batching as it will instabilize the"
-                " training. Use `train_batch_size` instead."
-            )

+class A2C(A3C):
+    @classmethod
+    @override(A3C)
+    def get_default_config(cls) -> AlgorithmConfig:
+        return A2CConfig()

     @override(Algorithm)
     def setup(self, config: PartialAlgorithmConfigDict):
@@ -190,7 +173,7 @@ def training_step(self) -> ResultDict:
         # apply the averaged gradient in one SGD step. This conserves GPU
         # memory, allowing for extremely large experience batches to be
         # used.
-        if self._by_agent_steps:
+        if self.config.count_steps_by == "agent_steps":
             train_batch = synchronous_parallel_sample(
                 worker_set=self.workers, max_agent_steps=self.config["microbatch_size"]
             )
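Taken together, the a2c.py changes relocate A2C's checks from `A2C.validate_config(config_dict)` into `A2CConfig.validate()`, so they run against typed attributes (`self.microbatch_size`) rather than dict keys, and can fire before any Algorithm or rollout workers exist. A minimal sketch of the new call path, assuming a Ray build that includes this PR; the concrete values are illustrative:

    from ray.rllib.algorithms.a2c import A2CConfig

    # Micro-batching combined with multi-GPU is the case the relocated
    # A2CConfig.validate() rejects (see the diff above).
    config = (
        A2CConfig()
        .training(microbatch_size=100, train_batch_size=1000)
        .resources(num_gpus=2)
    )
    try:
        config.validate()
    except AttributeError as e:
        print(e)  # A2C does not support multiple GPUs when micro-batching is set.

    # With the new rollout_fragment_length="auto" default, the fragment
    # length resolves to the microbatch size:
    single_gpu = A2CConfig().training(microbatch_size=100, train_batch_size=1000)
    assert single_gpu.get_rollout_fragment_length() == 100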
26 changes: 14 additions & 12 deletions rllib/algorithms/a3c/a3c.py
@@ -20,7 +20,6 @@
 )
 from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
 from ray.rllib.utils.typing import (
-    AlgorithmConfigDict,
     PartialAlgorithmConfigDict,
     ResultDict,
 )
@@ -151,6 +150,16 @@ def training(

         return self

+    @override(AlgorithmConfig)
+    def validate(self) -> None:
+        # Call super's validation method.
+        super().validate()
+
+        if self.entropy_coeff < 0:
+            raise ValueError("`entropy_coeff` must be >= 0.0!")
+        if self.num_rollout_workers <= 0 and self.sample_async:
+            raise ValueError("`num_workers` for A3C must be >= 1!")
+

 class A3C(Algorithm):
     @classmethod
@@ -165,18 +174,11 @@ def setup(self, config: PartialAlgorithmConfigDict):
             self.workers.remote_workers(), max_remote_requests_in_flight_per_worker=1
         )

+    @classmethod
     @override(Algorithm)
-    def validate_config(self, config: AlgorithmConfigDict) -> None:
-        # Call super's validation method.
-        super().validate_config(config)
-
-        if config["entropy_coeff"] < 0:
-            raise ValueError("`entropy_coeff` must be >= 0.0!")
-        if config["num_workers"] <= 0 and config["sample_async"]:
-            raise ValueError("`num_workers` for A3C must be >= 1!")
-
-    @override(Algorithm)
-    def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
+    def get_default_policy_class(
+        cls, config: AlgorithmConfig
+    ) -> Optional[Type[Policy]]:
         if config["framework"] == "torch":
             from ray.rllib.algorithms.a3c.a3c_torch_policy import A3CTorchPolicy

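The same pattern applies in a3c.py: the `entropy_coeff` and `sample_async` checks move from `A3C.validate_config()` into `A3CConfig.validate()`, and `get_default_policy_class` becomes a classmethod taking an `AlgorithmConfig`. A short sketch of triggering the moved check, again assuming a Ray build that includes this PR:

    from ray.rllib.algorithms.a3c import A3CConfig

    # A negative entropy coefficient is now rejected by the config object
    # itself, before any workers are created.
    config = A3CConfig().training(entropy_coeff=-0.01)
    try:
        config.validate()
    except ValueError as e:
        print(e)  # `entropy_coeff` must be >= 0.0!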