Skip to content

Commit

Permalink
[RLlib] New ConnectorV2 API #1: Some preparatory cleanups and fixes. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sven1977 authored Nov 17, 2023
1 parent 0e2a523 commit ca29fec
Show file tree
Hide file tree
Showing 15 changed files with 194 additions and 88 deletions.
5 changes: 4 additions & 1 deletion rllib/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1932,7 +1932,10 @@ def compute_actions(
filtered_obs, filtered_state = [], []
for agent_id, ob in observations.items():
worker = self.workers.local_worker()
preprocessed = worker.preprocessors[policy_id].transform(ob)
if worker.preprocessors.get(policy_id) is not None:
preprocessed = worker.preprocessors[policy_id].transform(ob)
else:
preprocessed = ob
filtered = worker.filters[policy_id](preprocessed, update=False)
filtered_obs.append(filtered)
if state is None:
Expand Down
33 changes: 24 additions & 9 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,26 +319,37 @@ def __init__(self, algo_class=None):
# If not specified, we will try to auto-detect this.
self._is_atari = None

# TODO (sven): Rename this method into `AlgorithmConfig.sampling()`
# `self.rollouts()`
self.env_runner_cls = None
# TODO (sven): Rename into `num_env_runner_workers`.
self.num_rollout_workers = 0
self.num_envs_per_worker = 1
self.sample_collector = SimpleListCollector
self.create_env_on_local_worker = False
self.sample_async = False
self.enable_connectors = True
self.update_worker_filter_stats = True
self.use_worker_filter_stats = True
# TODO (sven): Rename into `sample_timesteps` (or `sample_duration`
# and `sample_duration_unit` (replacing batch_mode), like we do it
# in the evaluation config).
self.rollout_fragment_length = 200
# TODO (sven): Rename into `sample_mode`.
self.batch_mode = "truncate_episodes"
# TODO (sven): Rename into `validate_env_runner_workers_after_construction`.
self.validate_workers_after_construction = True
self.compress_observations = False
# TODO (sven): Rename into `env_runner_perf_stats_ema_coef`.
self.sampler_perf_stats_ema_coef = None

# TODO (sven): Deprecate together with old API stack.
self.sample_async = False
self.remote_worker_envs = False
self.remote_env_batch_wait_ms = 0
self.validate_workers_after_construction = True
self.enable_tf1_exec_eagerly = False
self.sample_collector = SimpleListCollector
self.preprocessor_pref = "deepmind"
self.observation_filter = "NoFilter"
self.compress_observations = False
self.enable_tf1_exec_eagerly = False
self.sampler_perf_stats_ema_coef = None
self.update_worker_filter_stats = True
self.use_worker_filter_stats = True
# TODO (sven): End: deprecate.

# `self.training()`
self.gamma = 0.99
Expand Down Expand Up @@ -890,7 +901,7 @@ def validate(self) -> None:
error=True,
)

# RLModule API only works with connectors and with Learner API.
# New API stack (RLModule, Learner APIs) only works with connectors.
if not self.enable_connectors and self._enable_new_api_stack:
raise ValueError(
"The new API stack (RLModule and Learner APIs) only works with "
Expand Down Expand Up @@ -938,6 +949,8 @@ def validate(self) -> None:
"https://github.com/ray-project/ray/issues/35409 for more details."
)

# TODO (sven): Remove this hack. We should not have env-var dependent logic
# in the codebase.
if bool(os.environ.get("RLLIB_ENABLE_RL_MODULE", False)):
# Enable RLModule API and connectors if env variable is set
# (to be used in unittesting)
Expand Down Expand Up @@ -1765,6 +1778,8 @@ def training(
dashboard. If you're seeing that the object store is filling up,
turn down the number of remote requests in flight, or enable compression
in your experiment of timesteps.
learner_class: The `Learner` class to use for (distributed) updating of the
RLModule. Only used when `_enable_new_api_stack=True`.
Returns:
This updated AlgorithmConfig object.
Expand Down
2 changes: 2 additions & 0 deletions rllib/core/learner/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
ResultDict,
TensorType,
)
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.rllib.core.rl_module.torch.torch_compile_config import TorchCompileConfig
Expand Down Expand Up @@ -226,6 +227,7 @@ def get_hps_for_module(self, module_id: ModuleID) -> "LearnerHyperparameters":
return self


@PublicAPI(stability="alpha")
class Learner:
"""Base class for Learners.
Expand Down
2 changes: 2 additions & 0 deletions rllib/core/learner/learner_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ray.rllib.utils.numpy import convert_to_numpy
from ray.train._internal.backend_executor import BackendExecutor
from ray.tune.utils.file_transfer import sync_dir_between_nodes
from ray.util.annotations import PublicAPI


if TYPE_CHECKING:
Expand Down Expand Up @@ -58,6 +59,7 @@ def _is_module_trainable(module_id: ModuleID, batch: MultiAgentBatch) -> bool:
return True


@PublicAPI(stability="alpha")
class LearnerGroup:
"""Coordinator of Learners.
Expand Down
45 changes: 17 additions & 28 deletions rllib/core/models/torch/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,21 @@ def __init__(self, config: RecurrentEncoderConfig) -> None:
bias=config.use_bias,
)

