From 2b34011fe60a132d922e69de0121937b27d6c2e8 Mon Sep 17 00:00:00 2001 From: avnishn Date: Mon, 6 Dec 2021 11:04:11 -0800 Subject: [PATCH 01/14] Solve base_env merge conflicts --- rllib/env/base_env.py | 30 ++++++++++++++++++++++++++++-- rllib/env/vector_env.py | 22 +++++++++++++++++++--- 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index 1717694b23b8..2abe587e60e5 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -185,12 +185,18 @@ def try_reset(self, env_id: Optional[EnvID] = None return None @PublicAPI - def get_sub_environments(self) -> List[EnvType]: + def get_sub_environments( + self, as_dict: bool = False) -> Union[List[EnvType], dict]: """Return a reference to the underlying sub environments, if any. + Args: + as_dict: If True, return a dict mapping from env_id to env. + Returns: - List of the underlying sub environments or []. + List or dictionary of the underlying sub environments or [] / {}. """ + if as_dict: + return {} return [] @PublicAPI @@ -218,6 +224,26 @@ def stop(self) -> None: def get_unwrapped(self) -> List[EnvType]: return self.get_sub_environments() + @PublicAPI + @property + def observation_space(self) -> gym.Space: + """Returns the observation space for each environment. + + Returns: + The observation space for each environment. + """ + raise NotImplementedError + + @PublicAPI + @property + def action_space(self) -> gym.Space: + """Returns the action space for each environment. + + Returns: + The observation space for each environment. 
+ """ + raise NotImplementedError + # Fixed agent identifier when there is only the single agent in the env _DUMMY_AGENT_ID = "agent0" diff --git a/rllib/env/vector_env.py b/rllib/env/vector_env.py index f815d82fb834..ad52bb7518c5 100644 --- a/rllib/env/vector_env.py +++ b/rllib/env/vector_env.py @@ -1,7 +1,7 @@ import logging import gym import numpy as np -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Union from ray.rllib.env.base_env import BaseEnv from ray.rllib.utils.annotations import Deprecated, override, PublicAPI @@ -312,8 +312,24 @@ def try_reset(self, env_id: Optional[EnvID] = None) -> MultiEnvDict: } @override(BaseEnv) - def get_sub_environments(self) -> List[EnvType]: - return self.vector_env.get_sub_environments() + def get_sub_environments( + self, as_dict: bool = False) -> Union[List[EnvType], dict]: + """Return a reference to the underlying sub environments, if any. + + Args: + as_dict: If True, return a dict mapping from env_id to env. + + Returns: + List or dictionary of the underlying sub environments or [] / {}. 
+ """ + if not as_dict: + return self.vector_env.get_sub_environments() + else: + return { + _id: env + for _id, env in enumerate( + self.vector_env.get_sub_environments()) + } @override(BaseEnv) def try_render(self, env_id: Optional[EnvID] = None) -> None: From f94b916f5d2938e19166e3d13e369666287a52c0 Mon Sep 17 00:00:00 2001 From: avnishn Date: Mon, 6 Dec 2021 11:04:44 -0800 Subject: [PATCH 02/14] Solve base_env merge conflicts --- rllib/env/vector_env.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/rllib/env/vector_env.py b/rllib/env/vector_env.py index ad52bb7518c5..764144c6983c 100644 --- a/rllib/env/vector_env.py +++ b/rllib/env/vector_env.py @@ -335,3 +335,21 @@ def get_sub_environments( def try_render(self, env_id: Optional[EnvID] = None) -> None: assert env_id is None or isinstance(env_id, int) return self.vector_env.try_render_at(env_id) + + @property + def observation_space(self) -> gym.Space: + """Returns the observation space for each environment. + + Returns: + The observation space for each environment. + """ + return self.vector_env.observation_space + + @property + def action_space(self) -> gym.Space: + """Returns the action space for each environment. + + Returns: + The action space for each environment. + """ + return self.vector_env.action_space From 1b750af345dfbcb7fa3d127a67eebe25eba256a3 Mon Sep 17 00:00:00 2001 From: avnishn Date: Wed, 1 Dec 2021 15:20:41 -0800 Subject: [PATCH 03/14] Add support for MultiAgentEnvs --- rllib/env/base_env.py | 6 ++++++ rllib/env/multi_agent_env.py | 20 +++++++++++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index 2abe587e60e5..710bc2da799d 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -229,6 +229,9 @@ def get_unwrapped(self) -> List[EnvType]: def observation_space(self) -> gym.Space: """Returns the observation space for each environment. 
+ Note: samples from the observation space need to be preprocessed into a + `MultiEnvDict` before being used by a policy. + Returns: The observation space for each environment. """ @@ -239,6 +242,9 @@ def observation_space(self) -> gym.Space: def action_space(self) -> gym.Space: """Returns the action space for each environment. + Note: samples from the action space need to be preprocessed into a + `MultiEnvDict` before being passed to `send_actions`. + Returns: The observation space for each environment. """ diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py index 0b6d0fb73c22..4e336487503c 100644 --- a/rllib/env/multi_agent_env.py +++ b/rllib/env/multi_agent_env.py @@ -336,7 +336,12 @@ def try_reset(self, return obs @override(BaseEnv) - def get_sub_environments(self) -> List[EnvType]: + def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]: + if as_dict: + return { + _id: env_state + for _id, env_state in enumerate(self.env_states) + } return [state.env for state in self.env_states] @override(BaseEnv) @@ -346,6 +351,19 @@ def try_render(self, env_id: Optional[EnvID] = None) -> None: assert isinstance(env_id, int) return self.envs[env_id].render() + @property + def observation_space(self) -> gym.Space: + space = { + _id: env.observation_space + for _id, env in enumerate(self.envs) + } + return gym.spaces.Dict(space) + + @property + def action_space(self) -> gym.Space: + space = {_id: env.action_space for _id, env in enumerate(self.envs)} + return gym.spaces.Dict(space) + class _MultiAgentEnvState: def __init__(self, env: MultiAgentEnv): From 0c47d33bd4b5b99ccdedf923b9d974ee16db2b92 Mon Sep 17 00:00:00 2001 From: avnishn Date: Wed, 1 Dec 2021 15:32:43 -0800 Subject: [PATCH 04/14] Add support for external envs --- rllib/env/external_env.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/rllib/env/external_env.py b/rllib/env/external_env.py index cbd1bb174ff8..7323f62da377 100644 --- 
a/rllib/env/external_env.py +++ b/rllib/env/external_env.py @@ -337,11 +337,11 @@ def __init__(self, self.external_env = external_env self.prep = preprocessor self.multiagent = issubclass(type(external_env), ExternalMultiAgentEnv) - self.action_space = external_env.action_space + self._action_space = external_env.action_space if preprocessor: - self.observation_space = preprocessor.observation_space + self._observation_space = preprocessor.observation_space else: - self.observation_space = external_env.observation_space + self._observation_space = external_env.observation_space external_env.start() @override(BaseEnv) @@ -413,3 +413,11 @@ def fix(d, zero_val): with_dummy_agent_id(all_dones, "__all__"), \ with_dummy_agent_id(all_infos), \ with_dummy_agent_id(off_policy_actions) + + @property + def observation_space(self) -> gym.Space: + return self._observation_space + + @property + def action_space(self) -> gym.Space: + return self._action_space From 5d3287eef23bf63fa46e8ad082df89f3886b4521 Mon Sep 17 00:00:00 2001 From: avnishn Date: Wed, 1 Dec 2021 17:04:20 -0800 Subject: [PATCH 05/14] Add support for remote base envs --- rllib/env/remote_base_env.py | 42 +++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/rllib/env/remote_base_env.py b/rllib/env/remote_base_env.py index 4e5b538c151d..afdc12feb297 100644 --- a/rllib/env/remote_base_env.py +++ b/rllib/env/remote_base_env.py @@ -1,6 +1,8 @@ import logging from typing import Callable, Dict, List, Optional, Tuple +import gym + import ray from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN from ray.rllib.utils.annotations import override, PublicAPI @@ -17,6 +19,8 @@ class RemoteBaseEnv(BaseEnv): from the remote simulator actors. Both single and multi-agent child envs are supported, and envs can be stepped synchronously or asynchronously. 
+ NOTE: This class implicitly assumes that the remote envs are gym.Env's + You shouldn't need to instantiate this class directly. It's automatically inserted when you use the `remote_worker_envs=True` option in your Trainer's config. @@ -61,6 +65,8 @@ def __init__(self, # List of ray actor handles (each handle points to one @ray.remote # sub-environment). self.actors: Optional[List[ray.actor.ActorHandle]] = None + self._observation_space = None + self._action_space = None # Dict mapping object refs (return values of @ray.remote calls), # whose actual values we are waiting for (via ray.wait in # `self.poll()`) to their corresponding actor handles (the actors @@ -97,6 +103,10 @@ def make_remote_env(i): self.actors = [ make_remote_env(i) for i in range(self.num_envs) ] + self._observation_space = ray.get( + self.actors[0].observation_space.remote()) + self._action_space = ray.get( + self.actors[0].action_space.remote()) # Lazy initialization. Call `reset()` on all @ray.remote # sub-environment actors at the beginning. 
@@ -199,9 +209,23 @@ def stop(self) -> None: @override(BaseEnv) @PublicAPI - def get_sub_environments(self) -> List[EnvType]: + def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]: + if as_dict: + return {env_id: actor for env_id, actor in enumerate(self.actors)} return self.actors + @property + @override(BaseEnv) + @PublicAPI + def observation_space(self) -> gym.Space: + return self._observation_space + + @property + @override(BaseEnv) + @PublicAPI + def action_space(self) -> gym.Space: + return self._action_space + @ray.remote(num_cpus=0) class _RemoteMultiAgentEnv: @@ -221,6 +245,14 @@ def reset(self): def step(self, action_dict): return self.env.step(action_dict) + # defining these 2 functions that way this information can be queried + # with a call to ray.get() + def observation_space(self): + return self.env.observation_space + + def action_space(self): + return self.env.action_space + @ray.remote(num_cpus=0) class _RemoteSingleAgentEnv: @@ -243,3 +275,11 @@ def step(self, action): } for x in [obs, rew, done, info]] done["__all__"] = done[_DUMMY_AGENT_ID] return obs, rew, done, info + + # defining these 2 functions that way this information can be queried + # with a call to ray.get() + def observation_space(self): + return self.env.observation_space + + def action_space(self): + return self.env.action_space From c4291c55c615fc1ea3b7aab90f170dec814af186 Mon Sep 17 00:00:00 2001 From: avnishn Date: Thu, 2 Dec 2021 15:25:20 -0800 Subject: [PATCH 06/14] Force space types for base envs to be Dict spaces --- rllib/env/base_env.py | 22 +++++++++++++++++++++- rllib/env/external_env.py | 2 +- rllib/env/multi_agent_env.py | 2 +- rllib/env/remote_base_env.py | 2 +- rllib/env/vector_env.py | 28 +++++++--------------------- 5 files changed, 31 insertions(+), 25 deletions(-) diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index 710bc2da799d..9c644fa0ed38 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -226,7 +226,7 @@ def 
get_unwrapped(self) -> List[EnvType]: @PublicAPI @property - def observation_space(self) -> gym.Space: + def observation_space(self) -> gym.spaces.Dict: """Returns the observation space for each environment. Note: samples from the observation space need to be preprocessed into a @@ -250,6 +250,26 @@ def action_space(self) -> gym.Space: """ raise NotImplementedError + def observation_space_contains(self, x: MultiEnvDict) -> bool: + self._space_contains(self.observation_space, x) + + def action_space_contains(self, x: MultiEnvDict) -> bool: + return self._space_contains(self.action_space, x) + + @staticmethod + def _space_contains(space, x: MultiEnvDict) -> bool: + # this removes the agent_id key and inner dicts + # in MultiEnvDicts + flattened_obs = { + env_id: list(obs.values()) + for env_id, obs in x.items() + } + ret = True + for env_id in flattened_obs: + for obs in flattened_obs[env_id]: + ret = ret and space[env_id].contains(obs) + return ret + # Fixed agent identifier when there is only the single agent in the env _DUMMY_AGENT_ID = "agent0" diff --git a/rllib/env/external_env.py b/rllib/env/external_env.py index 7323f62da377..c13ddd3c883e 100644 --- a/rllib/env/external_env.py +++ b/rllib/env/external_env.py @@ -415,7 +415,7 @@ def fix(d, zero_val): with_dummy_agent_id(off_policy_actions) @property - def observation_space(self) -> gym.Space: + def observation_space(self) -> gym.spaces.Dict: return self._observation_space @property diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py index 4e336487503c..32cdd8bfad4d 100644 --- a/rllib/env/multi_agent_env.py +++ b/rllib/env/multi_agent_env.py @@ -352,7 +352,7 @@ def try_render(self, env_id: Optional[EnvID] = None) -> None: return self.envs[env_id].render() @property - def observation_space(self) -> gym.Space: + def observation_space(self) -> gym.spaces.Dict: space = { _id: env.observation_space for _id, env in enumerate(self.envs) diff --git a/rllib/env/remote_base_env.py 
b/rllib/env/remote_base_env.py index afdc12feb297..1ae5b0a6802c 100644 --- a/rllib/env/remote_base_env.py +++ b/rllib/env/remote_base_env.py @@ -217,7 +217,7 @@ def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]: @property @override(BaseEnv) @PublicAPI - def observation_space(self) -> gym.Space: + def observation_space(self) -> gym.spaces.Dict: return self._observation_space @property diff --git a/rllib/env/vector_env.py b/rllib/env/vector_env.py index 764144c6983c..7b80c05b7ce4 100644 --- a/rllib/env/vector_env.py +++ b/rllib/env/vector_env.py @@ -314,14 +314,6 @@ def try_reset(self, env_id: Optional[EnvID] = None) -> MultiEnvDict: @override(BaseEnv) def get_sub_environments( self, as_dict: bool = False) -> Union[List[EnvType], dict]: - """Return a reference to the underlying sub environments, if any. - - Args: - as_dict: If True, return a dict mapping from env_id to env. - - Returns: - List or dictionary of the underlying sub environments or [] / {}. - """ if not as_dict: return self.vector_env.get_sub_environments() else: @@ -337,19 +329,13 @@ def try_render(self, env_id: Optional[EnvID] = None) -> None: return self.vector_env.try_render_at(env_id) @property - def observation_space(self) -> gym.Space: - """Returns the observation space for each environment. - - Returns: - The observation space for each environment. - """ - return self.vector_env.observation_space + @override(BaseEnv) + @PublicAPI + def observation_space(self) -> gym.spaces.Dict: + return gym.spaces.Dict({0: self.vector_env.observation_space}) @property + @override(BaseEnv) + @PublicAPI def action_space(self) -> gym.Space: - """Returns the action space for each environment. - - Returns: - The action space for each environment. 
- """ - return self.vector_env.action_space + return gym.spaces.Dict({0: self.vector_env.action_space}) From 9c7aeaf0f93ca3dd0b0f6b102073408a8334c2a4 Mon Sep 17 00:00:00 2001 From: avnishn Date: Mon, 6 Dec 2021 11:07:18 -0800 Subject: [PATCH 07/14] Add missing import --- rllib/env/base_env.py | 1 + 1 file changed, 1 insertion(+) diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index 9c644fa0ed38..f5554287b992 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -1,6 +1,7 @@ from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING,\ Union +import gym import ray from ray.rllib.utils.annotations import Deprecated, override, PublicAPI from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiAgentDict, \ From eaf9d9d339c5952136773866f2c8d6bd11f0ac94 Mon Sep 17 00:00:00 2001 From: avnishn Date: Mon, 6 Dec 2021 13:22:50 -0800 Subject: [PATCH 08/14] Fix merge error --- rllib/env/vector_env.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/rllib/env/vector_env.py b/rllib/env/vector_env.py index 7b80c05b7ce4..56c6de9e7ac3 100644 --- a/rllib/env/vector_env.py +++ b/rllib/env/vector_env.py @@ -265,14 +265,26 @@ class VectorEnvWrapper(BaseEnv): def __init__(self, vector_env: VectorEnv): self.vector_env = vector_env - self.action_space = vector_env.action_space - self.observation_space = vector_env.observation_space self.num_envs = vector_env.num_envs self.new_obs = None # lazily initialized self.cur_rewards = [None for _ in range(self.num_envs)] self.cur_dones = [False for _ in range(self.num_envs)] self.cur_infos = [None for _ in range(self.num_envs)] + obs_space = { + _id: env.observation_space + for _id, env in enumerate( + self.vector_env.get_sub_environments()) + } + + act_space = { + _id: env.observation_space + for _id, env in enumerate( + self.vector_env.get_sub_environments()) + } + self._observation_space = gym.spaces.Dict(obs_space) + self._action_space = 
gym.spaces.Dict(act_space) + @override(BaseEnv) def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]: @@ -332,10 +344,10 @@ def try_render(self, env_id: Optional[EnvID] = None) -> None: @override(BaseEnv) @PublicAPI def observation_space(self) -> gym.spaces.Dict: - return gym.spaces.Dict({0: self.vector_env.observation_space}) + return self._observation_space @property @override(BaseEnv) @PublicAPI def action_space(self) -> gym.Space: - return gym.spaces.Dict({0: self.vector_env.action_space}) + return self._action_space From f2464113c182e2e18fd5c0353931aff43c0d880b Mon Sep 17 00:00:00 2001 From: avnishn Date: Mon, 6 Dec 2021 13:23:13 -0800 Subject: [PATCH 09/14] fix lint --- rllib/env/vector_env.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/rllib/env/vector_env.py b/rllib/env/vector_env.py index 56c6de9e7ac3..7d8b7ba82ebb 100644 --- a/rllib/env/vector_env.py +++ b/rllib/env/vector_env.py @@ -273,14 +273,12 @@ def __init__(self, vector_env: VectorEnv): obs_space = { _id: env.observation_space - for _id, env in enumerate( - self.vector_env.get_sub_environments()) + for _id, env in enumerate(self.vector_env.get_sub_environments()) } act_space = { _id: env.observation_space - for _id, env in enumerate( - self.vector_env.get_sub_environments()) + for _id, env in enumerate(self.vector_env.get_sub_environments()) } self._observation_space = gym.spaces.Dict(obs_space) self._action_space = gym.spaces.Dict(act_space) From a60d4863b9119dc7e97eef39c6d09045161a3341 Mon Sep 17 00:00:00 2001 From: avnishn Date: Mon, 6 Dec 2021 14:45:29 -0800 Subject: [PATCH 10/14] Fix bug in vector_env action space --- rllib/env/multi_agent_env.py | 4 ++++ rllib/env/vector_env.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py index 32cdd8bfad4d..a578388dc8fb 100644 --- a/rllib/env/multi_agent_env.py +++ b/rllib/env/multi_agent_env.py @@ -352,6 
+352,8 @@ def try_render(self, env_id: Optional[EnvID] = None) -> None: return self.envs[env_id].render() @property + @override(BaseEnv) + @PublicAPI def observation_space(self) -> gym.spaces.Dict: space = { _id: env.observation_space @@ -360,6 +362,8 @@ def observation_space(self) -> gym.spaces.Dict: return gym.spaces.Dict(space) @property + @override(BaseEnv) + @PublicAPI def action_space(self) -> gym.Space: space = {_id: env.action_space for _id, env in enumerate(self.envs)} return gym.spaces.Dict(space) diff --git a/rllib/env/vector_env.py b/rllib/env/vector_env.py index 7d8b7ba82ebb..0f34d99469c1 100644 --- a/rllib/env/vector_env.py +++ b/rllib/env/vector_env.py @@ -277,7 +277,7 @@ def __init__(self, vector_env: VectorEnv): } act_space = { - _id: env.observation_space + _id: env.action_space for _id, env in enumerate(self.vector_env.get_sub_environments()) } self._observation_space = gym.spaces.Dict(obs_space) From bd6980776398daf1b2acbc34a3161f2ea6ce59d0 Mon Sep 17 00:00:00 2001 From: avnishn Date: Tue, 7 Dec 2021 10:33:44 -0800 Subject: [PATCH 11/14] Change vector env spaces to return sub_env space --- rllib/env/base_env.py | 2 +- rllib/env/vector_env.py | 28 ++++++++++++++++------------ 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index f5554287b992..a2d6e6c6e689 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -258,7 +258,7 @@ def action_space_contains(self, x: MultiEnvDict) -> bool: return self._space_contains(self.action_space, x) @staticmethod - def _space_contains(space, x: MultiEnvDict) -> bool: + def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool: # this removes the agent_id key and inner dicts # in MultiEnvDicts flattened_obs = { diff --git a/rllib/env/vector_env.py b/rllib/env/vector_env.py index 0f34d99469c1..d4acb5e1114d 100644 --- a/rllib/env/vector_env.py +++ b/rllib/env/vector_env.py @@ -270,18 +270,8 @@ def __init__(self, vector_env: VectorEnv): 
self.cur_rewards = [None for _ in range(self.num_envs)] self.cur_dones = [False for _ in range(self.num_envs)] self.cur_infos = [None for _ in range(self.num_envs)] - - obs_space = { - _id: env.observation_space - for _id, env in enumerate(self.vector_env.get_sub_environments()) - } - - act_space = { - _id: env.action_space - for _id, env in enumerate(self.vector_env.get_sub_environments()) - } - self._observation_space = gym.spaces.Dict(obs_space) - self._action_space = gym.spaces.Dict(act_space) + self._observation_space = vector_env.observation_space + self._action_space = vector_env.action_space @override(BaseEnv) def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, @@ -349,3 +339,17 @@ def observation_space(self) -> gym.spaces.Dict: @PublicAPI def action_space(self) -> gym.Space: return self._action_space + + @staticmethod + def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool: + """Check if + + Note: With vector envs, we can process the raw observations + and ignore the agent ids and env ids, since vector envs' + sub environements are guaranteed to be the same + """ + for _, multi_agent_dict in x.items(): + for _, element in multi_agent_dict.items(): + if not space.contains(element): + return False + return True From f85858552c61251c1b534cd1629268c73ec71d27 Mon Sep 17 00:00:00 2001 From: avnishn Date: Wed, 8 Dec 2021 10:06:26 -0800 Subject: [PATCH 12/14] Update docstrings and fix lint --- rllib/env/base_env.py | 9 +++++++++ rllib/env/external_env.py | 4 ++++ rllib/env/vector_env.py | 9 ++++++++- 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index a2d6e6c6e689..fc469f77ded2 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -259,6 +259,15 @@ def action_space_contains(self, x: MultiEnvDict) -> bool: @staticmethod def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool: + """Check if the given space contains the observations of x. 
+ + Args: + space: The space to check if x's observations are contained in. + x: The observations to check. + + Returns: + True if the observations of x are contained in space. + """ # this removes the agent_id key and inner dicts # in MultiEnvDicts flattened_obs = { diff --git a/rllib/env/external_env.py b/rllib/env/external_env.py index c13ddd3c883e..302f82a43549 100644 --- a/rllib/env/external_env.py +++ b/rllib/env/external_env.py @@ -415,9 +415,13 @@ def fix(d, zero_val): with_dummy_agent_id(off_policy_actions) @property + @override(BaseEnv) + @PublicAPI def observation_space(self) -> gym.spaces.Dict: return self._observation_space @property + @override(BaseEnv) + @PublicAPI def action_space(self) -> gym.Space: return self._action_space diff --git a/rllib/env/vector_env.py b/rllib/env/vector_env.py index d4acb5e1114d..20018ad0076a 100644 --- a/rllib/env/vector_env.py +++ b/rllib/env/vector_env.py @@ -342,11 +342,18 @@ def action_space(self) -> gym.Space: @staticmethod def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool: - """Check if + """Check if the given space contains the observations of x. + + Args: + space: The space to check if x's observations are contained in. + x: The observations to check. Note: With vector envs, we can process the raw observations and ignore the agent ids and env ids, since vector envs' sub environements are guaranteed to be the same + + Returns: + True if the observations of x are contained in space. 
""" for _, multi_agent_dict in x.items(): for _, element in multi_agent_dict.items(): From 94df1d47277280d55f94947a95666f008d02f7d7 Mon Sep 17 00:00:00 2001 From: avnishn Date: Wed, 8 Dec 2021 13:26:06 -0800 Subject: [PATCH 13/14] Mark long running sac tests as flakey --- rllib/BUILD | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 541bde6750eb..2efac3eef879 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -290,7 +290,7 @@ py_test( name = "learning_cartpole_simpleq_fake_gpus", main = "tests/run_regression_tests.py", tags = ["team:ml", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "fake_gpus"], - size = "large", + size = "medium", srcs = ["tests/run_regression_tests.py"], data = ["tuned_examples/dqn/cartpole-simpleq-fake-gpus.yaml"], args = ["--yaml-dir=tuned_examples/dqn"] @@ -468,7 +468,7 @@ py_test( py_test( name = "learning_tests_transformed_actions_pendulum_sac", main = "tests/run_regression_tests.py", - tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"], + tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "flaky"], size = "large", srcs = ["tests/run_regression_tests.py"], data = ["tuned_examples/sac/pendulum-transformed-actions-sac.yaml"], @@ -478,7 +478,7 @@ py_test( py_test( name = "learning_pendulum_sac_fake_gpus", main = "tests/run_regression_tests.py", - tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "fake_gpus"], + tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "fake_gpus", "flaky"], size = "large", srcs = ["tests/run_regression_tests.py"], data = ["tuned_examples/sac/pendulum-sac-fake-gpus.yaml"], From 89f183fb5988cf94cc35584fb1e1b493e1415cfb Mon Sep 17 00:00:00 2001 From: avnishn Date: Wed, 8 Dec 2021 15:21:19 -0800 Subject: [PATCH 14/14] Remove unecessary documentation in order to pass lint --- 
rllib/env/base_env.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index fc469f77ded2..03b986d6daea 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -32,12 +32,6 @@ class BaseEnv: rllib.MultiAgentEnv (is-a gym.Env) => rllib.VectorEnv => rllib.BaseEnv rllib.ExternalEnv => rllib.BaseEnv - Attributes: - action_space (gym.Space): Action space. This must be defined for - single-agent envs. Multi-agent envs can set this to None. - observation_space (gym.Space): Observation space. This must be defined - for single-agent envs. Multi-agent envs can set this to None. - Examples: >>> env = MyBaseEnv() >>> obs, rewards, dones, infos, off_policy_actions = env.poll()