diff --git a/docs/_scripts/generate_gif_image.py b/docs/_scripts/generate_gif_image.py index 321363c93..ad320054a 100644 --- a/docs/_scripts/generate_gif_image.py +++ b/docs/_scripts/generate_gif_image.py @@ -11,7 +11,7 @@ def generate_data(nameline, module): dir = f"frames/{nameline}/" os.mkdir(dir) - env = module.env() + env = module.env(render_mode="rgb_array") # env = gin_rummy_v0.env() env.reset() for step in range(100): @@ -30,7 +30,7 @@ def generate_data(nameline, module): if env.terminations[agent] or env.truncations[agent]: env.reset() - ndarray = env.render(mode="rgb_array") + ndarray = env.render() # tot_size = max(ndarray.shape) # target_size = 500 # ratio = target_size / tot_size diff --git a/docs/code_examples/aec_rps.py b/docs/code_examples/aec_rps.py index 8b6b149f5..88fc3c7c2 100644 --- a/docs/code_examples/aec_rps.py +++ b/docs/code_examples/aec_rps.py @@ -1,5 +1,6 @@ import functools +import gym import numpy as np from gym.spaces import Discrete @@ -25,15 +26,17 @@ } -def env(): +def env(render_mode=None): """ The env function often wraps the environment in wrappers by default. You can find full documentation for these methods elsewhere in the developer documentation. 
""" - env = raw_env() + internal_render_mode = render_mode if render_mode != "ansi" else "human" + env = raw_env(render_mode=internal_render_mode) # This wrapper is only for environments which print results to the terminal - env = wrappers.CaptureStdoutWrapper(env) + if render_mode == "ansi": + env = wrappers.CaptureStdoutWrapper(env) # this wrapper helps error handling for discrete action spaces env = wrappers.AssertOutOfBoundsWrapper(env) # Provides a wide vareity of helpful user errors @@ -52,7 +55,7 @@ class raw_env(AECEnv): metadata = {"render_modes": ["human"], "name": "rps_v2"} - def __init__(self): + def __init__(self, render_mode=None): """ The init method takes in environment arguments and should define the following attributes: @@ -71,6 +74,7 @@ def __init__(self): self._observation_spaces = { agent: Discrete(4) for agent in self.possible_agents } + self.render_mode = render_mode # this cache ensures that same space object is returned for the same agent # allows action space seeding to work as expected @@ -83,11 +87,17 @@ def observation_space(self, agent): def action_space(self, agent): return Discrete(3) - def render(self, mode="human"): + def render(self): """ Renders the environment. In human mode, it can print to terminal, open up a graphical window, or open up some other display that a human can see and understand. """ + if self.render_mode is None: + gym.logger.warn( + "You are calling render method without specifying any render mode." 
+ ) + return + if len(self.agents) == 2: string = "Current state: Agent1: {} , Agent2: {}".format( MOVES[self.state[self.agents[0]]], MOVES[self.state[self.agents[1]]] @@ -203,3 +213,6 @@ def step(self, action): self.agent_selection = self._agent_selector.next() # Adds .rewards to ._cumulative_rewards self._accumulate_rewards() + + if self.render_mode == "human": + self.render() diff --git a/docs/code_examples/parallel_rps.py b/docs/code_examples/parallel_rps.py index 44384121d..cc1e7f442 100644 --- a/docs/code_examples/parallel_rps.py +++ b/docs/code_examples/parallel_rps.py @@ -1,5 +1,6 @@ import functools +import gym from gym.spaces import Discrete from pettingzoo import ParallelEnv @@ -24,15 +25,17 @@ } -def env(): +def env(render_mode=None): """ The env function often wraps the environment in wrappers by default. You can find full documentation for these methods elsewhere in the developer documentation. """ - env = raw_env() + internal_render_mode = render_mode if render_mode != "ansi" else "human" + env = raw_env(render_mode=internal_render_mode) # This wrapper is only for environments which print results to the terminal - env = wrappers.CaptureStdoutWrapper(env) + if render_mode == "ansi": + env = wrappers.CaptureStdoutWrapper(env) # this wrapper helps error handling for discrete action spaces env = wrappers.AssertOutOfBoundsWrapper(env) # Provides a wide vareity of helpful user errors @@ -41,12 +44,12 @@ def env(): return env -def raw_env(): +def raw_env(render_mode=None): """ To support the AEC API, the raw_env() function just uses the from_parallel function to convert from a ParallelEnv to an AEC env """ - env = parallel_env() + env = parallel_env(render_mode=render_mode) env = parallel_to_aec(env) return env @@ -54,7 +57,7 @@ def raw_env(): class parallel_env(ParallelEnv): metadata = {"render_modes": ["human"], "name": "rps_v2"} - def __init__(self): + def __init__(self, render_mode=None): """ The init method takes in environment arguments and should 
define the following attributes: - possible_agents @@ -66,6 +69,7 @@ def __init__(self): self.agent_name_mapping = dict( zip(self.possible_agents, list(range(len(self.possible_agents)))) ) + self.render_mode = render_mode # this cache ensures that same space object is returned for the same agent # allows action space seeding to work as expected @@ -78,11 +82,17 @@ def observation_space(self, agent): def action_space(self, agent): return Discrete(3) - def render(self, mode="human"): + def render(self): """ Renders the environment. In human mode, it can print to terminal, open up a graphical window, or open up some other display that a human can see and understand. """ + if self.render_mode is None: + gym.logger.warn( + "You are calling render method without specifying any render mode." + ) + return + if len(self.agents) == 2: string = "Current state: Agent1: {} , Agent2: {}".format( MOVES[self.state[self.agents[0]]], MOVES[self.state[self.agents[1]]] @@ -157,4 +167,6 @@ def step(self, actions): if env_truncation: self.agents = [] + if self.render_mode == "human": + self.render() return observations, rewards, terminations, truncations, infos diff --git a/docs/content/basic_usage.md b/docs/content/basic_usage.md index f72c52bc4..de2170cd9 100644 --- a/docs/content/basic_usage.md +++ b/docs/content/basic_usage.md @@ -71,7 +71,7 @@ PettingZoo models games as *Agent Environment Cycle* (AEC) games, and thus can s `seed(seed=None)`: Reseeds the environment. `reset()` must be called after `seed()`, and before `step()`. -`render(mode='human')`: Displays a rendered frame from the environment, if supported. Alternate render modes in the default environments are `'rgb_array'` which returns a numpy array and is supported by all environments outside of classic, and `'ansi'` which returns the strings printed (specific to classic environments). +`render()`: Returns a rendered frame from the environment using render mode specified at initialization. 
In the case render mode is `'rgb_array'`, returns a numpy array, while with `'ansi'` returns the strings printed. There is no need to call `render()` with `human` mode. `close()`: Closes the rendering window. diff --git a/pettingzoo/atari/base_atari_env.py b/pettingzoo/atari/base_atari_env.py index 53c487b5a..ea0316c5e 100644 --- a/pettingzoo/atari/base_atari_env.py +++ b/pettingzoo/atari/base_atari_env.py @@ -39,6 +39,7 @@ def __init__( full_action_space=False, env_name=None, max_cycles=100000, + render_mode=None, auto_rom_install_path=None, ): """Initializes the `ParallelAtariEnv` class. @@ -56,6 +57,7 @@ def __init__( full_action_space, env_name, max_cycles, + render_mode, auto_rom_install_path, ) @@ -75,6 +77,7 @@ def __init__( "name": env_name, "render_fps": 60, } + self.render_mode = render_mode multi_agent_ale_py.ALEInterface.setLoggerMode("error") self.ale = multi_agent_ale_py.ALEInterface() @@ -230,12 +233,24 @@ def step(self, action_dict): } infos = {agent: {} for agent in self.possible_agents if agent in self.agents} self.agents = [agent for agent in self.agents if not terminations[agent]] + + if self.render_mode == "human": + self.render() return observations, rewards, terminations, truncations, infos - def render(self, mode="human"): + def render(self): + if self.render_mode is None: + gym.logger.warn( + "You are calling render method without specifying any render mode." 
+ ) + return + + assert ( + self.render_mode in self.metadata["render_modes"] + ), f"{self.render_mode} is not a valid render mode" (screen_width, screen_height) = self.ale.getScreenDims() image = self.ale.getScreenRGB() - if mode == "human": + if self.render_mode == "human": import pygame zoom_factor = 4 @@ -256,10 +271,8 @@ def render(self, mode="human"): self._screen.blit(myImage, (0, 0)) pygame.display.flip() - elif mode == "rgb_array": + elif self.render_mode == "rgb_array": return image - else: - raise ValueError("bad value for render mode") def close(self): if self._screen is not None: diff --git a/pettingzoo/butterfly/cooperative_pong/cooperative_pong.py b/pettingzoo/butterfly/cooperative_pong/cooperative_pong.py index f5e7589fa..034069caa 100644 --- a/pettingzoo/butterfly/cooperative_pong/cooperative_pong.py +++ b/pettingzoo/butterfly/cooperative_pong/cooperative_pong.py @@ -147,6 +147,7 @@ def __init__( bounce_randomness=False, max_reward=100, off_screen_penalty=-10, + render_mode=None, render_ratio=2, kernel_window_length=2, ): @@ -184,6 +185,7 @@ def __init__( low=0, high=255, shape=((self.s_height, self.s_width, 3)), dtype=np.uint8 ) + self.render_mode = render_mode self.renderOn = False # set speed @@ -258,16 +260,24 @@ def enable_render(self): self.renderOn = True self.draw() - def render(self, mode="human"): - if not self.renderOn and mode == "human": + def render(self): + if self.render_mode is None: + gym.logger.warn( + "You are calling render method without specifying any render mode." 
+ ) + return + + if not self.renderOn and self.render_mode == "human": # sets self.renderOn to true and initializes display self.enable_render() observation = np.array(pygame.surfarray.pixels3d(self.screen)) - if mode == "human": + if self.render_mode == "human": pygame.display.flip() return ( - np.transpose(observation, axes=(1, 0, 2)) if mode == "rgb_array" else None + np.transpose(observation, axes=(1, 0, 2)) + if self.render_mode == "rgb_array" + else None ) def observe(self): @@ -357,6 +367,7 @@ def __init__(self, **kwargs): self.seed() + self.render_mode = self.env.render_mode self.agents = self.env.agents[:] self.possible_agents = self.agents[:] self._agent_selector = agent_selector(self.agents) @@ -410,8 +421,8 @@ def state(self): def close(self): self.env.close() - def render(self, mode="human"): - return self.env.render(mode) + def render(self): + return self.env.render() def step(self, action): if ( diff --git a/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py b/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py index ac21570ff..23915470b 100644 --- a/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py +++ b/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py @@ -184,6 +184,7 @@ import sys from itertools import repeat +import gym import numpy as np import pygame import pygame.gfxdraw @@ -239,6 +240,7 @@ def __init__( vector_state=True, use_typemasks=False, transformer=False, + render_mode=None, ): EzPickle.__init__( self, @@ -255,6 +257,7 @@ def __init__( vector_state, use_typemasks, transformer, + render_mode, ) # variable state space self.transformer = transformer @@ -273,6 +276,7 @@ def __init__( self.frames = 0 self.closed = False self.has_reset = False + self.render_mode = render_mode self.render_on = False # Game Constants @@ -765,6 +769,9 @@ def step(self, action): self._accumulate_rewards() self._deads_step_first() + if self.render_mode == "human": + self.render() 
+ def enable_render(self): self.WINDOW = pygame.display.set_mode([const.SCREEN_WIDTH, const.SCREEN_HEIGHT]) # self.WINDOW = pygame.Surface((const.SCREEN_WIDTH, const.SCREEN_HEIGHT)) @@ -788,16 +795,24 @@ def draw(self): self.archer_list.draw(self.WINDOW) self.knight_list.draw(self.WINDOW) - def render(self, mode="human"): - if not self.render_on and mode == "human": + def render(self): + if self.render_mode is None: + gym.logger.warn( + "You are calling render method without specifying any render mode." + ) + return + + if not self.render_on and self.render_mode == "human": # sets self.render_on to true and initializes display self.enable_render() observation = np.array(pygame.surfarray.pixels3d(self.WINDOW)) - if mode == "human": + if self.render_mode == "human": pygame.display.flip() return ( - np.transpose(observation, axes=(1, 0, 2)) if mode == "rgb_array" else None + np.transpose(observation, axes=(1, 0, 2)) + if self.render_mode == "rgb_array" + else None ) def close(self): diff --git a/pettingzoo/butterfly/pistonball/pistonball.py b/pettingzoo/butterfly/pistonball/pistonball.py index 132eaf4cf..f813470b2 100644 --- a/pettingzoo/butterfly/pistonball/pistonball.py +++ b/pettingzoo/butterfly/pistonball/pistonball.py @@ -147,6 +147,7 @@ def __init__( ball_friction=0.3, ball_elasticity=1.5, max_cycles=125, + render_mode=None, ): EzPickle.__init__( self, @@ -159,6 +160,7 @@ def __init__( ball_friction, ball_elasticity, max_cycles, + render_mode, ) self.dt = 1.0 / FPS self.n_pistons = n_pistons @@ -221,6 +223,7 @@ def __init__( pygame.init() pymunk.pygame_util.positive_y_is_up = False + self.render_mode = render_mode self.renderOn = False self.screen = pygame.Surface((self.screen_width, self.screen_height)) self.max_cycles = max_cycles @@ -593,8 +596,14 @@ def get_local_reward(self, prev_position, curr_position): local_reward = 0.5 * (prev_position - curr_position) return local_reward - def render(self, mode="human"): - if mode == "human" and not self.renderOn: + 
def render(self): + if self.render_mode is None: + gym.logger.warn( + "You are calling render method without specifying any render mode." + ) + return + + if self.render_mode == "human" and not self.renderOn: # sets self.renderOn to true and initializes display self.enable_render() @@ -602,10 +611,12 @@ def render(self, mode="human"): self.draw() observation = np.array(pygame.surfarray.pixels3d(self.screen)) - if mode == "human": + if self.render_mode == "human": pygame.display.flip() return ( - np.transpose(observation, axes=(1, 0, 2)) if mode == "rgb_array" else None + np.transpose(observation, axes=(1, 0, 2)) + if self.render_mode == "rgb_array" + else None ) def step(self, action): @@ -671,5 +682,8 @@ def step(self, action): self._cumulative_rewards[agent] = 0 self._accumulate_rewards() + if self.render_mode == "human": + self.render() + # Game art created by J K Terry diff --git a/pettingzoo/classic/chess/chess.py b/pettingzoo/classic/chess/chess.py index 94fad2e29..c036e131c 100644 --- a/pettingzoo/classic/chess/chess.py +++ b/pettingzoo/classic/chess/chess.py @@ -89,6 +89,7 @@ """ import chess +import gym import numpy as np from gym import spaces @@ -99,9 +100,11 @@ from . 
import chess_utils -def env(): - env = raw_env() - env = wrappers.CaptureStdoutWrapper(env) +def env(render_mode=None): + internal_render_mode = render_mode if render_mode != "ansi" else "human" + env = raw_env(render_mode=internal_render_mode) + if render_mode == "ansi": + env = wrappers.CaptureStdoutWrapper(env) env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) env = wrappers.AssertOutOfBoundsWrapper(env) env = wrappers.OrderEnforcingWrapper(env) @@ -117,7 +120,7 @@ class raw_env(AECEnv): "render_fps": 2, } - def __init__(self): + def __init__(self, render_mode=None): super().__init__() self.board = chess.Board() @@ -151,6 +154,8 @@ def __init__(self): self.board_history = np.zeros((8, 8, 104), dtype=bool) + self.render_mode = render_mode + def observation_space(self, agent): return self.observation_spaces[agent] @@ -236,8 +241,16 @@ def step(self, action): self._agent_selector.next() ) # Give turn to the next agent - def render(self, mode="human"): - print(self.board) + if self.render_mode == "human": + self.render() + + def render(self): + if self.render_mode is None: + gym.logger.warn( + "You are calling render method without specifying any render mode." 
+ ) + else: + print(self.board) def close(self): pass diff --git a/pettingzoo/classic/connect_four/connect_four.py b/pettingzoo/classic/connect_four/connect_four.py index d965424d4..680330175 100644 --- a/pettingzoo/classic/connect_four/connect_four.py +++ b/pettingzoo/classic/connect_four/connect_four.py @@ -64,6 +64,7 @@ import os +import gym import numpy as np import pygame from gym import spaces @@ -85,8 +86,8 @@ def get_image(path): return sfc -def env(): - env = raw_env() +def env(render_mode=None): + env = raw_env(render_mode=render_mode) env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) env = wrappers.AssertOutOfBoundsWrapper(env) env = wrappers.OrderEnforcingWrapper(env) @@ -101,7 +102,7 @@ class raw_env(AECEnv): "render_fps": 2, } - def __init__(self): + def __init__(self, render_mode=None): super().__init__() # 6 rows x 7 columns # blank space = 0 @@ -109,6 +110,7 @@ def __init__(self): # agent 1 -- 2 # flat representation in row major order self.screen = None + self.render_mode = render_mode self.board = [0] * (6 * 7) @@ -202,6 +204,9 @@ def step(self, action): self._accumulate_rewards() + if self.render_mode == "human": + self.render() + def reset(self, seed=None, return_info=False, options=None): # reset environment self.board = [0] * (6 * 7) @@ -217,17 +222,22 @@ def reset(self, seed=None, return_info=False, options=None): self.agent_selection = self._agent_selector.reset() - def render(self, mode="human"): + def render(self): + if self.render_mode is None: + gym.logger.warn( + "You are calling render method without specifying any render mode." 
+ ) + return + screen_width = 1287 screen_height = 1118 - if self.screen is None: - if mode == "human": + if self.render_mode == "human": + if self.screen is None: pygame.init() self.screen = pygame.display.set_mode((screen_width, screen_height)) - else: - self.screen = pygame.Surface((screen_width, screen_height)) - if mode == "human": pygame.event.get() + elif self.screen is None: + self.screen = pygame.Surface((screen_width, screen_height)) # Load and scale all of the necessary images tile_size = (screen_width * (91 / 99)) / 7 @@ -268,13 +278,15 @@ def render(self, mode="human"): ), ) - if mode == "human": + if self.render_mode == "human": pygame.display.update() observation = np.array(pygame.surfarray.pixels3d(self.screen)) return ( - np.transpose(observation, axes=(1, 0, 2)) if mode == "rgb_array" else None + np.transpose(observation, axes=(1, 0, 2)) + if self.render_mode == "rgb_array" + else None ) def close(self): diff --git a/pettingzoo/classic/go/go.py b/pettingzoo/classic/go/go.py index 78768c98d..62dfd297b 100644 --- a/pettingzoo/classic/go/go.py +++ b/pettingzoo/classic/go/go.py @@ -112,7 +112,9 @@ """ import os +from typing import Optional +import gym import numpy as np import pygame from gym import spaces @@ -153,7 +155,9 @@ class raw_env(AECEnv): "render_fps": 2, } - def __init__(self, board_size: int = 19, komi: float = 7.5): + def __init__( + self, board_size: int = 19, komi: float = 7.5, render_mode: Optional[str] = None + ): # board_size: a int, representing the board size (board has a board_size x board_size shape) # komi: a float, representing points given to the second player. 
super().__init__() @@ -194,6 +198,8 @@ def __init__(self, board_size: int = 19, komi: float = 7.5): self.board_history = np.zeros((self._N, self._N, 16), dtype=bool) + self.render_mode = render_mode + def observation_space(self, agent): return self.observation_spaces[agent] @@ -305,6 +311,9 @@ def step(self, action): ) self._accumulate_rewards() + if self.render_mode == "human": + self.render() + def reset(self, seed=None, return_info=False, options=None): self.has_reset = True self._go = go_base.Position(board=None, komi=self._komi) @@ -325,17 +334,23 @@ def reset(self, seed=None, return_info=False, options=None): self._last_obs = self.observe(self.agents[0]) self.board_history = np.zeros((self._N, self._N, 16), dtype=bool) - def render(self, mode="human"): + def render(self): + if self.render_mode is None: + gym.logger.warn( + "You are calling render method without specifying any render mode." + ) + return + screen_width = 1026 screen_height = 1026 if self.screen is None: - if mode == "human": + if self.render_mode == "human": pygame.init() self.screen = pygame.display.set_mode((screen_width, screen_height)) else: self.screen = pygame.Surface((screen_width, screen_height)) - if mode == "human": + if self.render_mode == "human": pygame.event.get() size = go_base.N @@ -407,13 +422,15 @@ def render(self, mode="human"): ((i * (tile_size) + offset), int(j) * (tile_size) + offset), ) - if mode == "human": + if self.render_mode == "human": pygame.display.update() observation = np.array(pygame.surfarray.pixels3d(self.screen)) return ( - np.transpose(observation, axes=(1, 0, 2)) if mode == "rgb_array" else None + np.transpose(observation, axes=(1, 0, 2)) + if self.render_mode == "rgb_array" + else None ) def close(self): diff --git a/pettingzoo/classic/hanabi/hanabi.py b/pettingzoo/classic/hanabi/hanabi.py index 089cb2ea8..a15543b20 100644 --- a/pettingzoo/classic/hanabi/hanabi.py +++ b/pettingzoo/classic/hanabi/hanabi.py @@ -163,6 +163,7 @@ from typing import Dict, List, 
Optional, Union +import gym import numpy as np from gym import spaces from gym.utils import EzPickle @@ -196,11 +197,15 @@ def __float__(self): def env(**kwargs): - env = r_env = raw_env(**kwargs) - env = wrappers.CaptureStdoutWrapper(env) - env = wrappers.TerminateIllegalWrapper( - env, illegal_reward=HanabiScorePenalty(r_env) - ) + render_mode = kwargs.get("render_mode") + if render_mode == "ansi": + kwargs["render_mode"] = "human" + env = raw_env(**kwargs) + env = wrappers.CaptureStdoutWrapper(env) + else: + env = raw_env(**kwargs) + + env = wrappers.TerminateIllegalWrapper(env, illegal_reward=HanabiScorePenalty(env)) env = wrappers.AssertOutOfBoundsWrapper(env) env = wrappers.OrderEnforcingWrapper(env) return env @@ -238,6 +243,7 @@ def __init__( max_life_tokens: int = 3, observation_type: int = 1, random_start_player: bool = False, + render_mode: Optional[str] = None, ): """Initializes the `raw_env` class. @@ -294,6 +300,7 @@ def __init__( max_life_tokens, observation_type, random_start_player, + render_mode, ) # ToDo: Starts @@ -354,6 +361,8 @@ def __init__( for player_name in self.agents } + self.render_mode = render_mode + def observation_space(self, agent): return self.observation_spaces[agent] @@ -544,11 +553,17 @@ def _process_latest_observations( for player_name in self.agents } - def render(self, mode="human"): + def render(self): """Prints player's data. Supports console print only. """ + if self.render_mode is None: + gym.logger.warn( + "You are calling render method without specifying any render mode." 
+ ) + return + player_data = self.latest_observations["player_observations"] print( "Active player:", diff --git a/pettingzoo/classic/rlcard_envs/gin_rummy.py b/pettingzoo/classic/rlcard_envs/gin_rummy.py index 68a23a807..feb4eaf1f 100644 --- a/pettingzoo/classic/rlcard_envs/gin_rummy.py +++ b/pettingzoo/classic/rlcard_envs/gin_rummy.py @@ -113,6 +113,7 @@ """ +import gym import numpy as np from gym.utils import EzPickle from rlcard.games.gin_rummy.player import GinRummyPlayer @@ -127,8 +128,13 @@ def env(**kwargs): - env = raw_env(**kwargs) - env = wrappers.CaptureStdoutWrapper(env) + render_mode = kwargs.get("render_mode") + if render_mode == "ansi": + kwargs["render_mode"] = "human" + env = raw_env(**kwargs) + env = wrappers.CaptureStdoutWrapper(env) + else: + env = raw_env(**kwargs) env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) env = wrappers.AssertOutOfBoundsWrapper(env) env = wrappers.OrderEnforcingWrapper(env) @@ -149,8 +155,9 @@ def __init__( knock_reward: float = 0.5, gin_reward: float = 1.0, opponents_hand_visible=False, + render_mode=None, ): - EzPickle.__init__(self, knock_reward, gin_reward) + EzPickle.__init__(self, knock_reward, gin_reward, opponents_hand_visible, render_mode) self._opponents_hand_visible = opponents_hand_visible num_planes = 5 if self._opponents_hand_visible else 4 RLCardBase.__init__(self, "gin-rummy", 2, (num_planes, 52)) @@ -158,6 +165,7 @@ def __init__( self._gin_reward = gin_reward self.env.game.judge.scorer.get_payoff = self._get_payoff + self.render_mode = render_mode def _get_payoff(self, player: GinRummyPlayer, game) -> float: going_out_action = game.round.going_out_action @@ -194,7 +202,13 @@ def observe(self, agent): return {"observation": observation, "action_mask": action_mask} - def render(self, mode="human"): + def render(self): + if self.render_mode is None: + gym.logger.warn( + "You are calling render method without specifying any render mode." 
+ ) + return + for player in self.possible_agents: state = self.env.game.round.players[self._name_to_int(player)].hand print(f"\n===== {player}'s Hand =====") diff --git a/pettingzoo/classic/rlcard_envs/leduc_holdem.py b/pettingzoo/classic/rlcard_envs/leduc_holdem.py index 2a921d291..b8567cb75 100644 --- a/pettingzoo/classic/rlcard_envs/leduc_holdem.py +++ b/pettingzoo/classic/rlcard_envs/leduc_holdem.py @@ -85,6 +85,7 @@ """ +import gym from rlcard.utils.utils import print_card from pettingzoo.utils import wrappers @@ -93,8 +94,13 @@ def env(**kwargs): - env = raw_env(**kwargs) - env = wrappers.CaptureStdoutWrapper(env) + render_mode = kwargs.get("render_mode") + if render_mode == "ansi": + kwargs["render_mode"] = "human" + env = raw_env(**kwargs) + env = wrappers.CaptureStdoutWrapper(env) + else: + env = raw_env(**kwargs) env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) env = wrappers.AssertOutOfBoundsWrapper(env) env = wrappers.OrderEnforcingWrapper(env) @@ -110,10 +116,17 @@ class raw_env(RLCardBase): "render_fps": 1, } - def __init__(self, num_players=2): + def __init__(self, num_players=2, render_mode=None): super().__init__("leduc-holdem", num_players, (36,)) + self.render_mode = render_mode + + def render(self): + if self.render_mode is None: + gym.logger.warn( + "You are calling render method without specifying any render mode." 
+ ) + return - def render(self, mode="human"): for player in self.possible_agents: state = self.env.game.get_state(self._name_to_int(player)) print(f"\n=============== {player}'s Hand ===============") diff --git a/pettingzoo/classic/rlcard_envs/rlcard_base.py b/pettingzoo/classic/rlcard_envs/rlcard_base.py index ac9519ce6..68df199a3 100644 --- a/pettingzoo/classic/rlcard_envs/rlcard_base.py +++ b/pettingzoo/classic/rlcard_envs/rlcard_base.py @@ -135,7 +135,7 @@ def reset(self, seed=None, return_info=False, options=None): self.next_legal_moves = list(sorted(obs["legal_actions"])) self._last_obs = obs["obs"] - def render(self, mode="human"): + def render(self): raise NotImplementedError() def close(self): diff --git a/pettingzoo/classic/rlcard_envs/texas_holdem.py b/pettingzoo/classic/rlcard_envs/texas_holdem.py index d313494eb..8589e12a0 100644 --- a/pettingzoo/classic/rlcard_envs/texas_holdem.py +++ b/pettingzoo/classic/rlcard_envs/texas_holdem.py @@ -83,6 +83,7 @@ import os +import gym import numpy as np import pygame @@ -126,10 +127,17 @@ class raw_env(RLCardBase): "render_fps": 1, } - def __init__(self, num_players=2): + def __init__(self, num_players=2, render_mode=None): super().__init__("limit-holdem", num_players, (72,)) + self.render_mode = render_mode + + def render(self): + if self.render_mode is None: + gym.logger.warn( + "You are calling render method without specifying any render mode." 
+ ) + return - def render(self, mode="human"): def calculate_width(self, screen_width, i): return int( ( @@ -154,15 +162,14 @@ def calculate_height(screen_height, divisor, multiplier, tile_size, offset): + np.ceil(len(self.possible_agents) / 2) * (screen_height * 1 / 2) ) - if self.screen is None: - if mode == "human": + if self.render_mode == "human": + if self.screen is None: pygame.init() self.screen = pygame.display.set_mode((screen_width, screen_height)) - else: - pygame.font.init() - self.screen = pygame.Surface((screen_width, screen_height)) - if mode == "human": pygame.event.get() + elif self.screen is None: + pygame.font.init() + self.screen = pygame.Surface((screen_width, screen_height)) # Setup dimensions for card size and setup for colors tile_size = screen_height * 2 / 10 @@ -357,11 +364,13 @@ def calculate_height(screen_height, divisor, multiplier, tile_size, offset): ), ) - if mode == "human": + if self.render_mode == "human": pygame.display.update() observation = np.array(pygame.surfarray.pixels3d(self.screen)) return ( - np.transpose(observation, axes=(1, 0, 2)) if mode == "rgb_array" else None + np.transpose(observation, axes=(1, 0, 2)) + if self.render_mode == "rgb_array" + else None ) diff --git a/pettingzoo/classic/rlcard_envs/texas_holdem_no_limit.py b/pettingzoo/classic/rlcard_envs/texas_holdem_no_limit.py index b8af53fde..1e7162a15 100644 --- a/pettingzoo/classic/rlcard_envs/texas_holdem_no_limit.py +++ b/pettingzoo/classic/rlcard_envs/texas_holdem_no_limit.py @@ -95,6 +95,7 @@ import os +import gym import numpy as np import pygame from gym import spaces @@ -139,7 +140,7 @@ class raw_env(RLCardBase): "render_fps": 1, } - def __init__(self, num_players=2): + def __init__(self, num_players=2, render_mode=None): super().__init__("no-limit-holdem", num_players, (54,)) self.observation_spaces = self._convert_to_dict( [ @@ -166,7 +167,15 @@ def __init__(self, num_players=2): ] ) - def render(self, mode="human"): + self.render_mode = render_mode + 
+ def render(self): + if self.render_mode is None: + gym.logger.warn( + "You are calling render method without specifying any render mode." + ) + return + def calculate_width(self, screen_width, i): return int( ( @@ -191,15 +200,14 @@ def calculate_height(screen_height, divisor, multiplier, tile_size, offset): + np.ceil(len(self.possible_agents) / 2) * (screen_height * 1 / 2) ) - if self.screen is None: - if mode == "human": + if self.render_mode == "human": + if self.screen is None: pygame.init() self.screen = pygame.display.set_mode((screen_width, screen_height)) - else: - pygame.font.init() - self.screen = pygame.Surface((screen_width, screen_height)) - if mode == "human": pygame.event.get() + elif self.screen is None: + pygame.font.init() + self.screen = pygame.Surface((screen_width, screen_height)) # Setup dimensions for card size and setup for colors tile_size = screen_height * 2 / 10 @@ -394,11 +402,13 @@ def calculate_height(screen_height, divisor, multiplier, tile_size, offset): ), ) - if mode == "human": + if self.render_mode == "human": pygame.display.update() observation = np.array(pygame.surfarray.pixels3d(self.screen)) return ( - np.transpose(observation, axes=(1, 0, 2)) if mode == "rgb_array" else None + np.transpose(observation, axes=(1, 0, 2)) + if self.render_mode == "rgb_array" + else None ) diff --git a/pettingzoo/classic/rps/rps.py b/pettingzoo/classic/rps/rps.py index 3b95dd0b7..98e49bf78 100644 --- a/pettingzoo/classic/rps/rps.py +++ b/pettingzoo/classic/rps/rps.py @@ -117,6 +117,7 @@ import os +import gym import numpy as np import pygame from gym.spaces import Discrete @@ -170,7 +171,7 @@ class raw_env(AECEnv): "render_fps": 2, } - def __init__(self, num_actions=3, max_cycles=15): + def __init__(self, num_actions=3, max_cycles=15, render_mode=None): self.max_cycles = max_cycles # number of actions must be odd and greater than 3 @@ -194,6 +195,7 @@ def __init__(self, num_actions=3, max_cycles=15): agent: Discrete(1 + num_actions) for agent in 
self.agents } + self.render_mode = render_mode self.screen = None self.reinit() @@ -221,7 +223,13 @@ def reinit(self): self.num_moves = 0 - def render(self, mode="human"): + def render(self): + if self.render_mode is None: + gym.logger.warn( + "You are calling render method without specifying any render mode." + ) + return + def offset(i, size, offset=0): if i == 0: return -(size) - offset @@ -231,15 +239,14 @@ def offset(i, size, offset=0): screen_height = 350 screen_width = int(screen_height * 5 / 14) - if self.screen is None: - if mode == "human": + if self.render_mode == "human": + if self.screen is None: pygame.init() self.screen = pygame.display.set_mode((screen_width, screen_height)) - else: - pygame.font.init() - self.screen = pygame.Surface((screen_width, screen_height)) - if mode == "human": pygame.event.get() + elif self.screen is None: + pygame.font.init() + self.screen = pygame.Surface((screen_width, screen_height)) # Load and all of the necessary images paper = get_image(os.path.join("img", "Paper.png")) @@ -413,13 +420,15 @@ def offset(i, size, offset=0): ), ) - if mode == "human": + if self.render_mode == "human": pygame.display.update() observation = np.array(pygame.surfarray.pixels3d(self.screen)) return ( - np.transpose(observation, axes=(1, 0, 2)) if mode == "rgb_array" else None + np.transpose(observation, axes=(1, 0, 2)) + if self.render_mode == "rgb_array" + else None ) def observe(self, agent): @@ -488,3 +497,6 @@ def step(self, action): self._cumulative_rewards[self.agent_selection] = 0 self.agent_selection = self._agent_selector.next() self._accumulate_rewards() + + if self.render_mode == "human": + self.render() diff --git a/pettingzoo/classic/tictactoe/tictactoe.py b/pettingzoo/classic/tictactoe/tictactoe.py index 1a716c8f3..b9e3fe0a9 100644 --- a/pettingzoo/classic/tictactoe/tictactoe.py +++ b/pettingzoo/classic/tictactoe/tictactoe.py @@ -73,6 +73,7 @@ """ +import gym import numpy as np from gym import spaces @@ -82,9 +83,11 @@ from 
.board import Board -def env(): - env = raw_env() - env = wrappers.CaptureStdoutWrapper(env) +def env(render_mode=None): + internal_render_mode = render_mode if render_mode != "ansi" else "human" + env = raw_env(render_mode=internal_render_mode) + if render_mode == "ansi": + env = wrappers.CaptureStdoutWrapper(env) env = wrappers.TerminateIllegalWrapper(env, illegal_reward=-1) env = wrappers.AssertOutOfBoundsWrapper(env) env = wrappers.OrderEnforcingWrapper(env) @@ -99,7 +102,7 @@ class raw_env(AECEnv): "render_fps": 1, } - def __init__(self): + def __init__(self, render_mode=None): super().__init__() self.board = Board() @@ -127,6 +130,8 @@ def __init__(self): self._agent_selector = agent_selector(self.agents) self.agent_selection = self._agent_selector.reset() + self.render_mode = render_mode + # Key # ---- # blank space = 0 @@ -203,6 +208,8 @@ def step(self, action): self.agent_selection = next_agent self._accumulate_rewards() + if self.render_mode == "human": + self.render() def reset(self, seed=None, return_info=False, options=None): # reset environment @@ -219,7 +226,13 @@ def reset(self, seed=None, return_info=False, options=None): self._agent_selector.reset() self.agent_selection = self._agent_selector.reset() - def render(self, mode="human"): + def render(self): + if self.render_mode is None: + gym.logger.WARN( + "You are calling render method without specifying any render mode." 
+ ) + return + def getSymbol(input): if input == 0: return "-" diff --git a/pettingzoo/magent/adversarial_pursuit/adversarial_pursuit.py b/pettingzoo/magent/adversarial_pursuit/adversarial_pursuit.py index dbda4c8e5..3054734c2 100644 --- a/pettingzoo/magent/adversarial_pursuit/adversarial_pursuit.py +++ b/pettingzoo/magent/adversarial_pursuit/adversarial_pursuit.py @@ -127,12 +127,13 @@ def parallel_env( max_cycles=max_cycles_default, minimap_mode=minimap_mode_default, extra_features=False, + render_mode=None, **reward_args ): env_reward_args = dict(**default_reward_args) env_reward_args.update(reward_args) return _parallel_env( - map_size, minimap_mode, env_reward_args, max_cycles, extra_features + map_size, minimap_mode, env_reward_args, max_cycles, extra_features, render_mode ) @@ -198,9 +199,23 @@ class _parallel_env(magent_parallel_env, EzPickle): "render_fps": 5, } - def __init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_features): + def __init__( + self, + map_size, + minimap_mode, + reward_args, + max_cycles, + extra_features, + render_mode=None, + ): EzPickle.__init__( - self, map_size, minimap_mode, reward_args, max_cycles, extra_features + self, + map_size, + minimap_mode, + reward_args, + max_cycles, + extra_features, + render_mode, ) assert map_size >= 7, "size of map must be at least 7" env = magent.GridWorld( @@ -223,6 +238,7 @@ def __init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_featur reward_range, minimap_mode, extra_features, + render_mode, ) def generate_map(self): diff --git a/pettingzoo/magent/battle/battle.py b/pettingzoo/magent/battle/battle.py index 432567e87..9e7e606fb 100644 --- a/pettingzoo/magent/battle/battle.py +++ b/pettingzoo/magent/battle/battle.py @@ -146,12 +146,13 @@ def parallel_env( max_cycles=max_cycles_default, minimap_mode=minimap_mode_default, extra_features=False, + render_mode=None, **reward_args ): env_reward_args = dict(**default_reward_args) env_reward_args.update(reward_args) 
return _parallel_env( - map_size, minimap_mode, env_reward_args, max_cycles, extra_features + map_size, minimap_mode, env_reward_args, max_cycles, extra_features, render_mode ) @@ -225,9 +226,23 @@ class _parallel_env(magent_parallel_env, EzPickle): "render_fps": 5, } - def __init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_features): + def __init__( + self, + map_size, + minimap_mode, + reward_args, + max_cycles, + extra_features, + render_mode=None, + ): EzPickle.__init__( - self, map_size, minimap_mode, reward_args, max_cycles, extra_features + self, + map_size, + minimap_mode, + reward_args, + max_cycles, + extra_features, + render_mode, ) assert map_size >= 12, "size of map must be at least 12" env = magent.GridWorld( @@ -250,6 +265,7 @@ def __init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_featur reward_range, minimap_mode, extra_features, + render_mode, ) def generate_map(self): diff --git a/pettingzoo/magent/battlefield/battlefield.py b/pettingzoo/magent/battlefield/battlefield.py index 78078c2a5..82bc0fec1 100644 --- a/pettingzoo/magent/battlefield/battlefield.py +++ b/pettingzoo/magent/battlefield/battlefield.py @@ -148,12 +148,13 @@ def parallel_env( max_cycles=max_cycles_default, minimap_mode=minimap_mode_default, extra_features=False, + render_mode=None, **reward_args ): env_reward_args = dict(**default_reward_args) env_reward_args.update(reward_args) return _parallel_env( - map_size, minimap_mode, env_reward_args, max_cycles, extra_features + map_size, minimap_mode, env_reward_args, max_cycles, extra_features, render_mode ) @@ -179,9 +180,23 @@ class _parallel_env(magent_parallel_env, EzPickle): "render_fps": 5, } - def __init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_features): + def __init__( + self, + map_size, + minimap_mode, + reward_args, + max_cycles, + extra_features, + render_mode=None, + ): EzPickle.__init__( - self, map_size, minimap_mode, reward_args, max_cycles, extra_features + 
self, + map_size, + minimap_mode, + reward_args, + max_cycles, + extra_features, + render_mode, ) assert map_size >= 46, "size of map must be at least 46" env = magent.GridWorld( @@ -204,6 +219,7 @@ def __init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_featur reward_range, minimap_mode, extra_features, + render_mode, ) def generate_map(self): diff --git a/pettingzoo/magent/combined_arms/combined_arms.py b/pettingzoo/magent/combined_arms/combined_arms.py index a8bf244a3..f218b58a5 100644 --- a/pettingzoo/magent/combined_arms/combined_arms.py +++ b/pettingzoo/magent/combined_arms/combined_arms.py @@ -151,12 +151,13 @@ def parallel_env( max_cycles=max_cycles_default, minimap_mode=minimap_mode_default, extra_features=False, + render_mode=None, **reward_args ): env_reward_args = dict(**default_reward_args) env_reward_args.update(reward_args) return _parallel_env( - map_size, minimap_mode, env_reward_args, max_cycles, extra_features + map_size, minimap_mode, env_reward_args, max_cycles, extra_features, render_mode ) @@ -165,10 +166,18 @@ def raw_env( max_cycles=max_cycles_default, minimap_mode=minimap_mode_default, extra_features=False, + render_mode=None, **reward_args ): return parallel_to_aec_wrapper( - parallel_env(map_size, max_cycles, minimap_mode, extra_features, **reward_args) + parallel_env( + map_size, + max_cycles, + minimap_mode, + extra_features, + render_mode=render_mode, + **reward_args + ) ) @@ -360,7 +369,15 @@ class _parallel_env(magent_parallel_env, EzPickle): "render_fps": 5, } - def __init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_features): + def __init__( + self, + map_size, + minimap_mode, + reward_args, + max_cycles, + extra_features, + render_mode=None, + ): EzPickle.__init__( self, map_size, minimap_mode, reward_args, max_cycles, extra_features ) @@ -381,6 +398,7 @@ def __init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_featur reward_range, minimap_mode, extra_features, + render_mode, 
) def generate_map(self): diff --git a/pettingzoo/magent/gather/gather.py b/pettingzoo/magent/gather/gather.py index 9785464c9..7c982d420 100644 --- a/pettingzoo/magent/gather/gather.py +++ b/pettingzoo/magent/gather/gather.py @@ -131,12 +131,13 @@ def parallel_env( max_cycles=max_cycles_default, minimap_mode=minimap_mode_default, extra_features=False, + render_mode=None, **reward_args ): env_reward_args = dict(**default_reward_args) env_reward_args.update(reward_args) return _parallel_env( - map_size, minimap_mode, env_reward_args, max_cycles, extra_features + map_size, minimap_mode, env_reward_args, max_cycles, extra_features, render_mode ) @@ -209,7 +210,15 @@ class _parallel_env(magent_parallel_env, EzPickle): "render_fps": 5, } - def __init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_features): + def __init__( + self, + map_size, + minimap_mode, + reward_args, + max_cycles, + extra_features, + render_mode=None, + ): EzPickle.__init__( self, map_size, minimap_mode, reward_args, max_cycles, extra_features ) @@ -230,6 +239,7 @@ def __init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_featur reward_range, minimap_mode, extra_features, + render_mode, ) def generate_map(self): diff --git a/pettingzoo/magent/magent_env.py b/pettingzoo/magent/magent_env.py index 1100c23ca..2912015d3 100644 --- a/pettingzoo/magent/magent_env.py +++ b/pettingzoo/magent/magent_env.py @@ -1,3 +1,4 @@ +import gym import numpy as np from gym.spaces import Box, Discrete from gym.utils import seeding @@ -28,6 +29,7 @@ def __init__( reward_range, minimap_mode, extra_features, + render_mode=None, ): self.map_size = map_size self.max_cycles = max_cycles @@ -93,6 +95,7 @@ def __init__( walls = self.env._get_walls_info() wall_x, wall_y = zip(*walls) self.base_state[wall_x, wall_y, 0] = 1 + self.render_mode = render_mode self._renderer = None self.frames = 0 @@ -136,13 +139,19 @@ def _calc_state_shape(self): return (self.map_size, self.map_size, state_depth) - 
def render(self, mode="human"): + def render(self): + if self.render_mode is None: + gym.logger.WARN( + "You are calling render method without specifying any render mode." + ) + return + if self._renderer is None: - self._renderer = Renderer(self.env, self.map_size, mode) + self._renderer = Renderer(self.env, self.map_size, self.render_mode) assert ( - mode == self._renderer.mode + self.render_mode == self._renderer.mode ), "mode must be consistent across render calls" - return self._renderer.render(mode) + return self._renderer.render(self.render_mode) def close(self): if self._renderer is not None: @@ -276,4 +285,7 @@ def step(self, all_actions): for agent in self.agents if not (terminations[agent] or truncations[agent]) ] + + if self.render_mode == "human": + self.render() return observations, rewards, terminations, truncations, infos diff --git a/pettingzoo/magent/tiger_deer/tiger_deer.py b/pettingzoo/magent/tiger_deer/tiger_deer.py index bf7a5d2a7..b2304595c 100644 --- a/pettingzoo/magent/tiger_deer/tiger_deer.py +++ b/pettingzoo/magent/tiger_deer/tiger_deer.py @@ -112,12 +112,13 @@ def parallel_env( max_cycles=max_cycles_default, minimap_mode=minimap_mode_default, extra_features=False, + render_mode=None, **env_args ): env_env_args = dict(**default_env_args) env_env_args.update(env_args) return _parallel_env( - map_size, minimap_mode, env_env_args, max_cycles, extra_features + map_size, minimap_mode, env_env_args, max_cycles, extra_features, render_mode ) @@ -197,9 +198,23 @@ class _parallel_env(magent_parallel_env, EzPickle): "render_fps": 5, } - def __init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_features): + def __init__( + self, + map_size, + minimap_mode, + reward_args, + max_cycles, + extra_features, + render_mode=None, + ): EzPickle.__init__( - self, map_size, minimap_mode, reward_args, max_cycles, extra_features + self, + map_size, + minimap_mode, + reward_args, + max_cycles, + extra_features, + render_mode, ) assert map_size >= 
10, "size of map must be at least 10" env = magent.GridWorld( @@ -223,6 +238,7 @@ def __init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_featur reward_range, minimap_mode, extra_features, + render_mode, ) def generate_map(self): diff --git a/pettingzoo/mpe/_mpe_utils/simple_env.py b/pettingzoo/mpe/_mpe_utils/simple_env.py index 8e9d39179..103a8a046 100644 --- a/pettingzoo/mpe/_mpe_utils/simple_env.py +++ b/pettingzoo/mpe/_mpe_utils/simple_env.py @@ -1,5 +1,6 @@ import os +import gym import numpy as np import pygame from gym import spaces @@ -27,11 +28,24 @@ def env(**kwargs): class SimpleEnv(AECEnv): + metadata = { + "render_modes": ["human", "rgb_array"], + "is_parallelizable": True, + "render_fps": 10, + } + def __init__( - self, scenario, world, max_cycles, continuous_actions=False, local_ratio=None + self, + scenario, + world, + max_cycles, + render_mode=None, + continuous_actions=False, + local_ratio=None, ): super().__init__() + self.render_mode = render_mode pygame.init() self.viewer = None self.width = 700 @@ -47,12 +61,6 @@ def __init__( self.renderOn = False self.seed() - self.metadata = { - "render_modes": ["human", "rgb_array"], - "is_parallelizable": True, - "render_fps": 10, - } - self.max_cycles = max_cycles self.scenario = scenario self.world = world @@ -251,20 +259,31 @@ def step(self, action): self._cumulative_rewards[cur_agent] = 0 self._accumulate_rewards() + if self.render_mode == "human": + self.render() + def enable_render(self, mode="human"): if not self.renderOn and mode == "human": self.screen = pygame.display.set_mode(self.screen.get_size()) self.renderOn = True - def render(self, mode="human"): - self.enable_render(mode) + def render(self): + if self.render_mode is None: + gym.logger.WARN( + "You are calling render method without specifying any render mode." 
+ ) + return + + self.enable_render(self.render_mode) observation = np.array(pygame.surfarray.pixels3d(self.screen)) - if mode == "human": + if self.render_mode == "human": self.draw() pygame.display.flip() return ( - np.transpose(observation, axes=(1, 0, 2)) if mode == "rgb_array" else None + np.transpose(observation, axes=(1, 0, 2)) + if self.render_mode == "rgb_array" + else None ) def draw(self): diff --git a/pettingzoo/mpe/simple/simple.py b/pettingzoo/mpe/simple/simple.py index 12850cb31..06fb58edc 100644 --- a/pettingzoo/mpe/simple/simple.py +++ b/pettingzoo/mpe/simple/simple.py @@ -56,10 +56,16 @@ class raw_env(SimpleEnv): - def __init__(self, max_cycles=25, continuous_actions=False): + def __init__(self, max_cycles=25, continuous_actions=False, render_mode=None): scenario = Scenario() world = scenario.make_world() - super().__init__(scenario, world, max_cycles, continuous_actions) + super().__init__( + scenario=scenario, + world=world, + render_mode=render_mode, + max_cycles=max_cycles, + continuous_actions=continuous_actions, + ) self.metadata["name"] = "simple_v2" diff --git a/pettingzoo/mpe/simple_adversary/simple_adversary.py b/pettingzoo/mpe/simple_adversary/simple_adversary.py index 6e84a53a3..253900fb6 100644 --- a/pettingzoo/mpe/simple_adversary/simple_adversary.py +++ b/pettingzoo/mpe/simple_adversary/simple_adversary.py @@ -67,16 +67,17 @@ class raw_env(SimpleEnv, EzPickle): - def __init__(self, N=2, max_cycles=25, continuous_actions=False): - EzPickle.__init__( - self, - N, - max_cycles, - continuous_actions, - ) + def __init__(self, N=2, max_cycles=25, continuous_actions=False, render_mode=None): + EzPickle.__init__(self, N, max_cycles, continuous_actions, render_mode) scenario = Scenario() world = scenario.make_world(N) - super().__init__(scenario, world, max_cycles, continuous_actions) + super().__init__( + scenario=scenario, + world=world, + render_mode=render_mode, + max_cycles=max_cycles, + continuous_actions=continuous_actions, + ) 
self.metadata["name"] = "simple_adversary_v2" diff --git a/pettingzoo/mpe/simple_crypto/simple_crypto.py b/pettingzoo/mpe/simple_crypto/simple_crypto.py index 37456d94a..d7573973b 100644 --- a/pettingzoo/mpe/simple_crypto/simple_crypto.py +++ b/pettingzoo/mpe/simple_crypto/simple_crypto.py @@ -78,15 +78,17 @@ class raw_env(SimpleEnv, EzPickle): - def __init__(self, max_cycles=25, continuous_actions=False): - EzPickle.__init__( - self, - max_cycles, - continuous_actions, - ) + def __init__(self, max_cycles=25, continuous_actions=False, render_mode=None): + EzPickle.__init__(self, max_cycles, continuous_actions, render_mode) scenario = Scenario() world = scenario.make_world() - super().__init__(scenario, world, max_cycles, continuous_actions) + super().__init__( + scenario=scenario, + world=world, + render_mode=render_mode, + max_cycles=max_cycles, + continuous_actions=continuous_actions, + ) self.metadata["name"] = "simple_crypto_v2" diff --git a/pettingzoo/mpe/simple_push/simple_push.py b/pettingzoo/mpe/simple_push/simple_push.py index 7f93aee9a..2c34193a1 100644 --- a/pettingzoo/mpe/simple_push/simple_push.py +++ b/pettingzoo/mpe/simple_push/simple_push.py @@ -62,15 +62,17 @@ class raw_env(SimpleEnv, EzPickle): - def __init__(self, max_cycles=25, continuous_actions=False): - EzPickle.__init__( - self, - max_cycles, - continuous_actions, - ) + def __init__(self, max_cycles=25, continuous_actions=False, render_mode=None): + EzPickle.__init__(self, max_cycles, continuous_actions, render_mode) scenario = Scenario() world = scenario.make_world() - super().__init__(scenario, world, max_cycles, continuous_actions) + super().__init__( + scenario=scenario, + world=world, + render_mode=render_mode, + max_cycles=max_cycles, + continuous_actions=continuous_actions, + ) self.metadata["name"] = "simple_push_v2" diff --git a/pettingzoo/mpe/simple_reference/simple_reference.py b/pettingzoo/mpe/simple_reference/simple_reference.py index 22ee03d62..9d55ca87f 100644 --- 
a/pettingzoo/mpe/simple_reference/simple_reference.py +++ b/pettingzoo/mpe/simple_reference/simple_reference.py @@ -68,19 +68,29 @@ class raw_env(SimpleEnv, EzPickle): - def __init__(self, local_ratio=0.5, max_cycles=25, continuous_actions=False): + def __init__( + self, local_ratio=0.5, max_cycles=25, continuous_actions=False, render_mode=None + ): EzPickle.__init__( self, local_ratio, max_cycles, continuous_actions, + render_mode, ) assert ( 0.0 <= local_ratio <= 1.0 ), "local_ratio is a proportion. Must be between 0 and 1." scenario = Scenario() world = scenario.make_world() - super().__init__(scenario, world, max_cycles, continuous_actions, local_ratio) + super().__init__( + scenario=scenario, + world=world, + render_mode=render_mode, + max_cycles=max_cycles, + continuous_actions=continuous_actions, + local_ratio=local_ratio, + ) self.metadata["name"] = "simple_reference_v2" diff --git a/pettingzoo/mpe/simple_speaker_listener/simple_speaker_listener.py b/pettingzoo/mpe/simple_speaker_listener/simple_speaker_listener.py index 1da6cb649..1ecded5ca 100644 --- a/pettingzoo/mpe/simple_speaker_listener/simple_speaker_listener.py +++ b/pettingzoo/mpe/simple_speaker_listener/simple_speaker_listener.py @@ -64,15 +64,17 @@ class raw_env(SimpleEnv, EzPickle): - def __init__(self, max_cycles=25, continuous_actions=False): - EzPickle.__init__( - self, - max_cycles, - continuous_actions, - ) + def __init__(self, max_cycles=25, continuous_actions=False, render_mode=None): + EzPickle.__init__(self, max_cycles, continuous_actions, render_mode) scenario = Scenario() world = scenario.make_world() - super().__init__(scenario, world, max_cycles, continuous_actions) + super().__init__( + scenario=scenario, + world=world, + render_mode=render_mode, + max_cycles=max_cycles, + continuous_actions=continuous_actions, + ) self.metadata["name"] = "simple_speaker_listener_v3" diff --git a/pettingzoo/mpe/simple_spread/simple_spread.py b/pettingzoo/mpe/simple_spread/simple_spread.py index 
b8c5e3b1a..e9fb4c578 100644 --- a/pettingzoo/mpe/simple_spread/simple_spread.py +++ b/pettingzoo/mpe/simple_spread/simple_spread.py @@ -67,20 +67,30 @@ class raw_env(SimpleEnv, EzPickle): - def __init__(self, N=3, local_ratio=0.5, max_cycles=25, continuous_actions=False): + def __init__( + self, + N=3, + local_ratio=0.5, + max_cycles=25, + continuous_actions=False, + render_mode=None, + ): EzPickle.__init__( - self, - N, - local_ratio, - max_cycles, - continuous_actions, + self, N, local_ratio, max_cycles, continuous_actions, render_mode ) assert ( 0.0 <= local_ratio <= 1.0 ), "local_ratio is a proportion. Must be between 0 and 1." scenario = Scenario() world = scenario.make_world(N) - super().__init__(scenario, world, max_cycles, continuous_actions, local_ratio) + super().__init__( + scenario=scenario, + world=world, + render_mode=render_mode, + max_cycles=max_cycles, + continuous_actions=continuous_actions, + local_ratio=local_ratio, + ) self.metadata["name"] = "simple_spread_v2" diff --git a/pettingzoo/mpe/simple_tag/simple_tag.py b/pettingzoo/mpe/simple_tag/simple_tag.py index 47b5250b3..cf81d7873 100644 --- a/pettingzoo/mpe/simple_tag/simple_tag.py +++ b/pettingzoo/mpe/simple_tag/simple_tag.py @@ -84,6 +84,7 @@ def __init__( num_obstacles=2, max_cycles=25, continuous_actions=False, + render_mode=None, ): EzPickle.__init__( self, @@ -92,10 +93,17 @@ def __init__( num_obstacles, max_cycles, continuous_actions, + render_mode, ) scenario = Scenario() world = scenario.make_world(num_good, num_adversaries, num_obstacles) - super().__init__(scenario, world, max_cycles, continuous_actions) + super().__init__( + scenario=scenario, + world=world, + render_mode=render_mode, + max_cycles=max_cycles, + continuous_actions=continuous_actions, + ) self.metadata["name"] = "simple_tag_v2" diff --git a/pettingzoo/mpe/simple_world_comm/simple_world_comm.py b/pettingzoo/mpe/simple_world_comm/simple_world_comm.py index e5d3ef731..6166eb08f 100644 --- 
a/pettingzoo/mpe/simple_world_comm/simple_world_comm.py +++ b/pettingzoo/mpe/simple_world_comm/simple_world_comm.py @@ -97,6 +97,7 @@ def __init__( max_cycles=25, num_forests=2, continuous_actions=False, + render_mode=None, ): EzPickle.__init__( self, @@ -106,12 +107,19 @@ def __init__( max_cycles, num_forests, continuous_actions, + render_mode, ) scenario = Scenario() world = scenario.make_world( num_good, num_adversaries, num_obstacles, num_food, num_forests ) - super().__init__(scenario, world, max_cycles, continuous_actions) + super().__init__( + scenario=scenario, + world=world, + render_mode=render_mode, + max_cycles=max_cycles, + continuous_actions=continuous_actions, + ) self.metadata["name"] = "simple_world_comm_v2" diff --git a/pettingzoo/sisl/multiwalker/multiwalker.py b/pettingzoo/sisl/multiwalker/multiwalker.py index c2e67d26c..e8fc7aa56 100755 --- a/pettingzoo/sisl/multiwalker/multiwalker.py +++ b/pettingzoo/sisl/multiwalker/multiwalker.py @@ -157,7 +157,7 @@ class raw_env(AECEnv, EzPickle): def __init__(self, *args, **kwargs): EzPickle.__init__(self, *args, **kwargs) self.env = _env(*args, **kwargs) - + self.render_mode = self.env.render_mode self.agents = ["walker_" + str(r) for r in range(self.env.num_agents)] self.possible_agents = self.agents[:] self.agent_name_mapping = dict(zip(self.agents, list(range(self.num_agents)))) @@ -196,8 +196,8 @@ def reset(self, seed=None, return_info=False, options=None): def close(self): self.env.close() - def render(self, mode="human"): - return self.env.render(mode) + def render(self): + return self.env.render() def observe(self, agent): return self.env.observe(self.agent_name_mapping[agent]) diff --git a/pettingzoo/sisl/multiwalker/multiwalker_base.py b/pettingzoo/sisl/multiwalker/multiwalker_base.py index 6fdd85e43..7d2838089 100644 --- a/pettingzoo/sisl/multiwalker/multiwalker_base.py +++ b/pettingzoo/sisl/multiwalker/multiwalker_base.py @@ -306,6 +306,7 @@ def __init__( remove_on_fall=True, 
terrain_length=TERRAIN_LENGTH, max_cycles=500, + render_mode=None, ): """Initializes the `MultiWalkerEnv` class. @@ -340,6 +341,7 @@ def __init__( self.last_dones = [False for _ in range(self.n_walkers)] self.last_obs = [None for _ in range(self.n_walkers)] self.max_cycles = max_cycles + self.render_mode = render_mode self.frames = 0 def get_param_values(self): @@ -523,6 +525,9 @@ def step(self, action, agent_id, is_last): self.last_dones = done self.frames = self.frames + 1 + if self.render_mode == "human": + self.render() + def get_last_rewards(self): return dict( zip( @@ -547,7 +552,7 @@ def observe(self, agent): o = np.array(o, dtype=np.float32) return o - def render(self, mode="human", close=False): + def render(self, close=False): if close: self.close() return @@ -688,9 +693,9 @@ def render(self, mode="human", close=False): self.surf = pygame.transform.flip(self.surf, False, True) self.screen.blit(self.surf, (-self.scroll * render_scale - offset, 0)) - if mode == "human": + if self.render_mode == "human": pygame.display.flip() - elif mode == "rgb_array": + elif self.render_mode == "rgb_array": return np.transpose( np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2) ) diff --git a/pettingzoo/sisl/pursuit/pursuit.py b/pettingzoo/sisl/pursuit/pursuit.py index 3300ff7a8..41085481b 100755 --- a/pettingzoo/sisl/pursuit/pursuit.py +++ b/pettingzoo/sisl/pursuit/pursuit.py @@ -118,6 +118,7 @@ class raw_env(AECEnv, EzPickle): def __init__(self, *args, **kwargs): EzPickle.__init__(self, *args, **kwargs) self.env = _env(*args, **kwargs) + self.render_mode = kwargs.get("render_mode") pygame.init() self.agents = ["pursuer_" + str(a) for a in range(self.env.num_agents)] self.possible_agents = self.agents[:] @@ -152,9 +153,9 @@ def close(self): self.closed = True self.env.close() - def render(self, mode="human"): + def render(self): if not self.closed: - return self.env.render(mode) + return self.env.render() def step(self, action): if ( diff --git 
a/pettingzoo/sisl/pursuit/pursuit_base.py b/pettingzoo/sisl/pursuit/pursuit_base.py index bcf3b4b75..3b4f97221 100755 --- a/pettingzoo/sisl/pursuit/pursuit_base.py +++ b/pettingzoo/sisl/pursuit/pursuit_base.py @@ -1,6 +1,7 @@ from collections import defaultdict from typing import Optional +import gym import numpy as np import pygame from gym import spaces @@ -29,6 +30,7 @@ def __init__( catch_reward: float = 5.0, urgency_reward: float = -0.1, surround: bool = True, + render_mode=None, constraint_window: float = 1.0, ): """In evade pursuit a set of pursuers must 'tag' a set of evaders. @@ -142,6 +144,7 @@ def __init__( self.surround = surround + self.render_mode = render_mode self.constraint_window = constraint_window self.surround_mask = np.array([[-1, 0], [1, 0], [0, 1], [0, -1]]) @@ -273,6 +276,9 @@ def step(self, action, agent_id, is_last): self.local_ratio * local_val + (1 - self.local_ratio) * global_val ) + if self.render_mode == "human": + self.render() + def draw_model_state(self): # -1 is building pixel flag x_len, y_len = self.model_state[0].shape @@ -379,9 +385,15 @@ def draw_agent_counts(self): self.screen.blit(text, (pos_x, pos_y - self.pixel_scale // 2)) - def render(self, mode="human"): + def render(self): + if self.render_mode is None: + gym.logger.WARN( + "You are calling render method without specifying any render mode." 
+ ) + return + if not self.renderOn: - if mode == "human": + if self.render_mode == "human": pygame.display.init() self.screen = pygame.display.set_mode( (self.pixel_scale * self.x_size, self.pixel_scale * self.y_size) @@ -403,11 +415,11 @@ def render(self, mode="human"): observation = pygame.surfarray.pixels3d(self.screen) new_observation = np.copy(observation) del observation - if mode == "human": + if self.render_mode == "human": pygame.display.flip() return ( np.transpose(new_observation, axes=(1, 0, 2)) - if mode == "rgb_array" + if self.render_mode == "rgb_array" else None ) diff --git a/pettingzoo/sisl/pursuit/utils/discrete_agent.py b/pettingzoo/sisl/pursuit/utils/discrete_agent.py index 2e84158d0..42f2443ad 100644 --- a/pettingzoo/sisl/pursuit/utils/discrete_agent.py +++ b/pettingzoo/sisl/pursuit/utils/discrete_agent.py @@ -83,6 +83,7 @@ def step(self, a): tpos += self.motion_range[a] x = tpos[0] y = tpos[1] + # check bounds if not self.inbounds(x, y): return cpos diff --git a/pettingzoo/sisl/waterworld/waterworld.py b/pettingzoo/sisl/waterworld/waterworld.py index fd7954faa..63dfbf939 100755 --- a/pettingzoo/sisl/waterworld/waterworld.py +++ b/pettingzoo/sisl/waterworld/waterworld.py @@ -209,8 +209,8 @@ def close(self): if self.has_reset: self.env.close() - def render(self, mode="human"): - return self.env.render(mode) + def render(self): + return self.env.render() def step(self, action): if ( diff --git a/pettingzoo/sisl/waterworld/waterworld_base.py b/pettingzoo/sisl/waterworld/waterworld_base.py index ad6d46aa2..7089c98d4 100755 --- a/pettingzoo/sisl/waterworld/waterworld_base.py +++ b/pettingzoo/sisl/waterworld/waterworld_base.py @@ -1,5 +1,6 @@ import math +import gym import numpy as np import pygame from gym import spaces @@ -160,6 +161,7 @@ def __init__( local_ratio=1.0, speed_features=True, max_cycles=500, + render_mode=None, ): raise AssertionError( "Please do not use Waterworld, at its current state it is incredibly buggy and the soundness of 
the environment is not guaranteed." @@ -246,6 +248,7 @@ def __init__( self.action_space = [agent.action_space for agent in self._pursuers] self.observation_space = [agent.observation_space for agent in self._pursuers] + self.render_mode = render_mode self.renderOn = False self.pixel_scale = 30 * 25 @@ -681,6 +684,8 @@ def move_objects(objects): self.frames += 1 + if self.render_mode == "human": + self.render() return self.observe(agent_id) def observe(self, agent): @@ -736,9 +741,15 @@ def draw_poisons(self): self.screen, color, center, self.pixel_scale * self.radius * 3 / 4 ) - def render(self, mode="human"): + def render(self): + if self.render_mode is None: + gym.logger.WARN( + "You are calling render method without specifying any render mode." + ) + return + if not self.renderOn: - if mode == "human": + if self.render_mode == "human": pygame.display.init() self.screen = pygame.display.set_mode( (self.pixel_scale, self.pixel_scale) @@ -756,10 +767,10 @@ def render(self, mode="human"): observation = pygame.surfarray.pixels3d(self.screen) new_observation = np.copy(observation) del observation - if mode == "human": + if self.render_mode == "human": pygame.display.flip() return ( np.transpose(new_observation, axes=(1, 0, 2)) - if mode == "rgb_array" + if self.render_mode == "rgb_array" else None ) diff --git a/pettingzoo/test/example_envs/generated_agents_env_v0.py b/pettingzoo/test/example_envs/generated_agents_env_v0.py index 01c45e7b6..72aca4596 100644 --- a/pettingzoo/test/example_envs/generated_agents_env_v0.py +++ b/pettingzoo/test/example_envs/generated_agents_env_v0.py @@ -20,7 +20,7 @@ class raw_env(AECEnv): metadata = {"render_modes": ["human"], "name": "generated_agents_env_v0"} - def __init__(self, max_cycles=100): + def __init__(self, max_cycles=100, render_mode=None): super().__init__() self._obs_spaces = {} self._act_spaces = {} @@ -28,6 +28,7 @@ def __init__(self, max_cycles=100): self._agent_counters = {} self.max_cycles = max_cycles self.seed() + 
self.render_mode = render_mode for i in range(3): self.add_type() @@ -117,9 +118,16 @@ def step(self, action): self._accumulate_rewards() self._deads_step_first() - - def render(self, mode="human"): - print(self.agents) + if self.render_mode == "human": + self.render() + + def render(self): + if self.render_mode is None: + gym.logger.WARN( + "You are calling render method without specifying any render mode." + ) + else: + print(self.agents) def close(self): pass diff --git a/pettingzoo/test/example_envs/generated_agents_parallel_v0.py b/pettingzoo/test/example_envs/generated_agents_parallel_v0.py index e40277093..aae7c11f4 100644 --- a/pettingzoo/test/example_envs/generated_agents_parallel_v0.py +++ b/pettingzoo/test/example_envs/generated_agents_parallel_v0.py @@ -24,7 +24,7 @@ class parallel_env(ParallelEnv): metadata = {"render_modes": ["human"], "name": "generated_agents_parallel_v0"} - def __init__(self, max_cycles=100): + def __init__(self, max_cycles=100, render_mode=None): super().__init__() self._obs_spaces = {} self._act_spaces = {} @@ -32,6 +32,7 @@ def __init__(self, max_cycles=100): self._agent_counters = {} self.max_cycles = max_cycles self.seed() + self.render_mode = render_mode for i in range(3): self.add_type() @@ -107,10 +108,18 @@ def step(self, actions): for agent in self.agents if not (all_truncations[agent] or all_terminations[agent]) ] + + if self.render_mode == "human": + self.render() return all_observes, all_rewards, all_terminations, all_truncations, all_infos - def render(self, mode="human"): - print(self.agents) + def render(self): + if self.render_mode is None: + gym.logger.WARN( + "You are calling render method without specifying any render mode." 
+ ) + else: + print(self.agents) def close(self): pass diff --git a/pettingzoo/test/render_test.py b/pettingzoo/test/render_test.py index 045d304f8..507079088 100644 --- a/pettingzoo/test/render_test.py +++ b/pettingzoo/test/render_test.py @@ -3,7 +3,7 @@ import numpy as np -def collect_render_results(env, mode): +def collect_render_results(env): results = [] env.reset() @@ -18,20 +18,21 @@ def collect_render_results(env, mode): else: action = env.action_space(agent).sample() env.step(action) - render_result = env.render(mode=mode) + render_result = env.render() results.append(render_result) return results def render_test(env_fn, custom_tests={}): - env = env_fn() + env = env_fn(render_mode="human") render_modes = env.metadata.get("render_modes")[:] assert ( render_modes is not None ), "Environments that support rendering must define render_modes in metadata" for mode in render_modes: - render_results = collect_render_results(env, mode) + env = env_fn(render_mode=mode) + render_results = collect_render_results(env) for res in render_results: if mode in custom_tests.keys(): assert custom_tests[mode](res) @@ -49,4 +50,3 @@ def render_test(env_fn, custom_tests={}): if mode == "human": assert res is None env.close() - env = env_fn() diff --git a/pettingzoo/utils/conversions.py b/pettingzoo/utils/conversions.py index 7bddd6f52..ff3ea6068 100644 --- a/pettingzoo/utils/conversions.py +++ b/pettingzoo/utils/conversions.py @@ -173,8 +173,8 @@ def step(self, actions): self.agents = self.aec_env.agents return observations, rewards, terminations, truncations, infos - def render(self, mode="human"): - return self.aec_env.render(mode) + def render(self): + return self.aec_env.render() def state(self): return self.aec_env.state() @@ -190,6 +190,8 @@ def __init__(self, parallel_env): self.metadata = {**parallel_env.metadata} self.metadata["is_parallelizable"] = True + self.render_mode = self.env.render_mode + try: self.possible_agents = parallel_env.possible_agents except 
AttributeError: @@ -321,8 +323,8 @@ def last(self, observe=True): self.infos[agent], ) - def render(self, mode="human"): - return self.env.render(mode) + def render(self): + return self.env.render() def close(self): self.env.close() @@ -426,8 +428,8 @@ def step(self, actions): self.agents = self.aec_env.agents return observations, rewards, terminations, truncations, infos - def render(self, mode="human"): - return self.aec_env.render(mode) + def render(self): + return self.aec_env.render() def state(self): return self.aec_env.state() diff --git a/pettingzoo/utils/env.py b/pettingzoo/utils/env.py index 5b08b088e..9af7d24a3 100644 --- a/pettingzoo/utils/env.py +++ b/pettingzoo/utils/env.py @@ -86,10 +86,11 @@ def observe(self, agent: str) -> Optional[ObsType]: """ raise NotImplementedError - def render(self, mode: str = "human") -> None | np.ndarray | str: - """Displays a rendered frame from the environment, if supported. + def render(self) -> None | np.ndarray | str | list: + """Renders the environment as specified by self.render_mode. - Alternate render modes in the default environments are `'rgb_array'` + Render mode can be `human` to display a window. + Other render modes in the default environments are `'rgb_array'` which returns a numpy array and is supported by all environments outside of classic, and `'ansi'` which returns the strings printed (specific to classic environments). """ @@ -329,7 +330,7 @@ def step( """ raise NotImplementedError - def render(self, mode="human") -> None | np.ndarray | str: + def render(self) -> None | np.ndarray | str | List: """Displays a rendered frame from the environment, if supported. 
Alternate render modes in the default environments are `'rgb_array'` diff --git a/pettingzoo/utils/wrappers/base.py b/pettingzoo/utils/wrappers/base.py index 2416809a4..52a281ab6 100644 --- a/pettingzoo/utils/wrappers/base.py +++ b/pettingzoo/utils/wrappers/base.py @@ -39,6 +39,12 @@ def __init__(self, env): except AttributeError: pass + def __getattr__(self, name): + """Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" + if name.startswith("_"): + raise AttributeError(f"accessing private attribute '{name}' is prohibited") + return getattr(self.env, name) + @property def observation_spaces(self): warnings.warn( @@ -78,8 +84,8 @@ def unwrapped(self): def close(self): self.env.close() - def render(self, mode="human"): - return self.env.render(mode) + def render(self): + return self.env.render() def reset(self, seed=None, return_info=False, options=None): self.env.reset(seed=seed, options=options) diff --git a/pettingzoo/utils/wrappers/base_parallel.py b/pettingzoo/utils/wrappers/base_parallel.py index 221d22dd4..4e51d61e9 100644 --- a/pettingzoo/utils/wrappers/base_parallel.py +++ b/pettingzoo/utils/wrappers/base_parallel.py @@ -40,8 +40,8 @@ def step(self, actions): self.agents = self.env.agents return res - def render(self, mode="human"): - return self.env.render(mode) + def render(self): + return self.env.render() def close(self): return self.env.close() diff --git a/pettingzoo/utils/wrappers/capture_stdout.py b/pettingzoo/utils/wrappers/capture_stdout.py index f93fe9a69..95a601500 100644 --- a/pettingzoo/utils/wrappers/capture_stdout.py +++ b/pettingzoo/utils/wrappers/capture_stdout.py @@ -4,19 +4,18 @@ class CaptureStdoutWrapper(BaseWrapper): def __init__(self, env): + assert ( + env.render_mode == "human" + ), f"CaptureStdoutWrapper works only with human rendering mode, but found {env.render_mode} instead." 
super().__init__(env) self.metadata["render_modes"].append("ansi") + self.render_mode = "ansi" - def render(self, mode="human"): - if mode == "ansi": - with capture_stdout() as stdout: - - super().render("human") - - val = stdout.getvalue() - return val - else: - return super().render(mode) + def render(self): + with capture_stdout() as stdout: + super().render() + val = stdout.getvalue() + return val def __str__(self): return str(self.env) diff --git a/pettingzoo/utils/wrappers/order_enforcing.py b/pettingzoo/utils/wrappers/order_enforcing.py index 2fdefd88b..6ef519b91 100644 --- a/pettingzoo/utils/wrappers/order_enforcing.py +++ b/pettingzoo/utils/wrappers/order_enforcing.py @@ -26,6 +26,8 @@ def __getattr__(self, value): """ if value == "unwrapped": return self.env.unwrapped + elif value == "render_mode": + return self.env.render_mode elif value == "possible_agents": EnvLogger.error_possible_agents_attribute_missing("possible_agents") elif value == "observation_spaces": @@ -55,12 +57,11 @@ def __getattr__(self, value): f"'{type(self).__name__}' object has no attribute '{value}'" ) - def render(self, mode="human"): + def render(self): if not self._has_reset: EnvLogger.error_render_before_reset() - assert mode in self.metadata["render_modes"] self._has_rendered = True - return super().render(mode) + return super().render() def step(self, action): if not self._has_reset: diff --git a/test/all_parameter_combs_test.py b/test/all_parameter_combs_test.py index dd8353779..517de97c4 100644 --- a/test/all_parameter_combs_test.py +++ b/test/all_parameter_combs_test.py @@ -253,7 +253,7 @@ def test_module(name, env_module, kwargs): if "atari/" not in name: seed_test(lambda: env_module.env(**kwargs), 50) - render_test(lambda: env_module.env(**kwargs)) + render_test(lambda render_mode: env_module.env(render_mode=render_mode, **kwargs)) if hasattr(env_module, "parallel_env"): par_env = env_module.parallel_env(**kwargs) try: diff --git a/test/pytest_runner_test.py 
b/test/pytest_runner_test.py index b95f781a5..aed873d23 100644 --- a/test/pytest_runner_test.py +++ b/test/pytest_runner_test.py @@ -16,7 +16,7 @@ @pytest.mark.parametrize(("name", "env_module"), list(all_environments.items())) def test_module(name, env_module): - _env = env_module.env() + _env = env_module.env(render_mode="human") assert str(_env) == os.path.basename(name) api_test(_env) if "classic/" not in name: @@ -34,7 +34,7 @@ def test_module(name, env_module): max_cycles_test(env_module) if ("butterfly/" in name) or ("mpe/" in name) or ("magent/" in name): - state_test(_env, env_module.parallel_env()) + state_test(env_module.env(), env_module.parallel_env()) # recreated_env = pickle.loads(pickle.dumps(_env)) # recreated_env.seed(42) diff --git a/test/unwrapped_test.py b/test/unwrapped_test.py index f2f11de52..19cf590a1 100644 --- a/test/unwrapped_test.py +++ b/test/unwrapped_test.py @@ -32,7 +32,7 @@ def discrete_observation(env, agents): @pytest.mark.parametrize(("name", "env_module"), list(all_environments.items())) def test_unwrapped(name, env_module): - env = env_module.env() + env = env_module.env(render_mode="human") base_env = env.unwrapped env.reset() diff --git a/tutorials/manual_pistonball_policy.py b/tutorials/manual_pistonball_policy.py index 99e3887fc..24abf46ce 100644 --- a/tutorials/manual_pistonball_policy.py +++ b/tutorials/manual_pistonball_policy.py @@ -73,6 +73,7 @@ def main(): ball_friction=0.3, ball_elasticity=1.5, max_cycles=125, + render_mode="rgb_array", ) total_reward = 0 obs_list = [] @@ -87,9 +88,7 @@ def main(): total_reward += rew i += 1 if i % (len(env.possible_agents) + 1) == 0: - obs_list.append( - np.transpose(env.render(mode="rgb_array"), axes=(1, 0, 2)) - ) + obs_list.append(np.transpose(env.render(), axes=(1, 0, 2))) env.close() print("average total reward: ", total_reward / NUM_RESETS) diff --git a/tutorials/render_rllib_pistonball.py b/tutorials/render_rllib_pistonball.py index 59b2716f4..664addd1c 100644 --- 
a/tutorials/render_rllib_pistonball.py +++ b/tutorials/render_rllib_pistonball.py @@ -43,6 +43,7 @@ def env_creator(): ball_friction=0.3, ball_elasticity=1.5, max_cycles=125, + render_mode="rgb_array", ) env = ss.color_reduction_v0(env, mode="B") env = ss.dtype_v0(env, "float32") @@ -85,7 +86,8 @@ def env_creator(): env.step(action) i += 1 if i % (len(env.possible_agents) + 1) == 0: - frame_list.append(Image.fromarray(env.render(mode="rgb_array"))) + img = Image.fromarray(env.render()) + frame_list.append(img) env.close()