From 6505baff3609793ef53e2abc03578c14a3b64952 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Tue, 9 Jul 2024 14:11:38 +0000 Subject: [PATCH 1/4] feat: add navix --- stoix/configs/env/navix/door_key_8x8.yaml | 15 +++ stoix/configs/env/navix/empty_5x5.yaml | 15 +++ .../doorkey_5x5.yaml} | 2 +- .../empty_6x6.yaml} | 2 +- stoix/utils/make_env.py | 30 ++++++ stoix/wrappers/navix.py | 97 +++++++++++++++++++ 6 files changed, 159 insertions(+), 2 deletions(-) create mode 100644 stoix/configs/env/navix/door_key_8x8.yaml create mode 100644 stoix/configs/env/navix/empty_5x5.yaml rename stoix/configs/env/{minigrid/minigrid_doorkey_5x5.yaml => xland_minigrid/doorkey_5x5.yaml} (94%) rename stoix/configs/env/{minigrid/minigrid_empty_6x6.yaml => xland_minigrid/empty_6x6.yaml} (94%) create mode 100644 stoix/wrappers/navix.py diff --git a/stoix/configs/env/navix/door_key_8x8.yaml b/stoix/configs/env/navix/door_key_8x8.yaml new file mode 100644 index 00000000..67f30d2b --- /dev/null +++ b/stoix/configs/env/navix/door_key_8x8.yaml @@ -0,0 +1,15 @@ +# ---Environment Configs--- +env_name: navix +scenario: + name: Navix-DoorKey-8x8-v0 + task_name: navix-door_key-8x8-v0 + +kwargs: {} + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +# Optional wrapper to flatten the observation space. +wrapper: + _target_: stoix.wrappers.transforms.FlattenObservationWrapper diff --git a/stoix/configs/env/navix/empty_5x5.yaml b/stoix/configs/env/navix/empty_5x5.yaml new file mode 100644 index 00000000..91f47be4 --- /dev/null +++ b/stoix/configs/env/navix/empty_5x5.yaml @@ -0,0 +1,15 @@ +# ---Environment Configs--- +env_name: navix +scenario: + name: Navix-Empty-5x5-v0 + task_name: navix-dempty-5x5-v0 + +kwargs: {} + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +# Optional wrapper to flatten the observation space. +wrapper: + _target_: stoix.wrappers.transforms.FlattenObservationWrapper diff --git a/stoix/configs/env/minigrid/minigrid_doorkey_5x5.yaml b/stoix/configs/env/xland_minigrid/doorkey_5x5.yaml similarity index 94% rename from stoix/configs/env/minigrid/minigrid_doorkey_5x5.yaml rename to stoix/configs/env/xland_minigrid/doorkey_5x5.yaml index 6c848151..af2e555b 100644 --- a/stoix/configs/env/minigrid/minigrid_doorkey_5x5.yaml +++ b/stoix/configs/env/xland_minigrid/doorkey_5x5.yaml @@ -1,5 +1,5 @@ # ---Environment Configs--- -env_name: minigrid +env_name: xland_minigrid scenario: name: MiniGrid-DoorKey-5x5 task_name: minigrid_doorkey_5x5 diff --git a/stoix/configs/env/minigrid/minigrid_empty_6x6.yaml b/stoix/configs/env/xland_minigrid/empty_6x6.yaml similarity index 94% rename from stoix/configs/env/minigrid/minigrid_empty_6x6.yaml rename to stoix/configs/env/xland_minigrid/empty_6x6.yaml index fd11326b..566ed5cc 100644 --- a/stoix/configs/env/minigrid/minigrid_empty_6x6.yaml +++ b/stoix/configs/env/xland_minigrid/empty_6x6.yaml @@ -1,5 +1,5 @@ # ---Environment Configs--- -env_name: minigrid +env_name: xland_minigrid scenario: name: MiniGrid-Empty-6x6 task_name: minigrid_empty_6x6 diff --git a/stoix/utils/make_env.py b/stoix/utils/make_env.py index 7471aca5..67ebed25 100644 --- a/stoix/utils/make_env.py +++ b/stoix/utils/make_env.py @@ -6,6 +6,7 @@ import jax.numpy as jnp import jaxmarl import jumanji +import navix import pgx import popjym import xminigrid @@ -18,6 +19,7 @@ from jumanji.registration import _REGISTRY as JUMANJI_REGISTRY from jumanji.specs import BoundedArray, MultiDiscreteArray from jumanji.wrappers import AutoResetWrapper, MultiToSingleWrapper +from navix import registry as navix_registry from omegaconf import DictConfig from popjym.registration import REGISTERED_ENVS as POPJYM_REGISTRY from xminigrid.registration import _REGISTRY as XMINIGRID_REGISTRY @@ -26,6 +28,7 @@ from stoix.wrappers import GymnaxWrapper, JumanjiWrapper, RecordEpisodeMetrics from stoix.wrappers.brax import BraxJumanjiWrapper from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper +from stoix.wrappers.navix import NavixWrapper from stoix.wrappers.pgx import PGXWrapper from stoix.wrappers.transforms import ( AddStartFlagAndPrevAction, @@ -343,6 +346,31 @@ def make_popjym_env(env_name: str, config: DictConfig) -> Tuple[Environment, Env return env, eval_env +def make_navix_env(env_name: str, config: DictConfig) -> Tuple[Environment, Environment]: + """ + Create Navix environments for training and evaluation. + + Args: + env_name (str): The name of the environment to create. + config (Dict): The configuration of the environment. + + Returns: + A tuple of the environments. + """ + + # Create envs. + env = navix.make(env_name, **config.env.kwargs) + eval_env = navix.make(env_name, **config.env.kwargs) + + env = NavixWrapper(env) + eval_env = NavixWrapper(eval_env) + + env = AutoResetWrapper(env, next_obs_in_extras=True) + env = RecordEpisodeMetrics(env) + + return env, eval_env + + def make(config: DictConfig) -> Tuple[Environment, Environment]: """ Create environments for training and evaluation.. @@ -373,6 +401,8 @@ def make(config: DictConfig) -> Tuple[Environment, Environment]: envs = make_pgx_env(env_name, config) elif env_name in POPJYM_REGISTRY: envs = make_popjym_env(env_name, config) + elif env_name in navix_registry(): + envs = make_navix_env(env_name, config) else: raise ValueError(f"{env_name} is not a supported environment.") diff --git a/stoix/wrappers/navix.py b/stoix/wrappers/navix.py new file mode 100644 index 00000000..5b548007 --- /dev/null +++ b/stoix/wrappers/navix.py @@ -0,0 +1,97 @@ +from typing import TYPE_CHECKING, Tuple + +import chex +import jax +import jax.numpy as jnp +from jumanji import specs +from jumanji.specs import Array, DiscreteArray, Spec +from jumanji.types import StepType, TimeStep, restart +from jumanji.wrappers import Wrapper +from navix.environments import Environment +from navix.environments import Timestep as NavixState + +from stoix.base_types import Observation + +if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 + from dataclasses import dataclass +else: + from chex import dataclass + + +@dataclass +class NavixEnvState: + key: chex.PRNGKey + navix_state: NavixState + + +class NavixWrapper(Wrapper): + def __init__(self, env: Environment): + self._env = env + self._n_actions = len(self._env.action_set) + + def reset(self, key: chex.PRNGKey) -> Tuple[NavixEnvState, TimeStep]: + key, key_reset = jax.random.split(key) + navix_state = self._env.reset(key_reset) + agent_view = navix_state.observation.astype(float) + legal_action_mask = jnp.ones((self._n_actions,), dtype=float) + step_count = navix_state.t.astype(int) + obs = Observation(agent_view, legal_action_mask, step_count) + timestep = restart(obs, extras={}) + state = NavixEnvState(key=key, navix_state=navix_state) + return state, timestep + + def step(self, state: NavixEnvState, action: chex.Array) -> Tuple[NavixEnvState, TimeStep]: + key, key_step = jax.random.split(state.key) + + navix_state = self._env.step(state.navix_state, action) + + agent_view = navix_state.observation.astype(float) + legal_action_mask = jnp.ones((self._n_actions,), dtype=float) + step_count = navix_state.t.astype(int) + next_obs = Observation(agent_view, legal_action_mask, step_count) + + reward = navix_state.reward.astype(float) + terminal = navix_state.is_termination() + truncated = navix_state.is_truncation() + + discount = jnp.array(1.0 - terminal, dtype=float) + final_step = jnp.logical_or(terminal, truncated) + + timestep = TimeStep( + observation=next_obs, + reward=reward, + discount=discount, + step_type=jax.lax.select(final_step, StepType.LAST, StepType.MID), + extras={}, + ) + next_state = NavixEnvState(key=key_step, navix_state=navix_state) + return next_state, timestep + + def reward_spec(self) -> specs.Array: + return specs.Array(shape=(), dtype=float, name="reward") + + def discount_spec(self) -> specs.BoundedArray: + return specs.BoundedArray(shape=(), dtype=float, minimum=0.0, maximum=1.0, name="discount") + + def action_spec(self) -> Spec: + return DiscreteArray(num_values=self._n_actions) + + def observation_spec(self) -> Spec: + agent_view_shape = self._env.observation_space.shape + agent_view_min = self._env.observation_space.minimum + agent_view_max = self._env.observation_space.maximum + agent_view_spec = specs.BoundedArray( + shape=agent_view_shape, + dtype=float, + minimum=agent_view_min, + maximum=agent_view_max, + ) + action_mask_spec = Array(shape=(self._n_actions,), dtype=float) + + return specs.Spec( + Observation, + "ObservationSpec", + agent_view=agent_view_spec, + action_mask=action_mask_spec, + step_count=Array(shape=(), dtype=int), + ) From 4d06683ff554cb30b38ba203268bca3f70faea06 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Tue, 9 Jul 2024 14:16:05 +0000 Subject: [PATCH 2/4] chore: fix typo and add config --- stoix/configs/env/navix/empty_5x5.yaml | 2 +- stoix/configs/env/xland_minigrid/empty_5x5.yaml | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 stoix/configs/env/xland_minigrid/empty_5x5.yaml diff --git a/stoix/configs/env/navix/empty_5x5.yaml b/stoix/configs/env/navix/empty_5x5.yaml index 91f47be4..17818d03 100644 --- a/stoix/configs/env/navix/empty_5x5.yaml +++ b/stoix/configs/env/navix/empty_5x5.yaml @@ -2,7 +2,7 @@ env_name: navix scenario: name: Navix-Empty-5x5-v0 - task_name: navix-dempty-5x5-v0 + task_name: navix-empty-5x5-v0 kwargs: {} diff --git a/stoix/configs/env/xland_minigrid/empty_5x5.yaml b/stoix/configs/env/xland_minigrid/empty_5x5.yaml new file mode 100644 index 00000000..09970382 --- /dev/null +++ b/stoix/configs/env/xland_minigrid/empty_5x5.yaml @@ -0,0 +1,16 @@ +# ---Environment Configs--- +env_name: xland_minigrid +scenario: + name: MiniGrid-Empty-5x5 + task_name: minigrid_empty_5x5 + +kwargs: {} + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + + +# Optional wrapper to flatten the observation space. +wrapper: + _target_: stoix.wrappers.transforms.FlattenObservationWrapper From 11db44c41a431a111275e350ed84ff988118071d Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Tue, 9 Jul 2024 14:35:07 +0000 Subject: [PATCH 3/4] chore: edit readme --- README.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d313ab15..d4836932 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ Stoix currently offers the following building blocks for Single-Agent RL researc - **Sampled Alpha/Mu-Zero** - [Paper](https://arxiv.org/abs/2104.06303) ### Environment Wrappers ๐Ÿฌ -Stoix offers wrappers for [Gymnax][gymnax], [Jumanji][jumanji], [Brax][brax], [XMinigrid][xminigrid], [Craftax][craftax], [POPJym][popjym] and even [JAXMarl][jaxmarl] (although using Centralised Controllers). +Stoix offers wrappers for [Gymnax][gymnax], [Jumanji][jumanji], [Brax][brax], [XMinigrid][xminigrid], [Craftax][craftax], [POPJym][popjym], [Navix][navix] and even [JAXMarl][jaxmarl] (although using Centralised Controllers). ### Statistically Robust Evaluation ๐Ÿงช Stoix natively supports logging to json files which adhere to the standard suggested by [Gorsane et al. (2022)][toward_standard_eval]. This enables easy downstream experiment plotting and aggregation using the tools found in the [MARL-eval][marl_eval] library. @@ -140,6 +140,12 @@ or if you wanted to do dueling C51, you could do: python stoix/systems/q_learning/ff_c51.py network=mlp_dueling_c51 ``` +## Important Considerations + +1. If your environment does not have a timestep limit or is not guaranteed to end through some game mechanic, then it is possible for the evaluation to seem as if it is hanging forever thereby stalling the training but in fact your agent is just so good _or bad_ that the episode never finishes. Keep this in mind if you are seeing this behaviour. One solution is to simply add a time step limit or potentially action masking. + +2. Due to the way Stoix is set up, you are not guaranteed to run for exactly the number of timesteps you set. A warning is given at the beginning of a run on the actual number of timesteps that will be run. This value will always be less than or equal to the specified sample budget. To get the exact number of transitions to run, ensure that the number of timesteps is divisible by the rollout length * total_num_envs and additionally ensure that the number of evaluations spaced out throughout training perfectly divide the number of updates to be performed. To see the exact calculation, see the file total_timestep_checker.py. This will give an indication of how the actual number of timesteps is calculated and how you can easily set it up to run the exact amount you desire. Its relatively trivial to do so but it is important to keep in mind. + ## Contributing ๐Ÿค Please read our [contributing docs](docs/CONTRIBUTING.md) for details on how to submit pull requests, our Contributor License Agreement and community guidelines. @@ -210,5 +216,6 @@ We would like to thank the authors and developers of [Mava](mava) as this was es [xminigrid]: https://github.com/corl-team/xland-minigrid/ [craftax]: https://github.com/MichaelTMatthews/Craftax [popjym]: https://github.com/FLAIROx/popjym +[navix]: https://github.com/epignatelli/navix Disclaimer: This is not an official InstaDeep product nor is any of the work putforward associated with InstaDeep in any official capacity. From 0d6b05faf5e04f2ebe7afccce2289ffb22aed6c5 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Tue, 9 Jul 2024 14:36:12 +0000 Subject: [PATCH 4/4] feat: add navix to requirements --- requirements/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index e564f9b9..a94587d7 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -14,6 +14,7 @@ jaxlib jaxmarl jumanji==1.0.0 mctx +navix neptune numpy omegaconf