feat: update architecture of config.yaml (PKU-Alignment#126)
zmsn-2077 committed Mar 14, 2023
1 parent ae995b0 · commit a7005ea
Showing 72 changed files with 2,337 additions and 6,019 deletions.
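
The headline change is the config.yaml architecture: flat top-level keys such as epochs, data_dir, actor_lr, and num_threads are regrouped into nested sections that the code now reads as cfgs.train_cfgs.*, cfgs.algo_cfgs.*, cfgs.logger_cfgs.*, and (in the examples) env_cfgs. A rough sketch of the new layout, assembled only from keys visible in the diffs below — values are illustrative, and the authoritative defaults live in the per-algorithm YAML files loaded by get_default_kwargs_yaml():

# Sketch of the new nested config layout, inferred from this diff only.
custom_cfgs = {
    'train_cfgs': {            # training-loop settings
        'total_steps': 1638400,
        'device': 'cpu',
        'torch_threads': 16,
        'parallel': 1,
        # 'vector_env_nums' appears under train_cfgs in one commented
        # example below and under env_cfgs in another.
    },
    'algo_cfgs': {             # per-algorithm hyperparameters
        'update_cycle': 1000,
        'update_iters': 1,
        'obs_normalize': True,
        'reward_normalize': True,
        'cost_normalize': True,
        'use_cost': True,
    },
    'logger_cfgs': {           # logging and experiment tracking
        'use_wandb': False,
    },
    'env_cfgs': {              # environment construction
        'vector_env_nums': 1,
    },
}
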
24 changes: 12 additions & 12 deletions .github/workflows/ci.yml
@@ -61,9 +61,9 @@ jobs:
       run: |
         make addlicense
-      - name: mypy
-        run: |
-          make mypy
+      # - name: mypy
+      #   run: |
+      #     make mypy

       - name: Install dependencies
         run: |
@@ -80,15 +80,15 @@ jobs:
       # TODO: enable this when ready
       # - name: Run tests and collect coverage
-      #   run: |
-      #     pytest tests --ignore-glob='*profile.py' --cov=omnisafe --cov-report=xml
-      #     --cov-report=term-missing --durations=0 -v --color=yes
+      #   run: |
+      #     pytest tests --ignore-glob='*profile.py' --cov=omnisafe --cov-report=xml
+      #     --cov-report=term-missing --durations=0 -v --color=yes

       # TODO: enable this when ready
       # - name: Upload coverage reports to Codecov
-      #   run: |
-      #     # Replace `linux` below with the appropriate OS
-      #     # Options are `alpine`, `linux`, `macos`, `windows`
-      #     curl -Os https://uploader.codecov.io/latest/linux/codecov
-      #     chmod +x codecov
-      #     ./codecov -t ${CODECOV_TOKEN=634594d3-0416-4632-ab6a-3bf34a8c0af3}
+      #   run: |
+      #     # Replace `linux` below with the appropriate OS
+      #     # Options are `alpine`, `linux`, `macos`, `windows`
+      #     curl -Os https://uploader.codecov.io/latest/linux/codecov
+      #     chmod +x codecov
+      #     ./codecov -t ${CODECOV_TOKEN=634594d3-0416-4632-ab6a-3bf34a8c0af3}
18 changes: 12 additions & 6 deletions examples/benchmarks/run_experiment_grid.py
@@ -53,10 +53,16 @@ def train(

 if __name__ == '__main__':
     eg = ExperimentGrid(exp_name='Safety_Gymnasium_Goal')
-    eg.add('algo', ['PPO', 'PPOLag'])
+    base_policy = ['PolicyGradient', 'NaturalPG', 'TRPO', 'PPO']
+    naive_lagrange_policy = ['PPOLag', 'TRPOLag', 'RCPO', 'OnCRPO', 'PDO']
+    first_order_policy = ['CUP', 'FOCOPS']
+    second_order_policy = ['CPO', 'PCPO']
+    eg.add('algo', base_policy + naive_lagrange_policy + first_order_policy + second_order_policy)
     eg.add('env_id', ['SafetyPointGoal1-v0'])
-    eg.add('epochs', 1)
-    eg.add('actor_lr', [0.001, 0.003, 0.004], 'lr', True)
-    eg.add('actor_iters', [1, 2], 'ac_iters', True)
-    eg.add('seed', [0, 5, 10])
-    eg.run(train, num_pool=10)
+    eg.add('logger_cfgs:use_wandb', [True])
+    eg.add('logger_cfgs:wandb_project', ['omnisafe_jiaming'])
+    # eg.add('train_cfgs:total_steps', 2000)
+    # eg.add('algo_cfgs:update_cycle', 1000)
+    # eg.add('train_cfgs:vector_env_nums', 1)
+    eg.add('seed', [0])
+    eg.run(train, num_pool=13)
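
Note the colon-separated keys such as 'logger_cfgs:use_wandb' above: with the nested config, an ExperimentGrid parameter that targets a sub-section is written as a path through the sections. A minimal sketch of how such a path expands into a nested override (an illustrative helper, not omnisafe's actual implementation):

def path_to_nested(key: str, value):
    """Expand 'logger_cfgs:use_wandb' + value into {'logger_cfgs': {'use_wandb': value}}."""
    out = value
    for part in reversed(key.split(':')):
        out = {part: out}
    return out

assert path_to_nested('logger_cfgs:use_wandb', True) == {'logger_cfgs': {'use_wandb': True}}
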
18 changes: 16 additions & 2 deletions examples/train_from_custom_dict.py
@@ -28,9 +28,23 @@
     metavar='N',
     help='Number of paralleled progress for calculations.',
 )
-custom_dict = {'epochs': 1, 'data_dir': './runs'}
+custom_cfgs = {
+    'train_cfgs': {
+        'total_steps': 1000,
+    },
+    'algo_cfgs': {
+        'update_cycle': 1000,
+        'update_iters': 1,
+    },
+    'logger_cfgs': {
+        'use_wandb': False,
+    },
+    'env_cfgs': {
+        'vector_env_nums': 1,
+    },
+}
 args, _ = parser.parse_known_args()
-agent = omnisafe.Agent('PPOLag', env_id, custom_cfgs=custom_dict, parallel=args.parallel)
+agent = omnisafe.Agent('PPOLag', env_id, custom_cfgs=custom_cfgs, parallel=args.parallel)
 agent.learn()

 # obs = env.reset()
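
For reference, the updated example boils down to the following standalone script (a minimal sketch assuming omnisafe and safety-gymnasium are installed; argparse and the parallel flag are trimmed, env_id is hard-coded to the value used throughout this commit, and omitted keys keep their YAML defaults):

import omnisafe

custom_cfgs = {
    'train_cfgs': {'total_steps': 1000},
    'algo_cfgs': {'update_cycle': 1000, 'update_iters': 1},
    'logger_cfgs': {'use_wandb': False},
    'env_cfgs': {'vector_env_nums': 1},
}
agent = omnisafe.Agent('PPOLag', 'SafetyPointGoal1-v0', custom_cfgs=custom_cfgs)
agent.learn()
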
47 changes: 40 additions & 7 deletions examples/train_policy.py
@@ -17,6 +17,7 @@
 import argparse

 import omnisafe
+from omnisafe.utils.tools import custom_cfgs_to_dict, update_dic


if __name__ == '__main__':
@@ -26,32 +27,64 @@
         type=str,
         metavar='ALGO',
         default='PPOLag',
-        help='Algorithm to train',
+        help='algorithm to train',
         choices=omnisafe.ALGORITHMS['all'],
     )
     parser.add_argument(
         '--env-id',
         type=str,
         metavar='ENV',
         default='SafetyPointGoal1-v0',
-        help='The name of test environment',
+        help='the name of test environment',
     )
     parser.add_argument(
         '--parallel',
         default=1,
         type=int,
         metavar='N',
-        help='Number of paralleled progress for calculations.',
+        help='number of paralleled progress for calculations.',
     )
+    parser.add_argument(
+        '--total-steps',
+        type=int,
+        default=1638400,
+        metavar='STEPS',
+        help='total number of steps to train for algorithm',
+    )
+    parser.add_argument(
+        '--device',
+        type=str,
+        default='cpu',
+        metavar='DEVICES',
+        help='device to use for training',
+    )
+    parser.add_argument(
+        '--vector-env-nums',
+        type=int,
+        default=16,
+        metavar='VECTOR-ENV',
+        help='number of vector envs to use for training',
+    )
+    parser.add_argument(
+        '--torch-threads',
+        type=int,
+        default=16,
+        metavar='THREADS',
+        help='number of threads to use for torch',
+    )
     args, unparsed_args = parser.parse_known_args()
     keys = [k[2:] for k in unparsed_args[0::2]]
     values = list(unparsed_args[1::2])
-    unparsed_dict = dict(zip(keys, values))
-    # env = omnisafe.Env(args.env_id)
+    unparsed_args = dict(zip(keys, values))
+
+    custom_cfgs = {}
+    for k, v in unparsed_args.items():
+        update_dic(custom_cfgs, custom_cfgs_to_dict(k, v))
+
     agent = omnisafe.Agent(
         args.algo,
         args.env_id,
-        parallel=args.parallel,
-        custom_cfgs=unparsed_dict,
+        train_terminal_cfgs=vars(args),
+        custom_cfgs=custom_cfgs,
     )
     agent.learn()
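
custom_cfgs_to_dict and update_dic come from omnisafe.utils.tools and are not shown in this diff. Judging from their call site above, the intent is: turn each leftover '--section:key value' CLI pair into a one-entry nested dict, then merge it into custom_cfgs. A sketch consistent with that call site (the real helpers may differ, e.g. in how they coerce value types):

from typing import Any, Dict

def custom_cfgs_to_dict(key: str, value: Any) -> Dict[str, Any]:
    # 'algo_cfgs:update_cycle', 1000 -> {'algo_cfgs': {'update_cycle': 1000}}
    head, _, rest = key.partition(':')
    return {head: custom_cfgs_to_dict(rest, value) if rest else value}

def update_dic(total: Dict[str, Any], item: Dict[str, Any]) -> None:
    # Recursively merge `item` into `total` in place.
    for k, v in item.items():
        if isinstance(v, dict) and isinstance(total.get(k), dict):
            update_dic(total[k], v)
        else:
            total[k] = v

So `python train_policy.py --algo PPOLag --logger_cfgs:use_wandb True` leaves ('logger_cfgs:use_wandb', 'True') in unparsed_args, which the loop above folds into custom_cfgs as {'logger_cfgs': {'use_wandb': 'True'}} — note the string value, hence the type-coercion caveat.
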
6 changes: 3 additions & 3 deletions omnisafe/adapter/online_adapter.py
@@ -47,9 +47,9 @@ def __init__(  # pylint: disable=too-many-arguments
         self._env_id = env_id
         self._env = make(env_id, num_envs=num_envs)
         self._wrapper(
-            obs_normalize=cfgs.obs_normalize,
-            reward_normalize=cfgs.reward_normalize,
-            cost_normalize=cfgs.cost_normalize,
+            obs_normalize=cfgs.algo_cfgs.obs_normalize,
+            reward_normalize=cfgs.algo_cfgs.reward_normalize,
+            cost_normalize=cfgs.algo_cfgs.cost_normalize,
         )
         self._env.set_seed(seed)
2 changes: 1 addition & 1 deletion omnisafe/adapter/onpolicy_adapter.py
@@ -62,7 +62,7 @@ def roll_out(  # pylint: disable=too-many-locals

         self._log_value(reward=reward, cost=cost, info=info)

-        if self._cfgs.use_cost:
+        if self._cfgs.algo_cfgs.use_cost:
             logger.store(**{'Value/cost': value_c})
         logger.store(**{'Value/reward': value_r})
8 changes: 1 addition & 7 deletions omnisafe/algorithms/__init__.py
@@ -21,7 +21,7 @@
 from omnisafe.algorithms.base_algo import BaseAlgo

 # On-Policy Safe
-from omnisafe.algorithms.on_policy import (  # PPOLagSimmerPid,; PPOLagSimmerQ,; PPOSimmerPid,; PPOSimmerQ,
+from omnisafe.algorithms.on_policy import (
     CPO,
     CUP,
     FOCOPS,
@@ -30,17 +30,11 @@
     PPO,
     RCPO,
     TRPO,
-    CPPOPid,
     NaturalPG,
     OnCRPO,
     PolicyGradient,
-    PPOEarlyTerminated,
     PPOLag,
-    PPOLagEarlyTerminated,
-    PPOLagSaute,
-    PPOSaute,
     TRPOLag,
-    TRPOPid,
 )


72 changes: 44 additions & 28 deletions omnisafe/algorithms/algo_wrapper.py
@@ -15,7 +15,6 @@
 """Implementation of the AlgoWrapper Class."""

 import difflib
-import os
 import sys
 from typing import Any, Dict, Optional

@@ -25,7 +24,7 @@

 from omnisafe.algorithms import ALGORITHM2TYPE, ALGORITHMS, registry
 from omnisafe.utils import distributed
-from omnisafe.utils.config import get_default_kwargs_yaml
+from omnisafe.utils.config import check_all_configs, get_default_kwargs_yaml


class AlgoWrapper:
@@ -35,66 +34,83 @@ def __init__(
         self,
         algo: str,
         env_id: str,
-        parallel: int = 1,
+        train_terminal_cfgs: Optional[Dict[str, Any]] = None,
         custom_cfgs: Optional[Dict[str, Any]] = None,
     ):
         self.algo = algo
-        self.parallel = parallel
         self.env_id = env_id
         # algo_type will set in _init_checks()
         self.algo_type: str

+        self.train_terminal_cfgs = train_terminal_cfgs
         self.custom_cfgs = custom_cfgs
         self.evaluator = None
+        self.cfgs = self._init_config()
         self._init_checks()

+    def _init_config(self):
+        """Init config."""
+        assert self.algo in ALGORITHMS['all'], (
+            f"{self.algo} doesn't exist. "
+            f"Did you mean {difflib.get_close_matches(self.algo, ALGORITHMS['all'], n=1)[0]}?"
+        )
+        self.algo_type = ALGORITHM2TYPE.get(self.algo, '')
+        if self.algo_type is None or self.algo_type == '':
+            raise ValueError(f'{self.algo} is not supported!')
+        if self.algo_type in ['off-policy', 'model-based']:
+            assert (
+                self.train_terminal_cfgs.parallel == 1
+            ), 'off-policy or model-based only support parallel==1!'
+        cfgs = get_default_kwargs_yaml(self.algo, self.env_id, self.algo_type)
+
+        # update the cfgs from custom configurations
+        if self.custom_cfgs:
+            cfgs.recurisve_update(self.custom_cfgs)
+        # update the cfgs from custom terminal configurations
+        if self.train_terminal_cfgs:
+            cfgs.train_cfgs.recurisve_update(self.train_terminal_cfgs)
+
+        # the exp_name format is PPO-<SafetyPointGoal1-v0>-
+        exp_name = f'{self.algo}-<{self.env_id}>'
+        cfgs.recurisve_update({'exp_name': exp_name, 'env_id': self.env_id})
+        cfgs.train_cfgs.recurisve_update(
+            {'epochs': cfgs.train_cfgs.total_steps // cfgs.algo_cfgs.update_cycle}
+        )
+        return cfgs

     def _init_checks(self):
         """Init checks."""
         assert isinstance(self.algo, str), 'algo must be a string!'
-        assert isinstance(self.parallel, int), 'parallel must be an integer!'
-        assert self.parallel > 0, 'parallel must be greater than 0!'
+        assert isinstance(self.cfgs.train_cfgs.parallel, int), 'parallel must be an integer!'
+        assert self.cfgs.train_cfgs.parallel > 0, 'parallel must be greater than 0!'
         assert (
             isinstance(self.custom_cfgs, dict) or self.custom_cfgs is None
         ), 'custom_cfgs must be a dict!'
-        assert self.algo in ALGORITHMS['all'], (
-            f"{self.algo} doesn't exist. "
-            f"Did you mean {difflib.get_close_matches(self.algo, ALGORITHMS['all'], n=1)[0]}?"
-        )
         assert self.env_id in safe_registry, (
             f"{self.env_id} doesn't exist. "
             f'Did you mean {difflib.get_close_matches(self.env_id, safe_registry, n=1)[0]}?'
         )
-        self.algo_type = ALGORITHM2TYPE.get(self.algo, '')
-        if self.algo_type is None or self.algo_type == '':
-            raise ValueError(f'{self.algo} is not supported!')
-        if self.algo_type in ['off-policy', 'model-based']:
-            assert self.parallel == 1, 'off-policy or model-based only support parallel==1!'

     def learn(self):
         """Agent Learning."""
         # Use number of physical cores as default.
         # If also hardware threading CPUs should be used
         # enable this by the use_number_of_threads=True
         physical_cores = psutil.cpu_count(logical=False)
-        use_number_of_threads = bool(self.parallel > physical_cores)
-
-        cfgs = get_default_kwargs_yaml(self.algo, self.env_id, self.algo_type)
-        exp_name = os.path.join(self.env_id, self.algo)
-        cfgs.recurisve_update({'exp_name': exp_name, 'env_id': self.env_id})
-        if self.custom_cfgs is not None:
-            cfgs.recurisve_update(self.custom_cfgs)
-
-        # check_all_configs(cfgs, self.algo_type)
-
-        torch.set_num_threads(cfgs.num_threads)
+        use_number_of_threads = bool(self.cfgs.train_cfgs.parallel > physical_cores)
+
+        check_all_configs(self.cfgs, self.algo_type)
+        torch.set_num_threads(self.cfgs.train_cfgs.torch_threads)
         if distributed.fork(
-            self.parallel, use_number_of_threads=use_number_of_threads, device=cfgs.device
+            self.cfgs.train_cfgs.parallel,
+            use_number_of_threads=use_number_of_threads,
+            device=self.cfgs.train_cfgs.device,
         ):
             # Re-launches the current script with workers linked by MPI
             sys.exit()
         agent = registry.get(self.algo)(
             env_id=self.env_id,
-            cfgs=cfgs,
+            cfgs=self.cfgs,
         )
         ep_ret, ep_cost, ep_len = agent.learn()
         return ep_ret, ep_len, ep_cost
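
Config resolution now happens once, in _init_config, instead of inside learn(). The effective precedence, lowest to highest: algorithm YAML defaults, then custom_cfgs (anywhere in the tree), then train_terminal_cfgs (merged into train_cfgs only), then exp_name/env_id bookkeeping, and finally the derived epochs count. A worked example of that last step, using the commented values from run_experiment_grid.py above:

# epochs is now derived rather than configured directly:
#     epochs = train_cfgs.total_steps // algo_cfgs.update_cycle
total_steps = 2000    # from the commented 'train_cfgs:total_steps' example
update_cycle = 1000   # from the commented 'algo_cfgs:update_cycle' example
epochs = total_steps // update_cycle
assert epochs == 2    # two epochs of 1000 environment steps each
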
4 changes: 2 additions & 2 deletions omnisafe/algorithms/base_algo.py
@@ -35,8 +35,8 @@ def __init__(self, env_id: str, cfgs: Config) -> None:
         self._seed = cfgs.seed + distributed.get_rank() * 1000
         seed_all(self._seed)

-        assert hasattr(cfgs, 'device'), 'Please specify the device in the config file.'
-        self._device = torch.device(self._cfgs.device)
+        assert hasattr(cfgs.train_cfgs, 'device'), 'Please specify the device in the config file.'
+        self._device = torch.device(self._cfgs.train_cfgs.device)

         distributed.setup_distributed()
(Diff truncated: the remaining 63 of 72 changed files are not shown.)
