-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] [WIP] [MultiAgentEnv Refactor #1] Add new methods to base env #21027
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import logging | ||
from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING,\ | ||
Union | ||
|
||
|
@@ -15,6 +16,8 @@ | |
|
||
ASYNC_RESET_RETURN = "async_reset_return" | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@PublicAPI | ||
class BaseEnv: | ||
|
@@ -194,6 +197,16 @@ def get_sub_environments( | |
return {} | ||
return [] | ||
|
||
@PublicAPI | ||
def get_agent_ids(self) -> Dict[EnvID, List[AgentID]]: | ||
"""Return the agent ids for each sub-environment. | ||
|
||
Returns: | ||
A dict mapping from env_id to a list of agent_ids. | ||
""" | ||
logger.warning("get_agent_ids() has not been implemented") | ||
return {} | ||
|
||
@PublicAPI | ||
def try_render(self, env_id: Optional[EnvID] = None) -> None: | ||
"""Tries to render the sub-environment with the given id or all. | ||
|
@@ -245,10 +258,72 @@ def action_space(self) -> gym.Space: | |
""" | ||
raise NotImplementedError | ||
|
||
@PublicAPI | ||
def action_space_sample(self, agent_id: list = None) -> MultiEnvDict: | ||
"""Returns a random action for each environment, and potentially each | ||
agent in that environment. | ||
|
||
Args: | ||
agent_id: List of agent ids to sample actions for. If None or empty | ||
list, sample actions for all agents in the environment. | ||
|
||
Returns: | ||
A random action for each environment. | ||
""" | ||
del agent_id | ||
return {} | ||
|
||
@PublicAPI | ||
def observation_space_sample(self, agent_id: list = None) -> MultiEnvDict: | ||
"""Returns a random observation for each environment, and potentially | ||
each agent in that environment. | ||
|
||
Args: | ||
            agent_id: List of agent ids to sample observations for. If None or
                empty list, sample observations for all agents in the
                environment. | ||
|
||
Returns: | ||
            A random observation for each environment. | ||
""" | ||
logger.warning("observation_space_sample() has not been implemented") | ||
return {} | ||
|
||
@PublicAPI | ||
def last(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, | ||
MultiEnvDict, MultiEnvDict]: | ||
"""Returns the last observations, rewards, and done flags that were | ||
returned by the environment. | ||
|
||
Returns: | ||
            The last observations, rewards, and done flags for each
            environment. | ||
""" | ||
logger.warning("last has not been implemented for this environment.") | ||
return {}, {}, {}, {}, {} | ||
|
||
@PublicAPI | ||
def observation_space_contains(self, x: MultiEnvDict) -> bool: | ||
"""Checks if the given observation is valid for each environment. | ||
|
||
Args: | ||
x: Observations to check. | ||
|
||
Returns: | ||
True if the observations are contained within their respective | ||
spaces. False otherwise. | ||
""" | ||
        return self._space_contains(self.observation_space, x) | ||
|
||
@PublicAPI | ||
def action_space_contains(self, x: MultiEnvDict) -> bool: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. a very minor question, do you envision obs/action_space_contains() getting used outside of the environment checking module? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can imagine it having other uses for users who are trying to develop their own environments -- I have definitely used functions like this while developing my own environments |
||
"""Checks if the given actions is valid for each environment. | ||
|
||
Args: | ||
x: Actions to check. | ||
|
||
Returns: | ||
True if the actions are contained within their respective | ||
spaces. False otherwise. | ||
""" | ||
return self._space_contains(self.action_space, x) | ||
|
||
@staticmethod | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wait, is the type still MultiEnvDict here and below??
I thought we are saying gym and multi-agent envs return different types now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes -- the return type of poll() and try_reset() are MultiEnvDicts, and so I thought it would be appropriate if observations and actions produced by the environment/policy should be able to be easily passed to the environment for checking.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, it seems to me we are settling on using multi-agent api for single agent env as well, which is totally fine, and probably logical.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The BaseEnv is never a single-agent env. If there is only one agent and we derive the BaseEnv from e.g. a gym.Env, we auto-create "DUMMY_AGENT_ID" in the env as the agent's ID.