Skip to content

Commit

Permalink
fix: support new config for exp_grid (#142)
Browse files: browse the repository at this point in the history
Co-authored-by: borong <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Mar 9, 2023
1 parent 989ff4b commit 69086ab
Show file tree
Hide file tree
Showing 11 changed files with 27 additions and 21 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ import omnisafe

env = 'SafetyPointGoal1-v0'

custom_dict = {'epochs': 1, 'data_dir': './runs'}
custom_dict = {'epochs': 1, 'log_dir': './runs'}
agent = omnisafe.Agent('PPOLag', env, custom_cfgs=custom_dict)
agent.learn()
```
Expand Down
2 changes: 1 addition & 1 deletion docs/source/BaseRL/PPO.rst
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ Quick start
env = omnisafe.Env('SafetyPointGoal1-v0')
custom_dict = {'epochs': 1, 'data_dir': './runs'}
custom_dict = {'epochs': 1, 'log_dir': './runs'}
agent = omnisafe.Agent('PPO', env, custom_cfgs=custom_dict)
agent.learn()
Expand Down
2 changes: 1 addition & 1 deletion docs/source/BaseRL/TRPO.rst
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ Quick start
env = omnisafe.Env('SafetyPointGoal1-v0')
custom_dict = {'epochs': 1, 'data_dir': './runs'}
custom_dict = {'epochs': 1, 'log_dir': './runs'}
agent = omnisafe.Agent('TRPO', env, custom_cfgs=custom_dict)
agent.learn()
Expand Down
2 changes: 1 addition & 1 deletion docs/source/Documentation/Introduction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ We give an example below:
env = omnisafe.Env('SafetyPointGoal1-v0')
custom_dict = {'epochs': 1, 'data_dir': './runs'}
custom_dict = {'epochs': 1, 'log_dir': './runs'}
agent = omnisafe.Agent('CPO', env, custom_cfgs=custom_dict)
agent.learn()
Expand Down
2 changes: 1 addition & 1 deletion docs/source/SafeRL/CPO.rst
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ Quick start
env = omnisafe.Env('SafetyPointGoal1-v0')
custom_dict = {'epochs': 1, 'data_dir': './runs'}
custom_dict = {'epochs': 1, 'log_dir': './runs'}
agent = omnisafe.Agent('CPO', env, custom_cfgs=custom_dict)
agent.learn()
Expand Down
2 changes: 1 addition & 1 deletion docs/source/SafeRL/FOCOPS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ Quick start
env = omnisafe.Env('SafetyPointGoal1-v0')
custom_dict = {'epochs': 1, 'data_dir': './runs'}
custom_dict = {'epochs': 1, 'log_dir': './runs'}
agent = omnisafe.Agent('FOCOPS', env, custom_cfgs=custom_dict)
agent.learn()
Expand Down
2 changes: 1 addition & 1 deletion docs/source/SafeRL/Lag.rst
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ Quick start
env = omnisafe.Env('SafetyPointGoal1-v0')
custom_dict = {'epochs': 1, 'data_dir': './runs'}
custom_dict = {'epochs': 1, 'log_dir': './runs'}
agent = omnisafe.Agent('PPOLag', env, custom_cfgs=custom_dict)
agent.learn()
Expand Down
2 changes: 1 addition & 1 deletion docs/source/SafeRL/PCPO.rst
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ Quick start
env = omnisafe.Env('SafetyPointGoal1-v0')
custom_dict = {'epochs': 1, 'data_dir': './runs'}
custom_dict = {'epochs': 1, 'log_dir': './runs'}
agent = omnisafe.Agent('PCPO', env, custom_cfgs=custom_dict)
agent.learn()
Expand Down
23 changes: 14 additions & 9 deletions examples/benchmarks/run_experiment_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


def train(
exp_id: str, algo: str, env_id: str, custom_cfgs: NamedTuple, num_threads: int = 6
exp_id: str, algo: str, env_id: str, custom_cfgs: NamedTuple
) -> Tuple[float, float, float]:
"""Train a policy from exp-x config with OmniSafe.
Expand All @@ -36,16 +36,19 @@ def train(
custom_cfgs (NamedTuple): Custom configurations.
num_threads (int, optional): Number of threads. Defaults to 6.
"""
torch.set_num_threads(num_threads)
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
print(f'exp-x: {exp_id} is training...')
USE_REDIRECTION = True
if USE_REDIRECTION:
if not os.path.exists(custom_cfgs['data_dir']):
os.makedirs(custom_cfgs['data_dir'])
sys.stdout = open(f'{custom_cfgs["data_dir"]}terminal.log', 'w', encoding='utf-8')
sys.stderr = open(f'{custom_cfgs["data_dir"]}error.log', 'w', encoding='utf-8')
if not os.path.exists(custom_cfgs['logger_cfgs']['log_dir']):
os.makedirs(custom_cfgs['logger_cfgs']['log_dir'])
sys.stdout = open(
f'{custom_cfgs["logger_cfgs"]["log_dir"]}terminal.log', 'w', encoding='utf-8'
)
sys.stderr = open(
f'{custom_cfgs["logger_cfgs"]["log_dir"]}error.log', 'w', encoding='utf-8'
)
agent = omnisafe.Agent(algo, env_id, custom_cfgs=custom_cfgs)
reward, cost, ep_len = agent.learn()
return reward, cost, ep_len
Expand All @@ -58,9 +61,11 @@ def train(
first_order_policy = ['CUP', 'FOCOPS']
second_order_policy = ['CPO', 'PCPO']
eg.add('algo', base_policy + naive_lagrange_policy + first_order_policy + second_order_policy)
eg.add('env_id', ['SafetyPointGoal1-v0'])
eg.add('logger_cfgs:use_wandb', [True])
eg.add('logger_cfgs:wandb_project', ['omnisafe_jiaming'])
eg.add('env_id', ['SafetyAntVelocity-v4'])
eg.add('logger_cfgs:use_wandb', [False])
eg.add('num_envs', [16])
eg.add('num_threads', [1])
# eg.add('logger_cfgs:wandb_project', ['omnisafe_jiaming'])
# eg.add('train_cfgs:total_steps', 2000)
# eg.add('algo_cfgs:update_cycle', 1000)
# eg.add('train_cfgs:vector_env_nums', 1)
Expand Down
7 changes: 4 additions & 3 deletions omnisafe/common/experiment_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def unflatten_var(var):
return new_variants

# pylint: disable-next=too-many-locals
def run(self, thunk, num_pool=1, data_dir=None, is_test=False):
def run(self, thunk, num_pool=1, log_dir=None, is_test=False):
r"""Run each variant in the grid with function 'thunk'.
Note: 'thunk' must be either a callable function, or a string. If it is
Expand Down Expand Up @@ -362,8 +362,9 @@ def run(self, thunk, num_pool=1, data_dir=None, is_test=False):
print('current_config', var)
exp_name = '_'.join([k + '_' + str(v) for k, v in var.items()])
exp_names.append(exp_name)
data_dir = os.path.join('./', 'exp-x', self.name, exp_name, '')
var['data_dir'] = data_dir
if log_dir is None:
log_dir = os.path.join('./', 'exp-x', self.name, exp_name, '')
var['logger_cfgs'] = {'log_dir': log_dir}
results.append(pool.submit(thunk, idx, var['algo'], var['env_id'], var))
pool.shutdown()

Expand Down
2 changes: 1 addition & 1 deletion omnisafe/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Config(dict):
num_mini_batches: int
actor_lr: float
critic_lr: float
data_dir: str
log_dir: str
target_kl: float
batch_size: int
use_cost: bool
Expand Down

0 comments on commit 69086ab

Please sign in to comment.