diff --git a/supersuit/generic_wrappers/delay_observations.py b/supersuit/generic_wrappers/delay_observations.py index 93e0aee..866b700 100644 --- a/supersuit/generic_wrappers/delay_observations.py +++ b/supersuit/generic_wrappers/delay_observations.py @@ -6,7 +6,7 @@ def delay_observations_v0(env, delay): class DelayObsModifier(BaseModifier): - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): self.delayer = Delayer(self.observation_space, delay) def modify_obs(self, obs): diff --git a/supersuit/generic_wrappers/frame_skip.py b/supersuit/generic_wrappers/frame_skip.py index e98bce6..c92e37b 100644 --- a/supersuit/generic_wrappers/frame_skip.py +++ b/supersuit/generic_wrappers/frame_skip.py @@ -44,7 +44,7 @@ def __init__(self, env, num_frames): check_transform_frameskip(num_frames) self.num_frames = num_frames - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): super().reset(seed=seed, options=options) self.agents = self.env.agents[:] self.terminations = make_defaultdict({agent: False for agent in self.agents}) diff --git a/supersuit/generic_wrappers/frame_stack.py b/supersuit/generic_wrappers/frame_stack.py index 0876b89..aadb7c8 100644 --- a/supersuit/generic_wrappers/frame_stack.py +++ b/supersuit/generic_wrappers/frame_stack.py @@ -28,7 +28,7 @@ def modify_obs_space(self, obs_space): self.observation_space = stack_obs_space(obs_space, stack_size, stack_dim) return self.observation_space - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): self.stack = stack_init(self.old_obs_space, stack_size, stack_dim) def modify_obs(self, obs): @@ -67,7 +67,7 @@ def modify_obs_space(self, obs_space): self.observation_space = stack_obs_space(obs_space, stack_size, stack_dim) return self.observation_space - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): self.stack = stack_init(self.old_obs_space, stack_size, stack_dim) self.reset_flag = True diff --git a/supersuit/generic_wrappers/max_observation.py b/supersuit/generic_wrappers/max_observation.py index b2e722d..811aba0 100644 --- a/supersuit/generic_wrappers/max_observation.py +++ b/supersuit/generic_wrappers/max_observation.py @@ -10,7 +10,7 @@ def max_observation_v0(env, memory): int(memory) # delay must be an int class MaxObsModifier(BaseModifier): - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): self.accumulator = Accumulator(self.observation_space, memory, np.maximum) def modify_obs(self, obs): diff --git a/supersuit/generic_wrappers/nan_wrappers.py b/supersuit/generic_wrappers/nan_wrappers.py index b2dc4b9..c1878c3 100644 --- a/supersuit/generic_wrappers/nan_wrappers.py +++ b/supersuit/generic_wrappers/nan_wrappers.py @@ -14,10 +14,10 @@ class NanRandomModifier(BaseModifier): def __init__(self): super().__init__() - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): self.np_random, seed = gymnasium.utils.seeding.np_random(seed) - return super().reset(seed, return_info=return_info, options=options) + return super().reset(seed, options=options) def modify_action(self, action): if action is not None and np.isnan(action).any(): diff --git a/supersuit/generic_wrappers/sticky_actions.py b/supersuit/generic_wrappers/sticky_actions.py index 986d42e..ad10c1b 100644 --- a/supersuit/generic_wrappers/sticky_actions.py +++ b/supersuit/generic_wrappers/sticky_actions.py @@ -11,7 +11,7 @@ class StickyActionsModifier(BaseModifier): def __init__(self): super().__init__() - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): self.np_random, _ = gymnasium.utils.seeding.np_random(seed) self.old_action = None diff --git a/supersuit/generic_wrappers/utils/base_modifier.py b/supersuit/generic_wrappers/utils/base_modifier.py index ae5b79a..404a93e 100644 --- a/supersuit/generic_wrappers/utils/base_modifier.py +++ b/supersuit/generic_wrappers/utils/base_modifier.py @@ -2,7 +2,7 @@ class BaseModifier: def __init__(self): pass - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): pass def modify_obs(self, obs): diff --git a/supersuit/generic_wrappers/utils/shared_wrapper_util.py b/supersuit/generic_wrappers/utils/shared_wrapper_util.py index 8945e4e..a96dc4e 100644 --- a/supersuit/generic_wrappers/utils/shared_wrapper_util.py +++ b/supersuit/generic_wrappers/utils/shared_wrapper_util.py @@ -42,7 +42,7 @@ def add_modifiers(self, agents_list): if self._cur_seed is not None: self._cur_seed += 1 - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): self._cur_seed = seed self._cur_options = options diff --git a/supersuit/lambda_wrappers/observation_lambda.py b/supersuit/lambda_wrappers/observation_lambda.py index 704880f..2bf9ea2 100644 --- a/supersuit/lambda_wrappers/observation_lambda.py +++ b/supersuit/lambda_wrappers/observation_lambda.py @@ -118,17 +118,10 @@ def step(self, action): observation = self._modify_observation(observation) return observation, rew, termination, truncation, info - def reset(self, seed=None, return_info=False, options=None): - if not return_info: - observation = self.env.reset(seed=seed, options=options) - observation = self._modify_observation(observation) - return observation - else: - observation, info = self.env.reset( - seed=seed, return_info=return_info, options=options - ) - observation = self._modify_observation(observation) - return observation, info + def reset(self, seed=None, options=None): + observation = self.env.reset(seed=seed, options=options) + observation = self._modify_observation(observation) + return observation observation_lambda_v0 = WrapperChooser( diff --git a/supersuit/lambda_wrappers/reward_lambda.py b/supersuit/lambda_wrappers/reward_lambda.py index 788668f..6f9e424 100644 --- a/supersuit/lambda_wrappers/reward_lambda.py +++ b/supersuit/lambda_wrappers/reward_lambda.py @@ -20,7 +20,7 @@ def _check_wrapper_params(self): def _modify_spaces(self): pass - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): super().reset(seed=seed, options=options) self.rewards = { agent: self._change_reward_fn(reward) diff --git a/supersuit/multiagent_wrappers/black_death.py b/supersuit/multiagent_wrappers/black_death.py index cc5f39b..b2a251c 100644 --- a/supersuit/multiagent_wrappers/black_death.py +++ b/supersuit/multiagent_wrappers/black_death.py @@ -16,13 +16,8 @@ def _check_valid_for_black_death(self): space, gymnasium.spaces.Box ), f"observation sapces for black death must be Box spaces, is {space}" - def reset(self, seed=None, return_info=False, options=None): - if not return_info: - obss = self.env.reset(seed=seed, options=options) - else: - obss, infos = self.env.reset( - seed=seed, return_info=return_info, options=options - ) + def reset(self, seed=None, options=None): + obss = self.env.reset(seed=seed, options=options) self.agents = self.env.agents[:] self._check_valid_for_black_death() @@ -32,11 +27,7 @@ def reset(self, seed=None, return_info=False, options=None): if agent not in obss } - if not return_info: - return {**obss, **black_obs} - else: - black_infos = {agent: {} for agent in self.agents if agent not in obss} - return {**obss, **black_obs}, {**black_infos, **infos} + return {**obss, **black_obs} def step(self, actions): active_actions = {agent: actions[agent] for agent in self.env.agents} diff --git a/supersuit/utils/base_aec_wrapper.py b/supersuit/utils/base_aec_wrapper.py index 5899560..fc6303d 100644 --- a/supersuit/utils/base_aec_wrapper.py +++ b/supersuit/utils/base_aec_wrapper.py @@ -27,7 +27,7 @@ def _modify_observation(self, agent, observation): def _update_step(self, agent): pass - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): super().reset(seed=seed, options=options) self._update_step(self.agent_selection) diff --git a/supersuit/vector/concat_vec_env.py b/supersuit/vector/concat_vec_env.py index 8ee4177..8fad58d 100644 --- a/supersuit/vector/concat_vec_env.py +++ b/supersuit/vector/concat_vec_env.py @@ -30,40 +30,19 @@ def __init__(self, vec_env_fns, obs_space=None, act_space=None): tot_num_envs = sum(env.num_envs for env in vec_envs) self.num_envs = tot_num_envs - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): _res_obs = [] - if not return_info: - if seed is not None: - for i in range(len(self.vec_envs)): - _obs = self.vec_envs[i].reset(seed=seed + i, options=options) - _res_obs.append(_obs) - else: - _res_obs = [ - vec_env.reset(seed=None, options=options) - for vec_env in self.vec_envs - ] - - return self.concat_obs(_res_obs) - + if seed is not None: + for i in range(len(self.vec_envs)): + _obs = self.vec_envs[i].reset(seed=seed + i, options=options) + _res_obs.append(_obs) else: - _res_info = [] - if seed is not None: - for i in range(len(self.vec_envs)): - _obs, _info = self.vec_envs[i].reset( - seed=seed + i, return_info=return_info, options=options - ) - _res_obs.append(_obs) - _res_info.append(_info) - else: - for vec_env in self.vec_envs: - _obs, _info = vec_env.reset( - seed=None, return_info=return_info, options=options - ) - _res_obs.append(_obs) - _res_info.append(_info) + _res_obs = [ + vec_env.reset(seed=None, options=options) for vec_env in self.vec_envs + ] - return self.concat_obs(_res_obs), sum(_res_info, []) + return self.concat_obs(_res_obs) def concat_obs(self, observations): return concatenate( diff --git a/supersuit/vector/markov_vector_wrapper.py b/supersuit/vector/markov_vector_wrapper.py index 219377a..0b9eac1 100644 --- a/supersuit/vector/markov_vector_wrapper.py +++ b/supersuit/vector/markov_vector_wrapper.py @@ -51,18 +51,10 @@ def step_async(self, actions): def step_wait(self): return self.step(self._saved_actions) - def reset(self, seed=None, return_info=False, options=None): - if not return_info: - _observations = self.par_env.reset(seed=seed, options=options) - observations = self.concat_obs(_observations) - return observations - else: - _observations, infos = self.par_env.reset( - seed=seed, return_info=return_info, options=options - ) - infs = [infos.get(agent, {}) for agent in self.par_env.possible_agents] - observations = self.concat_obs(_observations) - return observations, infs + def reset(self, seed=None, options=None): + _observations = self.par_env.reset(seed=seed, options=options) + observations = self.concat_obs(_observations) + return observations def step(self, actions): actions = list(iterate(self.action_space, actions)) diff --git a/supersuit/vector/multiproc_vec.py b/supersuit/vector/multiproc_vec.py index 4ca1fd6..93107f2 100644 --- a/supersuit/vector/multiproc_vec.py +++ b/supersuit/vector/multiproc_vec.py @@ -72,13 +72,7 @@ def async_loop( name, data = instr if name == "reset": - if not data[1]: - observations = vec_env.reset(seed=data[0], options=data[2]) - else: - observations, infos = vec_env.reset( - seed=data[0], return_info=data[1], options=data[2] - ) - comp_infos = compress_info(infos) + observations = vec_env.reset(seed=data[0], options=data[1]) write_observations(vec_env, env_start_idx, shared_obs, observations) shared_terms.np_arr[env_start_idx:env_end_idx] = False @@ -180,12 +174,12 @@ def __init__( assert num_envs == tot_num_envs self.idx_starts = idx_starts - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): for i, pipe in enumerate(self.pipes): if seed is not None: - pipe.send(("reset", (seed + i, return_info, options))) + pipe.send(("reset", (seed + i, options))) else: - pipe.send(("reset", (seed, return_info, options))) + pipe.send(("reset", (seed, options))) self._receive_info() diff --git a/supersuit/vector/sb3_vector_wrapper.py b/supersuit/vector/sb3_vector_wrapper.py index 630b284..60a7a19 100644 --- a/supersuit/vector/sb3_vector_wrapper.py +++ b/supersuit/vector/sb3_vector_wrapper.py @@ -8,7 +8,7 @@ def __init__(self, venv): self.observation_space = venv.observation_space self.action_space = venv.action_space - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): if seed is not None: self.seed(seed=seed) return self.venv.reset() diff --git a/supersuit/vector/sb_vector_wrapper.py b/supersuit/vector/sb_vector_wrapper.py index ef7d1d5..f1d0f89 100644 --- a/supersuit/vector/sb_vector_wrapper.py +++ b/supersuit/vector/sb_vector_wrapper.py @@ -8,7 +8,7 @@ def __init__(self, venv): self.observation_space = venv.observation_space self.action_space = venv.action_space - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): if seed is not None: self.seed(seed=seed) diff --git a/supersuit/vector/single_vec_env.py b/supersuit/vector/single_vec_env.py index ff76f38..90e8284 100644 --- a/supersuit/vector/single_vec_env.py +++ b/supersuit/vector/single_vec_env.py @@ -11,7 +11,7 @@ def __init__(self, gym_env_fns, *args): self.num_envs = 1 self.metadata = self.gym_env.metadata - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): return np.expand_dims(self.gym_env.reset(seed=seed, options=options), 0) def step_async(self, actions): diff --git a/test/dummy_aec_env.py b/test/dummy_aec_env.py index 031a8fb..ae2f209 100644 --- a/test/dummy_aec_env.py +++ b/test/dummy_aec_env.py @@ -42,7 +42,7 @@ def step(self, action, observe=True): self._accumulate_rewards() self._deads_step_first() - def reset(self, seed=None, return_info=False, options=None): + def reset(self, seed=None, options=None): self.agents = self.possible_agents[:] self._agent_selector = agent_selector(self.agents) self.agent_selection = self._agent_selector.reset() diff --git a/test/dummy_gym_env.py b/test/dummy_gym_env.py index 8961a18..c4e3917 100644 --- a/test/dummy_gym_env.py +++ b/test/dummy_gym_env.py @@ -11,8 +11,5 @@ def __init__(self, observation, observation_space, action_space): def step(self, action): return self._observation, 1, False, False, {} - def reset(self, seed=None, return_info=False, options=None): - if not return_info: - return self._observation - else: - return self._observation, {} + def reset(self, seed=None, options=None): + return self._observation diff --git a/test/parallel_env_test.py b/test/parallel_env_test.py index 3dfb9f8..5fdf2d5 100644 --- a/test/parallel_env_test.py +++ b/test/parallel_env_test.py @@ -40,11 +40,8 @@ def step(self, actions): self.infos, ) - def reset(self, seed=None, return_info=False, options=None): - if not return_info: - return self._observations - else: - return self._observations, self.infos + def reset(self, seed=None, options=None): + return self._observations def close(self): pass