Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the `return_info` parameter from `reset()` #205

Merged
Merged 3 commits on Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion supersuit/generic_wrappers/delay_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

def delay_observations_v0(env, delay):
class DelayObsModifier(BaseModifier):
def reset(self, seed=None, return_info=False, options=None):
def reset(self, seed=None, options=None):
self.delayer = Delayer(self.observation_space, delay)

def modify_obs(self, obs):
Expand Down
2 changes: 1 addition & 1 deletion supersuit/generic_wrappers/frame_skip.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, env, num_frames):
check_transform_frameskip(num_frames)
self.num_frames = num_frames

def reset(self, seed=None, return_info=False, options=None):
def reset(self, seed=None, options=None):
super().reset(seed=seed, options=options)
self.agents = self.env.agents[:]
self.terminations = make_defaultdict({agent: False for agent in self.agents})
Expand Down
4 changes: 2 additions & 2 deletions supersuit/generic_wrappers/frame_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def modify_obs_space(self, obs_space):
self.observation_space = stack_obs_space(obs_space, stack_size, stack_dim)
return self.observation_space

def reset(self, seed=None, return_info=False, options=None):
def reset(self, seed=None, options=None):
self.stack = stack_init(self.old_obs_space, stack_size, stack_dim)

def modify_obs(self, obs):
Expand Down Expand Up @@ -67,7 +67,7 @@ def modify_obs_space(self, obs_space):
self.observation_space = stack_obs_space(obs_space, stack_size, stack_dim)
return self.observation_space

def reset(self, seed=None, return_info=False, options=None):
def reset(self, seed=None, options=None):
self.stack = stack_init(self.old_obs_space, stack_size, stack_dim)
self.reset_flag = True

Expand Down
2 changes: 1 addition & 1 deletion supersuit/generic_wrappers/max_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def max_observation_v0(env, memory):
int(memory) # delay must be an int

class MaxObsModifier(BaseModifier):
def reset(self, seed=None, return_info=False, options=None):
def reset(self, seed=None, options=None):
self.accumulator = Accumulator(self.observation_space, memory, np.maximum)

def modify_obs(self, obs):
Expand Down
4 changes: 2 additions & 2 deletions supersuit/generic_wrappers/nan_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ class NanRandomModifier(BaseModifier):
def __init__(self):
super().__init__()

def reset(self, seed=None, return_info=False, options=None):
def reset(self, seed=None, options=None):
self.np_random, seed = gymnasium.utils.seeding.np_random(seed)

return super().reset(seed, return_info=return_info, options=options)
return super().reset(seed, options=options)

def modify_action(self, action):
if action is not None and np.isnan(action).any():
Expand Down
2 changes: 1 addition & 1 deletion supersuit/generic_wrappers/sticky_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class StickyActionsModifier(BaseModifier):
def __init__(self):
super().__init__()

def reset(self, seed=None, return_info=False, options=None):
def reset(self, seed=None, options=None):
self.np_random, _ = gymnasium.utils.seeding.np_random(seed)
self.old_action = None

Expand Down
2 changes: 1 addition & 1 deletion supersuit/generic_wrappers/utils/base_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ class BaseModifier:
def __init__(self):
pass

def reset(self, seed=None, return_info=False, options=None):
def reset(self, seed=None, options=None):
pass

def modify_obs(self, obs):
Expand Down
2 changes: 1 addition & 1 deletion supersuit/generic_wrappers/utils/shared_wrapper_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def add_modifiers(self, agents_list):
if self._cur_seed is not None:
self._cur_seed += 1

def reset(self, seed=None, return_info=False, options=None):
def reset(self, seed=None, options=None):
self._cur_seed = seed
self._cur_options = options

Expand Down
15 changes: 4 additions & 11 deletions supersuit/lambda_wrappers/observation_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,10 @@ def step(self, action):
observation = self._modify_observation(observation)
return observation, rew, termination, truncation, info

