Skip to content

Commit

Permalink
Expand Base env API to add necessary methods for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
avnishn committed Dec 15, 2021
1 parent e485aa8 commit 956759a
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions rllib/env/base_env.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING,\
Union

Expand All @@ -15,6 +16,8 @@

ASYNC_RESET_RETURN = "async_reset_return"

logger = logging.getLogger(__name__)


@PublicAPI
class BaseEnv:
Expand Down Expand Up @@ -194,6 +197,16 @@ def get_sub_environments(
return {}
return []

@PublicAPI
def get_agent_ids(self) -> Dict[EnvID, List[AgentID]]:
"""Return the agent ids for each sub-environment.
Returns:
A dict mapping from env_id to a list of agent_ids.
"""
logger.warning("get_agent_ids() has not been implemented")
return {}

@PublicAPI
def try_render(self, env_id: Optional[EnvID] = None) -> None:
"""Tries to render the sub-environment with the given id or all.
Expand Down Expand Up @@ -245,10 +258,72 @@ def action_space(self) -> gym.Space:
"""
raise NotImplementedError

@PublicAPI
def action_space_sample(self, agent_id: list = None) -> MultiEnvDict:
"""Returns a random action for each environment, and potentially each
agent in that environment.
Args:
agent_id: List of agent ids to sample actions for. If None or empty
list, sample actions for all agents in the environment.
Returns:
A random action for each environment.
"""
del agent_id
return {}

@PublicAPI
def observation_space_sample(self, agent_id: list = None) -> MultiEnvDict:
"""Returns a random observation for each environment, and potentially
each agent in that environment.
Args:
agent_id: List of agent ids to sample actions for. If None or empty
list, sample actions for all agents in the environment.
Returns:
A random action for each environment.
"""
logger.warning("observation_space_sample() has not been implemented")
return {}

@PublicAPI
def last(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
MultiEnvDict, MultiEnvDict]:
"""Returns the last observations, rewards, and done flags that were
returned by the environment.
Returns:
The last observations, rewards, and done flags for each environment
"""
logger.warning("last has not been implemented for this environment.")
return {}, {}, {}, {}, {}

@PublicAPI
def observation_space_contains(self, x: MultiEnvDict) -> bool:
"""Checks if the given observation is valid for each environment.
Args:
x: Observations to check.
Returns:
True if the observations are contained within their respective
spaces. False otherwise.
"""
self._space_contains(self.observation_space, x)

@PublicAPI
def action_space_contains(self, x: MultiEnvDict) -> bool:
"""Checks if the given actions is valid for each environment.
Args:
x: Actions to check.
Returns:
True if the actions are contained within their respective
spaces. False otherwise.
"""
return self._space_contains(self.action_space, x)

@staticmethod
Expand Down

0 comments on commit 956759a

Please sign in to comment.