diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py index 5de90741b62c..312da04b8f46 100644 --- a/rllib/env/multi_agent_env.py +++ b/rllib/env/multi_agent_env.py @@ -478,18 +478,17 @@ class MultiAgentEnvWrapper(BaseEnv): def __init__( self, make_env: Callable[[int], EnvType], - existing_envs: MultiAgentEnv, + existing_envs: List["MultiAgentEnv"], num_envs: int, ): """Wraps MultiAgentEnv(s) into the BaseEnv API. Args: - make_env (Callable[[int], EnvType]): Factory that produces a new - MultiAgentEnv instance. Must be defined, if the number of - existing envs is less than num_envs. - existing_envs (List[MultiAgentEnv]): List of already existing - multi-agent envs. - num_envs (int): Desired num multiagent envs to have at the end in + make_env: Factory that produces a new MultiAgentEnv instance taking the + vector index as only call argument. + Must be defined, if the number of existing envs is less than num_envs. + existing_envs: List of already existing multi-agent envs. + num_envs: Desired num multiagent envs to have at the end in total. This will include the given (already created) `existing_envs`. """ @@ -503,7 +502,6 @@ def __init__( assert isinstance(env, MultiAgentEnv) self.env_states = [_MultiAgentEnvState(env) for env in self.envs] self._unwrapped_env = self.envs[0].unwrapped - self._agent_ids = self._unwrapped_env.get_agent_ids() @override(BaseEnv) def poll( @@ -597,7 +595,7 @@ def action_space_sample(self, agent_ids: list = None) -> MultiEnvDict: @override(BaseEnv) def get_agent_ids(self) -> Set[AgentID]: - return self._agent_ids + return self.envs[0].get_agent_ids() class _MultiAgentEnvState: diff --git a/rllib/env/wrappers/pettingzoo_env.py b/rllib/env/wrappers/pettingzoo_env.py index 7a5d205ffcf7..b4e9c384b258 100644 --- a/rllib/env/wrappers/pettingzoo_env.py +++ b/rllib/env/wrappers/pettingzoo_env.py @@ -71,8 +71,9 @@ def __init__(self, env): super().__init__() self.env = env env.reset() - self._skip_env_checking = True # TODO avnishn - remove this after making - # petting zoo env compatible with check_env + # TODO (avnishn): Remove this after making petting zoo env compatible with + # check_env. + self._skip_env_checking = True # Get first observation space, assuming all agents have equal space self.observation_space = self.env.observation_space(self.env.agents[0]) @@ -145,6 +146,9 @@ def __init__(self, env): super().__init__() self.par_env = env self.par_env.reset() + # TODO (avnishn): Remove this after making petting zoo env compatible with + # check_env. + self._skip_env_checking = True # Get first observation space, assuming all agents have equal space self.observation_space = self.par_env.observation_space(self.par_env.agents[0])