Skip to content

Commit

Permalink
feat: add new algorithms (#52)
Browse files Browse the repository at this point in the history
Co-authored-by: Xuehai Pan <[email protected]>
Co-authored-by: ruiyang sun <[email protected]>
  • Loading branch information
3 people authored Dec 23, 2022
1 parent b7baf82 commit b2e9847
Show file tree
Hide file tree
Showing 96 changed files with 5,247 additions and 694 deletions.
134 changes: 70 additions & 64 deletions README.md

Large diffs are not rendered by default.

41 changes: 41 additions & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,44 @@ Binbin
Zhou
Pengfei
Yaodong
buf
Aivar
Sootla
Alexander
Cowen
Taher
Jafferjee
Ziyan
Wang
Mguni
Jun
Haitham
Ammar
Sun
Ziping
Xu
Meng
Fang
Zhenghao
Peng
Jiadong
Guo
Bo
lei
MDP
Bolei
Bou
Hao
Tuomas
Haarnoja
Aurick
Meger
Herke
Fujimoto
Lyapunov
Yinlam
Ofir
Nachum
Aleksandra
Duenez
Ghavamzadeh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from copy import deepcopy
from typing import Union

# import gymnasium
import gymnasium # pylint: disable=unused-import
import mujoco
import numpy as np
from gymnasium.envs.mujoco.mujoco_rendering import RenderContextOffscreen, Viewer
Expand Down
14 changes: 10 additions & 4 deletions examples/train_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,24 @@
parser.add_argument(
'--algo',
type=str,
default='PPOLag',
help='Choose from: {PolicyGradient, PPO, PPOLag, NaturalPG,'
' TRPO, TRPOLag, PDO, NPGLag, CPO, PCPO, FOCOPS, CPPOPid,CUP',
metavar='ALGO',
default='PPOLagEarlyTerminated',
help='Algorithm to train',
choices=omnisafe.ALGORITHMS['all'],
)
parser.add_argument(
'--env-id',
type=str,
metavar='ENV',
default='SafetyPointGoal1-v0',
help='The name of test environment',
)
parser.add_argument(
'--parallel', default=1, type=int, help='Number of paralleled progress for calculations.'
'--parallel',
default=1,
type=int,
metavar='N',
help='Number of paralleled progress for calculations.',
)
args, unparsed_args = parser.parse_known_args()
keys = [k[2:] for k in unparsed_args[0::2]]
Expand Down
2 changes: 2 additions & 0 deletions omnisafe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================
"""OmniSafe: A comprehensive and reliable benchmark for safe reinforcement learning."""

from omnisafe import algorithms
from omnisafe.algorithms import ALGORITHMS
from omnisafe.algorithms.algo_wrapper import AlgoWrapper as Agent

# from omnisafe.algorithms.env_wrapper import EnvWrapper as Env
Expand Down
93 changes: 55 additions & 38 deletions omnisafe/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,42 +14,59 @@
# ==============================================================================
"""Safe Reinforcement Learning algorithms."""

# Off Policy Safe
from omnisafe.algorithms.off_policy.ddpg import DDPG

# On Policy Safe
from omnisafe.algorithms.on_policy.cpo import CPO
from omnisafe.algorithms.on_policy.cppo_pid import CPPOPid
from omnisafe.algorithms.on_policy.cup import CUP
from omnisafe.algorithms.on_policy.focops import FOCOPS
from omnisafe.algorithms.on_policy.natural_pg import NaturalPG
from omnisafe.algorithms.on_policy.npg_lag import NPGLag
from omnisafe.algorithms.on_policy.pcpo import PCPO
from omnisafe.algorithms.on_policy.pdo import PDO
from omnisafe.algorithms.on_policy.policy_gradient import PolicyGradient
from omnisafe.algorithms.on_policy.ppo import PPO
from omnisafe.algorithms.on_policy.ppo_lag import PPOLag
from omnisafe.algorithms.on_policy.trpo import TRPO
from omnisafe.algorithms.on_policy.trpo_lag import TRPOLag


algo_type = {
'off-policy': ['DDPG'],
'on-policy': [
'CPO',
'FOCOPS',
'CPPOPid',
'FOCOPS',
'NaturalPG',
'NPGLag',
'PCPO',
'PDO',
'PolicyGradient',
'PPO',
'PPOLag',
'TRPO',
'TRPOLag',
'CUP',
],
'model-based': ['MBPPOLag', 'SafeLoop'],
import itertools
from types import MappingProxyType

from omnisafe.algorithms import off_policy, on_policy

# Off-Policy Safe
from omnisafe.algorithms.off_policy import DDPG, SAC, SDDPG, TD3, DDPGLag, SACLag, TD3Lag

# On-Policy Safe
from omnisafe.algorithms.on_policy import (
CPO,
CUP,
FOCOPS,
PCPO,
PDO,
PPO,
TRPO,
CPPOPid,
NaturalPG,
NPGLag,
PolicyGradient,
PPOEarlyTerminated,
PPOLag,
PPOLagEarlyTerminated,
PPOLagSaute,
PPOLagSimmerPid,
PPOLagSimmerQ,
PPOSaute,
PPOSimmerPid,
PPOSimmerQ,
TRPOLag,
TRPOPid,
)


ALGORITHMS = {
'off-policy': tuple(off_policy.__all__),
'on-policy': tuple(on_policy.__all__),
'model-based': (
'MBPPOLag',
'SafeLoop',
),
}

ALGORITHM2TYPE = {
algo: algo_type for algo_type, algorithms in ALGORITHMS.items() for algo in algorithms
}

__all__ = ALGORITHMS['all'] = tuple(itertools.chain.from_iterable(ALGORITHMS.values()))

assert len(ALGORITHM2TYPE) == len(__all__), 'Duplicate algorithm names found.'

ALGORITHMS = MappingProxyType(ALGORITHMS) # make this immutable
ALGORITHM2TYPE = MappingProxyType(ALGORITHM2TYPE) # make this immutable

del itertools, MappingProxyType
11 changes: 4 additions & 7 deletions omnisafe/algorithms/algo_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import psutil

from omnisafe.algorithms import algo_type, registry
from omnisafe.algorithms import ALGORITHM2TYPE, registry
from omnisafe.utils import distributed_utils
from omnisafe.utils.config_utils import check_all_configs, recursive_update
from omnisafe.utils.tools import get_default_kwargs_yaml
Expand All @@ -46,13 +46,10 @@ def _init_checks(self):
assert (
isinstance(self.custom_cfgs, dict) or self.custom_cfgs is None
), 'custom_cfgs must be a dict!'
for key, value in algo_type.items():
if self.algo in value:
self.algo_type = key
break
if algo_type is None or algo_type == '':
self.algo_type = ALGORITHM2TYPE.get(self.algo, None)
if self.algo_type is None or self.algo_type == '':
raise ValueError(f'{self.algo} is not supported!')
if algo_type == 'off-policy':
if self.algo_type == 'off-policy':
assert self.parallel == 1, 'off-policy only support parallel==1!'

def learn(self):
Expand Down
34 changes: 34 additions & 0 deletions omnisafe/algorithms/off_policy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2022 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Off-policy algorithms."""

from omnisafe.algorithms.off_policy.ddpg import DDPG
from omnisafe.algorithms.off_policy.ddpg_lag import DDPGLag
from omnisafe.algorithms.off_policy.sac import SAC
from omnisafe.algorithms.off_policy.sac_lag import SACLag
from omnisafe.algorithms.off_policy.sddpg import SDDPG
from omnisafe.algorithms.off_policy.td3 import TD3
from omnisafe.algorithms.off_policy.td3_lag import TD3Lag


__all__ = [
'DDPG',
'DDPGLag',
'SAC',
'SACLag',
'SDDPG',
'TD3',
'TD3Lag',
]
Loading

0 comments on commit b2e9847

Please sign in to comment.