self._state_in_out_spec = {
"h": TensorSpec(
"b, l, d",
d=self.config.hidden_dim,
l=self.config.num_layers,
framework="torch",
),
"c": TensorSpec(
"b, l, d",
d=self.config.hidden_dim,
l=self.config.num_layers,
framework="torch",
),
}

@override(Model)
def get_input_specs(self) -> Optional[Spec]:
return SpecDict(
Expand All @@ -293,20 +308,7 @@ def get_input_specs(self) -> Optional[Spec]:
SampleBatch.OBS: TensorSpec(
"b, t, d", d=self.config.input_dims[0], framework="torch"
),
STATE_IN: {
"h": TensorSpec(
"b, l, h",
h=self.config.hidden_dim,
l=self.config.num_layers,
framework="torch",
),
"c": TensorSpec(
"b, l, h",
h=self.config.hidden_dim,
l=self.config.num_layers,
framework="torch",
),
},
STATE_IN: self._state_in_out_spec,
}
)

Expand All @@ -317,20 +319,7 @@ def get_output_specs(self) -> Optional[Spec]:
ENCODER_OUT: TensorSpec(
"b, t, d", d=self.config.output_dims[0], framework="torch"
),
STATE_OUT: {
"h": TensorSpec(
"b, l, h",
h=self.config.hidden_dim,
l=self.config.num_layers,
framework="torch",
),
"c": TensorSpec(
"b, l, h",
h=self.config.hidden_dim,
l=self.config.num_layers,
framework="torch",
),
},
STATE_OUT: self._state_in_out_spec,
}
)

Expand Down
4 changes: 2 additions & 2 deletions rllib/core/rl_module/rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,9 +474,9 @@ def get_initial_state(self) -> Any:

@OverrideToImplementCustomLogic
def is_stateful(self) -> bool:
"""Returns True if the initial state is empty.
"""Returns False if the initial state is an empty dict (or None).
By default, RLlib assumes that the module is not recurrent if the initial
By default, RLlib assumes that the module is non-recurrent if the initial
state is an empty dict and recurrent otherwise.
This behavior can be overridden by implementing this method.
"""
Expand Down
1 change: 0 additions & 1 deletion rllib/env/env_runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import abc

from typing import Any, Dict, TYPE_CHECKING

from ray.rllib.utils.actor_manager import FaultAwareApply
Expand Down
44 changes: 22 additions & 22 deletions rllib/env/multi_agent_episode.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,75 +212,75 @@ def get_observations(

return self._getattr_by_index("observations", indices, global_ts)

def get_actions(
def get_infos(
self, indices: Union[int, List[int]] = -1, global_ts: bool = True
) -> MultiAgentDict:
"""Gets actions for all agents that stepped in the last timesteps.
"""Gets infos for all agents that stepped in the last timesteps.
Note that actions are only returned for agents that stepped
Note that infos are only returned for agents that stepped
during the given index range.
Args:
indices: Either a single index or a list of indices. The indices
can be reversed (e.g. [-1, -2]) or absolute (e.g. [98, 99]).
This defines the time indices for which the actions
This defines the time indices for which the infos
should be returned.
global_ts: Boolean that defines, if the indices should be considered
environment (`True`) or agent (`False`) steps.
Returns: A dictionary mapping agent ids to actions (of different
Returns: A dictionary mapping agent ids to infos (of different
timesteps). Only for agents that have stepped (were ready) at a
timestep, actions are returned (i.e. not all agent ids are
timestep, infos are returned (i.e. not all agent ids are
necessarily in the keys).
"""
return self._getattr_by_index("infos", indices, global_ts)

return self._getattr_by_index("actions", indices, global_ts)

