Skip to content

Commit

Permalink
chore: edit wrappers to have a separate flatten obs wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
EdanToledo committed Jun 20, 2024
1 parent 06c4003 commit ab34f06
Show file tree
Hide file tree
Showing 20 changed files with 228 additions and 135 deletions.
5 changes: 5 additions & 0 deletions stoix/configs/env/gymnax/asterix.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,8 @@ kwargs: {}
# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return


# Optional wrapper to flatten the observation space.
wrapper:
_target_: stoix.wrappers.transforms.FlattenObservationWrapper
5 changes: 5 additions & 0 deletions stoix/configs/env/gymnax/breakout.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,8 @@ kwargs: {}
# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return


# Optional wrapper to flatten the observation space.
wrapper:
_target_: stoix.wrappers.transforms.FlattenObservationWrapper
4 changes: 2 additions & 2 deletions stoix/configs/env/gymnax/cartpole.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# ---Environment Configs---
env_name: gymnax
env_name: gymnax # Used for logging purposes and selection of the corresponding wrapper.

scenario:
name: CartPole-v1
task_name: cartpole
task_name: cartpole # For logging purposes.

kwargs: {}

Expand Down
5 changes: 5 additions & 0 deletions stoix/configs/env/gymnax/freeway.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,8 @@ kwargs: {}
# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return


# Optional wrapper to flatten the observation space.
wrapper:
_target_: stoix.wrappers.transforms.FlattenObservationWrapper
5 changes: 5 additions & 0 deletions stoix/configs/env/gymnax/space_invaders.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,8 @@ kwargs: {}
# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return


# Optional wrapper to flatten the observation space.
wrapper:
_target_: stoix.wrappers.transforms.FlattenObservationWrapper
8 changes: 6 additions & 2 deletions stoix/configs/env/jaxmarl/mabrax.yaml
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
# ---Environment Configs---

env_name: MaBrax # Used for logging purposes and selection of the corresponding wrapper.

scenario:
name: ant_4x2 # [ant_4x2, halfcheetah_6x1, hopper_3x1, humanoid_9|8, walker2d_2x3]
task_name: ant_4x2 # For logging purposes.


add_agent_ids_to_state: False # Adds the agent IDs to the global state.
flatten_observation: True # Flattens the observations.
add_global_state : False # Adds the global state to the observations.

kwargs:
episode_length : 1000
action_repeat: 1
auto_reset: False
auto_reset: False # Must be set to False: an auto-reset wrapper is applied on top of this environment.
backend: spring
homogenisation_method: max # Default is None. Options: [max, concat] if dimensions of observations and actions are not homogeneous across agents,
# "max" adds one-hot for the agent's ID. In this case, system.add_agent_id should be set to False.
Expand All @@ -22,3 +22,7 @@ kwargs:
# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# Optional wrapper to flatten the observation space.
wrapper:
_target_: stoix.wrappers.transforms.FlattenObservationWrapper
6 changes: 5 additions & 1 deletion stoix/configs/env/jaxmarl/mpe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ scenario:
task_name: mpe_simple

add_agent_ids_to_state: False # Adds the agent IDs to the global state.
flatten_observation: True # Flattens the observations.
add_global_state : False # Adds the global state to the observations.

kwargs:
Expand All @@ -15,3 +14,8 @@ kwargs:
# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return


# Optional wrapper to flatten the observation space.
wrapper:
_target_: stoix.wrappers.transforms.FlattenObservationWrapper
6 changes: 5 additions & 1 deletion stoix/configs/env/jaxmarl/smax.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ scenario:
task_name: 2s3z

add_agent_ids_to_state: False # Adds the agent IDs to the global state.
flatten_observation: True # Flattens the observations.
add_global_state : False # Adds the global state to the observations.

kwargs:
Expand All @@ -18,3 +17,8 @@ kwargs:
# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return


# Optional wrapper to flatten the observation space.
wrapper:
_target_: stoix.wrappers.transforms.FlattenObservationWrapper
6 changes: 5 additions & 1 deletion stoix/configs/env/jumanji/2048.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# ---Environment Configs---
env_name: jumanji
observation_attribute : board
flatten_observation: True
multi_agent : False

scenario:
Expand All @@ -13,3 +12,8 @@ kwargs: {}
# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return


