[RLlib] APPO config objects #24376
Merged

Commits (showing changes from 2 of 8 commits):
d91ff3a  wip (sven1977)
da26714  wip (sven1977)
21bf876  wip (sven1977)
fa5faf2  Merge branch 'impala_config_objects' into appo_config_objects (sven1977)
c629956  wip (sven1977)
7b17862  wip (sven1977)
43adb98  wip (sven1977)
b5e2c8f  wip (sven1977)
@@ -1,6 +1,7 @@
-from ray.rllib.agents.impala.impala import DEFAULT_CONFIG, ImpalaTrainer
+from ray.rllib.agents.impala.impala import DEFAULT_CONFIG, ImpalaConfig, ImpalaTrainer

 __all__ = [
+    "DEFAULT_CONFIG",
+    "ImpalaConfig",
     "ImpalaTrainer",
-    "DEFAULT_CONFIG",
 ]
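As a side note, the newly exported ImpalaConfig can then be imported directly from the package. A minimal usage sketch follows; it assumes ImpalaConfig exposes the same fluent training()/rollouts()/build() methods as the APPOConfig shown further below, which this PR does not itself change.

from ray.rllib.agents.impala import ImpalaConfig

# Assumption: ImpalaConfig follows the same fluent config API as APPOConfig below.
config = ImpalaConfig().training(lr=0.0005).rollouts(num_rollout_workers=2)
trainer = config.build(env="CartPole-v1")
print(trainer.train())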
Large diffs are not rendered by default.
@@ -23,69 +23,140 @@
    _get_shared_metrics,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.typing import PartialTrainerConfigDict, TrainerConfigDict

# fmt: off
# __sphinx_doc_begin__

# Adds the following updates to the `IMPALATrainer` config in
# rllib/agents/impala/impala.py.
DEFAULT_CONFIG = impala.ImpalaTrainer.merge_trainer_configs(
    impala.DEFAULT_CONFIG,  # See keys in impala.py, which are also supported.
    {
        # Whether to use V-trace weighted advantages. If false, PPO GAE
        # advantages will be used instead.
        "vtrace": True,

        # == These two options only apply if vtrace: False ==
        # Should use a critic as a baseline (otherwise don't use value
        # baseline; required for using GAE).
        "use_critic": True,
        # If true, use the Generalized Advantage Estimator (GAE)
        # with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
        "use_gae": True,
        # GAE(lambda) parameter
        "lambda": 1.0,

        # == PPO surrogate loss options ==
        "clip_param": 0.4,

        # == PPO KL Loss options ==
        "use_kl_loss": False,
        "kl_coeff": 1.0,
        "kl_target": 0.01,

        # == IMPALA optimizer params (see documentation in impala.py) ==
        "rollout_fragment_length": 50,
        "train_batch_size": 500,
        "min_time_s_per_reporting": 10,
        "num_workers": 2,
        "num_gpus": 0,
        "num_multi_gpu_tower_stacks": 1,
        "minibatch_buffer_size": 1,
        "num_sgd_iter": 1,
        "replay_proportion": 0.0,
        "replay_buffer_num_slots": 100,
        "learner_queue_size": 16,
        "learner_queue_timeout": 300,
        "max_sample_requests_in_flight_per_worker": 2,
        "broadcast_interval": 1,
        "grad_clip": 40.0,
        "opt_type": "adam",
        "lr": 0.0005,
        "lr_schedule": None,
        "decay": 0.99,
        "momentum": 0.0,
        "epsilon": 0.1,
        "vf_loss_coeff": 0.5,
        "entropy_coeff": 0.01,
        "entropy_coeff_schedule": None,
    },
    _allow_unknown_configs=True,
)

# __sphinx_doc_end__
# fmt: on
class APPOConfig(impala.ImpalaConfig):
    """Defines an APPOTrainer configuration class from which a new Trainer can be built.

    Example:
        >>> from ray import tune

[Inline review comment on the line above: "APPOConfig import"]

        >>> config = APPOConfig().training(lr=0.01, grad_clip=30.0)\
        ...     .resources(num_gpus=1)\
        ...     .rollouts(num_rollout_workers=16)
        >>> print(config.to_dict())
        >>> # Build a Trainer object from the config and run 1 training iteration.
        >>> trainer = config.build(env="CartPole-v1")
        >>> trainer.train()

    Example:
        >>> config = APPOConfig()
        >>> # Print out some default values.
        >>> print(config.sample_async)
        >>> # Update the config object.
        >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
        >>> # Set the config object's env.
        >>> config.environment(env="CartPole-v1")
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.run(
        ...     "APPO",
        ...     stop={"episode_reward_mean": 200},
        ...     config=config.to_dict(),
        ... )
    """
    def __init__(self, trainer_class=None):
        """Initializes an APPOConfig instance."""
        super().__init__(trainer_class=trainer_class or APPOTrainer)

        # fmt: off
        # __sphinx_doc_begin__

        # APPO specific settings:
        self.vtrace = True
        self.use_critic = True
        self.use_gae = True
        self.lambda_ = 1.0
        self.clip_param = 0.4
        self.use_kl_loss = False
        self.kl_coeff = 1.0
        self.kl_target = 0.01

        # Override some of ImpalaConfig's default values with APPO-specific values.
        self.rollout_fragment_length = 50
        self.train_batch_size = 500
        self.min_time_s_per_reporting = 10
        self.num_workers = 2
        self.num_gpus = 0
        self.num_multi_gpu_tower_stacks = 1
        self.minibatch_buffer_size = 1
        self.num_sgd_iter = 1
        self.replay_proportion = 0.0
        self.replay_buffer_num_slots = 100
        self.learner_queue_size = 16
        self.learner_queue_timeout = 300
        self.max_sample_requests_in_flight_per_worker = 2
        self.broadcast_interval = 1
        self.grad_clip = 40.0
        self.opt_type = "adam"
        self.lr = 0.0005
        self.lr_schedule = None
        self.decay = 0.99
        self.momentum = 0.0
        self.epsilon = 0.1
        self.vf_loss_coeff = 0.5
        self.entropy_coeff = 0.01
        self.entropy_coeff_schedule = None
        # __sphinx_doc_end__
        # fmt: on

    @override(impala.ImpalaConfig)
    def training(
        self,
        *,
        vtrace: Optional[bool] = None,
        use_critic: Optional[bool] = None,
        use_gae: Optional[bool] = None,
        lambda_: Optional[float] = None,
        clip_param: Optional[float] = None,
        use_kl_loss: Optional[bool] = None,
        kl_coeff: Optional[float] = None,
        kl_target: Optional[float] = None,
        **kwargs,
    ) -> "APPOConfig":
        """Sets the training related configuration.

        Args:
            vtrace: Whether to use V-trace weighted advantages. If false, PPO GAE
                advantages will be used instead.
            use_critic: Should use a critic as a baseline (otherwise don't use value
                baseline; required for using GAE). Only applies if vtrace=False.
            use_gae: If true, use the Generalized Advantage Estimator (GAE)
                with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
                Only applies if vtrace=False.
            lambda_: GAE (lambda) parameter.
            clip_param: PPO surrogate clipping parameter.
            use_kl_loss: Whether to use the KL-term in the loss function.
            kl_coeff: Coefficient for weighting the KL-loss term.
            kl_target: Target term for the KL-term to reach (via adjusting the
                `kl_coeff` automatically).

        Returns:
            This updated TrainerConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        if vtrace is not None:
            self.vtrace = vtrace
        if use_critic is not None:
            self.use_critic = use_critic
        if use_gae is not None:
            self.use_gae = use_gae
        if lambda_ is not None:
            self.lambda_ = lambda_
        if clip_param is not None:
            self.clip_param = clip_param
        if use_kl_loss is not None:
            self.use_kl_loss = use_kl_loss
        if kl_coeff is not None:
            self.kl_coeff = kl_coeff
        if kl_target is not None:
            self.kl_target = kl_target

        return self


class UpdateTargetAndKL:

@@ -130,7 +201,7 @@ def __init__(self, config, *args, **kwargs):
    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG
        return APPOConfig().to_dict()

    @override(Trainer)
    def get_default_policy_class(

@@ -142,3 +213,20 @@ def get_default_policy_class(
            return AsyncPPOTorchPolicy
        else:
            return AsyncPPOTFPolicy


# Deprecated: Use ray.rllib.agents.ppo.appo.APPOConfig instead!
class _deprecated_default_config(dict):
    def __init__(self):
        super().__init__(APPOConfig().to_dict())

    @Deprecated(
        old="ray.rllib.agents.ppo.appo.DEFAULT_CONFIG",
        new="ray.rllib.agents.ppo.appo.APPOConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()
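To make the migration path concrete, here is a minimal usage sketch based on the docstring examples above. The hyperparameter values are arbitrary, and the 0-GPU, 2-worker setup is an assumption for illustration, not part of this PR.

from ray.rllib.agents.ppo.appo import DEFAULT_CONFIG, APPOConfig

# New style: build a Trainer from a fluent config object.
config = (
    APPOConfig()
    .training(lr=0.0005, clip_param=0.4, use_kl_loss=False)
    .resources(num_gpus=0)
    .rollouts(num_rollout_workers=2)
)
trainer = config.build(env="CartPole-v1")
print(trainer.train())

# Old style: reading keys from DEFAULT_CONFIG still works, but __getitem__ is
# wrapped by @Deprecated(error=False), so it emits a deprecation warning.
print(DEFAULT_CONFIG["clip_param"])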
Comment: Could you please explain this? Can we remove it altogether?

Reply: Oh, sorry, this shouldn't be here. Will remove ...