def reset(self, seed=None, return_info=False, options=None):
if not return_info:
observation = self.env.reset(seed=seed, options=options)
observation = self._modify_observation(observation)
return observation
else:
observation, info = self.env.reset(
seed=seed, return_info=return_info, options=options
)
observation = self._modify_observation(observation)
return observation, info
def reset(self, seed=None, options=None):
observation = self.env.reset(seed=seed, options=options)
observation = self._modify_observation(observation)
return observation


observation_lambda_v0 = WrapperChooser(
Expand Down
2 changes: 1 addition & 1 deletion supersuit/lambda_wrappers/reward_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def _check_wrapper_params(self):
def _modify_spaces(self):
pass

def reset(self, seed=None, return_info=False, options=None):
def reset(self, seed=None, options=None):
super().reset(seed=seed, options=options)
self.rewards = {
agent: self._change_reward_fn(reward)
Expand Down
15 changes: 3 additions & 12 deletions supersuit/multiagent_wrappers/black_death.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,8 @@ def _check_valid_for_black_death(self):
space, gymnasium.spaces.Box
), f"observation sapces for black death must be Box spaces, is {space}"

def reset(self, seed=None, return_info=False, options=None):
if not return_info:
obss = self.env.reset(seed=seed, options=options)
else:
obss, infos = self.env.reset(
seed=seed, return_info=return_info, options=options
)
def reset(self, seed=None, options=None):
obss = self.env.reset(seed=seed, options=options)

self.agents = self.env.agents[:]
self._check_valid_for_black_death()
Expand All @@ -32,11 +27,7 @@ def reset(self, seed=None, return_info=False, options=None):
if agent not in obss
}

if not return_info:
return {**obss, **black_obs}
else:
black_infos = {agent: {} for agent in self.agents if agent not in obss}
return {**obss, **black_obs}, {**black_infos, **infos}
return {**obss, **black_obs}

def step(self, actions):
active_actions = {agent: actions[agent] for agent in self.env.agents}
Expand Down
2 changes: 1 addition & 1 deletion supersuit/utils/base_aec_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _modify_observation(self, agent, observation):
def _update_step(self, agent):
pass

def reset(self, seed=None, return_info=False, options=None):
def reset(self, seed=None, options=None):
super().reset(seed=seed, options=options)
self._update_step(self.agent_selection)

Expand Down
39 changes: 9 additions & 30 deletions supersuit/vector/concat_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,40 +30,19 @@ def __init__(self, vec_env_fns, obs_space=None, act_space=None):
tot_num_envs = sum(env.num_envs for env in vec_envs)
self.num_envs = tot_num_envs

def reset(self, seed=None, return_info=False, options=None):
def reset(self, seed=None, options=None):
_res_obs = []

if not return_info:
if seed is not None:
for i in range(len(self.vec_envs)):
_obs = self.vec_envs[i].reset(seed=seed + i, options=options)
_res_obs.append(_obs)
else:
_res_obs = [
vec_env.reset(seed=None, options=options)
for vec_env in self.vec_envs
]

return self.concat_obs(_res_obs)

if seed is not None:
for i in range(len(self.vec_envs)):
_obs = self.vec_envs[i].reset(seed=seed + i, options=options)
_res_obs.append(_obs)
else:
_res_info = []
if seed is not None:
for i in range(len(self.vec_envs)):
_obs, _info = self.vec_envs[i].reset(
seed=seed + i, return_info=return_info, options=options
)
_res_obs.append(_obs)
_res_info.append(_info)
else:
for vec_env in self.vec_envs:
_obs, _info = vec_env.reset(
seed=None, return_info=return_info, options=options
)
_res_obs.append(_obs)
_res_info.append(_info)
_res_obs = [
vec_env.reset(seed=None, options=options) for vec_env in self.vec_envs
]

return self.concat_obs(_res_obs), sum(_res_info, [])
return self.concat_obs(_res_obs)