# Optional wrapper to flatten the observation space.
wrapper:
_target_: stoix.wrappers.transforms.FlattenObservationWrapper
6 changes: 5 additions & 1 deletion stoix/configs/env/jumanji/connector.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# ---Environment Configs---
env_name: jumanji
observation_attribute : grid
flatten_observation: True
multi_agent : True

scenario:
Expand All @@ -17,3 +16,8 @@ kwargs:
# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return


# Optional wrapper to flatten the observation space.
wrapper:
_target_: stoix.wrappers.transforms.FlattenObservationWrapper
5 changes: 4 additions & 1 deletion stoix/configs/env/jumanji/rware.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# ---Environment Configs---
env_name: jumanji
observation_attribute : agents_view
flatten_observation: True
multi_agent : True

scenario:
Expand All @@ -21,3 +20,7 @@ kwargs:
# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# Optional wrapper to flatten the observation space.
wrapper:
_target_: stoix.wrappers.transforms.FlattenObservationWrapper
13 changes: 8 additions & 5 deletions stoix/configs/env/jumanji/snake.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
# ---Environment Configs---
env_name: jumanji
observation_attribute : grid
flatten_observation: True
multi_agent : False

scenario:
name: Snake-v1
task_name: snake

kwargs: {
num_rows: 6,
num_cols: 6,
}
kwargs:
num_rows: 6
num_cols: 6


# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# Optional wrapper to flatten the observation space.
wrapper:
_target_: stoix.wrappers.transforms.FlattenObservationWrapper
6 changes: 5 additions & 1 deletion stoix/configs/env/jumanji/sokoban.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# ---Environment Configs---
env_name: jumanji
observation_attribute : grid
flatten_observation: True
multi_agent : False

scenario:
Expand All @@ -13,3 +12,8 @@ kwargs: {}
# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return


# Optional wrapper to flatten the observation space.
# wrapper:
# _target_: stoix.wrappers.transforms.FlattenObservationWrapper
5 changes: 4 additions & 1 deletion stoix/configs/env/minigrid/minigrid_doorkey_5x5.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# ---Environment Configs---
env_name: minigrid
flatten_observation: True
scenario:
name: MiniGrid-DoorKey-5x5
task_name: minigrid_doorkey_5x5
Expand All @@ -10,3 +9,7 @@ kwargs: {}
# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# Optional wrapper to flatten the observation space.
wrapper:
_target_: stoix.wrappers.transforms.FlattenObservationWrapper
6 changes: 5 additions & 1 deletion stoix/configs/env/minigrid/minigrid_empty_6x6.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# ---Environment Configs---
env_name: minigrid
flatten_observation: True
scenario:
name: MiniGrid-Empty-6x6
task_name: minigrid_empty_6x6
Expand All @@ -10,3 +9,8 @@ kwargs: {}
# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return


# Optional wrapper to flatten the observation space.
wrapper:
_target_: stoix.wrappers.transforms.FlattenObservationWrapper
11 changes: 4 additions & 7 deletions stoix/utils/make_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from stoix.wrappers import GymnaxWrapper, JumanjiWrapper, RecordEpisodeMetrics
from stoix.wrappers.brax import BraxJumanjiWrapper
from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper
from stoix.wrappers.jumanji import MultiBoundedToBounded, MultiDiscreteToDiscrete
from stoix.wrappers.pgx import PGXWrapper
from stoix.wrappers.transforms import MultiBoundedToBounded, MultiDiscreteToDiscrete
from stoix.wrappers.xminigrid import XMiniGridWrapper


