Skip to content

Commit

Permalink
chore(algorithms): rerender __init__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Dec 22, 2022
1 parent 034f8c3 commit 35b3f23
Show file tree
Hide file tree
Showing 12 changed files with 169 additions and 62 deletions.
1 change: 1 addition & 0 deletions omnisafe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================
"""OmniSafe: A comprehensive and reliable benchmark for safe reinforcement learning."""

from omnisafe import algorithms
from omnisafe.algorithms import ALGORITHMS
from omnisafe.algorithms.algo_wrapper import AlgoWrapper as Agent

Expand Down
89 changes: 27 additions & 62 deletions omnisafe/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,76 +17,41 @@
import itertools
from types import MappingProxyType

from omnisafe.algorithms import off_policy, on_policy

# Off-Policy Safe
from omnisafe.algorithms.off_policy.ddpg import DDPG
from omnisafe.algorithms.off_policy.ddpg_lag import DDPGLag
from omnisafe.algorithms.off_policy.sac import SAC
from omnisafe.algorithms.off_policy.sac_lag import SACLag
from omnisafe.algorithms.off_policy.sddpg import SDDPG
from omnisafe.algorithms.off_policy.td3 import TD3
from omnisafe.algorithms.off_policy.td3_lag import TD3Lag
from omnisafe.algorithms.off_policy import DDPG, SAC, SDDPG, TD3, DDPGLag, SACLag, TD3Lag

