feat: support command line interfaces for omnisafe (#144)
Co-authored-by: borong <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ruiyang sun <[email protected]>
Co-authored-by: zmsn-2077 <[email protected]>
Co-authored-by: friedmainfunction <[email protected]>
Co-authored-by: Gaiejj <[email protected]>
Co-authored-by: zmsn-2077 <[email protected]>
Co-authored-by: Ruiyang Sun <[email protected]>
Co-authored-by: Jiayi Zhou <[email protected]>
Co-authored-by: 1Asan <[email protected]>
fix(algo): fix no return in algo_wrapper::learn (#122)
fix(logger, wrapper): support csv file and velocity tasks (#131)
fix typo. (#134)
fix(ppo): fix entropy loss (#135)
fix bugs (#136)
fix: support new config for exp_grid (#142)
fix(rollout, exp_grid): fix logdir path conflict (#145)
fix(on-policy): fix the second order algorithms performance (#147)
10 people committed Mar 26, 2023
1 parent e5bf84c commit d5e2814
Showing 39 changed files with 654 additions and 549 deletions.
1 change: 1 addition & 0 deletions .pylintrc
@@ -196,6 +196,7 @@ good-names=i,
id,
e,
f,
eg,

# Good variable names regexes, separated by a comma. If names match any regex,
# they will always be accepted
27 changes: 26 additions & 1 deletion README.md
@@ -84,6 +84,7 @@ The supported interface algorithms currently include:
- [X] **[AAAI 2020]** [IPO: Interior-point Policy Optimization under Constraints (IPO)](https://arxiv.org/abs/1910.09615)
- [X] **[ICLR 2020]** [Projection-Based Constrained Policy Optimization (PCPO)](https://openreview.net/forum?id=rke3TJrtPS)
- [X] **[ICML 2021]** [CRPO: A New Approach for Safe Reinforcement Learning with Convergence Guarantee](https://arxiv.org/abs/2011.05869)
- [X] **[IJCAI 2022]** [Penalized Proximal Policy Optimization for Safe Reinforcement Learning (P3O)](https://arxiv.org/pdf/2205.11814.pdf)

#### Off-Policy Safe

@@ -147,6 +148,29 @@ cd examples
python train_policy.py --algo PPOLag --env-id SafetyPointGoal1-v0 --parallel 1 --total-steps 1024000 --device cpu --vector-env-nums 1 --torch-threads 1
```

#### Try with CLI

```bash
pip install omnisafe

omnisafe --help # Ask for help

omnisafe [command] --help # Ask for command specific help

# Quick benchmarking for your research: just specify 1. exp_name, 2. num_pool (how many processes run concurrently), 3. the path of the config file (refer to omnisafe/examples/benchmarks for the format)
omnisafe benchmark test_benchmark, 2, "./saved_source/benchmark_config.yaml"

# Quickly evaluate and render your trained policy: just specify 1. the path of the algorithm you trained
omnisafe eval ./saved_source/PPO-{SafetyPointGoal1-v0}, "--num-episode", "1"

# Quickly train some algorithms to validate your ideas
# Note: with `key1:key2` you can select nested hyperparameter keys, and with `--custom-cfgs` you can pass custom configs via the CLI
omnisafe train --algo PPO --total-steps 1024 --vector-env-nums 1 --custom-cfgs algo_cfgs:update_cycle --custom-cfgs 512

# Quickly train some algorithms from a saved config file; the format is the same as the default config format
omnisafe train-config "./saved_source/train_config.yaml"
```
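The same training run can also be launched from the Python API via `omnisafe.Agent`, which this commit uses in `examples/benchmarks/run_experiment_grid.py`. Below is a minimal sketch of the equivalent of the `omnisafe train` call above; the nested layout of `custom_cfgs` is an assumption inferred from the `key1:key2` paths shown in the comments.

```python
# Minimal sketch (assumption: custom_cfgs mirrors the key1:key2 paths as a nested dict).
import omnisafe

custom_cfgs = {
    'train_cfgs': {'total_steps': 1024, 'vector_env_nums': 1},
    'algo_cfgs': {'update_cycle': 512},
}
agent = omnisafe.Agent('PPO', 'SafetyPointGoal1-v0', custom_cfgs=custom_cfgs)
agent.learn()  # returns (reward, cost, ep_len), as used in run_experiment_grid.py
```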

**algo:**
Type | Name
---------------| ----------------------------------------------
@@ -205,7 +229,8 @@ More information about environments, please refer to [Safety Gymnasium](https://
--------------------------------------------------------------------------------

## Getting Started

#### Important Hints
- `train_cfgs:torch_threads` is especially important for training speed and varies with the user's machine; this value should be neither too small nor too large (a rough sizing sketch follows below).
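As a rough illustration of that hint (an assumption on our part, not part of this commit), one way to pick a starting value is to divide the machine's CPU count by the number of concurrent runs and clamp the result:

```python
# Rough heuristic sketch (assumption): derive a starting value for
# train_cfgs:torch_threads from the CPU count and the number of parallel runs,
# then clamp it to a moderate range.
import os

parallel_runs = 4                     # hypothetical number of concurrent experiments
cpu_count = os.cpu_count() or 1
torch_threads = min(16, max(1, cpu_count // parallel_runs))
print(f'suggested train_cfgs:torch_threads = {torch_threads}')
```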
### 1. Run Agent from preset yaml file

```python
31 changes: 31 additions & 0 deletions examples/benchmarks/example_cli_benchmark_config.yaml
@@ -0,0 +1,31 @@
# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

algo:
['PolicyGradient', 'NaturalPG']
env_id:
['SafetyAntVelocity-v4']
logger_cfgs:use_wandb:
[False]
train_cfgs:vector_env_nums:
[2]
train_cfgs:torch_threads:
[1]
train_cfgs:total_steps:
1024
algo_cfgs:update_cycle:
512
seed:
[0]
32 changes: 11 additions & 21 deletions examples/benchmarks/run_experiment_grid.py
@@ -1,4 +1,4 @@
# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,8 +17,6 @@
import os
import sys

import torch

import omnisafe
from omnisafe.common.experiment_grid import ExperimentGrid
from omnisafe.typing import NamedTuple, Tuple
@@ -39,16 +37,10 @@ def train(
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
print(f'exp-x: {exp_id} is training...')
USE_REDIRECTION = True
if USE_REDIRECTION:
if not os.path.exists(custom_cfgs['logger_cfgs']['log_dir']):
os.makedirs(custom_cfgs['logger_cfgs']['log_dir'])
sys.stdout = open(
f'{custom_cfgs["logger_cfgs"]["log_dir"]}terminal.log', 'w', encoding='utf-8'
)
sys.stderr = open(
f'{custom_cfgs["logger_cfgs"]["log_dir"]}error.log', 'w', encoding='utf-8'
)
if not os.path.exists(custom_cfgs['logger_cfgs']['log_dir']):
os.makedirs(custom_cfgs['logger_cfgs']['log_dir'])
sys.stdout = open(f'{custom_cfgs["logger_cfgs"]["log_dir"]}terminal.log', 'w', encoding='utf-8')
sys.stderr = open(f'{custom_cfgs["logger_cfgs"]["log_dir"]}error.log', 'w', encoding='utf-8')
agent = omnisafe.Agent(algo, env_id, custom_cfgs=custom_cfgs)
reward, cost, ep_len = agent.learn()
return reward, cost, ep_len
@@ -58,16 +50,14 @@
eg = ExperimentGrid(exp_name='Safety_Gymnasium_Goal')
base_policy = ['PolicyGradient', 'NaturalPG', 'TRPO', 'PPO']
naive_lagrange_policy = ['PPOLag', 'TRPOLag', 'RCPO', 'OnCRPO', 'PDO']
first_order_policy = ['CUP', 'FOCOPS']
first_order_policy = ['CUP', 'FOCOPS', 'P3O']
second_order_policy = ['CPO', 'PCPO']
eg.add('algo', base_policy + naive_lagrange_policy + first_order_policy + second_order_policy)
eg.add('env_id', ['SafetyAntVelocity-v4'])
eg.add('logger_cfgs:use_wandb', [False])
eg.add('num_envs', [16])
eg.add('num_threads', [1])
# eg.add('logger_cfgs:wandb_project', ['omnisafe_jiaming'])
# eg.add('train_cfgs:total_steps', 2000)
# eg.add('algo_cfgs:update_cycle', 1000)
# eg.add('train_cfgs:vector_env_nums', 1)
eg.add('train_cfgs:vector_env_nums', [4])
eg.add('train_cfgs:torch_threads', [1])
eg.add('seed', [0])
eg.run(train, num_pool=13)
# The total number of experiments must be divisible by num_pool;
# users should also choose this value according to their machine.
eg.run(train, num_pool=14)
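The divisibility note above can be checked in a few lines; this sketch simply recounts the option lists added to the grid in this file (14 algorithms, one env, one seed, one vector-env setting, one thread setting), so `num_pool=14` divides the total evenly:

```python
# Sketch of the num_pool divisibility check described in the comments above.
algos = 4 + 5 + 3 + 2   # base + naive Lagrange + first-order (incl. P3O) + second-order
env_ids = seeds = vector_env_nums = torch_threads = 1
total_experiments = algos * env_ids * seeds * vector_env_nums * torch_threads  # 14
num_pool = 14
assert total_experiments % num_pool == 0, 'num_pool must divide the experiment count'
```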
6 changes: 2 additions & 4 deletions examples/evaluate_saved_policy.py
@@ -1,4 +1,4 @@
# Copyright 2022 OmniSafe Team. All Rights Reserved.
# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,10 +22,8 @@
# Just fill your experiment's log directory in here.
# Such as: ~/omnisafe/examples/runs/PPOLag-<SafetyPointGoal1-v0>/seed-000-2023-03-07-20-25-48
LOG_DIR = ''
play = True
save_replay = True
if __name__ == '__main__':
evaluator = omnisafe.Evaluator(play=play, save_replay=save_replay)
evaluator = omnisafe.Evaluator(render_mode='rgb_array')
for item in os.scandir(os.path.join(LOG_DIR, 'torch_save')):
if item.is_file() and item.name.split('.')[-1] == 'pt':
evaluator.load_saved(
2 changes: 1 addition & 1 deletion examples/train_from_custom_dict.py
@@ -1,4 +1,4 @@
# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
6 changes: 5 additions & 1 deletion omnisafe/adapter/onpolicy_adapter.py
@@ -17,6 +17,7 @@
from typing import Dict, Optional

import torch
from rich.progress import track

from omnisafe.adapter.online_adapter import OnlineAdapter
from omnisafe.common.buffer import VectorOnPolicyBuffer
@@ -56,7 +57,10 @@ def roll_out( # pylint: disable=too-many-locals
self._reset_log()

obs, _ = self.reset()
for step in range(steps_per_epoch):
for step in track(
range(steps_per_epoch),
description=f'Processing rollout for epoch: {logger.current_epoch}...',
):
act, value_r, value_c, logp = agent.step(obs)
next_obs, reward, cost, terminated, truncated, info = self.step(act)

9 changes: 7 additions & 2 deletions omnisafe/algorithms/algo_wrapper.py
@@ -25,6 +25,7 @@
from omnisafe.envs import support_envs
from omnisafe.utils import distributed
from omnisafe.utils.config import check_all_configs, get_default_kwargs_yaml
from omnisafe.utils.tools import recursive_check_config


class AlgoWrapper:
@@ -66,13 +67,17 @@ def _init_config(self):

# update the cfgs from custom configurations
if self.custom_cfgs:
recursive_check_config(self.custom_cfgs, cfgs, exclude_keys=('algo', 'env_id'))
cfgs.recurisve_update(self.custom_cfgs)
# update the cfgs from custom terminal configurations
if self.train_terminal_cfgs:
recursive_check_config(
self.train_terminal_cfgs, cfgs.train_cfgs, exclude_keys=('algo', 'env_id')
)
cfgs.train_cfgs.recurisve_update(self.train_terminal_cfgs)

# the exp_name format is PPO-(SafetyPointGoal1-v0)-
exp_name = f'{self.algo}-({self.env_id})'
# the exp_name format is PPO-{SafetyPointGoal1-v0}-
exp_name = f'{self.algo}-{{{self.env_id}}}'
cfgs.recurisve_update({'exp_name': exp_name, 'env_id': self.env_id})
cfgs.train_cfgs.recurisve_update(
{'epochs': cfgs.train_cfgs.total_steps // cfgs.algo_cfgs.update_cycle}
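A tiny worked instance of the epoch arithmetic and the new brace-style experiment name visible in this hunk (the step values are taken from examples elsewhere in this commit, not from `_init_config` itself):

```python
# Worked example of epochs = total_steps // update_cycle and the new exp_name format.
algo, env_id = 'PPO', 'SafetyPointGoal1-v0'
total_steps, update_cycle = 1024000, 512   # illustrative values from the README examples
epochs = total_steps // update_cycle       # 2000
exp_name = f'{algo}-{{{env_id}}}'          # 'PPO-{SafetyPointGoal1-v0}'
print(epochs, exp_name)
```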
9 changes: 5 additions & 4 deletions omnisafe/algorithms/on_policy/base/policy_gradient.py
@@ -1,4 +1,4 @@
# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@

import torch
import torch.nn as nn
from rich.progress import track
from torch.utils.data import DataLoader, TensorDataset

from omnisafe.adapter import OnPolicyAdapter
@@ -219,7 +220,7 @@ def _update(self) -> None:
shuffle=True,
)

for i in range(self._cfgs.algo_cfgs.update_iters):
for i in track(range(self._cfgs.algo_cfgs.update_iters), description='Updating...'):
for (
obs,
act,
@@ -245,12 +246,12 @@ def _update(self) -> None:
kl = distributed.dist_avg(kl)

if self._cfgs.algo_cfgs.kl_early_stop and kl > self._cfgs.algo_cfgs.target_kl:
self._logger.log(f'Early stopping at iter {i} due to reaching max kl')
self._logger.log(f'Early stopping at iter {i + 1} due to reaching max kl')
break

self._logger.store(
**{
'Train/StopIter': i + 1,
'Train/StopIter': i + 1, # pylint: disable=undefined-loop-variable
'Value/Adv': adv_r.mean().item(),
'Train/KL': kl,
}
9 changes: 5 additions & 4 deletions omnisafe/algorithms/on_policy/first_order/cup.py
@@ -1,4 +1,4 @@
# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
"""Implementation of the CUP algorithm."""

import torch
from rich.progress import track
from torch.distributions import Normal
from torch.utils.data import DataLoader, TensorDataset

@@ -142,7 +143,7 @@ def _update(self) -> None:
shuffle=True,
)

for i in range(self._cfgs.algo_cfgs.update_iters):
for i in track(range(self._cfgs.algo_cfgs.update_iters), description='Updating...'):
for obs, act, logp, adv_c, old_mean, old_std in dataloader:
self._p_dist = Normal(old_mean, old_std)
loss_cost, info = self._loss_pi_cost(obs, act, logp, adv_c)
@@ -166,15 +167,15 @@
kl = distributed.dist_avg(kl)

if self._cfgs.algo_cfgs.kl_early_stop and kl > self._cfgs.algo_cfgs.target_kl:
self._logger.log(f'Early stopping at iter {i} due to reaching max kl')
self._logger.log(f'Early stopping at iter {i + 1} due to reaching max kl')
break

self._logger.store(
**{
'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier.item(),
'Train/MaxRatio': self._max_ratio,
'Train/MinRatio': self._min_ratio,
'Train/SecondStepStopIter': i + 1,
'Train/SecondStepStopIter': i + 1, # pylint: disable=undefined-loop-variable
'Train/SecondStepEntropy': info['entropy'],
'Train/SecondStepPolicyRatio': info['ratio'],
}
9 changes: 5 additions & 4 deletions omnisafe/algorithms/on_policy/first_order/focops.py
@@ -1,4 +1,4 @@
# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
from typing import Dict, Tuple

import torch
from rich.progress import track
from torch.distributions import Normal
from torch.utils.data import DataLoader, TensorDataset

@@ -108,7 +109,7 @@ def _update(self) -> None:
shuffle=True,
)

for i in range(self._cfgs.algo_cfgs.update_iters):
for i in track(range(self._cfgs.algo_cfgs.update_iters), description='Updating...'):
for (
obs,
act,
@@ -138,12 +139,12 @@ def _update(self) -> None:
kl = distributed.dist_avg(kl)

if self._cfgs.algo_cfgs.kl_early_stop and kl > self._cfgs.algo_cfgs.target_kl:
self._logger.log(f'Early stopping at iter {i} due to reaching max kl')
self._logger.log(f'Early stopping at iter {i + 1} due to reaching max kl')
break

self._logger.store(
**{
'Train/StopIter': i + 1,
'Train/StopIter': i + 1, # pylint: disable=undefined-loop-variable
'Value/Adv': adv_r.mean().item(),
'Train/KL': kl,
'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier,