Skip to content

Commit

Permalink
Merge pull request #3 from rockmagma02/pr/Gaiejj/137
Browse files: browse the repository at this point in the history
  • Loading branch information
Gaiejj authored Mar 8, 2023
2 parents 2042d07 + 25f9dd2 commit 648ba4c
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 137 deletions.
16 changes: 1 addition & 15 deletions omnisafe/adapter/online_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,11 @@ def __init__( # pylint: disable=too-many-arguments
num_envs: int,
seed: int,
cfgs: Config,
**env_kwargs: Dict,
) -> None:
assert env_id in support_envs(), f'Env {env_id} is not supported.'

self._env_id = env_id
self._env = make(env_id, num_envs=num_envs, **env_kwargs)
self._env = make(env_id, num_envs=num_envs)
self._wrapper(
obs_normalize=cfgs.algo_cfgs.obs_normalize,
reward_normalize=cfgs.algo_cfgs.reward_normalize,
Expand Down Expand Up @@ -76,19 +75,6 @@ def _wrapper(
if self._env.num_envs == 1:
self._env = Unsqueeze(self._env)

def load(self, obs_normlizer_dict):  # pylint: disable=unused-argument
    """Restore the observation normalizer of the wrapped environment.

    Checks that observation normalization is enabled in the algorithm
    configs, then forwards the given dict to ``self._env.load``.

    Args:
        obs_normlizer_dict (Dict): the dict of the observation normalizer.

    Raises:
        AssertionError: if ``cfgs.algo_cfgs.obs_normalize`` is falsy.
    """
    normalize_enabled = self._cfgs.algo_cfgs.obs_normalize
    assert normalize_enabled, 'The observation normalizer is not loaded.'
    wrapped_env = self._env
    wrapped_env.load(obs_normlizer_dict)

def render(self) -> None:
    """Render the wrapped environment and pass back whatever it returns."""
    rendered = self._env.render()
    return rendered

@property
def action_space(self) -> OmnisafeSpace:
"""The action space of the environment.
Expand Down
6 changes: 5 additions & 1 deletion omnisafe/common/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================
"""Implementation of Normalizer."""

from typing import Tuple
from typing import Any, Mapping, Tuple

import torch
import torch.nn as nn
Expand Down Expand Up @@ -108,3 +108,7 @@ def _push(self, raw_data: torch.Tensor) -> None:
self._var = self._sumsq / (self._count - 1)
self._std = torch.sqrt(self._var)
self._std = torch.max(self._std, 1e-2 * torch.ones_like(self._std))

def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
    """Load a saved state dict into this normalizer.

    Clears the ``_first`` flag and then delegates to the parent class's
    ``load_state_dict``.

    Args:
        state_dict (Mapping[str, Any]): the saved state to restore.
        strict (bool): forwarded unchanged to the parent implementation.

    Returns:
        Whatever the parent class's ``load_state_dict`` returns.
    """
    # NOTE(review): presumably ``_first`` marks "no statistics seen yet", so it
    # is cleared to keep the restored statistics in use -- confirm against ``_push``.
    self._first = False
    outcome = super().load_state_dict(state_dict, strict)
    return outcome
10 changes: 10 additions & 0 deletions omnisafe/envs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class CMDP(ABC):
_support_envs: List[str]
_action_space: OmnisafeSpace
_observation_space: OmnisafeSpace
_metadata: Dict[str, Any]

_num_envs: int
_time_limit: Optional[int] = None
Expand Down Expand Up @@ -86,6 +87,15 @@ def observation_space(self) -> OmnisafeSpace:
"""
return self._observation_space

@property
def metadata(self) -> Dict[str, Any]:
    """Expose the environment metadata stored on this instance.

    Returns:
        Dict[str, Any]: the value held in ``_metadata``.
    """
    env_metadata = self._metadata
    return env_metadata

@property
def num_envs(self) -> int:
"""The parallel environments.
Expand Down
1 change: 1 addition & 0 deletions omnisafe/envs/safety_gymnasium_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(self, env_id: str, num_envs: int = 1, **kwargs) -> None:
self._observation_space = self._env.observation_space

self._num_envs = num_envs
self._metadata = self._env.metadata

def step(
self, action: torch.Tensor
Expand Down
4 changes: 0 additions & 4 deletions omnisafe/envs/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,6 @@ def save(self) -> Dict[str, torch.nn.Module]:
saved['obs_normalizer'] = self._obs_normalizer
return saved

def load(self, obs_normalizer_dict: dict) -> None:
    """Restore the observation normalizer from a saved state dict.

    Args:
        obs_normalizer_dict (dict): state handed to the normalizer's
            ``load_state_dict``.
    """
    normalizer = self._obs_normalizer
    normalizer.load_state_dict(obs_normalizer_dict)


class RewardNormalize(Wrapper):
"""Normalize the reward.
Expand Down
Loading

0 comments on commit 648ba4c

Please sign in to comment.