Skip to content

Commit

Permalink
refactor: change architecture of omnisafe (PKU-Alignment#121)
Browse files Browse the repository at this point in the history
  • Loading branch information
rockmagma02 authored and zmsn-2077 committed Mar 14, 2023
1 parent 0ebae0a commit 809525d
Show file tree
Hide file tree
Showing 134 changed files with 4,361 additions and 13,099 deletions.
7 changes: 3 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,9 @@ jobs:
run: |
make addlicense
# TODO: enable this when ready
# - name: mypy
# run: |
# make mypy
- name: mypy
run: |
make mypy
- name: Install dependencies
run: |
Expand Down
14 changes: 7 additions & 7 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,10 @@ exclude-too-few-public-methods=
ignored-parents=

# Maximum number of arguments for function / method.
max-args=5
max-args=8

# Maximum number of attributes for a class (see R0902).
max-attributes=7
max-attributes=12

# Maximum number of boolean expressions in an if statement (see R0916).
max-bool-expr=5
Expand All @@ -301,22 +301,22 @@ max-bool-expr=5
max-branches=12

# Maximum number of locals for function / method body.
max-locals=15
max-locals=20

# Maximum number of parents for a class (see R0901).
max-parents=7
max-parents=12

# Maximum number of public methods for a class (see R0904).
max-public-methods=20

# Maximum number of return / yield for function / method body.
max-returns=6
max-returns=8

# Maximum number of statements in function / method body.
max-statements=50
max-statements=80

# Minimum number of public methods for a class (see R0903).
min-public-methods=2
min-public-methods=1


[EXCEPTIONS]
Expand Down
9 changes: 9 additions & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -369,3 +369,12 @@ noqa
hyperparameters
json
msg
env's
CMDP
api
moviepy
normalizer
Unsqueeze
Golub
logp
loc
4 changes: 3 additions & 1 deletion omnisafe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
from omnisafe import algorithms
from omnisafe.algorithms import ALGORITHMS
from omnisafe.algorithms.algo_wrapper import AlgoWrapper as Agent
from omnisafe.evaluator import Evaluator

# from omnisafe.algorithms.env_wrapper import EnvWrapper as Env
from omnisafe.version import __version__


# from omnisafe.evaluator import Evaluator
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model-Based algorithms."""
"""Adapter for the environment and the algorithm."""

from omnisafe.algorithms.model_based.cap import CAP
from omnisafe.algorithms.model_based.mbppo_lag import MBPPOLag
from omnisafe.algorithms.model_based.safeloop import SafeLOOP


__all__ = [
'CAP',
'MBPPOLag',
'SafeLOOP',
]
from omnisafe.adapter.early_terminated_adapter import EarlyTerminatedAdapter
from omnisafe.adapter.online_adapter import OnlineAdapter
from omnisafe.adapter.onpolicy_adapter import OnPolicyAdapter
from omnisafe.adapter.saute_adapter import SauteAdapter
49 changes: 49 additions & 0 deletions omnisafe/adapter/early_terminated_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""OnPolicy Adapter for OmniSafe."""

from typing import Dict, Tuple

import torch

from omnisafe.adapter.onpolicy_adapter import OnPolicyAdapter
from omnisafe.utils.config import Config


class EarlyTerminatedAdapter(OnPolicyAdapter):
"""OnPolicy Adapter for OmniSafe."""

def __init__(self, env_id: str, num_envs: int, seed: int, cfgs: Config) -> None:
assert num_envs == 1, 'EarlyTerminatedAdapter only supports num_envs=1.'

super().__init__(env_id, num_envs, seed, cfgs)

self._cost_limit = cfgs.cost_limit
self._cost_logger = torch.zeros(self._env.num_envs)

def step(
self, action: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict]:
next_obs, reward, cost, terminated, truncated, info = super().step(action)

self._cost_logger += info.get('original_cost', cost)

if self._cost_logger > self._cost_limit:
reward = torch.zeros(self._env.num_envs) # r_e = 0
terminated = torch.ones(self._env.num_envs)
next_obs, _ = self._env.reset()
self._cost_logger = torch.zeros(self._env.num_envs)

return next_obs, reward, cost, terminated, truncated, info
125 changes: 125 additions & 0 deletions omnisafe/adapter/online_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Online Adapter for OmniSafe."""

from typing import Dict, Tuple

import torch

from omnisafe.envs.core import make, support_envs
from omnisafe.envs.wrapper import (
ActionScale,
AutoReset,
CostNormalize,
ObsNormalize,
RewardNormalize,
TimeLimit,
Unsqueeze,
)
from omnisafe.typing import OmnisafeSpace
from omnisafe.utils.config import Config


class OnlineAdapter:
"""Online Adapter for OmniSafe."""

def __init__( # pylint: disable=too-many-arguments
self,
env_id: str,
num_envs: int,
seed: int,
cfgs: Config,
) -> None:
assert env_id in support_envs(), f'Env {env_id} is not supported.'

self._env_id = env_id
self._env = make(env_id, num_envs=num_envs)
self._wrapper(
obs_normalize=cfgs.obs_normalize,
reward_normalize=cfgs.reward_normalize,
cost_normalize=cfgs.cost_normalize,
)
self._env.set_seed(seed)

self._cfgs = cfgs

def _wrapper(
self,
obs_normalize: bool = True,
reward_normalize: bool = True,
cost_normalize: bool = True,
):
if self._env.need_time_limit_wrapper:
self._env = TimeLimit(self._env, time_limit=1000)
if self._env.need_auto_reset_wrapper:
self._env = AutoReset(self._env)
if obs_normalize:
self._env = ObsNormalize(self._env)
if reward_normalize:
self._env = RewardNormalize(self._env)
if cost_normalize:
self._env = CostNormalize(self._env)
self._env = ActionScale(self._env, low=-1.0, high=1.0)
if self._env.num_envs == 1:
self._env = Unsqueeze(self._env)

@property
def action_space(self) -> OmnisafeSpace:
"""The action space of the environment.
Returns:
OmnisafeSpace: the action space.
"""
return self._env.action_space

@property
def observation_space(self) -> OmnisafeSpace:
"""The observation space of the environment.
Returns:
OmnisafeSpace: the observation space.
"""
return self._env.observation_space

def step(
self, action: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict]:
"""Run one timestep of the environment's dynamics using the agent actions.
Args:
action (torch.Tensor): action.
Returns:
observation (torch.Tensor): agent's observation of the current environment.
reward (torch.Tensor): amount of reward returned after previous action.
cost (torch.Tensor): amount of cost returned after previous action.
terminated (torch.Tensor): whether the episode has ended, in which case further step()
calls will return undefined results.
truncated (torch.Tensor): whether the episode has been truncated due to a time limit.
info (Dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning).
"""
return self._env.step(action)

def reset(self) -> Tuple[torch.Tensor, Dict]:
"""Resets the environment and returns an initial observation.
Args:
seed (Optional[int]): seed for the environment.
Returns:
observation (torch.Tensor): the initial observation of the space.
info (Dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning).
"""
return self._env.reset()
Loading

0 comments on commit 809525d

Please sign in to comment.