def concat_obs(self, observations):
return concatenate(
Expand Down
16 changes: 4 additions & 12 deletions supersuit/vector/markov_vector_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,10 @@ def step_async(self, actions):
def step_wait(self):
return self.step(self._saved_actions)

def reset(self, seed=None, return_info=False, options=None):
if not return_info:
_observations = self.par_env.reset(seed=seed, options=options)
observations = self.concat_obs(_observations)
return observations
else:
_observations, infos = self.par_env.reset(
seed=seed, return_info=return_info, options=options
)
infs = [infos.get(agent, {}) for agent in self.par_env.possible_agents]
observations = self.concat_obs(_observations)
return observations, infs
def reset(self, seed=None, options=None):
_observations = self.par_env.reset(seed=seed, options=options)
observations = self.concat_obs(_observations)
return observations

def step(self, actions):
actions = list(iterate(self.action_space, actions))
Expand Down
14 changes: 4 additions & 10 deletions supersuit/vector/multiproc_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,7 @@ def async_loop(
name, data = instr

if name == "reset":
if not data[1]:
observations = vec_env.reset(seed=data[0], options=data[2])
else:
observations, infos = vec_env.reset(
seed=data[0], return_info=data[1], options=data[2]
)
comp_infos = compress_info(infos)
observations = vec_env.reset(seed=data[0], options=data[1])

write_observations(vec_env, env_start_idx, shared_obs, observations)
shared_terms.np_arr[env_start_idx:env_end_idx] = False
Expand Down Expand Up @@ -180,12 +174,12 @@ def __init__(
assert num_envs == tot_num_envs
self.idx_starts = idx_starts

def reset(self, seed=None, return_info=False, options=None):
def reset(self, seed=None, options=None):
for i, pipe in enumerate(self.pipes):
if seed is not None:
pipe.send(("reset", (seed + i, return_info, options)))
pipe.send(("reset", (seed + i, options)))
else:
pipe.send(("reset", (seed, return_info, options)))
pipe.send(("reset", (seed, options)))

self._receive_info()

Expand Down
2 changes: 1 addition & 1 deletion supersuit/vector/sb3_vector_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(self, venv):
self.observation_space = venv.observation_space
self.action_space = venv.action_space

def reset(self, seed=None, return_info=False, options=None):
def reset(self, seed=None, options=None):
if seed is not None:
self.seed(seed=seed)
return self.venv.reset()
Expand Down
2 changes: 1 addition & 1 deletion supersuit/vector/sb_vector_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(self, venv):
self.observation_space = venv.observation_space
self.action_space = venv.action_space

def reset(self, seed=None, return_info=False, options=None):
def reset(self, seed=None, options=None):
if seed is not None:
self.seed(seed=seed)

Expand Down
2 changes: 1 addition & 1 deletion supersuit/vector/single_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self, gym_env_fns, *args):
self.num_envs = 1
self.metadata = self.gym_env.metadata

def reset(self, seed=None, return_info=False, options=None):
def reset(self, seed=None, options=None):
return np.expand_dims(self.gym_env.reset(seed=seed, options=options), 0)

def step_async(self, actions):
Expand Down
2 changes: 1 addition & 1 deletion test/dummy_aec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def step(self, action, observe=True):
self._accumulate_rewards()
self._deads_step_first()

def reset(self, seed=None, return_info=False, options=None):
def reset(self, seed=None, options=None):
self.agents = self.possible_agents[:]
self._agent_selector = agent_selector(self.agents)
self.agent_selection = self._agent_selector.reset()
Expand Down
7 changes: 2 additions & 5 deletions test/dummy_gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,5 @@ def __init__(self, observation, observation_space, action_space):
def step(self, action):
return self._observation, 1, False, False, {}

def reset(self, seed=None, return_info=False, options=None):
if not return_info:
return self._observation
else:
return self._observation, {}
def reset(self, seed=None, options=None):
return self._observation
7 changes: 2 additions & 5 deletions test/parallel_env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,8 @@ def step(self, actions):
self.infos,
)

def reset(self, seed=None, return_info=False, options=None):
if not return_info:
return self._observations
else:
return self._observations, self.infos
def reset(self, seed=None, options=None):
return self._observations

def close(self):
pass
Expand Down