def get_rewards(
def get_actions(
self, indices: Union[int, List[int]] = -1, global_ts: bool = True
) -> MultiAgentDict:
"""Gets rewards for all agents that stepped in the last timesteps.
"""Gets actions for all agents that stepped in the last timesteps.
Note that rewards are only returned for agents that stepped
Note that actions are only returned for agents that stepped
during the given index range.
Args:
indices: Either a single index or a list of indices. The indices
can be reversed (e.g. [-1, -2]) or absolute (e.g. [98, 99]).
This defines the time indices for which the rewards
This defines the time indices for which the actions
should be returned.
global_ts: Boolean that defines, if the indices should be considered
environment (`True`) or agent (`False`) steps.
Returns: A dictionary mapping agent ids to rewards (of different
Returns: A dictionary mapping agent ids to actions (of different
timesteps). Only for agents that have stepped (were ready) at a
timestep, rewards are returned (i.e. not all agent ids are
timestep, actions are returned (i.e. not all agent ids are
necessarily in the keys).
"""
return self._getattr_by_index("rewards", indices, global_ts)

def get_infos(
return self._getattr_by_index("actions", indices, global_ts)

def get_rewards(
self, indices: Union[int, List[int]] = -1, global_ts: bool = True
) -> MultiAgentDict:
"""Gets infos for all agents that stepped in the last timesteps.
"""Gets rewards for all agents that stepped in the last timesteps.
Note that infos are only returned for agents that stepped
Note that rewards are only returned for agents that stepped
during the given index range.
Args:
indices: Either a single index or a list of indices. The indices
can be reversed (e.g. [-1, -2]) or absolute (e.g. [98, 99]).
This defines the time indices for which the infos
This defines the time indices for which the rewards
should be returned.
global_ts: Boolean that defines, if the indices should be considered
environment (`True`) or agent (`False`) steps.
Returns: A dictionary mapping agent ids to infos (of different
Returns: A dictionary mapping agent ids to rewards (of different
timesteps). Only for agents that have stepped (were ready) at a
timestep, infos are returned (i.e. not all agent ids are
timestep, rewards are returned (i.e. not all agent ids are
necessarily in the keys).
"""
return self._getattr_by_index("infos", indices, global_ts)
return self._getattr_by_index("rewards", indices, global_ts)

def get_extra_model_outputs(
self, indices: Union[int, List[int]] = -1, global_ts: bool = True
Expand Down
25 changes: 7 additions & 18 deletions rllib/env/single_agent_env_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

# TODO (sven): This gives a tricky circular import that goes
# deep into the library. We have to see, where to dissolve it.
# deep into the library. We have to see, where to dissolve it.
from ray.rllib.env.single_agent_episode import SingleAgentEpisode

_, tf, _ = try_import_tf()
Expand All @@ -41,6 +41,8 @@ def __init__(self, config: "AlgorithmConfig", **kwargs):
# Get the worker index on which this instance is running.
self.worker_index: int = kwargs.get("worker_index")

# Create the vectorized gymnasium env.

# Register env for the local context.
# Note, `gym.register` has to be called on each worker.
if isinstance(self.config.env, str) and _global_registry.contains(
Expand All @@ -59,7 +61,6 @@ def __init__(self, config: "AlgorithmConfig", **kwargs):
)
gym.register("rllib-single-agent-env-runner-v0", entry_point=entry_point)

# Create the vectorized gymnasium env.
# Wrap into `VectorListInfo`` wrapper to get infos as lists.
self.env: gym.Wrapper = gym.wrappers.VectorListInfo(
gym.vector.make(
Expand All @@ -68,31 +69,19 @@ def __init__(self, config: "AlgorithmConfig", **kwargs):
asynchronous=self.config.remote_worker_envs,
)
)

self.num_envs: int = self.env.num_envs
assert self.num_envs == self.config.num_envs_per_worker

# Create our own instance of the single-agent `RLModule` (which
# Create our own instance of the (single-agent) `RLModule` (which
# the needs to be weight-synched) each iteration.
# TODO (sven, simon): We need to get rid here of the policy_dict,
# but the 'RLModule' takes the 'policy_spec.observation_space'
# from it.
# Below is the non nice solution.
# policy_dict, _ = self.config.get_multi_agent_setup(env=self.env)
module_spec: SingleAgentRLModuleSpec = self.config.get_default_rl_module_spec()
module_spec.observation_space = self.env.envs[0].observation_space
# TODO (simon): The `gym.Wrapper` for `gym.vector.VectorEnv` should
# actually hold the spaces for a single env, but for boxes the
# shape is (1, 1) which brings a problem with the action dists.
# shape=(1,) is expected.
# actually hold the spaces for a single env, but for boxes the
# shape is (1, 1) which brings a problem with the action dists.
# shape=(1,) is expected.
module_spec.action_space = self.env.envs[0].action_space
module_spec.model_config_dict = self.config.model

# TODO (sven): By time the `AlgorithmConfig` will get rid of `PolicyDict`
# as well. Then we have to change this function parameter.
# module_spec: MultiAgentRLModuleSpec = self.config.get_marl_module_spec(
# policy_dict=module_dict
# )
self.module: RLModule = module_spec.build()

# This should be the default.
Expand Down
2 changes: 1 addition & 1 deletion rllib/env/tests/test_single_agent_episode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ray.rllib.env.single_agent_episode import SingleAgentEpisode

# TODO (simon): Add to the tests `info` and `extra_model_outputs`
# as soon as #39732 is merged.
# as soon as #39732 is merged.


class TestEnv(gym.Env):
Expand Down
5 changes: 4 additions & 1 deletion rllib/evaluation/worker_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
from ray.exceptions import RayActorError
from ray.rllib.core.learner import LearnerGroup
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.env.env_runner import EnvRunner
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.utils.actor_manager import RemoteCallResults
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.env_runner import EnvRunner
from ray.rllib.offline import get_dataset_and_shards
from ray.rllib.policy.policy import Policy, PolicyState
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
Expand Down Expand Up @@ -690,6 +690,9 @@ def foreach_worker(
if local_worker and self.local_worker() is not None:
local_result = [func(self.local_worker())]

if not self.__worker_manager.actor_ids():
return local_result

remote_results = self.__worker_manager.foreach_actor(
func,
healthy_only=healthy_only,
Expand Down
Loading

0 comments on commit ca29fec

Please sign in to comment.