diff --git a/omnisafe/__init__.py b/omnisafe/__init__.py index 35d79b41a..f252acd93 100644 --- a/omnisafe/__init__.py +++ b/omnisafe/__init__.py @@ -14,6 +14,7 @@ # ============================================================================== """OmniSafe: A comprehensive and reliable benchmark for safe reinforcement learning.""" +from omnisafe import algorithms from omnisafe.algorithms import ALGORITHMS from omnisafe.algorithms.algo_wrapper import AlgoWrapper as Agent diff --git a/omnisafe/algorithms/__init__.py b/omnisafe/algorithms/__init__.py index c733055ba..497b58bbd 100644 --- a/omnisafe/algorithms/__init__.py +++ b/omnisafe/algorithms/__init__.py @@ -17,76 +17,41 @@ import itertools from types import MappingProxyType +from omnisafe.algorithms import off_policy, on_policy + # Off-Policy Safe -from omnisafe.algorithms.off_policy.ddpg import DDPG -from omnisafe.algorithms.off_policy.ddpg_lag import DDPGLag -from omnisafe.algorithms.off_policy.sac import SAC -from omnisafe.algorithms.off_policy.sac_lag import SACLag -from omnisafe.algorithms.off_policy.sddpg import SDDPG -from omnisafe.algorithms.off_policy.td3 import TD3 -from omnisafe.algorithms.off_policy.td3_lag import TD3Lag +from omnisafe.algorithms.off_policy import DDPG, SAC, SDDPG, TD3, DDPGLag, SACLag, TD3Lag # On-Policy Safe -from omnisafe.algorithms.on_policy.base.natural_pg import NaturalPG -from omnisafe.algorithms.on_policy.base.policy_gradient import PolicyGradient -from omnisafe.algorithms.on_policy.base.ppo import PPO -from omnisafe.algorithms.on_policy.base.trpo import TRPO -from omnisafe.algorithms.on_policy.early_terminated.ppo_early_terminated import PPOEarlyTerminated -from omnisafe.algorithms.on_policy.early_terminated.ppo_lag_early_terminated import ( +from omnisafe.algorithms.on_policy import ( + CPO, + CUP, + FOCOPS, + PCPO, + PDO, + PPO, + TRPO, + CPPOPid, + NaturalPG, + NPGLag, + PolicyGradient, + PPOEarlyTerminated, + PPOLag, PPOLagEarlyTerminated, + PPOLagSaute, + PPOLagSimmerPid, + PPOLagSimmerQ, + PPOSaute, + PPOSimmerPid, + PPOSimmerQ, + TRPOLag, + TRPOPid, ) -from omnisafe.algorithms.on_policy.first_order.cup import CUP -from omnisafe.algorithms.on_policy.first_order.focops import FOCOPS -from omnisafe.algorithms.on_policy.naive_lagrange.npg_lag import NPGLag -from omnisafe.algorithms.on_policy.naive_lagrange.pdo import PDO -from omnisafe.algorithms.on_policy.naive_lagrange.ppo_lag import PPOLag -from omnisafe.algorithms.on_policy.naive_lagrange.trpo_lag import TRPOLag -from omnisafe.algorithms.on_policy.pid_lagrange.cppo_pid import CPPOPid -from omnisafe.algorithms.on_policy.pid_lagrange.trpo_pid import TRPOPid -from omnisafe.algorithms.on_policy.saute.ppo_lag_saute import PPOLagSaute -from omnisafe.algorithms.on_policy.saute.ppo_saute import PPOSaute -from omnisafe.algorithms.on_policy.second_order.cpo import CPO -from omnisafe.algorithms.on_policy.second_order.pcpo import PCPO -from omnisafe.algorithms.on_policy.simmer.ppo_lag_simmer_pid import PPOLagSimmerPid -from omnisafe.algorithms.on_policy.simmer.ppo_lag_simmer_q import PPOLagSimmerQ -from omnisafe.algorithms.on_policy.simmer.ppo_simmer_pid import PPOSimmerPid -from omnisafe.algorithms.on_policy.simmer.ppo_simmer_q import PPOSimmerQ ALGORITHMS = { - 'off-policy': ( - 'DDPG', - 'DDPGLag', - 'TD3', - 'TD3Lag', - 'SAC', - 'SACLag', - 'SDDPG', - ), - 'on-policy': ( - 'PolicyGradient', - 'NaturalPG', - 'TRPO', - 'PPO', - 'PDO', - 'NPGLag', - 'TRPOLag', - 'PPOLag', - 'CPPOPid', - 'TRPOPid', - 'FOCOPS', - 'CUP', - 'CPO', - 'PCPO', - 'PPOSimmerPid', - 'PPOSimmerQ', - 'PPOLagSimmerQ', - 'PPOLagSimmerPid', - 'PPOSaute', - 'PPOLagSaute', - 'PPOEarlyTerminated', - 'PPOLagEarlyTerminated', - ), + 'off-policy': tuple(off_policy.__all__), + 'on-policy': tuple(on_policy.__all__), 'model-based': ( 'MBPPOLag', 'SafeLoop', diff --git a/omnisafe/algorithms/off_policy/__init__.py b/omnisafe/algorithms/off_policy/__init__.py index f96f861ff..aa4ea363b 100644 --- a/omnisafe/algorithms/off_policy/__init__.py +++ b/omnisafe/algorithms/off_policy/__init__.py @@ -13,3 +13,22 @@ # limitations under the License. # ============================================================================== """Off-policy algorithms.""" + +from omnisafe.algorithms.off_policy.ddpg import DDPG +from omnisafe.algorithms.off_policy.ddpg_lag import DDPGLag +from omnisafe.algorithms.off_policy.sac import SAC +from omnisafe.algorithms.off_policy.sac_lag import SACLag +from omnisafe.algorithms.off_policy.sddpg import SDDPG +from omnisafe.algorithms.off_policy.td3 import TD3 +from omnisafe.algorithms.off_policy.td3_lag import TD3Lag + + +__all__ = [ + 'DDPG', + 'DDPGLag', + 'SAC', + 'SACLag', + 'SDDPG', + 'TD3', + 'TD3Lag', +] diff --git a/omnisafe/algorithms/on_policy/__init__.py b/omnisafe/algorithms/on_policy/__init__.py index b71633f91..c7a781f09 100644 --- a/omnisafe/algorithms/on_policy/__init__.py +++ b/omnisafe/algorithms/on_policy/__init__.py @@ -13,3 +13,39 @@ # limitations under the License. # ============================================================================== """On-policy algorithms.""" + +from omnisafe.algorithms.on_policy import ( + base, + early_terminated, + first_order, + naive_lagrange, + pid_lagrange, + saute, + second_order, + simmer, +) +from omnisafe.algorithms.on_policy.base import PPO, TRPO, NaturalPG, PolicyGradient +from omnisafe.algorithms.on_policy.early_terminated import PPOEarlyTerminated, PPOLagEarlyTerminated +from omnisafe.algorithms.on_policy.first_order import CUP, FOCOPS +from omnisafe.algorithms.on_policy.naive_lagrange import PDO, NPGLag, PPOLag, TRPOLag +from omnisafe.algorithms.on_policy.pid_lagrange import CPPOPid, TRPOPid +from omnisafe.algorithms.on_policy.saute import PPOLagSaute, PPOSaute +from omnisafe.algorithms.on_policy.second_order import CPO, PCPO +from omnisafe.algorithms.on_policy.simmer import ( + PPOLagSimmerPid, + PPOLagSimmerQ, + PPOSimmerPid, + PPOSimmerQ, +) + + +__all__ = [ + *base.__all__, + *early_terminated.__all__, + *first_order.__all__, + *naive_lagrange.__all__, + *pid_lagrange.__all__, + *saute.__all__, + *second_order.__all__, + *simmer.__all__, +] diff --git a/omnisafe/algorithms/on_policy/base/__init__.py b/omnisafe/algorithms/on_policy/base/__init__.py index 434672651..0b8e240be 100644 --- a/omnisafe/algorithms/on_policy/base/__init__.py +++ b/omnisafe/algorithms/on_policy/base/__init__.py @@ -13,3 +13,16 @@ # limitations under the License. # ============================================================================== """Basic Reinforcement Learning algorithms.""" + +from omnisafe.algorithms.on_policy.base.natural_pg import NaturalPG +from omnisafe.algorithms.on_policy.base.policy_gradient import PolicyGradient +from omnisafe.algorithms.on_policy.base.ppo import PPO +from omnisafe.algorithms.on_policy.base.trpo import TRPO + + +__all__ = [ + 'NaturalPG', + 'PolicyGradient', + 'PPO', + 'TRPO', +] diff --git a/omnisafe/algorithms/on_policy/early_terminated/__init__.py b/omnisafe/algorithms/on_policy/early_terminated/__init__.py index aa270fe80..457ca0b3e 100644 --- a/omnisafe/algorithms/on_policy/early_terminated/__init__.py +++ b/omnisafe/algorithms/on_policy/early_terminated/__init__.py @@ -13,3 +13,14 @@ # limitations under the License. # ============================================================================== """Early terminated algorithms.""" + +from omnisafe.algorithms.on_policy.early_terminated.ppo_early_terminated import PPOEarlyTerminated +from omnisafe.algorithms.on_policy.early_terminated.ppo_lag_early_terminated import ( + PPOLagEarlyTerminated, +) + + +__all__ = [ + 'PPOEarlyTerminated', + 'PPOLagEarlyTerminated', +] diff --git a/omnisafe/algorithms/on_policy/first_order/__init__.py b/omnisafe/algorithms/on_policy/first_order/__init__.py index 7ff8122ea..630eedaa6 100644 --- a/omnisafe/algorithms/on_policy/first_order/__init__.py +++ b/omnisafe/algorithms/on_policy/first_order/__init__.py @@ -13,3 +13,12 @@ # limitations under the License. # ============================================================================== """The first order algorithms.""" + +from omnisafe.algorithms.on_policy.first_order.cup import CUP +from omnisafe.algorithms.on_policy.first_order.focops import FOCOPS + + +__all__ = [ + 'CUP', + 'FOCOPS', +] diff --git a/omnisafe/algorithms/on_policy/naive_lagrange/__init__.py b/omnisafe/algorithms/on_policy/naive_lagrange/__init__.py index e575cd4c6..018b41197 100644 --- a/omnisafe/algorithms/on_policy/naive_lagrange/__init__.py +++ b/omnisafe/algorithms/on_policy/naive_lagrange/__init__.py @@ -13,3 +13,16 @@ # limitations under the License. # ============================================================================== """Naive Lagrange algorithms.""" + +from omnisafe.algorithms.on_policy.naive_lagrange.npg_lag import NPGLag +from omnisafe.algorithms.on_policy.naive_lagrange.pdo import PDO +from omnisafe.algorithms.on_policy.naive_lagrange.ppo_lag import PPOLag +from omnisafe.algorithms.on_policy.naive_lagrange.trpo_lag import TRPOLag + + +__all__ = [ + 'NPGLag', + 'PDO', + 'PPOLag', + 'TRPOLag', +] diff --git a/omnisafe/algorithms/on_policy/pid_lagrange/__init__.py b/omnisafe/algorithms/on_policy/pid_lagrange/__init__.py index 25592db8a..aef2373d6 100644 --- a/omnisafe/algorithms/on_policy/pid_lagrange/__init__.py +++ b/omnisafe/algorithms/on_policy/pid_lagrange/__init__.py @@ -13,3 +13,12 @@ # limitations under the License. # ============================================================================== """PID Lagrange algorithms.""" + +from omnisafe.algorithms.on_policy.pid_lagrange.cppo_pid import CPPOPid +from omnisafe.algorithms.on_policy.pid_lagrange.trpo_pid import TRPOPid + + +__all__ = [ + 'CPPOPid', + 'TRPOPid', +] diff --git a/omnisafe/algorithms/on_policy/saute/__init__.py b/omnisafe/algorithms/on_policy/saute/__init__.py index 6dab3e35c..65e0a5087 100644 --- a/omnisafe/algorithms/on_policy/saute/__init__.py +++ b/omnisafe/algorithms/on_policy/saute/__init__.py @@ -13,3 +13,12 @@ # limitations under the License. # ============================================================================== """Saute algorithms.""" + +from omnisafe.algorithms.on_policy.saute.ppo_lag_saute import PPOLagSaute +from omnisafe.algorithms.on_policy.saute.ppo_saute import PPOSaute + + +__all__ = [ + 'PPOLagSaute', + 'PPOSaute', +] diff --git a/omnisafe/algorithms/on_policy/second_order/__init__.py b/omnisafe/algorithms/on_policy/second_order/__init__.py index db903c4c2..ac04d723b 100644 --- a/omnisafe/algorithms/on_policy/second_order/__init__.py +++ b/omnisafe/algorithms/on_policy/second_order/__init__.py @@ -13,3 +13,12 @@ # limitations under the License. # ============================================================================== """Second order algorithms.""" + +from omnisafe.algorithms.on_policy.second_order.cpo import CPO +from omnisafe.algorithms.on_policy.second_order.pcpo import PCPO + + +__all__ = [ + 'CPO', + 'PCPO', +] diff --git a/omnisafe/algorithms/on_policy/simmer/__init__.py b/omnisafe/algorithms/on_policy/simmer/__init__.py index 1a8a5d794..55fddc516 100644 --- a/omnisafe/algorithms/on_policy/simmer/__init__.py +++ b/omnisafe/algorithms/on_policy/simmer/__init__.py @@ -13,3 +13,16 @@ # limitations under the License. # ============================================================================== """Simmer algorithms.""" + +from omnisafe.algorithms.on_policy.simmer.ppo_lag_simmer_pid import PPOLagSimmerPid +from omnisafe.algorithms.on_policy.simmer.ppo_lag_simmer_q import PPOLagSimmerQ +from omnisafe.algorithms.on_policy.simmer.ppo_simmer_pid import PPOSimmerPid +from omnisafe.algorithms.on_policy.simmer.ppo_simmer_q import PPOSimmerQ + + +__all__ = [ + 'PPOLagSimmerPid', + 'PPOLagSimmerQ', + 'PPOSimmerPid', + 'PPOSimmerQ', +]