[RLlib] Add necessary fields to Base Envs, and BaseEnv wrapper classes #20832

Merged
merged 14 commits on Dec 9, 2021
6 changes: 3 additions & 3 deletions rllib/BUILD
@@ -290,7 +290,7 @@ py_test(
name = "learning_cartpole_simpleq_fake_gpus",
main = "tests/run_regression_tests.py",
tags = ["team:ml", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "fake_gpus"],
size = "large",
size = "medium",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/dqn/cartpole-simpleq-fake-gpus.yaml"],
args = ["--yaml-dir=tuned_examples/dqn"]
@@ -468,7 +468,7 @@ py_test(
py_test(
name = "learning_tests_transformed_actions_pendulum_sac",
main = "tests/run_regression_tests.py",
tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "flaky"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/sac/pendulum-transformed-actions-sac.yaml"],
@@ -478,7 +478,7 @@ py_test(
py_test(
name = "learning_pendulum_sac_fake_gpus",
main = "tests/run_regression_tests.py",
tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "fake_gpus"],
tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "fake_gpus", "flaky"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/sac/pendulum-sac-fake-gpus.yaml"],
72 changes: 64 additions & 8 deletions rllib/env/base_env.py
@@ -1,6 +1,7 @@
from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING,\
Union

import gym
import ray
from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiAgentDict, \
@@ -31,12 +32,6 @@ class BaseEnv:
rllib.MultiAgentEnv (is-a gym.Env) => rllib.VectorEnv => rllib.BaseEnv
rllib.ExternalEnv => rllib.BaseEnv

Attributes:
action_space (gym.Space): Action space. This must be defined for
single-agent envs. Multi-agent envs can set this to None.
observation_space (gym.Space): Observation space. This must be defined
for single-agent envs. Multi-agent envs can set this to None.

Examples:
>>> env = MyBaseEnv()
>>> obs, rewards, dones, infos, off_policy_actions = env.poll()
@@ -185,12 +180,18 @@ def try_reset(self, env_id: Optional[EnvID] = None
return None

@PublicAPI
def get_sub_environments(self) -> List[EnvType]:
def get_sub_environments(
self, as_dict: bool = False) -> Union[List[EnvType], dict]:
"""Return a reference to the underlying sub environments, if any.

Args:
as_dict: If True, return a dict mapping from env_id to env.

Returns:
List of the underlying sub environments or [].
List or dictionary of the underlying sub environments or [] / {}.
"""
if as_dict:
return {}
return []

@PublicAPI
@@ -218,6 +219,61 @@ def stop(self) -> None:
def get_unwrapped(self) -> List[EnvType]:
return self.get_sub_environments()

@PublicAPI
@property
def observation_space(self) -> gym.spaces.Dict:
"""Returns the observation space for each environment.

Note: samples from the observation space need to be preprocessed into a
`MultiEnvDict` before being used by a policy.

Returns:
The observation space for each environment.
"""
raise NotImplementedError

@PublicAPI
@property
def action_space(self) -> gym.Space:
"""Returns the action space for each environment.

Note: samples from the action space need to be preprocessed into a
`MultiEnvDict` before being passed to `send_actions`.

Returns:
            The action space for each environment.
"""
raise NotImplementedError

    def observation_space_contains(self, x: MultiEnvDict) -> bool:
        return self._space_contains(self.observation_space, x)

def action_space_contains(self, x: MultiEnvDict) -> bool:
return self._space_contains(self.action_space, x)

@staticmethod
def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
"""Check if the given space contains the observations of x.

Args:
            space: The space to check whether x's observations are contained in.
x: The observations to check.

Returns:
True if the observations of x are contained in space.
"""
        # Flatten the MultiEnvDict: drop the agent_id level and keep only the
        # per-agent observations for each env_id.
flattened_obs = {
env_id: list(obs.values())
for env_id, obs in x.items()
}
ret = True
for env_id in flattened_obs:
for obs in flattened_obs[env_id]:
ret = ret and space[env_id].contains(obs)
return ret


# Fixed agent identifier when there is only the single agent in the env
_DUMMY_AGENT_ID = "agent0"
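A minimal sketch (illustrative only, not code from this PR) of how the new per-sub-environment spaces and the `_space_contains` check fit together; the Box spaces, env ids, and observations below are assumed for the example:

import gym
import numpy as np

# Assumed Dict observation space keyed by sub-environment index, mirroring
# what the new `observation_space` properties return (one entry per sub-env).
obs_space = gym.spaces.Dict({
    0: gym.spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32),
    1: gym.spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32),
})

# A MultiEnvDict maps env_id -> agent_id -> observation.
multi_env_obs = {
    0: {"agent0": np.zeros(3, dtype=np.float32)},
    1: {"agent0": np.full(3, 0.5, dtype=np.float32)},
}

# Same flattening as BaseEnv._space_contains: drop the agent_id level and
# check each per-agent observation against its sub-environment's space.
flattened_obs = {
    env_id: list(agent_obs.values())
    for env_id, agent_obs in multi_env_obs.items()
}
assert all(
    obs_space[env_id].contains(obs)
    for env_id, obs_list in flattened_obs.items()
    for obs in obs_list
)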
18 changes: 15 additions & 3 deletions rllib/env/external_env.py
@@ -337,11 +337,11 @@ def __init__(self,
self.external_env = external_env
self.prep = preprocessor
self.multiagent = issubclass(type(external_env), ExternalMultiAgentEnv)
self.action_space = external_env.action_space
self._action_space = external_env.action_space
if preprocessor:
self.observation_space = preprocessor.observation_space
self._observation_space = preprocessor.observation_space
else:
self.observation_space = external_env.observation_space
self._observation_space = external_env.observation_space
external_env.start()

@override(BaseEnv)
@@ -413,3 +413,15 @@ def fix(d, zero_val):
with_dummy_agent_id(all_dones, "__all__"), \
with_dummy_agent_id(all_infos), \
with_dummy_agent_id(off_policy_actions)

@property
@override(BaseEnv)
@PublicAPI
def observation_space(self) -> gym.spaces.Dict:
return self._observation_space

@property
@override(BaseEnv)
@PublicAPI
def action_space(self) -> gym.Space:
return self._action_space
24 changes: 23 additions & 1 deletion rllib/env/multi_agent_env.py
@@ -336,7 +336,12 @@ def try_reset(self,
return obs

@override(BaseEnv)
def get_sub_environments(self) -> List[EnvType]:
def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
if as_dict:
return {
                _id: env_state.env
for _id, env_state in enumerate(self.env_states)
}
return [state.env for state in self.env_states]

@override(BaseEnv)
@@ -346,6 +351,23 @@ def try_render(self, env_id: Optional[EnvID] = None) -> None:
assert isinstance(env_id, int)
return self.envs[env_id].render()

@property
@override(BaseEnv)
@PublicAPI
def observation_space(self) -> gym.spaces.Dict:
space = {
_id: env.observation_space
for _id, env in enumerate(self.envs)
}
return gym.spaces.Dict(space)

@property
@override(BaseEnv)
@PublicAPI
def action_space(self) -> gym.Space:
space = {_id: env.action_space for _id, env in enumerate(self.envs)}
return gym.spaces.Dict(space)


class _MultiAgentEnvState:
def __init__(self, env: MultiAgentEnv):
42 changes: 41 additions & 1 deletion rllib/env/remote_base_env.py
@@ -1,6 +1,8 @@
import logging
from typing import Callable, Dict, List, Optional, Tuple

import gym

import ray
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
from ray.rllib.utils.annotations import override, PublicAPI
@@ -17,6 +19,8 @@ class RemoteBaseEnv(BaseEnv):
from the remote simulator actors. Both single and multi-agent child envs
are supported, and envs can be stepped synchronously or asynchronously.

    NOTE: This class implicitly assumes that the remote envs are gym.Env instances.

You shouldn't need to instantiate this class directly. It's automatically
inserted when you use the `remote_worker_envs=True` option in your
Trainer's config.
@@ -61,6 +65,8 @@ def __init__(self,
# List of ray actor handles (each handle points to one @ray.remote
# sub-environment).
self.actors: Optional[List[ray.actor.ActorHandle]] = None
self._observation_space = None
self._action_space = None
# Dict mapping object refs (return values of @ray.remote calls),
# whose actual values we are waiting for (via ray.wait in
# `self.poll()`) to their corresponding actor handles (the actors
@@ -97,6 +103,10 @@ def make_remote_env(i):
self.actors = [
make_remote_env(i) for i in range(self.num_envs)
]
self._observation_space = ray.get(
self.actors[0].observation_space.remote())
self._action_space = ray.get(
self.actors[0].action_space.remote())

# Lazy initialization. Call `reset()` on all @ray.remote
# sub-environment actors at the beginning.
@@ -199,9 +209,23 @@ def stop(self) -> None:

@override(BaseEnv)
@PublicAPI
def get_sub_environments(self) -> List[EnvType]:
def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
if as_dict:
return {env_id: actor for env_id, actor in enumerate(self.actors)}
return self.actors

@property
@override(BaseEnv)
@PublicAPI
def observation_space(self) -> gym.spaces.Dict:
return self._observation_space

@property
@override(BaseEnv)
@PublicAPI
def action_space(self) -> gym.Space:
return self._action_space


@ray.remote(num_cpus=0)
class _RemoteMultiAgentEnv:
@@ -221,6 +245,14 @@ def reset(self):
def step(self, action_dict):
return self.env.step(action_dict)

    # Define these two methods so that the spaces can be queried from the
    # driver via `ray.get()` calls on the actor handle.
def observation_space(self):
return self.env.observation_space

def action_space(self):
return self.env.action_space


@ray.remote(num_cpus=0)
class _RemoteSingleAgentEnv:
@@ -243,3 +275,11 @@ def step(self, action):
} for x in [obs, rew, done, info]]
done["__all__"] = done[_DUMMY_AGENT_ID]
return obs, rew, done, info

    # Define these two methods so that the spaces can be queried from the
    # driver via `ray.get()` calls on the actor handle.
def observation_space(self):
return self.env.observation_space

def action_space(self):
return self.env.action_space
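Because attributes of a Ray actor cannot be read directly from the driver, the spaces are exposed as remote methods and fetched once with `ray.get()`, as done in `RemoteBaseEnv.__init__` above. A standalone sketch of that pattern (the `_EnvActor` class and the CartPole env name are assumptions for the example, not RLlib code):

import gym
import ray

@ray.remote(num_cpus=0)
class _EnvActor:
    # Illustrative stand-in for the remote sub-environment actors above.
    def __init__(self, env_name):
        self.env = gym.make(env_name)

    def observation_space(self):
        return self.env.observation_space

    def action_space(self):
        return self.env.action_space

ray.init(ignore_reinit_error=True)
actor = _EnvActor.remote("CartPole-v0")
# Mirrors RemoteBaseEnv.__init__: query the actor's spaces once via ray.get().
obs_space = ray.get(actor.observation_space.remote())
act_space = ray.get(actor.action_space.remote())
print(obs_space, act_space)
ray.shutdown()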
51 changes: 46 additions & 5 deletions rllib/env/vector_env.py
@@ -1,7 +1,7 @@
import logging
import gym
import numpy as np
from typing import Callable, List, Optional, Tuple
from typing import Callable, List, Optional, Tuple, Union

from ray.rllib.env.base_env import BaseEnv
from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
@@ -265,13 +265,13 @@ class VectorEnvWrapper(BaseEnv):

def __init__(self, vector_env: VectorEnv):
self.vector_env = vector_env
self.action_space = vector_env.action_space
self.observation_space = vector_env.observation_space
self.num_envs = vector_env.num_envs
self.new_obs = None # lazily initialized
self.cur_rewards = [None for _ in range(self.num_envs)]
self.cur_dones = [False for _ in range(self.num_envs)]
self.cur_infos = [None for _ in range(self.num_envs)]
self._observation_space = vector_env.observation_space
self._action_space = vector_env.action_space

@override(BaseEnv)
def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
@@ -312,10 +312,51 @@ def try_reset(self, env_id: Optional[EnvID] = None) -> MultiEnvDict:
}

@override(BaseEnv)
def get_sub_environments(self) -> List[EnvType]:
return self.vector_env.get_sub_environments()
def get_sub_environments(
self, as_dict: bool = False) -> Union[List[EnvType], dict]:
if not as_dict:
return self.vector_env.get_sub_environments()
else:
return {
_id: env
for _id, env in enumerate(
self.vector_env.get_sub_environments())
}

@override(BaseEnv)
def try_render(self, env_id: Optional[EnvID] = None) -> None:
assert env_id is None or isinstance(env_id, int)
return self.vector_env.try_render_at(env_id)

@property
@override(BaseEnv)
@PublicAPI
def observation_space(self) -> gym.spaces.Dict:
return self._observation_space

@property
@override(BaseEnv)
@PublicAPI
def action_space(self) -> gym.Space:
return self._action_space

@staticmethod
def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
"""Check if the given space contains the observations of x.

Args:
            space: The space to check whether x's observations are contained in.
x: The observations to check.

        Note: With vector envs, we can check the raw observations directly and
            ignore the agent and env ids, since all of a vector env's
            sub-environments are guaranteed to be identical.

Returns:
True if the observations of x are contained in space.
"""
for _, multi_agent_dict in x.items():
for _, element in multi_agent_dict.items():
if not space.contains(element):
return False
return True
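For the vector-env case, every sub-environment shares a single space, so the containment check above can ignore env and agent ids and test each raw observation directly. A small sketch with an assumed Box space and dummy observations (illustrative only):

import gym
import numpy as np

# Assumed observation space shared by all sub-environments of a VectorEnv.
space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

# A MultiEnvDict of raw observations: env_id -> agent_id -> observation.
x = {
    0: {"agent0": np.array([0.1, -0.2], dtype=np.float32)},
    1: {"agent0": np.array([0.9, 0.0], dtype=np.float32)},
}

# Same loop structure as VectorEnvWrapper._space_contains above.
assert all(
    space.contains(element)
    for multi_agent_dict in x.values()
    for element in multi_agent_dict.values()
)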