[RLlib] MAML config objects. #25066

Merged · 3 commits · May 23, 2022
3 changes: 2 additions & 1 deletion rllib/algorithms/maml/__init__.py
@@ -1,6 +1,7 @@
from ray.rllib.algorithms.maml.maml import MAMLTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.maml.maml import MAMLConfig, MAMLTrainer, DEFAULT_CONFIG

__all__ = [
"MAMLConfig",
"MAMLTrainer",
"DEFAULT_CONFIG",
]
227 changes: 166 additions & 61 deletions rllib/algorithms/maml/maml.py
@@ -1,10 +1,9 @@
import logging
import numpy as np
from typing import Type
from typing import Optional, Type

from ray.rllib.utils.sgd import standardized
from ray.rllib.agents import with_common_config
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
@@ -18,70 +17,159 @@
from ray.rllib.execution.metric_ops import CollectMetrics
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.sgd import standardized
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import from_actors, LocalIterator

logger = logging.getLogger(__name__)

# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# If true, use the Generalized Advantage Estimator (GAE)
# with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
"use_gae": True,
# GAE(lambda) parameter
"lambda": 1.0,
# Initial coefficient for KL divergence
"kl_coeff": 0.0005,
# Size of batches collected from each worker
"rollout_fragment_length": 200,
# Do create an actual env on the local worker (worker-idx=0).
"create_env_on_driver": True,
# Stepsize of SGD
"lr": 1e-3,
"model": {

class MAMLConfig(TrainerConfig):
"""Defines a configuration class from which a MAMLTrainer can be built.

Example:
>>> from ray.rllib.algorithms.maml import MAMLConfig
>>> config = MAMLConfig().training(use_gae=False).resources(num_gpus=1)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()

Example:
>>> from ray.rllib.algorithms.maml import MAMLConfig
>>> from ray import tune
>>> config = MAMLConfig()
>>> # Print out some default values.
>>> print(config.lr)
... 0.0004
Review comment (Contributor; conversation marked resolved by sven1977):

Ellipses are for continuation of a previous line. I think you need to remove them (and obv. keep the indent) for this to be doc-tested correctly.

Not sure why this doesn't fail, do we turn doctest off somewhere for these examples? Last time I checked it was still part of the docs build.
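
For reference, a minimal sketch of the format the reviewer is asking for: the expected output goes on its own line without the "..." continuation prefix, and the value shown should match the default actually set in __init__ below (self.lr = 1e-3):

>>> from ray.rllib.algorithms.maml import MAMLConfig
>>> config = MAMLConfig()
>>> print(config.lr)
0.001
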
>>> # Update the config object.
>>> config.training(grad_clip=tune.grid_search([10.0, 40.0]))
>>> # Set the config object's env.
>>> config.environment(env="CartPole-v1")
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.run(
... "MAML",
... stop={"episode_reward_mean": 200},
... config=config.to_dict(),
... )
"""

def __init__(self, trainer_class=None):
"""Initializes a PGConfig instance."""
super().__init__(trainer_class=trainer_class or MAMLTrainer)

# fmt: off
# __sphinx_doc_begin__
# MAML-specific config settings.
self.use_gae = True
self.lambda_ = 1.0
self.kl_coeff = 0.0005
self.vf_loss_coeff = 0.5
self.entropy_coeff = 0.0
self.clip_param = 0.3
self.vf_clip_param = 10.0
self.grad_clip = None
self.kl_target = 0.01
self.inner_adaptation_steps = 1
self.maml_optimizer_steps = 5
self.inner_lr = 0.1
self.use_meta_env = True

# Override some of TrainerConfig's default values with MAML-specific values.
self.rollout_fragment_length = 200
self.create_env_on_local_worker = True
self.lr = 1e-3

# Share layers for value function.
"vf_share_layers": False,
},
# Coefficient of the value function loss
"vf_loss_coeff": 0.5,
# Coefficient of the entropy regularizer
"entropy_coeff": 0.0,
# PPO clip parameter
"clip_param": 0.3,
# Clip param for the value function. Note that this is sensitive to the
# scale of the rewards. If your expected V is large, increase this.
"vf_clip_param": 10.0,
# If specified, clip the global norm of gradients by this amount
"grad_clip": None,
# Target value for KL divergence
"kl_target": 0.01,
# Whether to rollout "complete_episodes" or "truncate_episodes"
"batch_mode": "complete_episodes",
# Which observation filter to apply to the observation
"observation_filter": "NoFilter",
# Number of Inner adaptation steps for the MAML algorithm
"inner_adaptation_steps": 1,
# Number of MAML steps per meta-update iteration (PPO steps)
"maml_optimizer_steps": 5,
# Inner Adaptation Step size
"inner_lr": 0.1,
# Use Meta Env Template
"use_meta_env": True,

# Deprecated keys:
# Share layers for value function. If you set this to True, it's important
# to tune vf_loss_coeff.
# Use config.model.vf_share_layers instead.
"vf_share_layers": DEPRECATED_VALUE,

# Use `execution_plan` instead of `training_iteration`.
"_disable_execution_plan_api": False,
})
# __sphinx_doc_end__
# fmt: on
self.model.update({
"vf_share_layers": False,
})

self.batch_mode = "complete_episodes"
self._disable_execution_plan_api = False
# __sphinx_doc_end__
# fmt: on

# Deprecated keys:
self.vf_share_layers = DEPRECATED_VALUE

def training(
self,
*,
use_gae: Optional[bool] = None,
lambda_: Optional[float] = None,
kl_coeff: Optional[float] = None,
vf_loss_coeff: Optional[float] = None,
entropy_coeff: Optional[float] = None,
clip_param: Optional[float] = None,
vf_clip_param: Optional[float] = None,
grad_clip: Optional[float] = None,
kl_target: Optional[float] = None,
inner_adaptation_steps: Optional[int] = None,
maml_optimizer_steps: Optional[int] = None,
inner_lr: Optional[float] = None,
use_meta_env: Optional[bool] = None,
**kwargs,
) -> "MAMLConfig":
"""Sets the training related configuration.

Args:
use_gae: If true, use the Generalized Advantage Estimator (GAE)
with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
lambda_: The GAE (lambda) parameter.
kl_coeff: Initial coefficient for KL divergence.
vf_loss_coeff: Coefficient of the value function loss.
entropy_coeff: Coefficient of the entropy regularizer.
clip_param: PPO clip parameter.
vf_clip_param: Clip param for the value function. Note that this is
sensitive to the scale of the rewards. If your expected V is large,
increase this.
grad_clip: If specified, clip the global norm of gradients by this amount.
kl_target: Target value for KL divergence.
inner_adaptation_steps: Number of Inner adaptation steps for the MAML
algorithm.
maml_optimizer_steps: Number of MAML steps per meta-update iteration
(PPO steps).
inner_lr: Inner Adaptation Step size.
use_meta_env: Use Meta Env Template.

Returns:
This updated TrainerConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)

if use_gae is not None:
self.use_gae = use_gae
if lambda_ is not None:
self.lambda_ = lambda_
if kl_coeff is not None:
self.kl_coeff = kl_coeff
if vf_loss_coeff is not None:
self.vf_loss_coeff = vf_loss_coeff
if entropy_coeff is not None:
self.entropy_coeff = entropy_coeff
if clip_param is not None:
self.clip_param = clip_param
if vf_clip_param is not None:
self.vf_clip_param = vf_clip_param
if grad_clip is not None:
self.grad_clip = grad_clip
if kl_target is not None:
self.kl_target = kl_target
if inner_adaptation_steps is not None:
self.inner_adaptation_steps = inner_adaptation_steps
if maml_optimizer_steps is not None:
self.maml_optimizer_steps = maml_optimizer_steps
if inner_lr is not None:
self.inner_lr = inner_lr
if use_meta_env is not None:
self.use_meta_env = use_meta_env

return self
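
As a usage sketch only (parameter values are illustrative rather than tuned recommendations, and the env name simply follows the class docstring's example), the new training() setter chains with the other TrainerConfig setters:

from ray.rllib.algorithms.maml import MAMLConfig

config = (
    MAMLConfig()
    # MAML-specific settings go through the new training() setter.
    .training(inner_adaptation_steps=1, maml_optimizer_steps=5, inner_lr=0.1, kl_target=0.01)
    # Generic rollout settings still come from TrainerConfig.rollouts().
    .rollouts(num_rollout_workers=1)
)
# Build a trainer from the config and run one training iteration.
trainer = config.build(env="CartPole-v1")
print(trainer.train())
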


# @mluo: TODO
@@ -169,7 +257,7 @@ class MAMLTrainer(Trainer):
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
return DEFAULT_CONFIG
return MAMLConfig().to_dict()

@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
@@ -281,3 +369,20 @@ def inner_adaptation_steps(itr):
)
)
return train_op


# Deprecated: Use ray.rllib.algorithms.maml.maml.MAMLConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(MAMLConfig().to_dict())

@Deprecated(
old="ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG",
new="ray.rllib.algorithms.maml.maml.MAMLConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()
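
A minimal migration sketch (values shown assume the defaults defined above; reading the deprecated dict triggers the warning installed by the decorator, while the config object does not):

from ray.rllib.algorithms.maml import DEFAULT_CONFIG, MAMLConfig

# Old style: key lookup on the deprecated dict (logs a deprecation warning).
old_inner_lr = DEFAULT_CONFIG["inner_lr"]   # -> 0.1

# New style: attribute access on the config object.
new_inner_lr = MAMLConfig().inner_lr        # -> 0.1
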
7 changes: 3 additions & 4 deletions rllib/algorithms/maml/tests/test_maml.py
@@ -20,9 +20,8 @@ def tearDownClass(cls):

def test_maml_compilation(self):
"""Test whether a MAMLTrainer can be built with all frameworks."""
config = maml.DEFAULT_CONFIG.copy()
config["num_workers"] = 1
config["horizon"] = 200
config = maml.MAMLConfig().rollouts(num_rollout_workers=1, horizon=200)

num_iterations = 1

# Test for tf framework (torch not implemented yet).
@@ -35,7 +34,7 @@ def test_maml_compilation(self):
continue
print("env={}".format(env))
env_ = "ray.rllib.examples.env.{}".format(env)
trainer = maml.MAMLTrainer(config=config, env=env_)
trainer = config.build(env=env_)
for i in range(num_iterations):
results = trainer.train()
check_train_results(results)