# On-Policy Safe
from omnisafe.algorithms.on_policy.base.natural_pg import NaturalPG
from omnisafe.algorithms.on_policy.base.policy_gradient import PolicyGradient
from omnisafe.algorithms.on_policy.base.ppo import PPO
from omnisafe.algorithms.on_policy.base.trpo import TRPO
from omnisafe.algorithms.on_policy.early_terminated.ppo_early_terminated import PPOEarlyTerminated
from omnisafe.algorithms.on_policy.early_terminated.ppo_lag_early_terminated import (
from omnisafe.algorithms.on_policy import (
CPO,
CUP,
FOCOPS,
PCPO,
PDO,
PPO,
TRPO,
CPPOPid,
NaturalPG,
NPGLag,
PolicyGradient,
PPOEarlyTerminated,
PPOLag,
PPOLagEarlyTerminated,
PPOLagSaute,
PPOLagSimmerPid,
PPOLagSimmerQ,
PPOSaute,
PPOSimmerPid,
PPOSimmerQ,
TRPOLag,
TRPOPid,
)
from omnisafe.algorithms.on_policy.first_order.cup import CUP
from omnisafe.algorithms.on_policy.first_order.focops import FOCOPS
from omnisafe.algorithms.on_policy.naive_lagrange.npg_lag import NPGLag
from omnisafe.algorithms.on_policy.naive_lagrange.pdo import PDO
from omnisafe.algorithms.on_policy.naive_lagrange.ppo_lag import PPOLag
from omnisafe.algorithms.on_policy.naive_lagrange.trpo_lag import TRPOLag
from omnisafe.algorithms.on_policy.pid_lagrange.cppo_pid import CPPOPid
from omnisafe.algorithms.on_policy.pid_lagrange.trpo_pid import TRPOPid
from omnisafe.algorithms.on_policy.saute.ppo_lag_saute import PPOLagSaute
from omnisafe.algorithms.on_policy.saute.ppo_saute import PPOSaute
from omnisafe.algorithms.on_policy.second_order.cpo import CPO
from omnisafe.algorithms.on_policy.second_order.pcpo import PCPO
from omnisafe.algorithms.on_policy.simmer.ppo_lag_simmer_pid import PPOLagSimmerPid
from omnisafe.algorithms.on_policy.simmer.ppo_lag_simmer_q import PPOLagSimmerQ
from omnisafe.algorithms.on_policy.simmer.ppo_simmer_pid import PPOSimmerPid
from omnisafe.algorithms.on_policy.simmer.ppo_simmer_q import PPOSimmerQ


ALGORITHMS = {
'off-policy': (
'DDPG',
'DDPGLag',
'TD3',
'TD3Lag',
'SAC',
'SACLag',
'SDDPG',
),
'on-policy': (
'PolicyGradient',
'NaturalPG',
'TRPO',
'PPO',
'PDO',
'NPGLag',
'TRPOLag',
'PPOLag',
'CPPOPid',
'TRPOPid',
'FOCOPS',
'CUP',
'CPO',
'PCPO',
'PPOSimmerPid',
'PPOSimmerQ',
'PPOLagSimmerQ',
'PPOLagSimmerPid',
'PPOSaute',
'PPOLagSaute',
'PPOEarlyTerminated',
'PPOLagEarlyTerminated',
),
'off-policy': tuple(off_policy.__all__),
'on-policy': tuple(on_policy.__all__),
'model-based': (
'MBPPOLag',
'SafeLoop',
Expand Down
19 changes: 19 additions & 0 deletions omnisafe/algorithms/off_policy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,22 @@
# limitations under the License.
# ==============================================================================
"""Off-policy algorithms."""

from omnisafe.algorithms.off_policy.ddpg import DDPG
from omnisafe.algorithms.off_policy.ddpg_lag import DDPGLag
from omnisafe.algorithms.off_policy.sac import SAC
from omnisafe.algorithms.off_policy.sac_lag import SACLag
from omnisafe.algorithms.off_policy.sddpg import SDDPG
from omnisafe.algorithms.off_policy.td3 import TD3
from omnisafe.algorithms.off_policy.td3_lag import TD3Lag


__all__ = [
'DDPG',
'DDPGLag',
'SAC',
'SACLag',
'SDDPG',
'TD3',
'TD3Lag',
]
36 changes: 36 additions & 0 deletions omnisafe/algorithms/on_policy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,39 @@
# limitations under the License.
# ==============================================================================
"""On-policy algorithms."""

from omnisafe.algorithms.on_policy import (
base,
early_terminated,
first_order,
naive_lagrange,
pid_lagrange,
saute,
second_order,
simmer,
)
from omnisafe.algorithms.on_policy.base import PPO, TRPO, NaturalPG, PolicyGradient
from omnisafe.algorithms.on_policy.early_terminated import PPOEarlyTerminated, PPOLagEarlyTerminated
from omnisafe.algorithms.on_policy.first_order import CUP, FOCOPS
from omnisafe.algorithms.on_policy.naive_lagrange import PDO, NPGLag, PPOLag, TRPOLag
from omnisafe.algorithms.on_policy.pid_lagrange import CPPOPid, TRPOPid
from omnisafe.algorithms.on_policy.saute import PPOLagSaute, PPOSaute
from omnisafe.algorithms.on_policy.second_order import CPO, PCPO
from omnisafe.algorithms.on_policy.simmer import (
PPOLagSimmerPid,
PPOLagSimmerQ,
PPOSimmerPid,
PPOSimmerQ,
)


__all__ = [
*base.__all__,
*early_terminated.__all__,
*first_order.__all__,
*naive_lagrange.__all__,
*pid_lagrange.__all__,
*saute.__all__,
*second_order.__all__,
*simmer.__all__,
]
13 changes: 13 additions & 0 deletions omnisafe/algorithms/on_policy/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,16 @@
# limitations under the License.
# ==============================================================================
"""Basic Reinforcement Learning algorithms."""

from omnisafe.algorithms.on_policy.base.natural_pg import NaturalPG
from omnisafe.algorithms.on_policy.base.policy_gradient import PolicyGradient
from omnisafe.algorithms.on_policy.base.ppo import PPO
from omnisafe.algorithms.on_policy.base.trpo import TRPO


__all__ = [
'NaturalPG',
'PolicyGradient',
'PPO',
'TRPO',
]
11 changes: 11 additions & 0 deletions omnisafe/algorithms/on_policy/early_terminated/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,14 @@
# limitations under the License.
# ==============================================================================
"""Early terminated algorithms."""

from omnisafe.algorithms.on_policy.early_terminated.ppo_early_terminated import PPOEarlyTerminated
from omnisafe.algorithms.on_policy.early_terminated.ppo_lag_early_terminated import (
PPOLagEarlyTerminated,
)


__all__ = [
'PPOEarlyTerminated',
'PPOLagEarlyTerminated',
]
9 changes: 9 additions & 0 deletions omnisafe/algorithms/on_policy/first_order/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,12 @@
# limitations under the License.
# ==============================================================================
"""The first order algorithms."""

from omnisafe.algorithms.on_policy.first_order.cup import CUP
from omnisafe.algorithms.on_policy.first_order.focops import FOCOPS


__all__ = [
'CUP',
'FOCOPS',
]
13 changes: 13 additions & 0 deletions omnisafe/algorithms/on_policy/naive_lagrange/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,16 @@
# limitations under the License.
# ==============================================================================
"""Naive Lagrange algorithms."""

from omnisafe.algorithms.on_policy.naive_lagrange.npg_lag import NPGLag
from omnisafe.algorithms.on_policy.naive_lagrange.pdo import PDO
from omnisafe.algorithms.on_policy.naive_lagrange.ppo_lag import PPOLag
from omnisafe.algorithms.on_policy.naive_lagrange.trpo_lag import TRPOLag


__all__ = [
'NPGLag',
'PDO',
'PPOLag',
'TRPOLag',
]
9 changes: 9 additions & 0 deletions omnisafe/algorithms/on_policy/pid_lagrange/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,12 @@
# limitations under the License.
# ==============================================================================
"""PID Lagrange algorithms."""

from omnisafe.algorithms.on_policy.pid_lagrange.cppo_pid import CPPOPid
from omnisafe.algorithms.on_policy.pid_lagrange.trpo_pid import TRPOPid


__all__ = [
'CPPOPid',
'TRPOPid',
]
9 changes: 9 additions & 0 deletions omnisafe/algorithms/on_policy/saute/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,12 @@
# limitations under the License.
# ==============================================================================
"""Saute algorithms."""

from omnisafe.algorithms.on_policy.saute.ppo_lag_saute import PPOLagSaute
from omnisafe.algorithms.on_policy.saute.ppo_saute import PPOSaute


__all__ = [
'PPOLagSaute',
'PPOSaute',
]
9 changes: 9 additions & 0 deletions omnisafe/algorithms/on_policy/second_order/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,12 @@
# limitations under the License.
# ==============================================================================
"""Second order algorithms."""

from omnisafe.algorithms.on_policy.second_order.cpo import CPO
from omnisafe.algorithms.on_policy.second_order.pcpo import PCPO


__all__ = [
'CPO',
'PCPO',
]
13 changes: 13 additions & 0 deletions omnisafe/algorithms/on_policy/simmer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,16 @@
# limitations under the License.
# ==============================================================================
"""Simmer algorithms."""

from omnisafe.algorithms.on_policy.simmer.ppo_lag_simmer_pid import PPOLagSimmerPid
from omnisafe.algorithms.on_policy.simmer.ppo_lag_simmer_q import PPOLagSimmerQ
from omnisafe.algorithms.on_policy.simmer.ppo_simmer_pid import PPOSimmerPid
from omnisafe.algorithms.on_policy.simmer.ppo_simmer_q import PPOSimmerQ


__all__ = [
'PPOLagSimmerPid',
'PPOLagSimmerQ',
'PPOSimmerPid',
'PPOSimmerQ',
]

0 comments on commit 35b3f23

Please sign in to comment.