Expand Down Expand Up @@ -54,11 +54,10 @@ def make_jumanji_env(
env = jumanji.make(env_name, **env_kwargs)
eval_env = jumanji.make(env_name, **env_kwargs)
env, eval_env = JumanjiWrapper(
env, config.env.observation_attribute, config.env.flatten_observation
env, config.env.observation_attribute, config.env.multi_agent
), JumanjiWrapper(
eval_env,
config.env.observation_attribute,
config.env.flatten_observation,
config.env.multi_agent,
)

Expand Down Expand Up @@ -111,8 +110,8 @@ def make_xland_minigrid_env(env_name: str, config: DictConfig) -> Tuple[Environm

eval_env, eval_env_params = xminigrid.make(env_name, **config.env.kwargs)

env = XMiniGridWrapper(env, env_params, config.env.flatten_observation)
eval_env = XMiniGridWrapper(eval_env, eval_env_params, config.env.flatten_observation)
env = XMiniGridWrapper(env, env_params)
eval_env = XMiniGridWrapper(eval_env, eval_env_params)

env = AutoResetWrapper(env, next_obs_in_extras=True)
env = RecordEpisodeMetrics(env)
Expand Down Expand Up @@ -170,13 +169,11 @@ def make_jaxmarl_env(
# Create jaxmarl envs.
env = _jaxmarl_wrappers.get(config.env.env_name, JaxMarlWrapper)(
jaxmarl.make(env_name, **kwargs),
config.env.flatten_observation,
config.env.add_global_state,
config.env.add_agent_ids_to_state,
)
eval_env = _jaxmarl_wrappers.get(config.env.env_name, JaxMarlWrapper)(
jaxmarl.make(env_name, **kwargs),
config.env.flatten_observation,
config.env.add_global_state,
config.env.add_agent_ids_to_state,
)
Expand Down
23 changes: 5 additions & 18 deletions stoix/wrappers/jaxmarl.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ class JaxMarlWrapper(Wrapper):
def __init__(
self,
env: MultiAgentEnv,
flatten_observation: bool,
has_global_state: bool,
add_agent_ids_to_state: bool = False,
timelimit: int = 1000,
Expand All @@ -166,7 +165,6 @@ def __init__(
Args:
- env: The JaxMarl environment to wrap.
- flatten_observation: Whether to flatten the observation.
- has_global_state: Whether the environment has global state.
- add_agent_ids_to_state: Whether to add the agent ids to the global state.
- timelimit: The time limit for each episode.
Expand All @@ -180,7 +178,6 @@ def __init__(

super().__init__(env)
self._env: MultiAgentEnv
self._flatten_observation = flatten_observation
self._timelimit = timelimit
self.agents = self._env.agents
self.num_agents = self._env.num_agents
Expand Down Expand Up @@ -231,8 +228,6 @@ def _create_observation(
"""Create an observation from the raw observation and environment state."""

obs = batchify(obs, self.agents)
if self._flatten_observation:
obs = obs.reshape(-1)

obs_data = {
"agent_view": obs,
Expand All @@ -245,15 +240,14 @@ def _create_observation(

if self.has_global_state:
obs_data["global_state"] = self.get_global_state(wrapped_env_state, obs)
if self._flatten_observation:
obs_data["global_state"] = obs_data["global_state"].reshape(-1)

return ObservationGlobalState(**obs_data)
else:
return Observation(**obs_data)

def observation_spec(self) -> specs.Spec:
agent_view = jaxmarl_space_to_jumanji_spec(
merge_space(self._env.observation_spaces, self._flatten_observation),
merge_space(self._env.observation_spaces),
)

action_mask = specs.BoundedArray(
Expand All @@ -265,8 +259,7 @@ def observation_spec(self) -> specs.Spec:

if self.has_global_state:
global_state_shape: Sequence[int] = (self.num_agents, self.state_size)
if self._flatten_observation:
global_state_shape = (self.num_agents * self.state_size,)

global_state = specs.Array(
global_state_shape,
agent_view.dtype,
Expand Down Expand Up @@ -330,14 +323,11 @@ class SmaxWrapper(JaxMarlWrapper):
def __init__(
self,
env: MultiAgentEnv,
flatten_observation: bool = True,
has_global_state: bool = False,
timelimit: int = 500,
add_agent_ids_to_state: bool = False,
):
super().__init__(
env, flatten_observation, has_global_state, add_agent_ids_to_state, timelimit
)
super().__init__(env, has_global_state, add_agent_ids_to_state, timelimit)
self._env: SMAX

@cached_property
Expand Down Expand Up @@ -367,14 +357,11 @@ class MabraxWrapper(JaxMarlWrapper):
def __init__(
self,
env: MABraxEnv,
flatten_observation: bool = True,
has_global_state: bool = False,
timelimit: int = 1000,
add_agent_ids_to_state: bool = False,
):
super().__init__(
env, flatten_observation, has_global_state, add_agent_ids_to_state, timelimit
)
super().__init__(env, has_global_state, add_agent_ids_to_state, timelimit)
self._env: MABraxEnv

@cached_property
Expand Down
Loading

0 comments on commit ab34f06

Please sign in to comment.