Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Add more detailed options to Algorithm.add_module/remove_module. #46836

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
516 changes: 291 additions & 225 deletions rllib/algorithms/algorithm.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import ray
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.rl_module import validate_module_id
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.env.env_context import EnvContext
Expand All @@ -44,7 +45,6 @@
)
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import NotProvided, from_config
from ray.rllib.utils.policy import validate_policy_id
from ray.rllib.utils.schedules.scheduler import Scheduler
from ray.rllib.utils.serialization import (
NOT_SERIALIZABLE,
Expand Down Expand Up @@ -2655,7 +2655,7 @@ def multi_agent(
# Make sure our Policy IDs are ok (this should work whether `policies`
# is a dict or just any Sequence).
for pid in policies:
validate_policy_id(pid, error=True)
validate_module_id(pid, error=True)

# Collection: Convert to dict.
if isinstance(policies, (set, tuple, list)):
Expand Down
2 changes: 1 addition & 1 deletion rllib/algorithms/tests/test_worker_failures.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def on_algorithm_init(self, *, algorithm, metrics_logger, **kwargs):
algorithm.add_module(
module_id="test_module",
module_spec=spec,
evaluation_workers=True,
add_to_eval_env_runners=True,
)


Expand Down
8 changes: 4 additions & 4 deletions rllib/core/learner/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
LearnerConnectorPipeline,
)
from ray.rllib.core import COMPONENT_OPTIMIZER, COMPONENT_RL_MODULE, DEFAULT_MODULE_ID
from ray.rllib.core.rl_module import validate_module_id
from ray.rllib.core.rl_module.marl_module import (
MultiAgentRLModule,
MultiAgentRLModuleSpec,
Expand Down Expand Up @@ -58,7 +59,6 @@
MiniBatchCyclicIterator,
)
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.policy import validate_policy_id
from ray.rllib.utils.schedules.scheduler import Scheduler
from ray.rllib.utils.typing import (
EpisodeType,
Expand Down Expand Up @@ -718,9 +718,9 @@ def add_module(
returns False) will not be updated.

Returns:
The new MultiAgentRLModuleSpec (after the change has been performed).
The new MultiAgentRLModuleSpec (after the RLModule has been added).
"""
validate_policy_id(module_id, error=True)
validate_module_id(module_id, error=True)
self._check_is_built()

# Force-set inference-only = False.
Expand Down Expand Up @@ -771,7 +771,7 @@ def remove_module(
returns False) will not be updated.

Returns:
The new MultiAgentRLModuleSpec (after the change has been performed).
The new MultiAgentRLModuleSpec (after the RLModule has been removed).
"""
self._check_is_built()
module = self.module[module_id]
Expand Down
30 changes: 24 additions & 6 deletions rllib/core/learner/learner_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict, Counter
import copy
from functools import partial
import itertools
from typing import (
Any,
Callable,
Expand All @@ -21,6 +22,7 @@
from ray import ObjectRef
from ray.rllib.core import COMPONENT_LEARNER, COMPONENT_RL_MODULE
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.rl_module import validate_module_id
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
Expand All @@ -44,7 +46,6 @@
ShardEpisodesIterator,
ShardObjectRefIterator,
)
from ray.rllib.utils.policy import validate_policy_id
from ray.rllib.utils.typing import (
EpisodeType,
ModuleID,
Expand Down Expand Up @@ -695,7 +696,7 @@ def add_module(
Returns:
The new MultiAgentRLModuleSpec (after the change has been performed).
"""
validate_policy_id(module_id, error=True)
validate_module_id(module_id, error=True)

# Force-set inference-only = False.
module_spec = copy.deepcopy(module_spec)
Expand Down Expand Up @@ -815,16 +816,33 @@ def set_state(self, state: StateDict) -> None:
lambda _learner, _ref=state_ref: _learner.set_state(ray.get(_ref))
)

def get_weights(self) -> StateDict:
def get_weights(
self, module_ids: Optional[Collection[ModuleID]] = None
) -> StateDict:
"""Convenience method instead of self.get_state(components=...).

Args:
module_ids: An optional collection of ModuleIDs for which to return weights.
If None (default), return weights of all RLModules.

Returns:
The results of
`self.get_state(components='learner/rl_module')['learner']['rl_module']`.
"""
return self.get_state(components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE)[
COMPONENT_LEARNER
][COMPONENT_RL_MODULE]
# Return the entire RLModule state (all possible single-agent RLModules).
if module_ids is None:
components = COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE
# Return a subset of the single-agent RLModules.
else:
components = [
"".join(tup)
for tup in itertools.product(
[COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE + "/"],
list(module_ids),
)
]

return self.get_state(components)[COMPONENT_LEARNER][COMPONENT_RL_MODULE]

def set_weights(self, weights) -> None:
"""Convenience method instead of self.set_state({'learner': {'rl_module': ..}}).
Expand Down
40 changes: 40 additions & 0 deletions rllib/core/rl_module/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,53 @@
import logging
import re

from ray.rllib.core.rl_module.marl_module import (
MultiAgentRLModule,
MultiAgentRLModuleSpec,
)
from ray.rllib.core.rl_module.rl_module import RLModule, SingleAgentRLModuleSpec
from ray.util import log_once
from ray.util.annotations import PublicAPI

logger = logging.getLogger("ray.rllib")


@PublicAPI(stability="alpha")
def validate_module_id(policy_id: str, error: bool = False) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice!!

"""Makes sure the given `policy_id` is valid.

Args:
policy_id: The Policy ID to check.
IMPORTANT: Must not contain characters that
are also not allowed in Unix/Win filesystems, such as: `<>:"/\\|?*`
or a dot `.` or space ` ` at the end of the ID.
error: Whether to raise an error (ValueError) or a warning in case of an
invalid `policy_id`.

Raises:
ValueError: If the given `policy_id` is not a valid one and `error` is True.
"""
if (
not isinstance(policy_id, str)
or len(policy_id) == 0
or re.search('[<>:"/\\\\|?]', policy_id)
or policy_id[-1] in (" ", ".")
):
msg = (
f"PolicyID `{policy_id}` not valid! IDs must be a non-empty string, "
"must not contain characters that are also disallowed file- or directory "
"names on Unix/Windows and must not end with a dot `.` or a space ` `."
)
if error:
raise ValueError(msg)
elif log_once("invalid_policy_id"):
logger.warning(msg)


__all__ = [
"MultiAgentRLModule",
"MultiAgentRLModuleSpec",
"RLModule",
"SingleAgentRLModuleSpec",
"validate_module_id",
]
9 changes: 6 additions & 3 deletions rllib/core/rl_module/marl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
OverrideToImplementCustomLogic,
)
from ray.rllib.utils.checkpoints import Checkpointable
from ray.rllib.utils.policy import validate_policy_id
from ray.rllib.utils.serialization import serialize_type, deserialize_type
from ray.rllib.utils.typing import ModuleID, StateDict, T
from ray.util.annotations import PublicAPI
Expand Down Expand Up @@ -159,14 +158,18 @@ def add_module(
Raises:
ValueError: If the module ID already exists and override is False.
Warnings are raised if the module id is not valid according to the
logic of ``validate_policy_id()``.
logic of ``validate_module_id()``.
"""
validate_policy_id(module_id)
from ray.rllib.core.rl_module import validate_module_id

validate_module_id(module_id)

if module_id in self._rl_modules and not override:
raise ValueError(
f"Module ID {module_id} already exists. If your intention is to "
"override, set override=True."
)

# Set our own inference_only flag to False as soon as any added Module
# has `inference_only=False`.
if not module.config.inference_only:
Expand Down
4 changes: 2 additions & 2 deletions rllib/env/env_runner_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
COMPONENT_RL_MODULE,
)
from ray.rllib.core.learner import LearnerGroup
from ray.rllib.core.rl_module import validate_module_id
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.utils.actor_manager import RemoteCallResults
Expand All @@ -44,7 +45,6 @@
)
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED_LIFETIME
from ray.rllib.utils.policy import validate_policy_id
from ray.rllib.utils.typing import (
AgentID,
EnvCreator,
Expand Down Expand Up @@ -720,7 +720,7 @@ def add_policy(
"Only one of `policy_cls` or `policy` must be provided to "
"staticmethod: `EnvRunnerGroup.add_policy()`!"
)
validate_policy_id(policy_id, error=False)
validate_module_id(policy_id, error=False)

# Policy instance not provided: Use the information given here.
if policy_cls is not None:
Expand Down
7 changes: 4 additions & 3 deletions rllib/evaluation/rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
create_connectors_for_policy,
maybe_get_filters_for_syncing,
)
from ray.rllib.core.rl_module import validate_module_id
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
from ray.rllib.env.env_context import EnvContext
Expand Down Expand Up @@ -72,7 +73,7 @@
from ray.rllib.utils.filter import Filter, NoFilter, get_filter
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.policy import create_policy_for_framework, validate_policy_id
from ray.rllib.utils.policy import create_policy_for_framework
from ray.rllib.utils.sgd import do_minibatch_sgd
from ray.rllib.utils.tf_run_builder import _TFRunBuilder
from ray.rllib.utils.tf_utils import get_gpu_devices as get_tf_gpu_devices
Expand Down Expand Up @@ -1110,7 +1111,7 @@ def add_policy(
KeyError: If the given `policy_id` already exists in this worker's
PolicyMap.
"""
validate_policy_id(policy_id, error=False)
validate_module_id(policy_id, error=False)

if module_spec is not None and not self.config.enable_rl_module_and_learner:
raise ValueError(
Expand Down Expand Up @@ -1415,7 +1416,7 @@ def set_state(self, state: dict) -> None:
for pid, policy_state in policy_states.items():
# If - for some reason - we have an invalid PolicyID in the state,
# this might be from an older checkpoint (pre v1.0). Just warn here.
validate_policy_id(pid, error=False)
validate_module_id(pid, error=False)

if pid not in self.policy_map:
spec = policy_state.get("policy_spec", None)
Expand Down
4 changes: 2 additions & 2 deletions rllib/utils/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ def convert_to_msgpack_checkpoint(
this is the same as `msgpack_checkpoint_dir`.
"""
from ray.rllib.algorithms import Algorithm
from ray.rllib.utils.policy import validate_policy_id
from ray.rllib.core.rl_module import validate_module_id

# Try to import msgpack and msgpack_numpy.
msgpack = try_import_msgpack(error=True)
Expand Down Expand Up @@ -768,7 +768,7 @@ def convert_to_msgpack_checkpoint(
# Write individual policies to disk, each in their own subdirectory.
for pid, policy_state in policy_states.items():
# From here on, disallow policyIDs that would not work as directory names.
validate_policy_id(pid, error=True)
validate_module_id(pid, error=True)
policy_dir = os.path.join(msgpack_checkpoint_dir, "policies", pid)
os.makedirs(policy_dir, exist_ok=True)
policy = algo.get_policy(pid)
Expand Down
38 changes: 5 additions & 33 deletions rllib/utils/policy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import gymnasium as gym
import logging
import numpy as np
import re
from typing import (
Callable,
Dict,
Expand All @@ -16,6 +15,7 @@


import ray.cloudpickle as pickle
from ray.rllib.core.rl_module import validate_module_id
from ray.rllib.models.preprocessors import ATARI_OBS_SHAPE
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.policy.sample_batch import SampleBatch
Expand All @@ -41,38 +41,6 @@
tf1, tf, tfv = try_import_tf()


@PublicAPI(stability="alpha")
def validate_policy_id(policy_id: str, error: bool = False) -> None:
"""Makes sure the given `policy_id` is valid.

Args:
policy_id: The Policy ID to check.
IMPORTANT: Must not contain characters that
are also not allowed in Unix/Win filesystems, such as: `<>:"/\\|?*`
or a dot `.` or space ` ` at the end of the ID.
error: Whether to raise an error (ValueError) or a warning in case of an
invalid `policy_id`.

Raises:
ValueError: If the given `policy_id` is not a valid one and `error` is True.
"""
if (
not isinstance(policy_id, str)
or len(policy_id) == 0
or re.search('[<>:"/\\\\|?]', policy_id)
or policy_id[-1] in (" ", ".")
):
msg = (
f"PolicyID `{policy_id}` not valid! IDs must be a non-empty string, "
"must not contain characters that are also disallowed file- or directory "
"names on Unix/Windows and must not end with a dot `.` or a space ` `."
)
if error:
raise ValueError(msg)
elif log_once("invalid_policy_id"):
logger.warning(msg)


@PublicAPI
def create_policy_for_framework(
policy_id: str,
Expand Down Expand Up @@ -333,3 +301,7 @@ def __check_atari_obs_space(obs):
"ray.rllib.env.wrappers.atari_wrappers.wrap_deepmind to wrap "
"you environment."
)


# @OldAPIStack
validate_policy_id = validate_module_id
Loading