diff --git a/README.md b/README.md
index 506ff355a..2b9396bad 100644
--- a/README.md
+++ b/README.md
@@ -221,7 +221,7 @@ import omnisafe

 env = 'SafetyPointGoal1-v0'

-custom_dict = {'epochs': 1, 'data_dir': './runs'}
+custom_dict = {'epochs': 1, 'log_dir': './runs'}
 agent = omnisafe.Agent('PPOLag', env, custom_cfgs=custom_dict)
 agent.learn()
 ```
diff --git a/docs/source/BaseRL/PPO.rst b/docs/source/BaseRL/PPO.rst
index bc10f3781..75734e595 100644
--- a/docs/source/BaseRL/PPO.rst
+++ b/docs/source/BaseRL/PPO.rst
@@ -346,7 +346,7 @@ Quick start

     env = omnisafe.Env('SafetyPointGoal1-v0')

-    custom_dict = {'epochs': 1, 'data_dir': './runs'}
+    custom_dict = {'epochs': 1, 'log_dir': './runs'}
     agent = omnisafe.Agent('PPO', env, custom_cfgs=custom_dict)
     agent.learn()

diff --git a/docs/source/BaseRL/TRPO.rst b/docs/source/BaseRL/TRPO.rst
index 62a14aeb5..bc8a76412 100644
--- a/docs/source/BaseRL/TRPO.rst
+++ b/docs/source/BaseRL/TRPO.rst
@@ -528,7 +528,7 @@ Quick start

     env = omnisafe.Env('SafetyPointGoal1-v0')

-    custom_dict = {'epochs': 1, 'data_dir': './runs'}
+    custom_dict = {'epochs': 1, 'log_dir': './runs'}
     agent = omnisafe.Agent('TRPO', env, custom_cfgs=custom_dict)
     agent.learn()

diff --git a/docs/source/Documentation/Introduction.rst b/docs/source/Documentation/Introduction.rst
index acc88fafe..d19c9d5d6 100644
--- a/docs/source/Documentation/Introduction.rst
+++ b/docs/source/Documentation/Introduction.rst
@@ -257,7 +257,7 @@ We give an example below:

     env = omnisafe.Env('SafetyPointGoal1-v0')

-    custom_dict = {'epochs': 1, 'data_dir': './runs'}
+    custom_dict = {'epochs': 1, 'log_dir': './runs'}
     agent = omnisafe.Agent('CPO', env, custom_cfgs=custom_dict)
     agent.learn()

diff --git a/docs/source/SafeRL/CPO.rst b/docs/source/SafeRL/CPO.rst
index ac902527e..b231780c9 100644
--- a/docs/source/SafeRL/CPO.rst
+++ b/docs/source/SafeRL/CPO.rst
@@ -454,7 +454,7 @@ Quick start

     env = omnisafe.Env('SafetyPointGoal1-v0')

-    custom_dict = {'epochs': 1, 'data_dir': './runs'}
+    custom_dict = {'epochs': 1, 'log_dir': './runs'}
     agent = omnisafe.Agent('CPO', env, custom_cfgs=custom_dict)
     agent.learn()

diff --git a/docs/source/SafeRL/FOCOPS.rst b/docs/source/SafeRL/FOCOPS.rst
index 8d6989bde..52df6b300 100644
--- a/docs/source/SafeRL/FOCOPS.rst
+++ b/docs/source/SafeRL/FOCOPS.rst
@@ -454,7 +454,7 @@ Quick start

     env = omnisafe.Env('SafetyPointGoal1-v0')

-    custom_dict = {'epochs': 1, 'data_dir': './runs'}
+    custom_dict = {'epochs': 1, 'log_dir': './runs'}
     agent = omnisafe.Agent('FOCOPS', env, custom_cfgs=custom_dict)
     agent.learn()

diff --git a/docs/source/SafeRL/Lag.rst b/docs/source/SafeRL/Lag.rst
index 33e3e765d..a9824ed00 100644
--- a/docs/source/SafeRL/Lag.rst
+++ b/docs/source/SafeRL/Lag.rst
@@ -326,7 +326,7 @@ Quick start

     env = omnisafe.Env('SafetyPointGoal1-v0')

-    custom_dict = {'epochs': 1, 'data_dir': './runs'}
+    custom_dict = {'epochs': 1, 'log_dir': './runs'}
     agent = omnisafe.Agent('PPOLag', env, custom_cfgs=custom_dict)
     agent.learn()

diff --git a/docs/source/SafeRL/PCPO.rst b/docs/source/SafeRL/PCPO.rst
index 4f15fe7c0..bebdcfe58 100644
--- a/docs/source/SafeRL/PCPO.rst
+++ b/docs/source/SafeRL/PCPO.rst
@@ -442,7 +442,7 @@ Quick start

     env = omnisafe.Env('SafetyPointGoal1-v0')

-    custom_dict = {'epochs': 1, 'data_dir': './runs'}
+    custom_dict = {'epochs': 1, 'log_dir': './runs'}
     agent = omnisafe.Agent('PCPO', env, custom_cfgs=custom_dict)
     agent.learn()

diff --git a/examples/benchmarks/run_experiment_grid.py b/examples/benchmarks/run_experiment_grid.py
index 56634f042..0d6b4774a 100644
--- a/examples/benchmarks/run_experiment_grid.py
+++ b/examples/benchmarks/run_experiment_grid.py
@@ -25,7 +25,7 @@ def train(
-    exp_id: str, algo: str, env_id: str, custom_cfgs: NamedTuple, num_threads: int = 6
+    exp_id: str, algo: str, env_id: str, custom_cfgs: NamedTuple
 ) -> Tuple[float, float, float]:
     """Train a policy from exp-x config with OmniSafe.
@@ -36,16 +36,19 @@ def train(
         custom_cfgs (NamedTuple): Custom configurations.
         num_threads (int, optional): Number of threads. Defaults to 6.
     """
-    torch.set_num_threads(num_threads)
     sys.stdout = sys.__stdout__
     sys.stderr = sys.__stderr__
     print(f'exp-x: {exp_id} is training...')
     USE_REDIRECTION = True
     if USE_REDIRECTION:
-        if not os.path.exists(custom_cfgs['data_dir']):
-            os.makedirs(custom_cfgs['data_dir'])
-        sys.stdout = open(f'{custom_cfgs["data_dir"]}terminal.log', 'w', encoding='utf-8')
-        sys.stderr = open(f'{custom_cfgs["data_dir"]}error.log', 'w', encoding='utf-8')
+        if not os.path.exists(custom_cfgs['logger_cfgs']['log_dir']):
+            os.makedirs(custom_cfgs['logger_cfgs']['log_dir'])
+        sys.stdout = open(
+            f'{custom_cfgs["logger_cfgs"]["log_dir"]}terminal.log', 'w', encoding='utf-8'
+        )
+        sys.stderr = open(
+            f'{custom_cfgs["logger_cfgs"]["log_dir"]}error.log', 'w', encoding='utf-8'
+        )
     agent = omnisafe.Agent(algo, env_id, custom_cfgs=custom_cfgs)
     reward, cost, ep_len = agent.learn()
     return reward, cost, ep_len
@@ -58,9 +61,11 @@ def train(
     first_order_policy = ['CUP', 'FOCOPS']
     second_order_policy = ['CPO', 'PCPO']
     eg.add('algo', base_policy + naive_lagrange_policy + first_order_policy + second_order_policy)
-    eg.add('env_id', ['SafetyPointGoal1-v0'])
-    eg.add('logger_cfgs:use_wandb', [True])
-    eg.add('logger_cfgs:wandb_project', ['omnisafe_jiaming'])
+    eg.add('env_id', ['SafetyAntVelocity-v4'])
+    eg.add('logger_cfgs:use_wandb', [False])
+    eg.add('num_envs', [16])
+    eg.add('num_threads', [1])
+    # eg.add('logger_cfgs:wandb_project', ['omnisafe_jiaming'])
     # eg.add('train_cfgs:total_steps', 2000)
     # eg.add('algo_cfgs:update_cycle', 1000)
     # eg.add('train_cfgs:vector_env_nums', 1)
diff --git a/omnisafe/common/experiment_grid.py b/omnisafe/common/experiment_grid.py
index fb5069a05..cad050f48 100644
--- a/omnisafe/common/experiment_grid.py
+++ b/omnisafe/common/experiment_grid.py
@@ -292,7 +292,7 @@ def unflatten_var(var):
         return new_variants

     # pylint: disable-next=too-many-locals
-    def run(self, thunk, num_pool=1, data_dir=None, is_test=False):
+    def run(self, thunk, num_pool=1, log_dir=None, is_test=False):
         r"""Run each variant in the grid with function 'thunk'.

         Note: 'thunk' must be either a callable function, or a string. If it is
@@ -362,8 +362,9 @@ def run(self, thunk, num_pool=1, data_dir=None, is_test=False):
             print('current_config', var)
             exp_name = '_'.join([k + '_' + str(v) for k, v in var.items()])
             exp_names.append(exp_name)
-            data_dir = os.path.join('./', 'exp-x', self.name, exp_name, '')
-            var['data_dir'] = data_dir
+            if log_dir is None:
+                log_dir = os.path.join('./', 'exp-x', self.name, exp_name, '')
+            var['logger_cfgs'] = {'log_dir': log_dir}
             results.append(pool.submit(thunk, idx, var['algo'], var['env_id'], var))

         pool.shutdown()
diff --git a/omnisafe/utils/config.py b/omnisafe/utils/config.py
index 107c66038..c31501668 100644
--- a/omnisafe/utils/config.py
+++ b/omnisafe/utils/config.py
@@ -41,7 +41,7 @@ class Config(dict):
     num_mini_batches: int
     actor_lr: float
     critic_lr: float
-    data_dir: str
+    log_dir: str
     target_kl: float
     batch_size: int
     use_cost: bool