Skip to content

Commit

Permalink
feat: support gymnasium style reset API (#266)
Browse files: browse the repository at this point in the history
  • Loading branch information
Gaiejj authored Aug 12, 2023
1 parent 8702d79 commit 80c2c23
Show file tree
Hide file tree
Showing 18 changed files with 138 additions and 60 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ repos:
- id: debug-statements
- id: double-quote-string-fixer
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.282
rev: v0.0.284
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
12 changes: 10 additions & 2 deletions omnisafe/adapter/offline_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,22 @@ def step(
"""
return self._env.step(actions)

def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
def reset(
self,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
"""Reset the environment and returns an initial observation.
Args:
seed (int, optional): The random seed. Defaults to None.
options (dict[str, Any], optional): The options for the environment. Defaults to None.
Returns:
observation: The initial observation of the space.
info: Some information logged by the environment.
"""
return self._env.reset()
return self._env.reset(seed=seed, options=options)

def evaluate(
self,
Expand Down
12 changes: 10 additions & 2 deletions omnisafe/adapter/online_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,22 @@ def step(
"""
return self._env.step(action)

def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
def reset(
self,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
"""Reset the environment and returns an initial observation.
Args:
seed (int, optional): The random seed. Defaults to None.
options (dict[str, Any], optional): The options for the environment. Defaults to None.
Returns:
observation: The initial observation of the space.
info: Some information logged by the environment.
"""
return self._env.reset()
return self._env.reset(seed=seed, options=options)

def save(self) -> dict[str, torch.nn.Module]:
"""Save the important components of the environment.
Expand Down
12 changes: 10 additions & 2 deletions omnisafe/adapter/saute_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,17 +109,25 @@ def _wrapper(
if self._env.num_envs == 1:
self._env = Unsqueeze(self._env, device=self._device)

def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
def reset(
self,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
"""Reset the environment and returns an initial observation.
.. note::
Additionally, the safety observation will be reset.
Args:
seed (int, optional): The random seed. Defaults to None.
options (dict[str, Any], optional): The options for the environment. Defaults to None.
Returns:
observation: The initial observation of the space.
info: Some information logged by the environment.
"""
obs, info = self._env.reset()
obs, info = self._env.reset(seed=seed, options=options)
self._safety_obs = torch.ones(self._env.num_envs, 1).to(self._device)
obs = self._augment_obs(obs)
return obs, info
Expand Down
12 changes: 10 additions & 2 deletions omnisafe/adapter/simmer_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,26 @@ def __init__(self, env_id: str, num_envs: int, seed: int, cfgs: Config) -> None:
budget_bound=self._upper_budget.cpu(),
)

def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
def reset(
self,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
"""Reset the environment and returns an initial observation.
.. note::
Additionally, the safety observation will be reset. And the safety budget will be reset
to the value of current ``rel_safety_budget``.
Args:
seed (int, optional): The random seed. Defaults to None.
options (dict[str, Any], optional): The options for the environment. Defaults to None.
Returns:
observation: The initial observation of the space.
info: Some information logged by the environment.
"""
obs, info = self._env.reset()
obs, info = self._env.reset(seed=seed, options=options)
self._safety_obs = self._rel_safety_budget * torch.ones(self._num_envs, 1).to(self._device)
obs = self._augment_obs(obs)
return obs, info
Expand Down
6 changes: 3 additions & 3 deletions omnisafe/algorithms/model_based/planner/arc.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,14 +224,14 @@ def _update_mean_var(
self,
elite_actions: torch.Tensor,
elite_values: torch.Tensor,
info: dict[str, int | float],
info: dict[str, float],
) -> tuple[torch.Tensor, torch.Tensor]: # pylint: disable-next=unused-argument
"""Update the mean and variance of the elite actions.
Args:
elite_actions (torch.Tensor): The elite actions.
elite_values (torch.Tensor): The elite values.
info (dict[str, int | float]): The dictionary containing the information of the elite values and actions.
info (dict[str, float]): The dictionary containing the information of the elite values and actions.
Returns:
new_mean: The new mean of the elite actions.
Expand Down Expand Up @@ -261,7 +261,7 @@ def _update_mean_var(
return new_mean, new_var

@torch.no_grad()
def output_action(self, state: torch.Tensor) -> tuple[torch.Tensor, dict[str, int | float]]:
def output_action(self, state: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
"""Output the action given the state.
Args:
Expand Down
4 changes: 2 additions & 2 deletions omnisafe/algorithms/model_based/planner/cem.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,14 @@ def _update_mean_var( # pylint: disable=unused-argument
self,
elite_actions: torch.Tensor,
elite_values: torch.Tensor,
info: dict[str, int | float],
info: dict[str, float],
) -> tuple[torch.Tensor, torch.Tensor]:
"""Update the mean and variance of the elite actions.
Args:
elite_actions (torch.Tensor): The elite actions.
elite_values (torch.Tensor): The elite values.
info (dict[str, int | float]): The dictionary containing the information of the elite values and actions.
info (dict[str, float]): The dictionary containing the information of the elite values and actions.
Returns:
new_mean: The new mean of the elite actions.
Expand Down
8 changes: 4 additions & 4 deletions omnisafe/algorithms/model_based/planner/safe_arc.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,14 @@ def _update_mean_var(
self,
elite_actions: torch.Tensor,
elite_values: torch.Tensor,
info: dict[str, int | float],
info: dict[str, float],
) -> tuple[torch.Tensor, torch.Tensor]:
"""Update the mean and variance of the elite actions.
Args:
elite_actions (torch.Tensor): The elite actions.
elite_values (torch.Tensor): The elite values.
info (dict[str, int | float]): The dictionary containing the information of the elite values and actions.
info (dict[str, float]): The dictionary containing the information of the elite values and actions.
Returns:
new_mean: The new mean of the elite actions.
Expand Down Expand Up @@ -197,7 +197,7 @@ def _select_elites(
return elite_values, elite_actions, info

@torch.no_grad()
def output_action(self, state: torch.Tensor) -> tuple[torch.Tensor, dict[str, int | float]]:
def output_action(self, state: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
"""Output the action given the state.
Args:
Expand All @@ -217,7 +217,7 @@ def output_action(self, state: torch.Tensor) -> tuple[torch.Tensor, dict[str, in

current_iter = 0
actions_actor = self._act_from_actor(state)
info: dict[str, int | float] = {}
info: dict[str, float] = {}
while current_iter < self._num_iterations and last_var.max() > self._epsilon:
actions_gauss = self._act_from_last_gaus(last_mean=last_mean, last_var=last_var)
actions = torch.cat([actions_gauss, actions_actor], dim=1)
Expand Down
8 changes: 4 additions & 4 deletions omnisafe/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,11 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
self._epoch: int = 0
self._first_row: bool = True
self._what_to_save: dict[str, Any] | None = None
self._data: dict[str, deque[int | float] | list[int | float]] = {}
self._data: dict[str, deque[float] | list[float]] = {}
self._headers_windows: dict[str, int | None] = {}
self._headers_minmax: dict[str, bool] = {}
self._headers_delta: dict[str, bool] = {}
self._current_row: dict[str, int | float] = {}
self._current_row: dict[str, float] = {}

if config is not None:
self.save_config(config)
Expand Down Expand Up @@ -257,7 +257,7 @@ def store(
The data stored in ``data`` will be updated by ``kwargs``.
Args:
data (dict[str, int | float | np.ndarray | torch.Tensor] or None, optional): The data to
data (dict[str, float | np.ndarray | torch.Tensor] or None, optional): The data to
be stored. Defaults to None.
Keyword Args:
Expand Down Expand Up @@ -340,7 +340,7 @@ def get_stats(
self,
key: str,
min_and_max: bool = False,
) -> tuple[int | float, ...]:
) -> tuple[float, ...]:
"""Get the statistics of the key.
Args:
Expand Down
21 changes: 16 additions & 5 deletions omnisafe/envs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,16 @@ def step(
"""

@abstractmethod
def reset(self, seed: int | None = None) -> tuple[torch.Tensor, dict[str, Any]]:
def reset(
self,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
"""Reset the environment and returns an initial observation.
Args:
seed (int or None): Seed for the environment. Defaults to None.
seed (int, optional): The random seed. Defaults to None.
options (dict[str, Any], optional): The options for the environment. Defaults to None.
Returns:
observation: The initial observation of the space.
Expand Down Expand Up @@ -237,17 +242,23 @@ def step(
"""
return self._env.step(action)

def reset(self, seed: int | None = None) -> tuple[torch.Tensor, dict[str, Any]]:
def reset(
self,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
"""Reset the environment and returns an initial observation.
Args:
seed (int or None): Seed for the environment. Defaults to None.
seed (int, optional): The random seed. Defaults to None.
options (dict[str, Any], optional): The options for the environment. Defaults to None.
Returns:
observation: The initial observation of the space.
info: Some information logged by the environment.
"""
return self._env.reset(seed)
return self._env.reset(seed=seed, options=options)

def set_seed(self, seed: int) -> None:
"""Set the seed for this env's random number generator(s).
Expand Down
11 changes: 8 additions & 3 deletions omnisafe/envs/mujoco_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,22 @@ def step(

return obs, reward, cost, terminated, truncated, info

def reset(self, seed: int | None = None) -> tuple[torch.Tensor, dict]:
def reset(
self,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, dict]:
"""Reset the environment.
Args:
seed (int, optional): Seed to reset the environment. Defaults to None.
seed (int, optional): The random seed. Defaults to None.
options (dict[str, Any], optional): The options for the environment. Defaults to None.
Returns:
observation: Agent's observation of the current environment.
info: Auxiliary diagnostic information (helpful for debugging, and sometimes learning).
"""
obs, info = self._env.reset(seed=seed)
obs, info = self._env.reset(seed=seed, options=options)
return torch.as_tensor(obs, dtype=torch.float32, device=self._device), info

def set_seed(self, seed: int) -> None:
Expand Down
13 changes: 9 additions & 4 deletions omnisafe/envs/safety_gymnasium_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,18 +209,23 @@ def step(

return obs, reward, cost, terminated, truncated, info

def reset(self, seed: int | None = None) -> tuple[torch.Tensor, dict[str, Any]]:
def reset(
self,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
"""Reset the environment.
Args:
seed (int or None, optional): Seed to reset the environment.
Defaults to None.
seed (int, optional): The random seed. Defaults to None.
options (dict[str, Any], optional): The options for the environment. Defaults to None.
Returns:
observation: Agent's observation of the current environment.
info: Some information logged by the environment.
"""
obs, info = self._env.reset(seed=seed)
obs, info = self._env.reset(seed=seed, options=options)
return torch.as_tensor(obs, dtype=torch.float32, device=self._device), info

def set_seed(self, seed: int) -> None:
Expand Down
12 changes: 9 additions & 3 deletions omnisafe/envs/safety_gymnasium_modelbased.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,17 +465,23 @@ def step(

return obs, reward, cost, terminated, truncated, info

def reset(self, seed: int | None = None) -> tuple[torch.Tensor, dict[str, Any]]:
def reset(
self,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
"""Reset the environment.
Args:
seed (int, optional): Seed to reset the environment. Defaults to None.
seed (int, optional): The random seed. Defaults to None.
options (dict[str, Any], optional): The options for the environment. Defaults to None.
Returns:
observation: The initial observation of the space.
info: Some information logged by the environment.
"""
obs_original, info = self._env.reset(seed=seed)
obs_original, info = self._env.reset(seed=seed, options=options)
if self._task == 'Goal':
self.goal_position = self._env.task.goal.pos
self.robot_position = self._env.task.agent.pos
Expand Down
Loading

0 comments on commit 80c2c23

Please sign in to comment.