From 5e2adaf8080be5d746f38b8f588a9493e073a9e9 Mon Sep 17 00:00:00 2001 From: snow-fox Date: Tue, 28 Feb 2023 17:08:15 +0000 Subject: [PATCH 1/2] remove return info --- .../generic_wrappers/delay_observations.py | 2 +- supersuit/generic_wrappers/frame_skip.py | 2 +- supersuit/generic_wrappers/frame_stack.py | 4 +- supersuit/generic_wrappers/max_observation.py | 2 +- supersuit/generic_wrappers/nan_wrappers.py | 4 +- supersuit/generic_wrappers/sticky_actions.py | 2 +- .../generic_wrappers/utils/base_modifier.py | 2 +- .../utils/shared_wrapper_util.py | 2 +- .../lambda_wrappers/observation_lambda.py | 15 ++----- supersuit/lambda_wrappers/reward_lambda.py | 2 +- supersuit/multiagent_wrappers/black_death.py | 15 ++----- supersuit/utils/base_aec_wrapper.py | 2 +- supersuit/vector/concat_vec_env.py | 40 +++++-------------- supersuit/vector/markov_vector_wrapper.py | 16 ++------ supersuit/vector/multiproc_vec.py | 14 ++----- supersuit/vector/sb3_vector_wrapper.py | 2 +- supersuit/vector/sb_vector_wrapper.py | 2 +- supersuit/vector/single_vec_env.py | 2 +- test/dummy_aec_env.py | 2 +- test/dummy_gym_env.py | 7 +--- test/parallel_env_test.py | 7 +--- 21 files changed, 45 insertions(+), 101 deletions(-) diff --git a/supersuit/generic_wrappers/delay_observations.py b/supersuit/generic_wrappers/delay_observations.py index c6b8179b..a523fb11 100644 --- a/supersuit/generic_wrappers/delay_observations.py +++ b/supersuit/generic_wrappers/delay_observations.py @@ -5,7 +5,7 @@ def delay_observations_v0(env, delay): class DelayObsModifier(BaseModifier): - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): self.delayer = Delayer(self.observation_space, delay) def modify_obs(self, obs): diff --git a/supersuit/generic_wrappers/frame_skip.py b/supersuit/generic_wrappers/frame_skip.py index 5c3c6a4b..a4af7f28 100644 --- a/supersuit/generic_wrappers/frame_skip.py +++ b/supersuit/generic_wrappers/frame_skip.py @@ -44,7 +44,7 @@ def __init__(self, env, num_frames): check_transform_frameskip(num_frames) self.num_frames = num_frames - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): super().reset(seed=seed, options=options) self.agents = self.env.agents[:] self.terminations = make_defaultdict({agent: False for agent in self.agents}) diff --git a/supersuit/generic_wrappers/frame_stack.py b/supersuit/generic_wrappers/frame_stack.py index ffd59d03..7d061203 100644 --- a/supersuit/generic_wrappers/frame_stack.py +++ b/supersuit/generic_wrappers/frame_stack.py @@ -26,7 +26,7 @@ def modify_obs_space(self, obs_space): self.observation_space = stack_obs_space(obs_space, stack_size, stack_dim) return self.observation_space - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): self.stack = stack_init(self.old_obs_space, stack_size, stack_dim) def modify_obs(self, obs): @@ -65,7 +65,7 @@ def modify_obs_space(self, obs_space): self.observation_space = stack_obs_space(obs_space, stack_size, stack_dim) return self.observation_space - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): self.stack = stack_init(self.old_obs_space, stack_size, stack_dim) self.reset_flag = True diff --git a/supersuit/generic_wrappers/max_observation.py b/supersuit/generic_wrappers/max_observation.py index 133eab84..6b1f36ff 100644 --- a/supersuit/generic_wrappers/max_observation.py +++ b/supersuit/generic_wrappers/max_observation.py @@ -8,7 +8,7 @@ def max_observation_v0(env, memory): int(memory) # delay must be an int class MaxObsModifier(BaseModifier): - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): self.accumulator = Accumulator(self.observation_space, memory, np.maximum) def modify_obs(self, obs): diff --git a/supersuit/generic_wrappers/nan_wrappers.py b/supersuit/generic_wrappers/nan_wrappers.py index 00491928..8dec36cc 100644 --- a/supersuit/generic_wrappers/nan_wrappers.py +++ b/supersuit/generic_wrappers/nan_wrappers.py @@ -11,10 +11,10 @@ class NanRandomModifier(BaseModifier): def __init__(self): super().__init__() - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): self.np_random, seed = gymnasium.utils.seeding.np_random(seed) - return super().reset(seed, return_info=return_info, options=options) + return super().reset(seed, options=options) def modify_action(self, action): if action is not None and np.isnan(action).any(): diff --git a/supersuit/generic_wrappers/sticky_actions.py b/supersuit/generic_wrappers/sticky_actions.py index 3d0483a9..50a3a12b 100644 --- a/supersuit/generic_wrappers/sticky_actions.py +++ b/supersuit/generic_wrappers/sticky_actions.py @@ -10,7 +10,7 @@ class StickyActionsModifier(BaseModifier): def __init__(self): super().__init__() - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): self.np_random, _ = gymnasium.utils.seeding.np_random(seed) self.old_action = None diff --git a/supersuit/generic_wrappers/utils/base_modifier.py b/supersuit/generic_wrappers/utils/base_modifier.py index ae5b79a6..404a93ef 100644 --- a/supersuit/generic_wrappers/utils/base_modifier.py +++ b/supersuit/generic_wrappers/utils/base_modifier.py @@ -2,7 +2,7 @@ class BaseModifier: def __init__(self): pass - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): pass def modify_obs(self, obs): diff --git a/supersuit/generic_wrappers/utils/shared_wrapper_util.py b/supersuit/generic_wrappers/utils/shared_wrapper_util.py index db214cee..f102e16d 100644 --- a/supersuit/generic_wrappers/utils/shared_wrapper_util.py +++ b/supersuit/generic_wrappers/utils/shared_wrapper_util.py @@ -40,7 +40,7 @@ def add_modifiers(self, agents_list): if self._cur_seed is not None: self._cur_seed += 1 - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): self._cur_seed = seed self._cur_options = options diff --git a/supersuit/lambda_wrappers/observation_lambda.py b/supersuit/lambda_wrappers/observation_lambda.py index f30a0b99..f99ff167 100644 --- a/supersuit/lambda_wrappers/observation_lambda.py +++ b/supersuit/lambda_wrappers/observation_lambda.py @@ -116,17 +116,10 @@ def step(self, action): observation = self._modify_observation(observation) return observation, rew, termination, truncation, info - def reset(self, seed=None, return_info=False, options=None): - if not return_info: - observation = self.env.reset(seed=seed, options=options) - observation = self._modify_observation(observation) - return observation - else: - observation, info = self.env.reset( - seed=seed, return_info=return_info, options=options - ) - observation = self._modify_observation(observation) - return observation, info + def reset(self, seed=None, options=None): + observation = self.env.reset(seed=seed, options=options) + observation = self._modify_observation(observation) + return observation observation_lambda_v0 = WrapperChooser( diff --git a/supersuit/lambda_wrappers/reward_lambda.py b/supersuit/lambda_wrappers/reward_lambda.py index e4ab5e58..dd24c450 100644 --- a/supersuit/lambda_wrappers/reward_lambda.py +++ b/supersuit/lambda_wrappers/reward_lambda.py @@ -19,7 +19,7 @@ def _check_wrapper_params(self): def _modify_spaces(self): pass - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): super().reset(seed=seed, options=options) self.rewards = { agent: self._change_reward_fn(reward) diff --git a/supersuit/multiagent_wrappers/black_death.py b/supersuit/multiagent_wrappers/black_death.py index 48946011..7672c926 100644 --- a/supersuit/multiagent_wrappers/black_death.py +++ b/supersuit/multiagent_wrappers/black_death.py @@ -15,13 +15,8 @@ def _check_valid_for_black_death(self): space, gymnasium.spaces.Box ), f"observation sapces for black death must be Box spaces, is {space}" - def reset(self, seed=None, return_info=False, options=None): - if not return_info: - obss = self.env.reset(seed=seed, options=options) - else: - obss, infos = self.env.reset( - seed=seed, return_info=return_info, options=options - ) + def reset(self, seed=None, options=None): + obss = self.env.reset(seed=seed, options=options) self.agents = self.env.agents[:] self._check_valid_for_black_death() @@ -31,11 +26,7 @@ def reset(self, seed=None, return_info=False, options=None): if agent not in obss } - if not return_info: - return {**obss, **black_obs} - else: - black_infos = {agent: {} for agent in self.agents if agent not in obss} - return {**obss, **black_obs}, {**black_infos, **infos} + return {**obss, **black_obs} def step(self, actions): active_actions = {agent: actions[agent] for agent in self.env.agents} diff --git a/supersuit/utils/base_aec_wrapper.py b/supersuit/utils/base_aec_wrapper.py index 58995606..fc6303d1 100644 --- a/supersuit/utils/base_aec_wrapper.py +++ b/supersuit/utils/base_aec_wrapper.py @@ -27,7 +27,7 @@ def _modify_observation(self, agent, observation): def _update_step(self, agent): pass - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): super().reset(seed=seed, options=options) self._update_step(self.agent_selection) diff --git a/supersuit/vector/concat_vec_env.py b/supersuit/vector/concat_vec_env.py index 76a6795d..c22e3ba6 100644 --- a/supersuit/vector/concat_vec_env.py +++ b/supersuit/vector/concat_vec_env.py @@ -29,40 +29,20 @@ def __init__(self, vec_env_fns, obs_space=None, act_space=None): tot_num_envs = sum(env.num_envs for env in vec_envs) self.num_envs = tot_num_envs - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): _res_obs = [] - if not return_info: - if seed is not None: - for i in range(len(self.vec_envs)): - _obs = self.vec_envs[i].reset(seed=seed + i, options=options) - _res_obs.append(_obs) - else: - _res_obs = [ - vec_env.reset(seed=None, options=options) - for vec_env in self.vec_envs - ] - - return self.concat_obs(_res_obs) - + if seed is not None: + for i in range(len(self.vec_envs)): + _obs = self.vec_envs[i].reset(seed=seed + i, options=options) + _res_obs.append(_obs) else: - _res_info = [] - if seed is not None: - for i in range(len(self.vec_envs)): - _obs, _info = self.vec_envs[i].reset( - seed=seed + i, return_info=return_info, options=options - ) - _res_obs.append(_obs) - _res_info.append(_info) - else: - for vec_env in self.vec_envs: - _obs, _info = vec_env.reset( - seed=None, return_info=return_info, options=options - ) - _res_obs.append(_obs) - _res_info.append(_info) + _res_obs = [ + vec_env.reset(seed=None, options=options) + for vec_env in self.vec_envs + ] - return self.concat_obs(_res_obs), sum(_res_info, []) + return self.concat_obs(_res_obs) def concat_obs(self, observations): return concatenate( diff --git a/supersuit/vector/markov_vector_wrapper.py b/supersuit/vector/markov_vector_wrapper.py index 015caad2..ccb4971b 100644 --- a/supersuit/vector/markov_vector_wrapper.py +++ b/supersuit/vector/markov_vector_wrapper.py @@ -51,18 +51,10 @@ def step_async(self, actions): def step_wait(self): return self.step(self._saved_actions) - def reset(self, seed=None, return_info=False, options=None): - if not return_info: - _observations = self.par_env.reset(seed=seed, options=options) - observations = self.concat_obs(_observations) - return observations - else: - _observations, infos = self.par_env.reset( - seed=seed, return_info=return_info, options=options - ) - infs = [infos.get(agent, {}) for agent in self.par_env.possible_agents] - observations = self.concat_obs(_observations) - return observations, infs + def reset(self, seed=None, options=None): + _observations = self.par_env.reset(seed=seed, options=options) + observations = self.concat_obs(_observations) + return observations def step(self, actions): actions = list(iterate(self.action_space, actions)) diff --git a/supersuit/vector/multiproc_vec.py b/supersuit/vector/multiproc_vec.py index da671622..06b01d59 100644 --- a/supersuit/vector/multiproc_vec.py +++ b/supersuit/vector/multiproc_vec.py @@ -69,13 +69,7 @@ def async_loop(vec_env_constr, inpt_p, pipe, shared_obs, shared_rews, shared_ter name, data = instr if name == "reset": - if not data[1]: - observations = vec_env.reset(seed=data[0], options=data[2]) - else: - observations, infos = vec_env.reset( - seed=data[0], return_info=data[1], options=data[2] - ) - comp_infos = compress_info(infos) + observations = vec_env.reset(seed=data[0], options=data[1]) write_observations(vec_env, env_start_idx, shared_obs, observations) shared_terms.np_arr[env_start_idx:env_end_idx] = False @@ -175,12 +169,12 @@ def __init__( assert num_envs == tot_num_envs self.idx_starts = idx_starts - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): for i, pipe in enumerate(self.pipes): if seed is not None: - pipe.send(("reset", (seed + i, return_info, options))) + pipe.send(("reset", (seed + i, options))) else: - pipe.send(("reset", (seed, return_info, options))) + pipe.send(("reset", (seed, options))) self._receive_info() diff --git a/supersuit/vector/sb3_vector_wrapper.py b/supersuit/vector/sb3_vector_wrapper.py index b2fb28a8..dd1cdc1d 100644 --- a/supersuit/vector/sb3_vector_wrapper.py +++ b/supersuit/vector/sb3_vector_wrapper.py @@ -8,7 +8,7 @@ def __init__(self, venv): self.observation_space = venv.observation_space self.action_space = venv.action_space - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): if seed is not None: self.seed(seed=seed) return self.venv.reset() diff --git a/supersuit/vector/sb_vector_wrapper.py b/supersuit/vector/sb_vector_wrapper.py index ef7d1d5a..f1d0f898 100644 --- a/supersuit/vector/sb_vector_wrapper.py +++ b/supersuit/vector/sb_vector_wrapper.py @@ -8,7 +8,7 @@ def __init__(self, venv): self.observation_space = venv.observation_space self.action_space = venv.action_space - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): if seed is not None: self.seed(seed=seed) diff --git a/supersuit/vector/single_vec_env.py b/supersuit/vector/single_vec_env.py index b1ce2fe9..81c6e735 100644 --- a/supersuit/vector/single_vec_env.py +++ b/supersuit/vector/single_vec_env.py @@ -11,7 +11,7 @@ def __init__(self, gym_env_fns, *args): self.num_envs = 1 self.metadata = self.gym_env.metadata - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): return np.expand_dims(self.gym_env.reset(seed=seed, options=options), 0) def step_async(self, actions): diff --git a/test/dummy_aec_env.py b/test/dummy_aec_env.py index 69203be7..a5dd81cf 100644 --- a/test/dummy_aec_env.py +++ b/test/dummy_aec_env.py @@ -43,7 +43,7 @@ def step(self, action, observe=True): self._accumulate_rewards() self._deads_step_first() - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): self.agents = self.possible_agents[:] self._agent_selector = agent_selector(self.agents) self.agent_selection = self._agent_selector.reset() diff --git a/test/dummy_gym_env.py b/test/dummy_gym_env.py index bcafcdeb..4e9c8458 100644 --- a/test/dummy_gym_env.py +++ b/test/dummy_gym_env.py @@ -12,8 +12,5 @@ def __init__(self, observation, observation_space, action_space): def step(self, action): return self._observation, 1, False, False, {} - def reset(self, seed=None, return_info=False, options=None): - if not return_info: - return self._observation - else: - return self._observation, {} + def reset(self, seed=None, options=None): + return self._observation diff --git a/test/parallel_env_test.py b/test/parallel_env_test.py index e386878e..22656690 100644 --- a/test/parallel_env_test.py +++ b/test/parallel_env_test.py @@ -39,11 +39,8 @@ def step(self, actions): self.infos, ) - def reset(self, seed=None, return_info=False, options=None): - if not return_info: - return self._observations - else: - return self._observations, self.infos + def reset(self, seed=None, options=None): + return self._observations def close(self): pass From 455c768cfaa35fb15f3168408214e68150474fc0 Mon Sep 17 00:00:00 2001 From: jjshoots Date: Mon, 6 Mar 2023 16:22:04 +0000 Subject: [PATCH 2/2] black --- supersuit/vector/concat_vec_env.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/supersuit/vector/concat_vec_env.py b/supersuit/vector/concat_vec_env.py index 300ca840..8fad58de 100644 --- a/supersuit/vector/concat_vec_env.py +++ b/supersuit/vector/concat_vec_env.py @@ -39,8 +39,7 @@ def reset(self, seed=None, options=None): _res_obs.append(_obs) else: _res_obs = [ - vec_env.reset(seed=None, options=options) - for vec_env in self.vec_envs + vec_env.reset(seed=None, options=options) for vec_env in self.vec_envs ] return self.concat_obs(_res_obs)