diff --git a/stoix/configs/env/gymnax/asterix.yaml b/stoix/configs/env/gymnax/asterix.yaml
index 743a272d..3238600c 100644
--- a/stoix/configs/env/gymnax/asterix.yaml
+++ b/stoix/configs/env/gymnax/asterix.yaml
@@ -10,3 +10,8 @@ kwargs: {}
 # Defines the metric that will be used to evaluate the performance of the agent.
 # This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
 eval_metric: episode_return
+
+
+# Optional wrapper to flatten the observation space.
+wrapper:
+  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
diff --git a/stoix/configs/env/gymnax/breakout.yaml b/stoix/configs/env/gymnax/breakout.yaml
index 0ce2e176..90c1e969 100644
--- a/stoix/configs/env/gymnax/breakout.yaml
+++ b/stoix/configs/env/gymnax/breakout.yaml
@@ -10,3 +10,8 @@ kwargs: {}
 # Defines the metric that will be used to evaluate the performance of the agent.
 # This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
 eval_metric: episode_return
+
+
+# Optional wrapper to flatten the observation space.
+wrapper:
+  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
diff --git a/stoix/configs/env/gymnax/cartpole.yaml b/stoix/configs/env/gymnax/cartpole.yaml
index b2e4555a..a8d50531 100644
--- a/stoix/configs/env/gymnax/cartpole.yaml
+++ b/stoix/configs/env/gymnax/cartpole.yaml
@@ -1,9 +1,9 @@
 # ---Environment Configs---
-env_name: gymnax
+env_name: gymnax # Used for logging purposes and selection of the corresponding wrapper.
 
 scenario:
   name: CartPole-v1
-  task_name: cartpole
+  task_name: cartpole # For logging purposes.
 
 kwargs: {}
 
diff --git a/stoix/configs/env/gymnax/freeway.yaml b/stoix/configs/env/gymnax/freeway.yaml
index 949e99e2..edb323ae 100644
--- a/stoix/configs/env/gymnax/freeway.yaml
+++ b/stoix/configs/env/gymnax/freeway.yaml
@@ -10,3 +10,8 @@ kwargs: {}
 # Defines the metric that will be used to evaluate the performance of the agent.
 # This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
 eval_metric: episode_return
+
+
+# Optional wrapper to flatten the observation space.
+wrapper:
+  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
diff --git a/stoix/configs/env/gymnax/space_invaders.yaml b/stoix/configs/env/gymnax/space_invaders.yaml
index 2bf6b2d2..5308f128 100644
--- a/stoix/configs/env/gymnax/space_invaders.yaml
+++ b/stoix/configs/env/gymnax/space_invaders.yaml
@@ -10,3 +10,8 @@ kwargs: {}
 # Defines the metric that will be used to evaluate the performance of the agent.
 # This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
 eval_metric: episode_return
+
+
+# Optional wrapper to flatten the observation space.
+wrapper:
+  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
diff --git a/stoix/configs/env/jaxmarl/mabrax.yaml b/stoix/configs/env/jaxmarl/mabrax.yaml
index c7db78fe..a547102d 100644
--- a/stoix/configs/env/jaxmarl/mabrax.yaml
+++ b/stoix/configs/env/jaxmarl/mabrax.yaml
@@ -1,19 +1,19 @@
 # ---Environment Configs---
 env_name: MaBrax # Used for logging purposes and selection of the corresponding wrapper.
+
 scenario:
   name: ant_4x2 # [ant_4x2, halfcheetah_6x1, hopper_3x1, humanoid_9|8, walker2d_2x3]
   task_name: ant_4x2 # For logging purposes.
 
 add_agent_ids_to_state: False # Adds the agent IDs to the global state.
-flatten_observation: True # Flattens the observations.
 add_global_state : False # Adds the global state to the observations.
 
 kwargs:
   episode_length : 1000
   action_repeat: 1
-  auto_reset: False
+  auto_reset: False # Must be False - the auto-reset wrapper is applied on top of this.
   backend: spring
   homogenisation_method: max # Default is None. Options: [max, concat] if dimensions of observations and actions are not homogeneous across agents,
   # "max" adds one-hot for the agent's ID. In this case, system.add_agent_id should be set to False.
@@ -22,3 +22,7 @@ kwargs:
 # Defines the metric that will be used to evaluate the performance of the agent.
 # This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
 eval_metric: episode_return
+
+# Optional wrapper to flatten the observation space.
+wrapper:
+  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
diff --git a/stoix/configs/env/jaxmarl/mpe.yaml b/stoix/configs/env/jaxmarl/mpe.yaml
index fb36cda1..5055b0fe 100644
--- a/stoix/configs/env/jaxmarl/mpe.yaml
+++ b/stoix/configs/env/jaxmarl/mpe.yaml
@@ -6,7 +6,6 @@ scenario:
   task_name: mpe_simple
 
 add_agent_ids_to_state: False # Adds the agent IDs to the global state.
-flatten_observation: True # Flattens the observations.
 add_global_state : False # Adds the global state to the observations.
 
 kwargs:
@@ -15,3 +14,8 @@ kwargs:
 # Defines the metric that will be used to evaluate the performance of the agent.
 # This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
 eval_metric: episode_return
+
+
+# Optional wrapper to flatten the observation space.
+wrapper:
+  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
diff --git a/stoix/configs/env/jaxmarl/smax.yaml b/stoix/configs/env/jaxmarl/smax.yaml
index 92bab683..0eeb775d 100644
--- a/stoix/configs/env/jaxmarl/smax.yaml
+++ b/stoix/configs/env/jaxmarl/smax.yaml
@@ -7,7 +7,6 @@ scenario:
   task_name: 2s3z
 
 add_agent_ids_to_state: False # Adds the agent IDs to the global state.
-flatten_observation: True # Flattens the observations.
 add_global_state : False # Adds the global state to the observations.
 
 kwargs:
@@ -18,3 +17,8 @@ kwargs:
 # Defines the metric that will be used to evaluate the performance of the agent.
 # This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
 eval_metric: episode_return
+
+
+# Optional wrapper to flatten the observation space.
+wrapper:
+  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
diff --git a/stoix/configs/env/jumanji/2048.yaml b/stoix/configs/env/jumanji/2048.yaml
index f04a792f..cd1d5572 100644
--- a/stoix/configs/env/jumanji/2048.yaml
+++ b/stoix/configs/env/jumanji/2048.yaml
@@ -1,7 +1,6 @@
 # ---Environment Configs---
 env_name: jumanji
 observation_attribute : board
-flatten_observation: True
 multi_agent : False
 
 scenario:
@@ -13,3 +12,8 @@ kwargs: {}
 # Defines the metric that will be used to evaluate the performance of the agent.
 # This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
 eval_metric: episode_return
+
+
+# Optional wrapper to flatten the observation space.
+wrapper:
+  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
diff --git a/stoix/configs/env/jumanji/connector.yaml b/stoix/configs/env/jumanji/connector.yaml
index d44d8fbb..d47cf845 100644
--- a/stoix/configs/env/jumanji/connector.yaml
+++ b/stoix/configs/env/jumanji/connector.yaml
@@ -1,7 +1,6 @@
 # ---Environment Configs---
 env_name: jumanji
 observation_attribute : grid
-flatten_observation: True
 multi_agent : True
 
 scenario:
@@ -17,3 +16,8 @@ kwargs:
 # Defines the metric that will be used to evaluate the performance of the agent.
 # This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
 eval_metric: episode_return
+
+
+# Optional wrapper to flatten the observation space.
+wrapper:
+  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
diff --git a/stoix/configs/env/jumanji/rware.yaml b/stoix/configs/env/jumanji/rware.yaml
index 4cc65032..2e205ff8 100644
--- a/stoix/configs/env/jumanji/rware.yaml
+++ b/stoix/configs/env/jumanji/rware.yaml
@@ -1,7 +1,6 @@
 # ---Environment Configs---
 env_name: jumanji
 observation_attribute : agents_view
-flatten_observation: True
 multi_agent : True
 
 scenario:
@@ -21,3 +20,7 @@ kwargs:
 # Defines the metric that will be used to evaluate the performance of the agent.
 # This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
 eval_metric: episode_return
+
+# Optional wrapper to flatten the observation space.
+wrapper:
+  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
diff --git a/stoix/configs/env/jumanji/snake.yaml b/stoix/configs/env/jumanji/snake.yaml
index adbbeb9e..d6c639fe 100644
--- a/stoix/configs/env/jumanji/snake.yaml
+++ b/stoix/configs/env/jumanji/snake.yaml
@@ -1,18 +1,21 @@
 # ---Environment Configs---
 env_name: jumanji
 observation_attribute : grid
-flatten_observation: True
 multi_agent : False
 
 scenario:
   name: Snake-v1
   task_name: snake
 
-kwargs: {
-  num_rows: 6,
-  num_cols: 6,
-}
+kwargs:
+  num_rows: 6
+  num_cols: 6
+
 
 # Defines the metric that will be used to evaluate the performance of the agent.
 # This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
 eval_metric: episode_return
+
+# Optional wrapper to flatten the observation space.
+wrapper:
+  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
diff --git a/stoix/configs/env/jumanji/sokoban.yaml b/stoix/configs/env/jumanji/sokoban.yaml
index 2153efc9..ccfb4861 100644
--- a/stoix/configs/env/jumanji/sokoban.yaml
+++ b/stoix/configs/env/jumanji/sokoban.yaml
@@ -1,7 +1,6 @@
 # ---Environment Configs---
 env_name: jumanji
 observation_attribute : grid
-flatten_observation: True
 multi_agent : False
 
 scenario:
@@ -13,3 +12,8 @@ kwargs: {}
 # Defines the metric that will be used to evaluate the performance of the agent.
 # This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
 eval_metric: episode_return
+
+
+# Optional wrapper to flatten the observation space.
+# wrapper:
+#   _target_: stoix.wrappers.transforms.FlattenObservationWrapper
diff --git a/stoix/configs/env/minigrid/minigrid_doorkey_5x5.yaml b/stoix/configs/env/minigrid/minigrid_doorkey_5x5.yaml
index 1e6a1392..6c848151 100644
--- a/stoix/configs/env/minigrid/minigrid_doorkey_5x5.yaml
+++ b/stoix/configs/env/minigrid/minigrid_doorkey_5x5.yaml
@@ -1,6 +1,5 @@
 # ---Environment Configs---
 env_name: minigrid
-flatten_observation: True
 scenario:
   name: MiniGrid-DoorKey-5x5
   task_name: minigrid_doorkey_5x5
@@ -10,3 +9,7 @@ kwargs: {}
 # Defines the metric that will be used to evaluate the performance of the agent.
 # This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
 eval_metric: episode_return
+
+# Optional wrapper to flatten the observation space.
+wrapper:
+  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
diff --git a/stoix/configs/env/minigrid/minigrid_empty_6x6.yaml b/stoix/configs/env/minigrid/minigrid_empty_6x6.yaml
index 69a5baf6..fd11326b 100644
--- a/stoix/configs/env/minigrid/minigrid_empty_6x6.yaml
+++ b/stoix/configs/env/minigrid/minigrid_empty_6x6.yaml
@@ -1,6 +1,5 @@
 # ---Environment Configs---
 env_name: minigrid
-flatten_observation: True
 scenario:
   name: MiniGrid-Empty-6x6
   task_name: minigrid_empty_6x6
@@ -10,3 +9,8 @@ kwargs: {}
 # Defines the metric that will be used to evaluate the performance of the agent.
 # This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
 eval_metric: episode_return
+
+
+# Optional wrapper to flatten the observation space.
+wrapper:
+  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
diff --git a/stoix/utils/make_env.py b/stoix/utils/make_env.py
index f686f86b..a8e37822 100644
--- a/stoix/utils/make_env.py
+++ b/stoix/utils/make_env.py
@@ -24,8 +24,8 @@
 from stoix.wrappers import GymnaxWrapper, JumanjiWrapper, RecordEpisodeMetrics
 from stoix.wrappers.brax import BraxJumanjiWrapper
 from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper
-from stoix.wrappers.jumanji import MultiBoundedToBounded, MultiDiscreteToDiscrete
 from stoix.wrappers.pgx import PGXWrapper
+from stoix.wrappers.transforms import MultiBoundedToBounded, MultiDiscreteToDiscrete
 from stoix.wrappers.xminigrid import XMiniGridWrapper
 
 
@@ -54,11 +54,10 @@ def make_jumanji_env(
     env = jumanji.make(env_name, **env_kwargs)
     eval_env = jumanji.make(env_name, **env_kwargs)
     env, eval_env = JumanjiWrapper(
-        env, config.env.observation_attribute, config.env.flatten_observation
+        env, config.env.observation_attribute, config.env.multi_agent
     ), JumanjiWrapper(
         eval_env,
         config.env.observation_attribute,
-        config.env.flatten_observation,
         config.env.multi_agent,
     )
 
@@ -111,8 +110,8 @@ def make_xland_minigrid_env(env_name: str, config: DictConfig) -> Tuple[Environm
     eval_env, eval_env_params = xminigrid.make(env_name, **config.env.kwargs)
 
-    env = XMiniGridWrapper(env, env_params, config.env.flatten_observation)
-    eval_env = XMiniGridWrapper(eval_env, eval_env_params, config.env.flatten_observation)
+    env = XMiniGridWrapper(env, env_params)
+    eval_env = XMiniGridWrapper(eval_env, eval_env_params)
 
     env = AutoResetWrapper(env, next_obs_in_extras=True)
     env = RecordEpisodeMetrics(env)
 
@@ -170,13 +169,11 @@ def make_jaxmarl_env(
     # Create jaxmarl envs.
     env = _jaxmarl_wrappers.get(config.env.env_name, JaxMarlWrapper)(
         jaxmarl.make(env_name, **kwargs),
-        config.env.flatten_observation,
         config.env.add_global_state,
         config.env.add_agent_ids_to_state,
     )
     eval_env = _jaxmarl_wrappers.get(config.env.env_name, JaxMarlWrapper)(
         jaxmarl.make(env_name, **kwargs),
-        config.env.flatten_observation,
         config.env.add_global_state,
         config.env.add_agent_ids_to_state,
     )
diff --git a/stoix/wrappers/jaxmarl.py b/stoix/wrappers/jaxmarl.py
index 98a9c5c9..3d26e05c 100644
--- a/stoix/wrappers/jaxmarl.py
+++ b/stoix/wrappers/jaxmarl.py
@@ -156,7 +156,6 @@ class JaxMarlWrapper(Wrapper):
     def __init__(
         self,
         env: MultiAgentEnv,
-        flatten_observation: bool,
         has_global_state: bool,
         add_agent_ids_to_state: bool = False,
         timelimit: int = 1000,
@@ -166,7 +165,6 @@ def __init__(
 
         Args:
         - env: The JaxMarl environment to wrap.
-        - flatten_observation: Whether to flatten the observation.
         - has_global_state: Whether the environment has global state.
         - add_agent_ids_to_state: Whether to add the agent ids to the global state.
         - timelimit: The time limit for each episode.
@@ -180,7 +178,6 @@ def __init__(
         super().__init__(env)
         self._env: MultiAgentEnv
-        self._flatten_observation = flatten_observation
         self._timelimit = timelimit
         self.agents = self._env.agents
         self.num_agents = self._env.num_agents
@@ -231,8 +228,6 @@ def _create_observation(
         """Create an observation from the raw observation and environment state."""
         obs = batchify(obs, self.agents)
-        if self._flatten_observation:
-            obs = obs.reshape(-1)
 
         obs_data = {
             "agent_view": obs,
@@ -245,15 +240,14 @@ def _create_observation(
         if self.has_global_state:
             obs_data["global_state"] = self.get_global_state(wrapped_env_state, obs)
-            if self._flatten_observation:
-                obs_data["global_state"] = obs_data["global_state"].reshape(-1)
+
             return ObservationGlobalState(**obs_data)
         else:
             return Observation(**obs_data)
 
     def observation_spec(self) -> specs.Spec:
         agent_view = jaxmarl_space_to_jumanji_spec(
-            merge_space(self._env.observation_spaces, self._flatten_observation),
+            merge_space(self._env.observation_spaces),
         )
 
         action_mask = specs.BoundedArray(
@@ -265,8 +259,7 @@ def observation_spec(self) -> specs.Spec:
         if self.has_global_state:
             global_state_shape: Sequence[int] = (self.num_agents, self.state_size)
-            if self._flatten_observation:
-                global_state_shape = (self.num_agents * self.state_size,)
+
             global_state = specs.Array(
                 global_state_shape,
                 agent_view.dtype,
@@ -330,14 +323,11 @@ class SmaxWrapper(JaxMarlWrapper):
     def __init__(
         self,
         env: MultiAgentEnv,
-        flatten_observation: bool = True,
         has_global_state: bool = False,
         timelimit: int = 500,
         add_agent_ids_to_state: bool = False,
     ):
-        super().__init__(
-            env, flatten_observation, has_global_state, add_agent_ids_to_state, timelimit
-        )
+        super().__init__(env, has_global_state, add_agent_ids_to_state, timelimit)
         self._env: SMAX
 
     @cached_property
@@ -367,14 +357,11 @@ class MabraxWrapper(JaxMarlWrapper):
     def __init__(
         self,
         env: MABraxEnv,
-        flatten_observation: bool = True,
         has_global_state: bool = False,
         timelimit: int = 1000,
         add_agent_ids_to_state: bool = False,
     ):
-        super().__init__(
-            env, flatten_observation, has_global_state, add_agent_ids_to_state, timelimit
-        )
+        super().__init__(env, has_global_state, add_agent_ids_to_state, timelimit)
         self._env: MABraxEnv
 
     @cached_property
diff --git a/stoix/wrappers/jumanji.py b/stoix/wrappers/jumanji.py
index 63b8ec09..769b46f7 100644
--- a/stoix/wrappers/jumanji.py
+++ b/stoix/wrappers/jumanji.py
@@ -2,7 +2,6 @@
 
 import chex
 import jax.numpy as jnp
-import numpy as np
 from jumanji import specs
 from jumanji.env import Environment, State
 from jumanji.specs import Array, MultiDiscreteArray, Spec
@@ -10,6 +9,7 @@
 from jumanji.wrappers import MultiToSingleWrapper, Wrapper
 
 from stoix.base_types import Observation
+from stoix.wrappers.transforms import MultiDiscreteToDiscrete
 
 
 class JumanjiWrapper(Wrapper):
@@ -17,7 +17,6 @@ def __init__(
         self,
         env: Environment,
         observation_attribute: str,
-        flatten_observation: bool = False,
         multi_agent: bool = False,
     ) -> None:
         if isinstance(env.action_spec(), MultiDiscreteArray):
@@ -26,106 +25,58 @@ def __init__(
             env = MultiToSingleWrapper(env)
         self._env = env
         self._observation_attribute = observation_attribute
-        self._flatten_observation = flatten_observation
-        self._obs_shape = super().observation_spec().__dict__[self._observation_attribute].shape
-        if self._flatten_observation:
-            self._obs_shape = (np.prod(self._obs_shape),)
         self._legal_action_mask = jnp.ones((self.action_spec().num_values,), dtype=jnp.float32)
 
     def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
         state, timestep = self._env.reset(key)
-        obs = timestep.observation._asdict()[self._observation_attribute].astype(jnp.float32)
+        if self._observation_attribute:
+            agent_view = timestep.observation._asdict()[self._observation_attribute].astype(
+                jnp.float32
+            )
+        else:
+            agent_view = timestep.observation
+        obs = Observation(agent_view, self._legal_action_mask, state.step_count)
+        timestep_extras = timestep.extras
+        if not timestep_extras:
+            timestep_extras = {}
         timestep = timestep.replace(
-            observation=Observation(
-                obs.reshape(self._obs_shape), self._legal_action_mask, state.step_count
-            ),
-            extras={},
+            observation=obs,
+            extras=timestep_extras,
         )
         return state, timestep
 
     def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
         state, timestep = self._env.step(state, action)
-        obs = timestep.observation._asdict()[self._observation_attribute].astype(jnp.float32)
+        if self._observation_attribute:
+            agent_view = timestep.observation._asdict()[self._observation_attribute].astype(
+                jnp.float32
+            )
+        else:
+            agent_view = timestep.observation
+        obs = Observation(agent_view, self._legal_action_mask, state.step_count)
+        timestep_extras = timestep.extras
+        if not timestep_extras:
+            timestep_extras = {}
         timestep = timestep.replace(
-            observation=Observation(
-                obs.reshape(self._obs_shape), self._legal_action_mask, state.step_count
-            ),
-            extras={},
+            observation=obs,
+            extras=timestep_extras,
         )
         return state, timestep
 
     def observation_spec(self) -> Spec:
+        if self._observation_attribute:
+            agent_view_spec = Array(
+                shape=self._env.observation_spec().__dict__[self._observation_attribute].shape,
+                dtype=jnp.float32,
+            )
+        else:
+            agent_view_spec = self._env.observation_spec()
         return specs.Spec(
             Observation,
             "ObservationSpec",
-            agent_view=Array(shape=self._obs_shape, dtype=jnp.float32),
-            action_mask=Array(shape=(self.action_spec().num_values,), dtype=jnp.float32),
+            agent_view=agent_view_spec,
+            action_mask=Array(shape=self._legal_action_mask.shape, dtype=jnp.float32),
             step_count=Array(shape=(), dtype=jnp.int32),
         )
-
-
-class MultiDiscreteToDiscrete(Wrapper):
-    def __init__(self, env: Environment):
-        super().__init__(env)
-        self._action_spec_num_values = env.action_spec().num_values
-
-    def apply_factorisation(self, x: chex.Array) -> chex.Array:
-        """Applies the factorisation to the given action."""
-        action_components = []
-        flat_action = x
-        n = self._action_spec_num_values.shape[0]
-        for i in range(n - 1, 0, -1):
-            flat_action, remainder = jnp.divmod(flat_action, self._action_spec_num_values[i])
-            action_components.append(remainder)
-        action_components.append(flat_action)
-        action = jnp.stack(
-            list(reversed(action_components)),
-            axis=-1,
-            dtype=self._action_spec_num_values.dtype,
-        )
-        return action
-
-    def inverse_factorisation(self, y: chex.Array) -> chex.Array:
-        """Inverts the factorisation of the given action."""
-        n = self._action_spec_num_values.shape[0]
-        action_components = jnp.split(y, n, axis=-1)
-        flat_action = action_components[0]
-        for i in range(1, n):
-            flat_action = self._action_spec_num_values[i] * flat_action + action_components[i]
-        return flat_action
-
-    def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
-        action = self.apply_factorisation(action)
-        state, timestep = self._env.step(state, action)
-        return state, timestep
-
-    def action_spec(self) -> specs.Spec:
-        """Returns the action spec of the environment."""
-        original_action_spec = self._env.action_spec()
-        num_actions = int(np.prod(np.asarray(original_action_spec.num_values)))
-        return specs.DiscreteArray(num_actions, name="action")
-
-
-class MultiBoundedToBounded(Wrapper):
-    def __init__(self, env: Environment):
-        super().__init__(env)
-        self._true_action_shape = env.action_spec().shape
-
-    def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
-        action = action.reshape(self._true_action_shape)
-        state, timestep = self._env.step(state, action)
-        return state, timestep
-
-    def action_spec(self) -> specs.Spec:
-        """Returns the action spec of the environment."""
-        original_action_spec = self._env.action_spec()
-        size = int(np.prod(np.asarray(original_action_spec.shape)))
-        return specs.BoundedArray(
-            (size,),
-            minimum=original_action_spec.minimum,
-            maximum=original_action_spec.maximum,
-            dtype=original_action_spec.dtype,
-            name="action",
-        )
diff --git a/stoix/wrappers/transforms.py b/stoix/wrappers/transforms.py
new file mode 100644
index 00000000..9e959bb9
--- /dev/null
+++ b/stoix/wrappers/transforms.py
@@ -0,0 +1,109 @@
+from typing import Tuple
+
+import chex
+import jax.numpy as jnp
+import numpy as np
+from jumanji import specs
+from jumanji.env import Environment, State
+from jumanji.specs import Array, Spec
+from jumanji.types import TimeStep
+from jumanji.wrappers import Wrapper
+
+from stoix.base_types import Observation
+
+
+class FlattenObservationWrapper(Wrapper):
+    """Simple wrapper that flattens the agent view observation."""
+
+    def __init__(self, env: Environment) -> None:
+        self._env = env
+        obs_shape = self._env.observation_spec().agent_view.shape
+        self._obs_shape = (np.prod(obs_shape),)
+
+    def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
+        state, timestep = self._env.reset(key)
+        agent_view = timestep.observation.agent_view.astype(jnp.float32)
+        agent_view = agent_view.reshape(self._obs_shape)
+        timestep = timestep.replace(
+            observation=timestep.observation._replace(agent_view=agent_view),
+        )
+        return state, timestep
+
+    def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
+        state, timestep = self._env.step(state, action)
+        agent_view = timestep.observation.agent_view.astype(jnp.float32)
+        agent_view = agent_view.reshape(self._obs_shape)
+        timestep = timestep.replace(
+            observation=timestep.observation._replace(agent_view=agent_view),
+        )
+        return state, timestep
+
+    def observation_spec(self) -> Spec:
+        return self._env.observation_spec().replace(
+            agent_view=Array(shape=self._obs_shape, dtype=jnp.float32)
+        )
+
+
+class MultiDiscreteToDiscrete(Wrapper):
+    def __init__(self, env: Environment):
+        super().__init__(env)
+        self._action_spec_num_values = env.action_spec().num_values
+
+    def apply_factorisation(self, x: chex.Array) -> chex.Array:
+        """Applies the factorisation to the given action."""
+        action_components = []
+        flat_action = x
+        n = self._action_spec_num_values.shape[0]
+        for i in range(n - 1, 0, -1):
+            flat_action, remainder = jnp.divmod(flat_action, self._action_spec_num_values[i])
+            action_components.append(remainder)
+        action_components.append(flat_action)
+        action = jnp.stack(
+            list(reversed(action_components)),
+            axis=-1,
+            dtype=self._action_spec_num_values.dtype,
+        )
+        return action
+
+    def inverse_factorisation(self, y: chex.Array) -> chex.Array:
+        """Inverts the factorisation of the given action."""
+        n = self._action_spec_num_values.shape[0]
+        action_components = jnp.split(y, n, axis=-1)
+        flat_action = action_components[0]
+        for i in range(1, n):
+            flat_action = self._action_spec_num_values[i] * flat_action + action_components[i]
+        return flat_action
+
+    def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
+        action = self.apply_factorisation(action)
+        state, timestep = self._env.step(state, action)
+        return state, timestep
+
+    def action_spec(self) -> specs.Spec:
+        """Returns the action spec of the environment."""
+        original_action_spec = self._env.action_spec()
+        num_actions = int(np.prod(np.asarray(original_action_spec.num_values)))
+        return specs.DiscreteArray(num_actions, name="action")
+
+
+class MultiBoundedToBounded(Wrapper):
+    def __init__(self, env: Environment):
+        super().__init__(env)
+        self._true_action_shape = env.action_spec().shape
+
+    def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
+        action = action.reshape(self._true_action_shape)
+        state, timestep = self._env.step(state, action)
+        return state, timestep
+
+    def action_spec(self) -> specs.Spec:
+        """Returns the action spec of the environment."""
+        original_action_spec = self._env.action_spec()
+        size = int(np.prod(np.asarray(original_action_spec.shape)))
+        return specs.BoundedArray(
+            (size,),
+            minimum=original_action_spec.minimum,
+            maximum=original_action_spec.maximum,
+            dtype=original_action_spec.dtype,
+            name="action",
+        )
diff --git a/stoix/wrappers/xminigrid.py b/stoix/wrappers/xminigrid.py
index f0d94f9e..48ef465d 100644
--- a/stoix/wrappers/xminigrid.py
+++ b/stoix/wrappers/xminigrid.py
@@ -3,7 +3,6 @@
 import chex
 import jax
 import jax.numpy as jnp
-import numpy as np
 from jumanji import specs
 from jumanji.specs import Array, DiscreteArray, Spec
 from jumanji.types import TimeStep
@@ -25,10 +24,9 @@ class XMiniGridEnvState:
 
 
 class XMiniGridWrapper(Wrapper):
-    def __init__(self, env: Environment, env_params: EnvParams, flatten_observation: bool = False):
+    def __init__(self, env: Environment, env_params: EnvParams):
         self._env = env
         self._env_params = env_params
-        self._flatten_observation = flatten_observation
 
         self._legal_action_mask = jnp.ones((self.action_spec().num_values,), dtype=jnp.float32)
@@ -36,8 +34,6 @@ def reset(self, key: chex.PRNGKey) -> Tuple[XMiniGridEnvState, TimeStep]:
         key, reset_key = jax.random.split(key)
         minigrid_state_timestep = self._env.reset(self._env_params, reset_key)
         obs = minigrid_state_timestep.observation
-        if self._flatten_observation:
-            obs = obs.flatten()
         obs = Observation(obs, self._legal_action_mask, jnp.array(0))
         timestep = TimeStep(
             observation=obs,
@@ -54,8 +50,6 @@ def step(self, state: XMiniGridEnvState, action: chex.Array) -> Tuple[State, Tim
             self._env_params, state.minigrid_state_timestep, action
         )
         obs = minigrid_state_timestep.observation
-        if self._flatten_observation:
-            obs = obs.flatten()
         obs = Observation(
             obs,
             self._legal_action_mask,
@@ -78,8 +72,7 @@ def action_spec(self) -> Spec:
     def observation_spec(self) -> Spec:
         obs_shape = self._env.observation_shape(self._env_params)
-        if self._flatten_observation:
-            obs_shape = (np.prod(obs_shape),)
+
         return specs.Spec(
             Observation,
             "ObservationSpec",
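
Note on the new `wrapper` config entry: the YAML files above now declare an optional Hydra `_target_` under `wrapper`, but the hunk in make_env.py that consumes this entry is not part of the diff shown here. The sketch below is one way such an entry could be instantiated and applied on top of a created environment pair; it assumes Hydra-style instantiation, and the helper name `apply_optional_wrapper` is hypothetical rather than taken from this diff.

    import hydra
    from omegaconf import DictConfig


    def apply_optional_wrapper(env, eval_env, config: DictConfig):
        """If the config declares `env.wrapper`, wrap both environments with it.

        Hypothetical helper; the hunk that actually consumes `config.env.wrapper`
        is not shown in this diff.
        """
        if "wrapper" in config.env and config.env.wrapper is not None:
            # hydra.utils.instantiate builds the class named by `_target_` (here
            # stoix.wrappers.transforms.FlattenObservationWrapper) and passes the
            # environment as the wrapper's `env` argument.
            env = hydra.utils.instantiate(config.env.wrapper, env=env)
            eval_env = hydra.utils.instantiate(config.env.wrapper, env=eval_env)
        return env, eval_env

This composes cleanly because FlattenObservationWrapper only reshapes the `agent_view` field and rewrites the observation spec to match, passing the rest of the timestep through untouched. Similarly, MultiDiscreteToDiscrete treats a multi-discrete action as mixed-radix digits: apply_factorisation and inverse_factorisation are exact inverses, so a single discrete policy head can drive a MultiDiscrete environment.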