diff --git a/.pylintrc b/.pylintrc index 13c1bd408..abed81d63 100644 --- a/.pylintrc +++ b/.pylintrc @@ -196,6 +196,7 @@ good-names=i, id, e, f, + eg, # Good variable names regexes, separated by a comma. If names match any regex, # they will always be accepted diff --git a/README.md b/README.md index 7b7bdab87..3400712f1 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,7 @@ The supported interface algorithms currently include: - [X] **[AAAI 2020]** [IPO: Interior-point Policy Optimization under Constraints (IPO)](https://arxiv.org/abs/1910.09615) - [X] **[ICLR 2020]** [Projection-Based Constrained Policy Optimization (PCPO)](https://openreview.net/forum?id=rke3TJrtPS) - [X] **[ICML 2021]** [CRPO: A New Approach for Safe Reinforcement Learning with Convergence Guarantee](https://arxiv.org/abs/2011.05869) +- [X] **[IJCAI 2022]** [Penalized Proximal Policy Optimization for Safe Reinforcement Learning (P3O)](https://arxiv.org/pdf/2205.11814.pdf) #### Off-Policy Safe @@ -147,6 +148,29 @@ cd examples python train_policy.py --algo PPOLag --env-id SafetyPointGoal1-v0 --parallel 1 --total-steps 1024000 --device cpu --vector-env-nums 1 --torch-threads 1 ``` +#### Try with CLI + +```bash +pip install omnisafe + +omnisafe --help # Ask for help + +omnisafe [command] --help # Ask for command-specific help + +# Quick benchmarking for your research, just specify: 1. exp_name, 2. num_pool (how many processes run concurrently), 3. the path of the config file (refer to omnisafe/examples/benchmarks for the format) +omnisafe benchmark test_benchmark 2 "./saved_source/benchmark_config.yaml" + +# Quick evaluation and rendering of your trained policy, just specify: 1. the path of the trained policy +omnisafe eval ./saved_source/PPO-{SafetyPointGoal1-v0} --num-episode 1 + +# Quick training of some algorithms to validate your ideas +# Note: with `key1:key2` you can select nested hyperparameter keys, and with `--custom-cfgs` you can add custom configurations via the CLI +omnisafe train --algo PPO --total-steps 1024 --vector-env-nums 1 --custom-cfgs algo_cfgs:update_cycle --custom-cfgs 512 + +# Quick training of some algorithms via a saved config file, whose format is the same as the default one +omnisafe train-config "./saved_source/train_config.yaml" +``` + **algo:** Type | Name ---------------| ---------------------------------------------- @@ -205,7 +229,8 @@ More information about environments, please refer to [Safety Gymnasium](https:// -------------------------------------------------------------------------------- ## Getting Started - +#### Important Hints +- `train_cfgs:torch_threads` is especially important for training speed and varies with the user's machine; this value should be neither too small nor too large. ### 1. Run Agent from preset yaml file ```python diff --git a/examples/benchmarks/example_cli_benchmark_config.yaml b/examples/benchmarks/example_cli_benchmark_config.yaml new file mode 100644 index 000000000..f4588e97d --- /dev/null +++ b/examples/benchmarks/example_cli_benchmark_config.yaml @@ -0,0 +1,31 @@ +# Copyright 2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +algo: + ['PolicyGradient', 'NaturalPG'] +env_id: + ['SafetyAntVelocity-v4'] +logger_cfgs:use_wandb: + [False] +train_cfgs:vector_env_nums: + [2] +train_cfgs:torch_threads: + [1] +train_cfgs:total_steps: + 1024 +algo_cfgs:update_cycle: + 512 +seed: + [0] diff --git a/examples/benchmarks/run_experiment_grid.py b/examples/benchmarks/run_experiment_grid.py index 0d6b4774a..6590ccc50 100644 --- a/examples/benchmarks/run_experiment_grid.py +++ b/examples/benchmarks/run_experiment_grid.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 OmniSafe Team. All Rights Reserved. +# Copyright 2023 OmniSafe Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,8 +17,6 @@ import os import sys -import torch - import omnisafe from omnisafe.common.experiment_grid import ExperimentGrid from omnisafe.typing import NamedTuple, Tuple @@ -39,16 +37,10 @@ def train( sys.stdout = sys.__stdout__ sys.stderr = sys.__stderr__ print(f'exp-x: {exp_id} is training...') - USE_REDIRECTION = True - if USE_REDIRECTION: - if not os.path.exists(custom_cfgs['logger_cfgs']['log_dir']): - os.makedirs(custom_cfgs['logger_cfgs']['log_dir']) - sys.stdout = open( - f'{custom_cfgs["logger_cfgs"]["log_dir"]}terminal.log', 'w', encoding='utf-8' - ) - sys.stderr = open( - f'{custom_cfgs["logger_cfgs"]["log_dir"]}error.log', 'w', encoding='utf-8' - ) + if not os.path.exists(custom_cfgs['logger_cfgs']['log_dir']): + os.makedirs(custom_cfgs['logger_cfgs']['log_dir']) + sys.stdout = open(f'{custom_cfgs["logger_cfgs"]["log_dir"]}terminal.log', 'w', encoding='utf-8') + sys.stderr = open(f'{custom_cfgs["logger_cfgs"]["log_dir"]}error.log', 'w', encoding='utf-8') agent = omnisafe.Agent(algo, env_id, custom_cfgs=custom_cfgs) reward, cost, ep_len = agent.learn() return reward, cost, ep_len @@ -58,16 +50,14 @@ def train( eg = ExperimentGrid(exp_name='Safety_Gymnasium_Goal') base_policy = ['PolicyGradient', 'NaturalPG', 'TRPO', 'PPO'] naive_lagrange_policy = ['PPOLag', 'TRPOLag', 'RCPO', 'OnCRPO', 'PDO'] - first_order_policy = ['CUP', 'FOCOPS'] + first_order_policy = ['CUP', 'FOCOPS', 'P3O'] second_order_policy = ['CPO', 'PCPO'] eg.add('algo', base_policy + naive_lagrange_policy + first_order_policy + second_order_policy) eg.add('env_id', ['SafetyAntVelocity-v4']) eg.add('logger_cfgs:use_wandb', [False]) - eg.add('num_envs', [16]) - eg.add('num_threads', [1]) - # eg.add('logger_cfgs:wandb_project', ['omnisafe_jiaming']) - # eg.add('train_cfgs:total_steps', 2000) - # eg.add('algo_cfgs:update_cycle', 1000) - # eg.add('train_cfgs:vector_env_nums', 1) + eg.add('train_cfgs:vector_env_nums', [4]) + eg.add('train_cfgs:torch_threads', [1]) eg.add('seed', [0]) - eg.run(train, num_pool=13) + # the total number of experiments must be divisible by num_pool + # meanwhile, users should choose this value according to their machine + eg.run(train, num_pool=14) diff --git a/examples/evaluate_saved_policy.py b/examples/evaluate_saved_policy.py index 6fcd8ea45..226a8509f 100644 ---
a/examples/evaluate_saved_policy.py +++ b/examples/evaluate_saved_policy.py @@ -1,4 +1,4 @@ -# Copyright 2022 OmniSafe Team. All Rights Reserved. +# Copyright 2023 OmniSafe Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,10 +22,8 @@ # Just fill your experiment's log directory in here. # Such as: ~/omnisafe/examples/runs/PPOLag-/seed-000-2023-03-07-20-25-48 LOG_DIR = '' -play = True -save_replay = True if __name__ == '__main__': - evaluator = omnisafe.Evaluator(play=play, save_replay=save_replay) + evaluator = omnisafe.Evaluator(render_mode='rgb_array') for item in os.scandir(os.path.join(LOG_DIR, 'torch_save')): if item.is_file() and item.name.split('.')[-1] == 'pt': evaluator.load_saved( diff --git a/examples/train_from_custom_dict.py b/examples/train_from_custom_dict.py index 816c35ffb..a1efd8a91 100644 --- a/examples/train_from_custom_dict.py +++ b/examples/train_from_custom_dict.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 OmniSafe Team. All Rights Reserved. +# Copyright 2023 OmniSafe Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/omnisafe/adapter/onpolicy_adapter.py b/omnisafe/adapter/onpolicy_adapter.py index 4e8ab45e3..40a26e28d 100644 --- a/omnisafe/adapter/onpolicy_adapter.py +++ b/omnisafe/adapter/onpolicy_adapter.py @@ -17,6 +17,7 @@ from typing import Dict, Optional import torch +from rich.progress import track from omnisafe.adapter.online_adapter import OnlineAdapter from omnisafe.common.buffer import VectorOnPolicyBuffer @@ -56,7 +57,10 @@ def roll_out( # pylint: disable=too-many-locals self._reset_log() obs, _ = self.reset() - for step in range(steps_per_epoch): + for step in track( + range(steps_per_epoch), + description=f'Processing rollout for epoch: {logger.current_epoch}...', + ): act, value_r, value_c, logp = agent.step(obs) next_obs, reward, cost, terminated, truncated, info = self.step(act) diff --git a/omnisafe/algorithms/algo_wrapper.py b/omnisafe/algorithms/algo_wrapper.py index fad89a2a8..fe54560be 100644 --- a/omnisafe/algorithms/algo_wrapper.py +++ b/omnisafe/algorithms/algo_wrapper.py @@ -25,6 +25,7 @@ from omnisafe.envs import support_envs from omnisafe.utils import distributed from omnisafe.utils.config import check_all_configs, get_default_kwargs_yaml +from omnisafe.utils.tools import recursive_check_config class AlgoWrapper: @@ -66,13 +67,17 @@ def _init_config(self): # update the cfgs from custom configurations if self.custom_cfgs: + recursive_check_config(self.custom_cfgs, cfgs, exclude_keys=('algo', 'env_id')) cfgs.recurisve_update(self.custom_cfgs) # update the cfgs from custom terminal configurations if self.train_terminal_cfgs: + recursive_check_config( + self.train_terminal_cfgs, cfgs.train_cfgs, exclude_keys=('algo', 'env_id') + ) cfgs.train_cfgs.recurisve_update(self.train_terminal_cfgs) - # the exp_name format is PPO-(SafetyPointGoal1-v0)- - exp_name = f'{self.algo}-({self.env_id})' + # the exp_name format is PPO-{SafetyPointGoal1-v0}- + exp_name = f'{self.algo}-{{{self.env_id}}}' cfgs.recurisve_update({'exp_name': exp_name, 'env_id': self.env_id}) cfgs.train_cfgs.recurisve_update( {'epochs': cfgs.train_cfgs.total_steps // cfgs.algo_cfgs.update_cycle} diff --git a/omnisafe/algorithms/on_policy/base/policy_gradient.py b/omnisafe/algorithms/on_policy/base/policy_gradient.py index 18fdc18cd..0fd98b421 100644 
--- a/omnisafe/algorithms/on_policy/base/policy_gradient.py +++ b/omnisafe/algorithms/on_policy/base/policy_gradient.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 OmniSafe Team. All Rights Reserved. +# Copyright 2023 OmniSafe Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ import torch import torch.nn as nn +from rich.progress import track from torch.utils.data import DataLoader, TensorDataset from omnisafe.adapter import OnPolicyAdapter @@ -219,7 +220,7 @@ def _update(self) -> None: shuffle=True, ) - for i in range(self._cfgs.algo_cfgs.update_iters): + for i in track(range(self._cfgs.algo_cfgs.update_iters), description='Updating...'): for ( obs, act, @@ -245,12 +246,12 @@ def _update(self) -> None: kl = distributed.dist_avg(kl) if self._cfgs.algo_cfgs.kl_early_stop and kl > self._cfgs.algo_cfgs.target_kl: - self._logger.log(f'Early stopping at iter {i} due to reaching max kl') + self._logger.log(f'Early stopping at iter {i + 1} due to reaching max kl') break self._logger.store( **{ - 'Train/StopIter': i + 1, + 'Train/StopIter': i + 1, # pylint: disable=undefined-loop-variable 'Value/Adv': adv_r.mean().item(), 'Train/KL': kl, } diff --git a/omnisafe/algorithms/on_policy/first_order/cup.py b/omnisafe/algorithms/on_policy/first_order/cup.py index 85adc53cc..6f0e85e3a 100644 --- a/omnisafe/algorithms/on_policy/first_order/cup.py +++ b/omnisafe/algorithms/on_policy/first_order/cup.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 OmniSafe Team. All Rights Reserved. +# Copyright 2023 OmniSafe Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ """Implementation of the CUP algorithm.""" import torch +from rich.progress import track from torch.distributions import Normal from torch.utils.data import DataLoader, TensorDataset @@ -142,7 +143,7 @@ def _update(self) -> None: shuffle=True, ) - for i in range(self._cfgs.algo_cfgs.update_iters): + for i in track(range(self._cfgs.algo_cfgs.update_iters), description='Updating...'): for obs, act, logp, adv_c, old_mean, old_std in dataloader: self._p_dist = Normal(old_mean, old_std) loss_cost, info = self._loss_pi_cost(obs, act, logp, adv_c) @@ -166,7 +167,7 @@ def _update(self) -> None: kl = distributed.dist_avg(kl) if self._cfgs.algo_cfgs.kl_early_stop and kl > self._cfgs.algo_cfgs.target_kl: - self._logger.log(f'Early stopping at iter {i} due to reaching max kl') + self._logger.log(f'Early stopping at iter {i + 1} due to reaching max kl') break self._logger.store( @@ -174,7 +175,7 @@ def _update(self) -> None: 'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier.item(), 'Train/MaxRatio': self._max_ratio, 'Train/MinRatio': self._min_ratio, - 'Train/SecondStepStopIter': i + 1, + 'Train/SecondStepStopIter': i + 1, # pylint: disable=undefined-loop-variable 'Train/SecondStepEntropy': info['entropy'], 'Train/SecondStepPolicyRatio': info['ratio'], } diff --git a/omnisafe/algorithms/on_policy/first_order/focops.py b/omnisafe/algorithms/on_policy/first_order/focops.py index 3a4185dea..08f2a54fd 100644 --- a/omnisafe/algorithms/on_policy/first_order/focops.py +++ b/omnisafe/algorithms/on_policy/first_order/focops.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 OmniSafe Team. All Rights Reserved. +# Copyright 2023 OmniSafe Team. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ from typing import Dict, Tuple import torch +from rich.progress import track from torch.distributions import Normal from torch.utils.data import DataLoader, TensorDataset @@ -108,7 +109,7 @@ def _update(self) -> None: shuffle=True, ) - for i in range(self._cfgs.algo_cfgs.update_iters): + for i in track(range(self._cfgs.algo_cfgs.update_iters), description='Updating...'): for ( obs, act, @@ -138,12 +139,12 @@ def _update(self) -> None: kl = distributed.dist_avg(kl) if self._cfgs.algo_cfgs.kl_early_stop and kl > self._cfgs.algo_cfgs.target_kl: - self._logger.log(f'Early stopping at iter {i} due to reaching max kl') + self._logger.log(f'Early stopping at iter {i + 1} due to reaching max kl') break self._logger.store( **{ - 'Train/StopIter': i + 1, + 'Train/StopIter': i + 1, # pylint: disable=undefined-loop-variable 'Value/Adv': adv_r.mean().item(), 'Train/KL': kl, 'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier, diff --git a/omnisafe/common/experiment_grid.py b/omnisafe/common/experiment_grid.py index 1adc17709..555d55e61 100644 --- a/omnisafe/common/experiment_grid.py +++ b/omnisafe/common/experiment_grid.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 OmniSafe Team. All Rights Reserved. +# Copyright 2023 OmniSafe Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,10 +24,12 @@ from typing import Any, Dict, List import numpy as np +from rich.console import Console from tqdm import trange -from omnisafe.common.logger import WordColor +from omnisafe.algorithms import ALGORITHM2TYPE from omnisafe.utils.exp_grid_tools import all_bools, valid_str +from omnisafe.utils.tools import load_yaml, recursive_check_config # pylint: disable-next=too-many-instance-attributes @@ -42,6 +44,8 @@ def __init__(self, exp_name='') -> None: self.div_line_width = 80 assert isinstance(exp_name, str), 'Name has to be a string.' self.name = exp_name + self._console = Console() + self.log_dir: str # Whether GridSearch provides automatically-generated default shorthands self.default_shorthand = True @@ -65,12 +69,12 @@ def print(self) -> None: msg = base_msg % name_insert else: msg = base_msg % (name_insert + '\n') - print(WordColor.colorize(msg, color='green', bold=True)) + self._console.print(msg, style='green bold') # List off parameters, shorthands, and possible values. 
for key, value, shorthand in zip(self.keys, self.vals, self.shs): - color_k = WordColor.colorize(key.ljust(40), color='cyan', bold=True) - print('', color_k, '[' + shorthand + ']' if shorthand is not None else '', '\n') + self._console.print('', key.ljust(40), style='cyan bold', end='') + print('[' + shorthand + ']' if shorthand is not None else '', '\n') for _, val in enumerate(value): print('\t' + json.dumps(val, indent=4, sort_keys=True)) print() @@ -317,31 +321,27 @@ def run(self, thunk, num_pool=1, parent_dir=None, is_test=False): var_names = {self.variant_name(var) for var in variants} var_names = sorted(list(var_names)) line = '=' * self.div_line_width - preparing = WordColor.colorize( - 'Preparing to run the following experiments...', color='green', bold=True - ) + + self._console.print('\nPreparing to run the following experiments...', style='bold green') joined_var_names = '\n'.join(var_names) - announcement = f'\n{preparing}\n\n{joined_var_names}\n\n{line}' + announcement = f'\n{joined_var_names}\n\n{line}' print(announcement) if self.wait_defore_launch > 0: - delay_msg = ( - WordColor.colorize( - dedent( - """ + self._console.print( + dedent( + """ Launch delayed to give you a few seconds to review your experiments. To customize or disable this behavior, change WAIT_BEFORE_LAUNCH in spinup/user_config.py. """ - ), - color='cyan', - bold=True, - ) - + line + ), + style='cyan bold', + end='', ) - print(delay_msg) + print(line) wait, steps = self.wait_defore_launch, 100 prog_bar = trange( steps, @@ -358,15 +358,17 @@ def run(self, thunk, num_pool=1, parent_dir=None, is_test=False): # run the variants. results = [] exp_names = [] + for idx, var in enumerate(variants): + self.check_variant_vaild(var) print('current_config', var) exp_name = '_'.join([k + '_' + str(v) for k, v in var.items()]) exp_names.append(exp_name) if parent_dir is None: - log_dir = os.path.join('./', 'exp-x', self.name, exp_name, '') + self.log_dir = os.path.join('./', 'exp-x', self.name, exp_name, '') else: - log_dir = os.path.join(parent_dir, self.name, exp_name, '') - var['logger_cfgs'] = {'log_dir': log_dir} + self.log_dir = os.path.join(parent_dir, self.name, exp_name, '') + var['logger_cfgs'] = {'log_dir': self.log_dir} results.append(pool.submit(thunk, idx, var['algo'], var['env_id'], var)) pool.shutdown() @@ -375,7 +377,7 @@ def run(self, thunk, num_pool=1, parent_dir=None, is_test=False): def save_results(self, exp_names, variants, results): """Save results to a file.""" - path = os.path.join('./', 'exp-x', self.name, 'exp-x-results.txt') + path = os.path.join(self.log_dir, '..', 'exp-x-results.txt') str_len = max(len(exp_name) for exp_name in exp_names) exp_names = [exp_name.ljust(str_len) for exp_name in exp_names] with open(path, 'a+', encoding='utf-8') as f: @@ -386,3 +388,11 @@ def save_results(self, exp_names, variants, results): f.write('cost:' + str(round(cost, 2)) + ',') f.write('ep_len:' + str(ep_len)) f.write('\n') + + def check_variant_vaild(self, variant): + """Check if the variant is valid.""" + path = os.path.dirname(os.path.abspath(__file__)) + algo_type = ALGORITHM2TYPE.get(variant['algo'], '') + cfg_path = os.path.join(path, '..', 'configs', algo_type, f"{variant['algo']}.yaml") + default_config = load_yaml(cfg_path)['defaults'] + recursive_check_config(variant, default_config, exclude_keys=('algo', 'env_id')) diff --git a/omnisafe/common/logger.py b/omnisafe/common/logger.py index 801a973ea..3bd12821f 100644 --- a/omnisafe/common/logger.py +++ b/omnisafe/common/logger.py @@ -1,4 +1,4
@@ -# Copyright 2022-2023 OmniSafe Team. All Rights Reserved. +# Copyright 2023 OmniSafe Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,12 +19,15 @@ import os import time from collections import deque -from typing import Any, Deque, Dict, List, Literal, Optional, TextIO, Tuple, Union +from typing import Any, Deque, Dict, List, Optional, TextIO, Tuple, Union import numpy as np import torch import tqdm import wandb +from rich import print # pylint: disable=redefined-builtin +from rich.console import Console +from rich.table import Table from omnisafe.utils.config import Config from omnisafe.utils.distributed import dist_statistics_scalar, get_rank @@ -45,46 +48,6 @@ from torch.utils.tensorboard import SummaryWriter # isort:skip -ColorType = Literal['gray', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white', 'crimson'] - - -class WordColor: # pylint: disable=too-few-public-methods - """Implementation of the WordColor.""" - - GRAY: int = 30 - RED: int = 31 - GREEN: int = 32 - YELLOW: int = 33 - BLUE: int = 34 - MAGENTA: int = 35 - CYAN: int = 36 - WHITE: int = 37 - CRIMSON: int = 38 - - @staticmethod - def colorize(msg: str, color: str, bold: bool = False, highlight: bool = False) -> str: - """Colorize a message. - - Args: - msg (str): message to be colorized. - color (str): color of the message. - bold (bool): whether to use bold font. - highlight (bool): whether to use highlight. - - Returns: - str: colorized message. - """ - assert color.upper() in WordColor.__dict__, f'Invalid color: {color}' - color_code = WordColor.__dict__[color.upper()] - attr = [] - if highlight: - color_code += 10 - attr.append(str(color_code)) - if bold: - attr.append('1') - return f'\x1b[{";".join(attr)}m{msg}\x1b[0m' - - class Logger: # pylint: disable=too-many-instance-attributes """Implementation of the Logger. @@ -115,6 +78,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals self._log_dir = os.path.join(output_dir, exp_name, relpath) self._verbose = verbose self._maste_proc = get_rank() == 0 + self._console = Console() self._output_file: TextIO if self._maste_proc: @@ -162,9 +126,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals ), 'epochs must be specified in the config file when verbose is False' self._proc_bar = tqdm.tqdm(total=self._config['epochs'], desc='Epochs') - def log( - self, msg: str, color: ColorType = 'green', bold: bool = False, highlight: bool = False - ) -> None: + def log(self, msg: str, color: str = 'green', bold: bool = False) -> None: """Log the message to the console and the file. Args: @@ -172,7 +134,8 @@ def log( color (int): The color of the message. """ if self._verbose and self._maste_proc: - print(WordColor.colorize(msg, color, bold, highlight)) + style = ' '.join([color, 'bold' if bold else '']) + self._console.print(msg, style=style) def save_config(self, config: Config) -> None: """Save the configuration to the log directory. 
@@ -265,16 +228,15 @@ def store(self, **kwargs: Union[int, float, np.ndarray, torch.Tensor]) -> None: def dump_tabular(self) -> None: """Dump the tabular data to the console and the file.""" self._update_current_row() + table = Table('Metrics', 'Value') if self._maste_proc: self._epoch += 1 if self._verbose: key_lens = list(map(len, self._current_row.keys())) max_key_len = max(15, *key_lens) - n_slashes = 22 + max_key_len - print('-' * n_slashes) for key, val in self._current_row.items(): - print(f'| {key:<{max_key_len}} | {val:15.6g} |') - print('-' * n_slashes) + table.add_row(key[:max_key_len], str(val)[:max_key_len]) + else: self._proc_bar.update(1) @@ -291,6 +253,27 @@ def dump_tabular(self) -> None: if self._use_wandb: wandb.log(self._current_row, step=self._epoch) + self._console.print(table) + + def _update_current_row(self) -> None: + for key in self._data: + if self._headers_minmax[key]: + old_data = self._current_row[f'{key}/Mean'] + mean, min_val, max_val, std = self.get_stats(key, True) + self._current_row[f'{key}/Mean'] = mean + self._current_row[f'{key}/Min'] = min_val + self._current_row[f'{key}/Max'] = max_val + self._current_row[f'{key}/Std'] = std + else: + old_data = self._current_row[key] + mean = self.get_stats(key, False)[0] + self._current_row[key] = mean + + if self._headers_delta[key]: + self._current_row[f'{key}/Delta'] = mean - old_data + + if self._headers_windwos[key] is None: + self._data[key] = [] def _update_current_row(self) -> None: for key in self._data: @@ -330,6 +313,11 @@ def get_stats(self, key, min_and_max: bool = False) -> Tuple[Union[int, float], ) return (mean.item(),) + @property + def current_epoch(self) -> int: + """Return the current epoch.""" + return self._epoch + def close(self) -> None: """Close the logger.""" if self._maste_proc: diff --git a/omnisafe/evaluator.py b/omnisafe/evaluator.py index 2d060bb92..19139ba4d 100644 --- a/omnisafe/evaluator.py +++ b/omnisafe/evaluator.py @@ -38,8 +38,7 @@ class Evaluator: # pylint: disable=too-many-instance-attributes # pylint: disable-next=too-many-arguments def __init__( self, - play: bool = True, - save_replay: bool = True, + render_mode: str = None, ): """Initialize the evaluator. @@ -57,15 +56,12 @@ def __init__( self._save_dir: str self._model_name: str - # set the render mode - self._play = play - self._save_replay = save_replay - self._dividing_line = '\n' + '#' * 50 + '\n' - self.__set_render_mode(play, save_replay) + if render_mode: + self.__set_render_mode(render_mode) - def __set_render_mode(self, play: bool = True, save_replay: bool = True): + def __set_render_mode(self, render_mode: str): """Set the render mode. Args: @@ -73,12 +69,8 @@ def __set_render_mode(self, play: bool = True, save_replay: bool = True): save_replay (bool): whether to save the video. 
""" # set the render mode - if play and save_replay: - self._render_mode = 'rgb_array' - elif play and not save_replay: - self._render_mode = 'human' - elif not play and save_replay: - self._render_mode = 'rgb_array_list' + if render_mode in ['human', 'rgb_array', 'rgb_array_list']: + self._render_mode = render_mode else: raise NotImplementedError('The render mode is not implemented.') @@ -231,7 +223,7 @@ def evaluate( print('Evaluation results:') print(f'Average episode reward: {np.mean(episode_rewards)}') print(f'Average episode cost: {np.mean(episode_costs)}') - print(f'Average episode length: {np.mean(episode_lengths)+1}') + print(f'Average episode length: {np.mean(episode_lengths)}') return ( episode_rewards, episode_costs, @@ -257,6 +249,7 @@ def render( # pylint: disable=too-many-locals,too-many-arguments,too-many-branc num_episodes: int = 0, save_replay_path: Optional[str] = None, max_render_steps: int = 2000, + cost_criteria: float = 1.0, ): """Render the environment for one episode. @@ -267,6 +260,10 @@ def render( # pylint: disable=too-many-locals,too-many-arguments,too-many-branc if save_replay_path is None: save_replay_path = os.path.join(self._save_dir, 'video', self._model_name.split('.')[0]) + result_path = os.path.join(save_replay_path, 'result.txt') + print(self._dividing_line) + print(f'Saving the replay video to {save_replay_path},\n and the result to {result_path}.') + print(self._dividing_line) horizon = 1000 frames = [] @@ -276,17 +273,25 @@ def render( # pylint: disable=too-many-locals,too-many-arguments,too-many-branc elif self._render_mode == 'rgb_array': frames.append(self._env.render()) + episode_rewards: List[float] = [] + episode_costs: List[float] = [] + episode_lengths: List[float] = [] + for episode_idx in range(num_episodes): step = 0 done = False + ep_ret, ep_cost, length = 0.0, 0.0, 0.0 while ( not done and step <= max_render_steps ): # a big number to make sure the episode will end with torch.no_grad(): act = self._actor.predict(obs, deterministic=False) - obs, _, _, terminated, truncated, _ = self._env.step(act) + obs, rew, cost, terminated, truncated, _ = self._env.step(act) step += 1 done = bool(terminated or truncated) + ep_ret += rew.item() + ep_cost += (cost_criteria**length) * cost.item() + length += 1 if self._render_mode == 'rgb_array': frames.append(self._env.render()) @@ -306,3 +311,17 @@ def render( # pylint: disable=too-many-locals,too-many-arguments,too-many-branc ) self._env.reset() frames = [] + episode_rewards.append(ep_ret) + episode_costs.append(ep_cost) + episode_lengths.append(length) + with open(result_path, 'a+', encoding='utf-8') as f: + print(f'Episode {episode_idx+1} results:', file=f) + print(f'Episode reward: {ep_ret}', file=f) + print(f'Episode cost: {ep_cost}', file=f) + print(f'Episode length: {length}', file=f) + with open(result_path, 'a+', encoding='utf-8') as f: + print(self._dividing_line) + print('Evaluation results:', file=f) + print(f'Average episode reward: {np.mean(episode_rewards)}', file=f) + print(f'Average episode cost: {np.mean(episode_costs)}', file=f) + print(f'Average episode length: {np.mean(episode_lengths)}', file=f) diff --git a/omnisafe/utils/command_app.py b/omnisafe/utils/command_app.py new file mode 100644 index 000000000..e74a12cc5 --- /dev/null +++ b/omnisafe/utils/command_app.py @@ -0,0 +1,199 @@ +# Copyright 2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implementation of the command interfaces.""" + +import os +import sys +from typing import List + +import numpy as np +import typer +import yaml + +import omnisafe +from omnisafe.common.experiment_grid import ExperimentGrid +from omnisafe.typing import NamedTuple, Tuple +from omnisafe.utils.tools import custom_cfgs_to_dict, update_dic + + +app = typer.Typer() + + +@app.command() +def train( # pylint: disable=too-many-arguments + algo: str = typer.Option( + 'PPOLag', help=f"algorithm to train, one of {omnisafe.ALGORITHMS['all']}", case_sensitive=False + ), + env_id: str = typer.Option( + 'SafetyHumanoidVelocity-v4', help='the name of test environment', case_sensitive=False + ), + parallel: int = typer.Option(1, help='number of parallel processes for computation.'), + total_steps: int = typer.Option(1638400, help='total number of steps to train the algorithm'), + device: str = typer.Option('cpu', help='device to use for training'), + vector_env_nums: int = typer.Option(16, help='number of vector envs to use for training'), + torch_threads: int = typer.Option(16, help='number of threads to use for torch'), + log_dir: str = typer.Option( + os.path.join(os.getcwd()), help='directory to save logs, default is current directory' + ), + custom_cfgs: List[str] = typer.Option([], help='custom configuration for training'), +): + """Train a single policy in OmniSafe via command line.""" + args = { + 'algo': algo, + 'env_id': env_id, + 'parallel': parallel, + 'total_steps': total_steps, + 'device': device, + 'vector_env_nums': vector_env_nums, + 'torch_threads': torch_threads, + } + keys = custom_cfgs[0::2] + values = list(custom_cfgs[1::2]) + custom_cfgs = dict(zip(keys, values)) + custom_cfgs.update({'logger_cfgs:log_dir': os.path.join(log_dir, 'train')}) + + parsed_custom_cfgs = {} + for k, v in custom_cfgs.items(): + update_dic(parsed_custom_cfgs, custom_cfgs_to_dict(k, v)) + + agent = omnisafe.Agent( + algo=algo, + env_id=env_id, + train_terminal_cfgs=args, + custom_cfgs=parsed_custom_cfgs, + ) + agent.learn() + + +def train_grid( + exp_id: str, algo: str, env_id: str, custom_cfgs: NamedTuple +) -> Tuple[float, float, float]: + """Train a policy from exp-x config with OmniSafe. + + Args: + exp_id (str): Experiment ID. + algo (str): Algorithm to train. + env_id (str): The name of test environment. + custom_cfgs (NamedTuple): Custom configurations. + num_threads (int, optional): Number of threads. Defaults to 6.
+ """ + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ + print(f'exp-x: {exp_id} is training...') + if not os.path.exists(custom_cfgs['logger_cfgs']['log_dir']): + os.makedirs(custom_cfgs['logger_cfgs']['log_dir']) + # pylint: disable=consider-using-with + sys.stdout = open(f'{custom_cfgs["logger_cfgs"]["log_dir"]}terminal.log', 'w', encoding='utf-8') + # pylint: disable=consider-using-with + sys.stderr = open(f'{custom_cfgs["logger_cfgs"]["log_dir"]}error.log', 'w', encoding='utf-8') + agent = omnisafe.Agent(algo, env_id, custom_cfgs=custom_cfgs) + reward, cost, ep_len = agent.learn() + return reward, cost, ep_len + + +@app.command() +def benchmark( + exp_name: str = typer.Argument(..., help='experiment name'), + num_pool: int = typer.Argument(..., help='number of paralleled experiments.'), + config_path: str = typer.Argument( + ..., help='path to config file, it is supposed to be yaml file, e.g. ./configs/ppo.yaml' + ), + log_dir: str = typer.Option( + os.path.join(os.getcwd()), help='directory to save logs, default is current directory' + ), +): + """Benchmark algorithms configured by .yaml file in OmniSafe via command line.""" + assert config_path.endswith('.yaml'), 'config file must be yaml file' + with open(config_path, encoding='utf-8') as file: + try: + configs = yaml.load(file, Loader=yaml.FullLoader) + assert configs is not None, 'load file error' + except yaml.YAMLError as exc: + assert False, f'load file error: {exc}' + assert 'algo' in configs, 'algo must be specified in config file' + assert 'env_id' in configs, 'env_id must be specified in config file' + assert ( + np.prod([len(v) if isinstance(v, list) else 1 for v in configs.values()]) % num_pool == 0 + ), 'total number of experiments must can be divided by num_pool' + log_dir = os.path.join(log_dir, 'benchmark') + eg = ExperimentGrid(exp_name=exp_name) + for k, v in configs.items(): + eg.add(key=k, vals=v) + eg.run(train_grid, num_pool=num_pool, parent_dir=log_dir) + + +@app.command('eval') +def evaluate( + result_dir: str = typer.Argument( + ..., + help='directory of experiment results to evaluate, e.g. ./runs/PPO-{SafetyPointGoal1-v0}', + ), + num_episode: int = typer.Option(10, help='number of episodes to render'), + render: bool = typer.Option(True, help='whether to render'), + render_mode: str = typer.Option( + 'rgb_array', + help="render mode('human', 'rgb_array', 'rgb_array_list', 'depth_array', 'depth_array_list')", + ), + camera_name: str = typer.Option('track', help='camera name to render'), + width: int = typer.Option(256, help='width of rendered image'), + height: int = typer.Option(256, help='height of rendered image'), +): + """Evaluate a policy which trained by OmniSafe via command line.""" + evaluator = omnisafe.Evaluator(render_mode=render_mode) + assert os.path.exists(result_dir), f'path{result_dir}, no torch_save directory' + for seed_dir in os.scandir(result_dir): + if seed_dir.is_dir(): + models_dir = os.path.join(seed_dir.path, 'torch_save') + for item in os.scandir(models_dir): + if item.is_file() and item.name.split('.')[-1] == 'pt': + evaluator.load_saved( + save_dir=seed_dir, + model_name=item.name, + camera_name=camera_name, + width=width, + height=height, + ) + if render: + evaluator.render(num_episodes=num_episode) + evaluator.evaluate(num_episodes=num_episode) + + +@app.command() +def train_config( + config_path: str = typer.Argument( + ..., help='path to config file, it is supposed to be yaml file, e.g. 
./configs/ppo.yaml' + ), + log_dir: str = typer.Option( + os.path.join(os.getcwd()), help='directory to save logs, default is current directory' + ), +): + """Train a policy configured by .yaml file in OmniSafe via command line.""" + assert config_path.endswith('.yaml'), 'config file must be a yaml file' + with open(config_path, encoding='utf-8') as file: + try: + args = yaml.load(file, Loader=yaml.FullLoader) + assert args is not None, 'load file error' + except yaml.YAMLError as exc: + assert False, f'load file error: {exc}' + assert 'algo' in args, 'algo must be specified in config file' + assert 'env_id' in args, 'env_id must be specified in config file' + + args.update({'logger_cfgs': {'log_dir': os.path.join(log_dir, 'train_dict')}}) + agent = omnisafe.Agent(algo=args['algo'], env_id=args['env_id'], custom_cfgs=args) + agent.learn() + + +if __name__ == '__main__': + app() diff --git a/omnisafe/utils/config.py b/omnisafe/utils/config.py index c31501668..2e23ab1e3 100644 --- a/omnisafe/utils/config.py +++ b/omnisafe/utils/config.py @@ -18,9 +18,8 @@ import os from typing import Any, Dict, List -import yaml - from omnisafe.typing import Activation, ActorType, AdvatageEstimator, InitFunction +from omnisafe.utils.tools import load_yaml class Config(dict): @@ -163,11 +162,7 @@ def get_default_kwargs_yaml(algo: str, env_id: str, algo_type: str) -> Config: path = os.path.dirname(os.path.abspath(__file__)) cfg_path = os.path.join(path, '..', 'configs', algo_type, f'{algo}.yaml') print(f'Loading {algo}.yaml from {cfg_path}') - with open(cfg_path, encoding='utf-8') as file: - try: - kwargs = yaml.load(file, Loader=yaml.FullLoader) - except yaml.YAMLError as exc: - assert False, f'{algo}.yaml error: {exc}' + kwargs = load_yaml(cfg_path) default_kwargs = kwargs['defaults'] env_spec_kwargs = kwargs[env_id] if env_id in kwargs.keys() else None diff --git a/omnisafe/utils/tools.py b/omnisafe/utils/tools.py index 2db9a7a5e..ddd80c751 100644 --- a/omnisafe/utils/tools.py +++ b/omnisafe/utils/tools.py @@ -19,6 +19,7 @@ import numpy as np import torch +import yaml def get_flat_params_from(model: torch.nn.Module) -> torch.Tensor: @@ -151,3 +152,35 @@ def update_dic(total_dic, item_dic): else: total_value = item_value total_dic.update({idd: total_value}) + + +def load_yaml(path) -> dict: + """Get the default kwargs from a ``yaml`` file. + + .. note:: + This function searches for the ``yaml`` file by the algorithm name and environment name. + Make sure your newly implemented algorithm or environment has the same name as the yaml file. + + Args: + path (str): path of the ``yaml`` file. + """ + with open(path, encoding='utf-8') as file: + try: + kwargs = yaml.load(file, Loader=yaml.FullLoader) + except yaml.YAMLError as exc: + assert False, f'{path} error: {exc}' + + return kwargs + + +def recursive_check_config(config, default_config, exclude_keys=()): + '''Check whether ``config`` is valid against ``default_config``.''' + for key in config.keys(): + if key not in default_config.keys() and key not in exclude_keys: + raise KeyError(f'Invalid key: {key}') + if isinstance(config[key], dict): + recursive_check_config(config[key], default_config[key]) + elif isinstance(config[key], list): + for item in config[key]: + if isinstance(item, dict): + recursive_check_config(item, default_config[key][0]) diff --git a/omnisafe/version.py b/omnisafe/version.py index 4eaf6e8d6..ea4e3aa1b 100644 --- a/omnisafe/version.py +++ b/omnisafe/version.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
+# Copyright 2023 OmniSafe Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ # ============================================================================== """OmniSafe: A comprehensive and reliable benchmark for safe reinforcement learning.""" -__version__ = '0.0.2' +__version__ = '0.1.1' __license__ = 'Apache License, Version 2.0' __author__ = 'OmniSafe Contributors' __release__ = False diff --git a/setup.py b/setup.py index 34b30c632..0f16ce52b 100755 --- a/setup.py +++ b/setup.py @@ -36,6 +36,7 @@ setup( name='omnisafe', version=version.__version__, + entry_points={'console_scripts': ['omnisafe=omnisafe.utils.command_app:app']}, ) finally: if VERSION_CONTENT is not None: diff --git a/tests/saved_policy/PPO/seed-000/config.json b/tests/saved_policy/PPO/seed-000/config.json deleted file mode 100644 index 4525a7d4c..000000000 --- a/tests/saved_policy/PPO/seed-000/config.json +++ /dev/null @@ -1,78 +0,0 @@ -{ - "actor_iters": 1, - "actor_lr": 0.0003, - "batch_size": 10000, - "buffer_cfgs": { - "adv_estimation_method": "gae", - "gamma": 0.99, - "lam": 0.95, - "lam_c": 0.95, - "standardized_cost_adv": true, - "standardized_rew_adv": true - }, - "check_freq": 25, - "clip": 0.2, - "cost_gamma": 1.0, - "critic_iters": 1, - "critic_lr": 0.0003, - "critic_norm_coeff": 0.001, - "data_dir": "./runs", - "device": "cpu", - "device_id": 0, - "entropy_coef": 0.0, - "env_cfgs": { - "async_env": true, - "env_seed": 0, - "max_len": 100, - "normalized_cost": false, - "normalized_obs": false, - "normalized_rew": true, - "num_envs": 1, - "num_threads": 20 - }, - "env_id": "SafetyHumanoidVelocity-v4", - "epochs": 1, - "exp_name": "SafetyHumanoidVelocity-v4/PPO", - "exploration_noise_anneal": false, - "kl_early_stopping": true, - "linear_lr_decay": true, - "max_ep_len": 1000, - "max_grad_norm": 40, - "model_cfgs": { - "ac_kwargs": { - "pi": { - "activation": "tanh", - "clip_action": false, - "hidden_sizes": [ - 64, - 64 - ], - "output_activation": "identity", - "scale_action": false, - "std_init": 1.0, - "std_learning": true - }, - "val": { - "activation": "tanh", - "hidden_sizes": [ - 64, - 64 - ], - "num_critics": 1 - } - }, - "actor_type": "gaussian", - "shared_weights": false, - "weight_initialization_mode": "kaiming_uniform" - }, - "num_mini_batches": 64, - "penalty_param": 0.0, - "save_freq": 50, - "seed": 0, - "steps_per_epoch": 2000, - "target_kl": 0.02, - "use_cost": false, - "use_critic_norm": true, - "use_max_grad_norm": true, - "wrapper_type": "CMDPWrapper" -} diff --git a/tests/saved_policy/PPO/seed-000/tb/events.out.tfevents.1675081988.user.1133609.3 b/tests/saved_policy/PPO/seed-000/tb/events.out.tfevents.1675081988.user.1133609.3 deleted file mode 100644 index 4acef7496..000000000 Binary files a/tests/saved_policy/PPO/seed-000/tb/events.out.tfevents.1675081988.user.1133609.3 and /dev/null differ diff --git a/tests/saved_policy/PPO/seed-000/torch_save/model None.pt b/tests/saved_policy/PPO/seed-000/torch_save/model None.pt deleted file mode 100644 index 23dbbdd78..000000000 Binary files a/tests/saved_policy/PPO/seed-000/torch_save/model None.pt and /dev/null differ diff --git a/tests/saved_policy/PPOEarlyTerminated/seed-000/config.json b/tests/saved_policy/PPOEarlyTerminated/seed-000/config.json deleted file mode 100644 index 10e8d343b..000000000 --- a/tests/saved_policy/PPOEarlyTerminated/seed-000/config.json +++ /dev/null @@ -1,78 +0,0 @@ -{ - "actor_iters": 10, - 
"actor_lr": 0.0003, - "batch_size": 10000, - "buffer_cfgs": { - "adv_estimation_method": "gae", - "gamma": 0.99, - "lam": 0.95, - "lam_c": 0.95, - "standardized_cost_adv": true, - "standardized_rew_adv": true - }, - "check_freq": 25, - "clip": 0.2, - "cost_gamma": 1.0, - "critic_iters": 1, - "critic_lr": 0.0003, - "critic_norm_coeff": 0.001, - "data_dir": "./runs", - "device": "cpu", - "device_id": 0, - "entropy_coef": 0.0, - "env_cfgs": { - "async_env": true, - "max_len": 100, - "normalized_cost": true, - "normalized_obs": true, - "normalized_rew": true, - "num_envs": 1, - "num_threads": 20 - }, - "env_id": "SafetyHumanoidVelocity-v4", - "epochs": 1, - "exp_name": "SafetyHumanoidVelocity-v4/PPOEarlyTerminated", - "exploration_noise_anneal": false, - "kl_early_stopping": true, - "linear_lr_decay": true, - "max_ep_len": 1000, - "max_grad_norm": 40, - "model_cfgs": { - "ac_kwargs": { - "pi": { - "activation": "tanh", - "clip_action": false, - "hidden_sizes": [ - 64, - 64 - ], - "output_activation": "identity", - "scale_action": false, - "std_init": 1.0, - "std_learning": true - }, - "val": { - "activation": "tanh", - "hidden_sizes": [ - 64, - 64 - ], - "num_critics": 1 - } - }, - "actor_type": "gaussian", - "shared_weights": false, - "weight_initialization_mode": "kaiming_uniform" - }, - "num_mini_batches": 64, - "penalty_param": 0.0, - "save_freq": 100, - "seed": 0, - "standardized_obs": true, - "steps_per_epoch": 1000, - "target_kl": 0.02, - "use_cost": false, - "use_critic_norm": true, - "use_max_grad_norm": true, - "wrapper_type": "EarlyTerminatedWrapper" -} diff --git a/tests/saved_policy/PPOEarlyTerminated/seed-000/tb/events.out.tfevents.1675082211.user.1133609.30 b/tests/saved_policy/PPOEarlyTerminated/seed-000/tb/events.out.tfevents.1675082211.user.1133609.30 deleted file mode 100644 index 8abf318da..000000000 Binary files a/tests/saved_policy/PPOEarlyTerminated/seed-000/tb/events.out.tfevents.1675082211.user.1133609.30 and /dev/null differ diff --git a/tests/saved_policy/PPOEarlyTerminated/seed-000/torch_save/model None.pt b/tests/saved_policy/PPOEarlyTerminated/seed-000/torch_save/model None.pt deleted file mode 100644 index 94354d0cd..000000000 Binary files a/tests/saved_policy/PPOEarlyTerminated/seed-000/torch_save/model None.pt and /dev/null differ diff --git a/tests/saved_policy/PPOSaute/seed-000/config.json b/tests/saved_policy/PPOSaute/seed-000/config.json deleted file mode 100644 index 94622521c..000000000 --- a/tests/saved_policy/PPOSaute/seed-000/config.json +++ /dev/null @@ -1,81 +0,0 @@ -{ - "actor_iters": 10, - "actor_lr": 0.0003, - "batch_size": 10000, - "buffer_cfgs": { - "adv_estimation_method": "gae", - "gamma": 0.99, - "lam": 0.95, - "lam_c": 0.95, - "standardized_cost_adv": true, - "standardized_rew_adv": true - }, - "check_freq": 25, - "clip": 0.2, - "cost_gamma": 1.0, - "critic_iters": 1, - "critic_lr": 0.0003, - "critic_norm_coeff": 0.001, - "data_dir": "./runs", - "device": "cpu", - "device_id": 0, - "entropy_coef": 0.0, - "env_cfgs": { - "async_env": true, - "max_len": 100, - "normalized_cost": true, - "normalized_obs": true, - "normalized_rew": true, - "num_envs": 1, - "num_threads": 20, - "safety_budget": 25, - "saute_gamma": 0.9997, - "scale_safety_budget": true, - "unsafe_reward": -0.1 - }, - "env_id": "SafetyHumanoidVelocity-v4", - "epochs": 1, - "exp_name": "SafetyHumanoidVelocity-v4/PPOSaute", - "exploration_noise_anneal": false, - "kl_early_stopping": true, - "linear_lr_decay": true, - "max_ep_len": 1000, - "max_grad_norm": 40, - "model_cfgs": { - 
"ac_kwargs": { - "pi": { - "activation": "tanh", - "clip_action": false, - "hidden_sizes": [ - 64, - 64 - ], - "output_activation": "identity", - "scale_action": false, - "std_init": 1.0, - "std_learning": true - }, - "val": { - "activation": "tanh", - "hidden_sizes": [ - 64, - 64 - ], - "num_critics": 1 - } - }, - "actor_type": "gaussian", - "shared_weights": false, - "weight_initialization_mode": "kaiming_uniform" - }, - "num_mini_batches": 64, - "penalty_param": 0.0, - "save_freq": 50, - "seed": 0, - "steps_per_epoch": 1000, - "target_kl": 0.02, - "use_cost": false, - "use_critic_norm": true, - "use_max_grad_norm": true, - "wrapper_type": "SauteWrapper" -} diff --git a/tests/saved_policy/PPOSaute/seed-000/tb/events.out.tfevents.1675082213.user.1133609.32 b/tests/saved_policy/PPOSaute/seed-000/tb/events.out.tfevents.1675082213.user.1133609.32 deleted file mode 100644 index 292a31f6f..000000000 Binary files a/tests/saved_policy/PPOSaute/seed-000/tb/events.out.tfevents.1675082213.user.1133609.32 and /dev/null differ diff --git a/tests/saved_policy/PPOSaute/seed-000/torch_save/model None.pt b/tests/saved_policy/PPOSaute/seed-000/torch_save/model None.pt deleted file mode 100644 index 67f7a60f4..000000000 Binary files a/tests/saved_policy/PPOSaute/seed-000/torch_save/model None.pt and /dev/null differ diff --git a/tests/saved_policy/PPOSimmerQ/seed-000/config.json b/tests/saved_policy/PPOSimmerQ/seed-000/config.json deleted file mode 100644 index daa49ee23..000000000 --- a/tests/saved_policy/PPOSimmerQ/seed-000/config.json +++ /dev/null @@ -1,92 +0,0 @@ -{ - "actor_iters": 10, - "actor_lr": 0.0003, - "batch_size": 10000, - "buffer_cfgs": { - "adv_estimation_method": "gae", - "gamma": 0.99, - "lam": 0.95, - "lam_c": 0.95, - "standardized_cost_adv": true, - "standardized_rew_adv": true - }, - "check_freq": 25, - "clip": 0.2, - "cost_gamma": 1.0, - "critic_iters": 1, - "critic_lr": 0.0003, - "critic_norm_coeff": 0.001, - "data_dir": "./runs", - "device": "cpu", - "device_id": 0, - "entropy_coef": 0.0, - "env_cfgs": { - "async_env": true, - "controller_cfgs": { - "act_dim": 3, - "epsilon": 0.8, - "q_lr": 0.1, - "state_dim": 5, - "tau": 0.95, - "threshold": 2 - }, - "lower_budget": 15, - "max_len": 100, - "normalized_cost": true, - "normalized_obs": true, - "normalized_rew": true, - "num_envs": 1, - "num_threads": 20, - "scale_safety_budget": false, - "simmer_controller": "Q", - "simmer_gamma": 0.9997, - "unsafe_reward": -0.1, - "upper_budget": 25 - }, - "env_id": "SafetyHumanoidVelocity-v4", - "epochs": 1, - "exp_name": "SafetyHumanoidVelocity-v4/PPOSimmerQ", - "exploration_noise_anneal": false, - "kl_early_stopping": true, - "linear_lr_decay": true, - "max_ep_len": 1000, - "max_grad_norm": 40, - "model_cfgs": { - "ac_kwargs": { - "pi": { - "activation": "tanh", - "clip_action": false, - "hidden_sizes": [ - 64, - 64 - ], - "output_activation": "identity", - "scale_action": false, - "std_init": 1.0, - "std_learning": true - }, - "val": { - "activation": "tanh", - "hidden_sizes": [ - 64, - 64 - ], - "num_critics": 1 - } - }, - "actor_type": "gaussian", - "shared_weights": false, - "weight_initialization_mode": "kaiming_uniform" - }, - "num_mini_batches": 64, - "penalty_param": 0.0, - "save_freq": 100, - "seed": 0, - "standardized_obs": true, - "steps_per_epoch": 1000, - "target_kl": 0.02, - "use_cost": false, - "use_critic_norm": true, - "use_max_grad_norm": true, - "wrapper_type": "SimmerWrapper" -} diff --git 
a/tests/saved_policy/PPOSimmerQ/seed-000/tb/events.out.tfevents.1675082216.user.1133609.34 b/tests/saved_policy/PPOSimmerQ/seed-000/tb/events.out.tfevents.1675082216.user.1133609.34 deleted file mode 100644 index f395ffa02..000000000 Binary files a/tests/saved_policy/PPOSimmerQ/seed-000/tb/events.out.tfevents.1675082216.user.1133609.34 and /dev/null differ diff --git a/tests/saved_policy/PPOSimmerQ/seed-000/torch_save/model None.pt b/tests/saved_policy/PPOSimmerQ/seed-000/torch_save/model None.pt deleted file mode 100644 index 67f7a60f4..000000000 Binary files a/tests/saved_policy/PPOSimmerQ/seed-000/torch_save/model None.pt and /dev/null differ diff --git a/tests/saved_policy/SAC/seed-000/config.json b/tests/saved_policy/SAC/seed-000/config.json deleted file mode 100644 index a236abf3d..000000000 --- a/tests/saved_policy/SAC/seed-000/config.json +++ /dev/null @@ -1,80 +0,0 @@ -{ - "actor_lr": 0.0001, - "alpha": 0.2, - "alpha_gamma": 1.0, - "alpha_lr": 0.0003, - "auto_alpha": true, - "check_freq": 25, - "cost_gamma": 1.0, - "cost_limit_decay": false, - "critic_lr": 0.0001, - "critic_norm_coeff": 0.001, - "data_dir": "./runs", - "device": "cpu", - "device_id": 0, - "end_epoch": 100, - "env_cfgs": { - "async_env": true, - "max_len": 100, - "normalized_cost": false, - "normalized_obs": true, - "normalized_rew": false, - "num_envs": 1, - "num_threads": 20 - }, - "env_id": "SafetyHumanoidVelocity-v4", - "epochs": 1, - "exp_name": "SafetyHumanoidVelocity-v4/SAC", - "exploration_noise_anneal": false, - "gamma": 0.99, - "init_cost_limit": 25.0, - "kl_early_stopping": false, - "linear_lr_decay": true, - "max_ep_len": 1000, - "max_grad_norm": 40, - "model_cfgs": { - "ac_kwargs": { - "pi": { - "activation": "relu", - "clip_action": true, - "hidden_sizes": [ - 64, - 64 - ], - "output_activation": "tanh", - "scale_action": true, - "std_init": 1.0, - "std_learning": true - }, - "val": { - "activation": "relu", - "hidden_sizes": [ - 64, - 64 - ], - "num_critics": 2 - } - }, - "actor_type": "gaussian_stdnet", - "shared_weights": false, - "weight_initialization_mode": "kaiming_uniform" - }, - "num_test_episodes": 10, - "polyak": 0.999, - "replay_buffer_cfgs": { - "batch_size": 1024, - "size": 100000 - }, - "reward_penalty": false, - "save_freq": 10, - "seed": 5, - "start_steps": 10000, - "steps_per_epoch": 1000, - "target_cost_limit": 25.0, - "update_after": 999, - "update_every": 1, - "use_cost": false, - "use_critic_norm": false, - "use_max_grad_norm": false, - "wrapper_type": "CMDPWrapper" -} diff --git a/tests/saved_policy/SAC/seed-000/tb/events.out.tfevents.1675082020.user.1133609.6 b/tests/saved_policy/SAC/seed-000/tb/events.out.tfevents.1675082020.user.1133609.6 deleted file mode 100644 index 2a386757b..000000000 Binary files a/tests/saved_policy/SAC/seed-000/tb/events.out.tfevents.1675082020.user.1133609.6 and /dev/null differ diff --git a/tests/saved_policy/SAC/seed-000/torch_save/model None.pt b/tests/saved_policy/SAC/seed-000/torch_save/model None.pt deleted file mode 100644 index 3867944fc..000000000 Binary files a/tests/saved_policy/SAC/seed-000/torch_save/model None.pt and /dev/null differ diff --git a/tests/saved_source/PPO-{SafetyPointGoal1-v0}/seed-000-2023-03-16-12-08-52/config.json b/tests/saved_source/PPO-{SafetyPointGoal1-v0}/seed-000-2023-03-16-12-08-52/config.json new file mode 100644 index 000000000..4ce040b2e --- /dev/null +++ b/tests/saved_source/PPO-{SafetyPointGoal1-v0}/seed-000-2023-03-16-12-08-52/config.json @@ -0,0 +1,74 @@ +{ + "seed": 0, + "train_cfgs": { + "device": "cpu", 
+ "torch_threads": 16, + "vector_env_nums": 1, + "parallel": 1, + "total_steps": 1024000, + "algo": "PPO", + "env_id": "SafetyPointGoal1-v0", + "epochs": 1000 + }, + "algo_cfgs": { + "update_cycle": 1024, + "update_iters": 40, + "batch_size": 64, + "target_kl": 0.02, + "entropy_coef": 0.0, + "reward_normalize": true, + "cost_normalize": true, + "obs_normalize": true, + "kl_early_stop": true, + "use_max_grad_norm": true, + "max_grad_norm": 40.0, + "use_critic_norm": true, + "critic_norm_coef": 0.001, + "gamma": 0.99, + "cost_gamma": 0.99, + "lam": 0.95, + "lam_c": 0.95, + "clip": 0.2, + "adv_estimation_method": "gae", + "standardized_rew_adv": true, + "standardized_cost_adv": true, + "penalty_coef": 0.0, + "use_cost": false + }, + "logger_cfgs": { + "use_wandb": false, + "wandb_project": "omnisafe", + "use_tensorboard": true, + "save_model_freq": 100, + "log_dir": "./runs", + "window_lens": 100 + }, + "model_cfgs": { + "weight_initialization_mode": "kaiming_uniform", + "actor_type": "gaussian_learning", + "linear_lr_decay": true, + "exploration_noise_anneal": false, + "std_range": [ + 0.5, + 0.1 + ], + "actor": { + "hidden_sizes": [ + 64, + 64 + ], + "activation": "tanh", + "lr": 0.0003 + }, + "critic": { + "hidden_sizes": [ + 64, + 64 + ], + "activation": "tanh", + "lr": 0.0003 + } + }, + "exp_name": "PPO-(SafetyPointGoal1-v0)", + "env_id": "SafetyPointGoal1-v0" +} diff --git a/tests/saved_source/PPO-{SafetyPointGoal1-v0}/seed-000-2023-03-16-12-08-52/torch_save/epoch-0.pt b/tests/saved_source/PPO-{SafetyPointGoal1-v0}/seed-000-2023-03-16-12-08-52/torch_save/epoch-0.pt new file mode 100644 index 000000000..4b6e5e6e6 Binary files /dev/null and b/tests/saved_source/PPO-{SafetyPointGoal1-v0}/seed-000-2023-03-16-12-08-52/torch_save/epoch-0.pt differ diff --git a/tests/saved_source/benchmark_config.yaml b/tests/saved_source/benchmark_config.yaml new file mode 100644 index 000000000..f4588e97d --- /dev/null +++ b/tests/saved_source/benchmark_config.yaml @@ -0,0 +1,31 @@ +# Copyright 2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +algo: + ['PolicyGradient', 'NaturalPG'] +env_id: + ['SafetyAntVelocity-v4'] +logger_cfgs:use_wandb: + [False] +train_cfgs:vector_env_nums: + [2] +train_cfgs:torch_threads: + [1] +train_cfgs:total_steps: + 1024 +algo_cfgs:update_cycle: + 512 +seed: + [0] diff --git a/tests/saved_source/train_config.yaml b/tests/saved_source/train_config.yaml new file mode 100644 index 000000000..077e59051 --- /dev/null +++ b/tests/saved_source/train_config.yaml @@ -0,0 +1,26 @@ +# Copyright 2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +algo: + PPOLag +env_id: + SafetyAntVelocity-v4 +train_cfgs: + total_steps: + 1024 + vector_env_nums: 1 +algo_cfgs: + update_cycle: + 512 diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 000000000..4f4b1e1eb --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,81 @@ +# Copyright 2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import os + +from typer.testing import CliRunner + +from omnisafe import app + + +runner = CliRunner() +base_path = os.path.dirname(os.path.abspath(__file__)) + + +def test_benchmark(): + result = runner.invoke( + app, + [ + 'benchmark', + 'test_benchmark', + '2', + os.path.join(base_path, './saved_source/benchmark_config.yaml'), + ], + ) + assert result.exit_code == 0 + + +def test_train(): + result = runner.invoke( + app, + [ + 'train', + '--algo', + 'PPO', + '--total-steps', + '1024', + '--vector-env-nums', + '1', + '--custom-cfgs', + 'algo_cfgs:update_cycle', + '--custom-cfgs', + '512', + ], + ) + assert result.exit_code == 0 + + +def test_train_config(): + result = runner.invoke( + app, ['train-config', os.path.join(base_path, './saved_source/train_config.yaml')] + ) + assert result.exit_code == 0 + + +def test_eval(): + result = runner.invoke( + app, + [ + 'eval', + os.path.join(base_path, './saved_source/PPO-{SafetyPointGoal1-v0}'), + '--num-episode', + '1', + '--width', + '1', + '--height', + '1', + ], + ) + assert result.exit_code == 0