feat: update CLI for gpu and statistics tools (#192)
muchvo authored Mar 31, 2023
1 parent 14f5493 commit 5e91fb7
Showing 7 changed files with 177 additions and 16 deletions.
2 changes: 1 addition & 1 deletion examples/benchmarks/run_experiment_grid.py
@@ -93,7 +93,7 @@ def train(
    # if you want to use CPU, please set gpu_id = None
    # gpu_id = None

-   if set(gpu_id) > set(avaliable_gpus):
+   if not set(gpu_id).issubset(avaliable_gpus):
        warnings.warn('The GPU ID is not available, use CPU instead.', stacklevel=1)
        gpu_id = None

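Why the check flipped: a strict-superset test only fires when the requested IDs contain every available GPU, so a single nonexistent ID slips through. A minimal sketch, with assumed values:

    avaliable_gpus = [0, 1]  # e.g. torch.cuda.device_count() == 2; spelling follows the repo's identifier
    gpu_id = [3]             # a GPU that does not exist on this machine

    print(set(gpu_id) > set(avaliable_gpus))         # False -- the old check misses the bad ID
    print(not set(gpu_id).issubset(avaliable_gpus))  # True  -- the new check catches it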
8 changes: 5 additions & 3 deletions omnisafe/common/experiment_grid.py
@@ -455,9 +455,11 @@ def run(self, thunk, num_pool=1, parent_dir=None, is_test=False, gpu_id=None):
            device_id = gpu_id[idx % len(gpu_id)]
            device = f'cuda:{device_id}'
            var['train_cfgs'] = {'device': device}
-       no_seed_var = deepcopy(var)
-       no_seed_var.pop('seed', None)
-       exp_name = recursive_dict2json(no_seed_var)
+       clean_var = deepcopy(var)
+       clean_var.pop('seed', None)
+       if gpu_id is not None:
+           clean_var['train_cfgs'].pop('device', None)
+       exp_name = recursive_dict2json(clean_var)
        hashed_exp_name = var['env_id'][:30] + '---' + hash_string(exp_name)
        exp_names.append(':'.join((hashed_exp_name[:5], exp_name)))
        exp_log_dir = os.path.join(self.log_dir, hashed_exp_name, '')
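The renamed clean_var now also drops the per-run device assignment before the config is serialized, so variants that differ only in seed or assigned GPU hash to the same experiment name and share one log directory. A minimal runnable sketch with hypothetical variants:

    from copy import deepcopy

    # Hypothetical grid entries: same grid point, different seed and device.
    var_a = {'env_id': 'SafetyAntVelocity-v4', 'seed': 0, 'train_cfgs': {'device': 'cuda:0'}}
    var_b = {'env_id': 'SafetyAntVelocity-v4', 'seed': 1, 'train_cfgs': {'device': 'cuda:1'}}

    def clean(var):
        clean_var = deepcopy(var)
        clean_var.pop('seed', None)
        clean_var['train_cfgs'].pop('device', None)
        return clean_var

    assert clean(var_a) == clean(var_b)  # both runs serialize to the same exp_name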
7 changes: 4 additions & 3 deletions omnisafe/common/statistics_tools.py
@@ -92,7 +92,7 @@ def draw_graph(
            if it is specified, will only compare values in it.
        compare_num (int): number of values to compare,
            if it is specified, will combine any potential combination to compare.
+       cost_limit (float): the cost limit to show in graphs by a single line.
    .. Note::
        `values` and `compare_num` cannot be set at the same time.
    """
@@ -181,7 +181,7 @@ def make_config_groups(self, parameter, parameter_values: list, values: list, co
        compare_num <= len(parameter_values),
        (
            f'compare_num `{compare_num}` is larger than number of values '
-           '`{len(parameter_values)}` of parameter `{parameter}`',
+           '`{len(parameter_values)}` of parameter `{parameter}`'
        ),
    )
# if compare_num is specified, will combine any potential combination to compare
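The dropped trailing comma above is a real fix, not cosmetics: inside parentheses, a comma after implicitly concatenated string literals produces a 1-tuple rather than a string. A self-contained illustration:

    msg_as_tuple = (
        'compare_num `3` is larger than number of values '
        '`2` of parameter `algo`',  # trailing comma: this is a 1-tuple
    )
    msg_as_str = (
        'compare_num `3` is larger than number of values '
        '`2` of parameter `algo`'   # no comma: implicit concatenation yields one str
    )
    assert isinstance(msg_as_tuple, tuple) and isinstance(msg_as_str, str)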
@@ -192,7 +192,8 @@ def make_config_groups(self, parameter, parameter_values: list, values: list, co
        group_config.pop(parameter)
        # seed is not a parameter
        group_config.pop('seed')
-
+       if 'train_cfgs' in group_config:
+           group_config['train_cfgs'].pop('device', None)
# combine all possible combinations of other parameters
# fix them in a single graph and only vary values of parameter which is specified by us
for pinned_config in self.dict_permutations(group_config):
133 changes: 131 additions & 2 deletions omnisafe/utils/command_app.py
@@ -20,12 +20,14 @@
from typing import List

import numpy as np
+import torch
import typer
import yaml
from rich.console import Console

import omnisafe
from omnisafe.common.experiment_grid import ExperimentGrid
+from omnisafe.common.statistics_tools import StatisticsTools
from omnisafe.typing import NamedTuple, Tuple
from omnisafe.utils.tools import assert_with_exit, custom_cfgs_to_dict, update_dic

@@ -55,6 +57,15 @@ def train( # pylint: disable=too-many-arguments
        os.path.abspath('.'),
        help='directory to save logs, default is current directory',
    ),
+   plot: bool = typer.Option(False, help='whether to plot the training curve after training'),
+   render: bool = typer.Option(
+       False,
+       help='whether to render the trajectory of models saved during training',
+   ),
+   evaluate: bool = typer.Option(
+       False,
+       help='whether to evaluate the trajectory of models saved during training',
+   ),
    custom_cfgs: List[str] = typer.Option([], help='custom configuration for training'),
):
"""Train a single policy in OmniSafe via command line.
@@ -103,6 +114,22 @@
    )
    agent.learn()

+   if plot:
+       try:
+           agent.plot(smooth=1)
+       except RuntimeError:
+           console.print('failed to plot data', style='red bold')
+   if render:
+       try:
+           agent.render(num_episodes=10, render_mode='rgb_array', width=256, height=256)
+       except RuntimeError:
+           console.print('failed to render model', style='red bold')
+   if evaluate:
+       try:
+           agent.evaluate(num_episodes=10)
+       except RuntimeError:
+           console.print('failed to evaluate model', style='red bold')

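A usage sketch of the new flags through typer's test runner, in the same style as tests/test_cli.py below; the --algo and --env-id option names are assumptions, not shown in this diff:

    from typer.testing import CliRunner

    from omnisafe.utils.command_app import app

    runner = CliRunner()
    result = runner.invoke(
        app,
        # '--algo'/'--env-id' are assumed option names; '--plot'/'--evaluate' are added above
        ['train', '--algo', 'PPO', '--env-id', 'SafetyPointGoal1-v0', '--plot', '--evaluate'],
    )
    assert result.exit_code == 0, result.output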

def train_grid(
exp_id: str,
@@ -161,10 +188,23 @@ def benchmark(
        ...,
        help='path to config file, it is supposed to be yaml file, e.g. ./configs/ppo.yaml',
    ),
+   gpu_range: str = typer.Option(
+       None,
+       help='range of gpu to use, the format is as same as range in python,'
+       'for example, use 2==range(2), 0:2==range(0,2), 0:2:1==range(0,2,1) to select gpu',
+   ),
    log_dir: str = typer.Option(
        os.path.abspath('.'),
        help='directory to save logs, default is current directory',
    ),
+   render: bool = typer.Option(
+       False,
+       help='whether to render the trajectory of models saved during training',
+   ),
+   evaluate: bool = typer.Option(
+       False,
+       help='whether to evaluate the trajectory of models saved during training',
+   ),
):
"""Benchmark algorithms configured by .yaml file in OmniSafe via command line.
@@ -200,11 +240,38 @@
eg = ExperimentGrid(exp_name=exp_name)
for k, v in configs.items():
eg.add(key=k, vals=v)
-   eg.run(train_grid, num_pool=num_pool, parent_dir=log_dir)
+
+   gpu_id = None
+   if gpu_range is not None:
+       assert_with_exit(
+           len(gpu_range.split(':')) <= 3,
+           'gpu_range must be like x:y:z format,'
+           ' which means using gpu in [x, y) with step size z',
+       )
+       # Set the device.
+       avaliable_gpus = list(range(torch.cuda.device_count()))
+       gpu_id = list(range(*[int(i) for i in gpu_range.split(':')]))
+
+       if not set(gpu_id).issubset(avaliable_gpus):
+           warnings.warn('The GPU ID is not available, use CPU instead.', stacklevel=1)
+           gpu_id = None
+
+   eg.run(train_grid, num_pool=num_pool, parent_dir=log_dir, gpu_id=gpu_id)
+
+   if render:
+       try:
+           eg.render(num_episodes=10, render_mode='rgb_array', width=256, height=256)
+       except RuntimeError:
+           console.print('failed to render model', style='red bold')
+   if evaluate:
+       try:
+           eg.evaluate(num_episodes=10)
+       except RuntimeError:
+           console.print('failed to evaluate model', style='red bold')
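What the range-style --gpu-range strings expand to, using the same parsing expression as above (example values assumed):

    def parse(spec: str) -> list:
        # identical to the parsing used by benchmark(): '2' -> range(2), 'x:y[:z]' -> range(x, y[, z])
        return list(range(*[int(i) for i in spec.split(':')]))

    assert parse('2') == [0, 1]
    assert parse('0:2') == [0, 1]
    assert parse('0:4:2') == [0, 2]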


@app.command('eval')
-def evaluate(
+def evaluate_model(
result_dir: str = typer.Argument(
...,
help='directory of experiment results to evaluate, e.g. ./runs/PPO-{SafetyPointGoal1-v0}',
@@ -268,6 +335,15 @@ def train_config(
        os.path.abspath('.'),
        help='directory to save logs, default is current directory',
    ),
+   plot: bool = typer.Option(False, help='whether to plot the training curve after training'),
+   render: bool = typer.Option(
+       False,
+       help='whether to render the trajectory of models saved during training',
+   ),
+   evaluate: bool = typer.Option(
+       False,
+       help='whether to evaluate the trajectory of models saved during training',
+   ),
):
"""Train a policy configured by .yaml file in OmniSafe via command line.
@@ -297,6 +373,59 @@
agent = omnisafe.Agent(algo=args['algo'], env_id=args['env_id'], custom_cfgs=args)
agent.learn()

+   if plot:
+       try:
+           agent.plot(smooth=1)
+       except RuntimeError:
+           console.print('failed to plot data', style='red bold')
+   if render:
+       try:
+           agent.render(num_episodes=10, render_mode='rgb_array', width=256, height=256)
+       except RuntimeError:
+           console.print('failed to render model', style='red bold')
+   if evaluate:
+       try:
+           agent.evaluate(num_episodes=10)
+       except RuntimeError:
+           console.print('failed to evaluate model', style='red bold')


+@app.command()
+def analyze_grid(
+   path: str = typer.Argument(
+       ...,
+       help='path of experiment directory, these experiments are launched by omnisafe via experiment grid',
+   ),
+   parameter: str = typer.Argument(
+       ...,
+       help='name of parameter to analyze',
+   ),
+   compare_num: int = typer.Option(
+       None,
+       help='number of values to compare, if it is specified, will combine any potential combination to compare',
+   ),
+   cost_limit: int = typer.Option(
+       None,
+       help='the cost limit to show in graphs by a single line',
+   ),
+):
+   """Statistics tools for experiment grid.
+   Just specify in the name of the parameter of which value you want to compare,
+   then you can just specify how many values you want to compare in single graph at most,
+   and the function will automatically generate all possible combinations of the graph.
+   """
+
+   tools = StatisticsTools()
+   tools.load_source(path)
+
+   tools.draw_graph(
+       parameter=parameter,
+       values=None,
+       compare_num=compare_num,
+       cost_limit=cost_limit,
+   )

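To make the combination behaviour concrete: with three values of the analyzed parameter and --compare-num 2, the tool draws one graph per two-element combination. A sketch with hypothetical values:

    from itertools import combinations

    parameter_values = ['PPO', 'PPOLag', 'CPO']  # hypothetical values of `algo`
    print(list(combinations(parameter_values, 2)))
    # [('PPO', 'PPOLag'), ('PPO', 'CPO'), ('PPOLag', 'CPO')]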

if __name__ == '__main__':
app()
2 changes: 1 addition & 1 deletion tests/saved_source/benchmark_config.yaml
@@ -24,7 +24,7 @@ train_cfgs:vector_env_nums:
train_cfgs:torch_threads:
[1]
train_cfgs:total_steps:
-  2048
+  4096
algo_cfgs:update_cycle:
2048
seed:
2 changes: 1 addition & 1 deletion tests/saved_source/train_config.yaml
@@ -19,7 +19,7 @@ env_id:
SafetyAntVelocity-v4
train_cfgs:
total_steps:
-    1024
+    2048
vector_env_nums: 1
algo_cfgs:
update_cycle:
39 changes: 34 additions & 5 deletions tests/test_cli.py
@@ -32,9 +32,13 @@ def test_benchmark():
            'test_benchmark',
            '2',
            os.path.join(base_path, './saved_source/benchmark_config.yaml'),
+           # '--render',
+           '--evaluate',
+           # '--gpu-range',
+           # '0:1',
        ],
    )
-   assert result.exit_code == 0
+   assert result.exit_code == 0, result.output


def test_train():
@@ -52,17 +56,26 @@ def test_train():
            'algo_cfgs:update_cycle',
            '--custom-cfgs',
            '1024',
+           '--plot',
+           # '--render',
+           '--evaluate',
        ],
    )
-   assert result.exit_code == 0
+   assert result.exit_code == 0, result.output


def test_train_config():
    result = runner.invoke(
        app,
-       ['train-config', os.path.join(base_path, './saved_source/train_config.yaml')],
+       [
+           'train-config',
+           os.path.join(base_path, './saved_source/train_config.yaml'),
+           '--plot',
+           # '--render',
+           '--evaluate',
+       ],
    )
-   assert result.exit_code == 0
+   assert result.exit_code == 0, result.output


def test_eval():
@@ -80,4 +93,20 @@
'--no-render',
],
)
-   assert result.exit_code == 0
+   assert result.exit_code == 0, result.output


+def test_analyze_grid():
+   result = runner.invoke(
+       app,
+       [
+           'analyze-grid',
+           os.path.join(base_path, './saved_source/test_statistics_tools'),
+           'algo',
+           '--compare-num',
+           '2',
+           '--cost-limit',
+           '25',
+       ],
+   )
+   assert result.exit_code == 0, result.output
