
[RLlib] [WIP] [MultiAgentEnv Refactor #1] Add new methods to base env #21027

Merged 1 commit on Dec 16, 2021
75 changes: 75 additions & 0 deletions rllib/env/base_env.py
@@ -1,3 +1,4 @@
import logging
from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING,\
Union

@@ -15,6 +16,8 @@

ASYNC_RESET_RETURN = "async_reset_return"

logger = logging.getLogger(__name__)


@PublicAPI
class BaseEnv:
@@ -194,6 +197,16 @@ def get_sub_environments(
return {}
return []

@PublicAPI
def get_agent_ids(self) -> Dict[EnvID, List[AgentID]]:
"""Return the agent ids for each sub-environment.

Returns:
A dict mapping from env_id to a list of agent_ids.
"""
logger.warning("get_agent_ids() has not been implemented")
return {}
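
For illustration, a minimal sketch (the class and agent ids are hypothetical, not part of this diff) of the shape a concrete override of get_agent_ids() is expected to return:

from typing import Dict, List

class TwoSubEnvExample:
    """Hypothetical vectorized env: two sub-envs with two agents each."""

    def get_agent_ids(self) -> Dict[int, List[str]]:
        # env_id -> list of agent ids present in that sub-environment.
        return {env_id: ["agent_0", "agent_1"] for env_id in range(2)}

print(TwoSubEnvExample().get_agent_ids())
# {0: ['agent_0', 'agent_1'], 1: ['agent_0', 'agent_1']}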

@PublicAPI
def try_render(self, env_id: Optional[EnvID] = None) -> None:
"""Tries to render the sub-environment with the given id or all.
@@ -245,10 +258,72 @@ def action_space(self) -> gym.Space:
"""
raise NotImplementedError

@PublicAPI
def action_space_sample(self, agent_id: list = None) -> MultiEnvDict:
"""Returns a random action for each environment, and potentially each
agent in that environment.

Args:
agent_id: List of agent ids to sample actions for. If None or empty
list, sample actions for all agents in the environment.

Returns:
A random action for each environment.
"""
del agent_id  # Unused in this default no-op implementation.
return {}

@PublicAPI
def observation_space_sample(self, agent_id: list = None) -> MultiEnvDict:
"""Returns a random observation for each environment, and potentially
each agent in that environment.

Args:
agent_id: List of agent ids to sample observations for. If None or empty
list, sample observations for all agents in the environment.

Returns:
A random observation for each environment.
"""
logger.warning("observation_space_sample() has not been implemented")
return {}
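
For illustration, a hedged sketch of how a concrete implementation of these two samplers might be structured on top of gym spaces; the space layout and agent ids below are assumptions, not taken from this PR:

from gym import spaces

# Assumed layout: one sub-env (env_id 0) whose two agents share one space.
OBS_SPACE = spaces.Box(low=-1.0, high=1.0, shape=(3,))

def observation_space_sample(agent_id: list = None) -> dict:
    """Sample a MultiEnvDict of observations: {env_id: {agent_id: obs}}."""
    agent_ids = agent_id or ["agent_0", "agent_1"]  # None/empty -> all agents
    return {0: {aid: OBS_SPACE.sample() for aid in agent_ids}}

print(observation_space_sample(["agent_0"]))  # {0: {'agent_0': array([...])}}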

@PublicAPI
def last(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
MultiEnvDict, MultiEnvDict]:
"""Returns the last observations, rewards, and done flags that were
returned by the environment.

Returns:
The last observations, rewards, and done flags for each environment
"""
logger.warning("last has not been implemented for this environment.")
return {}, {}, {}, {}, {}
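
For orientation, a hedged sketch of the return layout (all values invented): five MultiEnvDicts, mirroring the tuple that poll() yields.

# Hypothetical contents; each element is a MultiEnvDict,
# i.e. {env_id: {agent_id: value}}.
obs, rewards, dones, infos, off_policy_actions = (
    {0: {"agent_0": [0.0, 0.1]}},                # last observations
    {0: {"agent_0": 1.0}},                       # last rewards
    {0: {"agent_0": False, "__all__": False}},   # done flags ("__all__" key assumed)
    {0: {"agent_0": {}}},                        # info dicts
    {0: {}},                                     # off-policy actions (none here)
)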

@PublicAPI
def observation_space_contains(self, x: MultiEnvDict) -> bool:
Review thread:

Member: Wait, is the type still MultiEnvDict here and below? I thought we were saying gym and multi-agent envs return different types now?

avnishn (Member Author, Dec 16, 2021): Yes -- the return types of poll() and try_reset() are MultiEnvDicts, so I thought it would be appropriate if observations and actions produced by the environment/policy could be passed straight back to the environment for checking.

Member: OK, it seems we are settling on using the multi-agent API for single-agent envs as well, which is totally fine, and probably logical.

Contributor: The BaseEnv is never a single-agent env. If there is only one agent and we derive the BaseEnv from e.g. a gym.Env, we auto-create "DUMMY_AGENT_ID" in the env as the agent's ID.

"""Checks if the given observation is valid for each environment.

Args:
x: Observations to check.

Returns:
True if the observations are contained within their respective
spaces. False otherwise.
"""
return self._space_contains(self.observation_space, x)
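
To make the thread above concrete, a sketch (ids illustrative, not RLlib's actual constants) of the MultiEnvDict nesting under discussion: a wrapped single-agent gym.Env gets one sub-env and one auto-generated dummy agent id, while true multi-agent sub-envs use their own ids.

# Single-agent gym.Env wrapped as a BaseEnv: one sub-env, one dummy agent.
single_agent_obs = {0: {"dummy_agent_id": [0.1, -0.2]}}

# The same {env_id: {agent_id: value}} nesting for two multi-agent sub-envs:
multi_agent_obs = {
    0: {"agent_0": [0.1], "agent_1": [0.2]},
    1: {"agent_0": [0.3], "agent_1": [0.4]},
}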

@PublicAPI
def action_space_contains(self, x: MultiEnvDict) -> bool:
Review thread:

Member: A very minor question: do you envision obs/action_space_contains() getting used outside of the environment-checking module?

avnishn (Member Author): I can imagine it having other uses for users who are trying to develop their own environments -- I have definitely used functions like this while developing my own.

"""Checks if the given actions is valid for each environment.

Args:
x: Actions to check.

Returns:
True if the actions are contained within their respective
spaces. False otherwise.
"""
return self._space_contains(self.action_space, x)
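
As the thread above suggests, a typical development-time round-trip of sample-then-check; the stub class is hypothetical and only mimics the MultiEnvDict contract of the methods added in this diff.

from gym import spaces

class StubEnv:
    """Hypothetical stand-in, not RLlib's BaseEnv."""
    action_space = spaces.Discrete(4)

    def action_space_sample(self, agent_id: list = None) -> dict:
        # {env_id: {agent_id: action}}
        return {0: {"agent_0": self.action_space.sample()}}

    def action_space_contains(self, x: dict) -> bool:
        return all(
            self.action_space.contains(action)
            for agent_dict in x.values()
            for action in agent_dict.values()
        )

env = StubEnv()
actions = env.action_space_sample()
assert env.action_space_contains(actions)  # holds by construction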

@staticmethod