From a5b7e6f1c8d5594f913b37bd1af542ef1ffe693f Mon Sep 17 00:00:00 2001
From: Gaiejj <524339208@qq.com>
Date: Mon, 6 Mar 2023 01:15:37 +0800
Subject: [PATCH 1/5] fix(logger, wrapper): support csv file and velocity tasks

---
 .../algorithms/on_policy/base/policy_gradient.py | 13 ++++++-------
 omnisafe/common/logger.py                        | 16 +++++++++-------
 omnisafe/configs/on-policy/CPO.yaml              |  2 +-
 omnisafe/envs/wrapper.py                         |  3 ++-
 omnisafe/utils/math.py                           |  6 +++---
 5 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/omnisafe/algorithms/on_policy/base/policy_gradient.py b/omnisafe/algorithms/on_policy/base/policy_gradient.py
index ae3ec1d22..cd194e35a 100644
--- a/omnisafe/algorithms/on_policy/base/policy_gradient.py
+++ b/omnisafe/algorithms/on_policy/base/policy_gradient.py
@@ -15,7 +15,7 @@
 """Implementation of the Policy Gradient algorithm."""
 
 import time
-from typing import Dict, Tuple, Union
+from typing import Any, Dict, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -98,12 +98,11 @@ def _init_log(self) -> None:
             config=self._cfgs,
         )
 
-        obs_normalizer = self._env.save()['obs_normalizer']
-        what_to_save = {
-            'pi': self._actor_critic.actor,
-            'obs_normalizer': obs_normalizer,
-        }
-
+        what_to_save: Dict[str, Any] = {}
+        what_to_save['pi'] = self._actor_critic.actor
+        if self._cfgs.algo_cfgs.obs_normalize:
+            obs_normalizer = self._env.save()['obs_normalizer']
+            what_to_save['obs_normalizer'] = obs_normalizer
         self._logger.setup_torch_saver(what_to_save)
         self._logger.torch_save()
 
diff --git a/omnisafe/common/logger.py b/omnisafe/common/logger.py
index 73a6b41a1..5938cc600 100644
--- a/omnisafe/common/logger.py
+++ b/omnisafe/common/logger.py
@@ -15,6 +15,7 @@
 """Implementation of the Logger."""
 
 import atexit
+import csv
 import os
 import time
 from collections import deque
@@ -96,7 +97,7 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         self,
         output_dir: str,
         exp_name: str,
-        output_fname: str = 'progress.txt',
+        output_fname: str = 'progress.csv',
         verbose: bool = True,
         seed: int = 0,
         use_tensorboard: bool = True,
@@ -123,6 +124,7 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         )
         atexit.register(self._output_file.close)
         self.log(f'Logging data to {self._output_file.name}', 'cyan', bold=True)
+        self._csv_writer = csv.writer(self._output_file)
 
         self._epoch: int = 0
         self._first_row: bool = True
@@ -144,10 +146,10 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         if self._use_tensorboard and self._maste_proc:
             self._tensorboard_writer = SummaryWriter(log_dir=os.path.join(self._log_dir, 'tb'))
         if self._use_wandb and self._maste_proc:
-            project: str = self._config.logger_cfgs.get('wandb_project', 'omnisafe')
-            name: str = f'{exp_name}-{relpath}'
-            print('project', project, 'name', name)
-            wandb.init(project=project, name=name, dir=self._log_dir, config=config)
+            project: str = self._config.get('wandb_project', 'omnisafe')
+            name: str = self._config.get('wandb_name', f'{exp_name}/{relpath}')
+            entity: str = self._config.get('wandb_entity', None)
+            wandb.init(project=project, name=name, entity=entity, dir=self._log_dir, config=config)
             if config is not None:
                 wandb.config.update(config)
             if models is not None:
@@ -277,9 +279,9 @@ def dump_tabular(self) -> None:
                 self._proc_bar.update(1)
 
             if self._first_row:
-                self._output_file.write(' '.join(self._current_row.keys()) + '\n')
+                self._csv_writer.writerow(self._current_row.keys())
                 self._first_row = False
-            self._output_file.write(' '.join(map(str, self._current_row.values())) + '\n')
+            self._csv_writer.writerow(self._current_row.values())
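(Aside, not part of the patch: with the logger now writing through csv.writer to the default progress.csv instead of a space-separated progress.txt, a finished run can be loaded back with the standard csv module. A minimal sketch; the run directory and column name below are illustrative placeholders, not guaranteed OmniSafe defaults.)

    import csv

    # Placeholder path: point this at the run directory created by the Logger.
    with open('./runs/my_experiment/progress.csv', newline='', encoding='utf-8') as csv_file:
        reader = csv.DictReader(csv_file)  # header row comes from the first dump_tabular() call
        rows = list(reader)

    # csv returns strings, so cast the columns you care about before analysis.
    returns = [float(row['Metrics/EpRet']) for row in rows if row.get('Metrics/EpRet')]
    print(f'{len(rows)} epochs logged; best return: {max(returns, default=float("nan")):.2f}')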
             self._output_file.flush()
 
         if self._use_tensorboard:
diff --git a/omnisafe/configs/on-policy/CPO.yaml b/omnisafe/configs/on-policy/CPO.yaml
index 36054c030..9caa93101 100644
--- a/omnisafe/configs/on-policy/CPO.yaml
+++ b/omnisafe/configs/on-policy/CPO.yaml
@@ -37,7 +37,7 @@ defaults:
     # batch size for each iteration
     batch_size: 16384
     # target kl divergence
-    target_kl: 0.01
+    target_kl: 0.02
     # entropy coefficient
     entropy_coef: 0.0
     # normalize reward
diff --git a/omnisafe/envs/wrapper.py b/omnisafe/envs/wrapper.py
index 0a774a1d8..70f658c9a 100644
--- a/omnisafe/envs/wrapper.py
+++ b/omnisafe/envs/wrapper.py
@@ -260,7 +260,7 @@ def step(
         action = self._old_min_action + (self._old_max_action - self._old_min_action) * (
             action - self._min_action
         ) / (self._max_action - self._min_action)
-        return super().step(action)
+        return super().step(action.numpy())
 
 
 class Unsqueeze(Wrapper):
@@ -283,6 +283,7 @@ def __init__(self, env: CMDP) -> None:
     def step(
         self, action: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict]:
+        action = action.squeeze(0)
         obs, reward, cost, terminated, truncated, info = super().step(action)
         obs, reward, cost, terminated, truncated = map(
             lambda x: x.unsqueeze(0), (obs, reward, cost, terminated, truncated)
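(Aside, not part of the patch: the Unsqueeze.step change above squeezes the incoming batched action before the wrapped single environment is stepped, then re-adds the batch dimension to the outputs. A toy sketch of that shape round trip, with an invented stand-in step function and made-up shapes.)

    from typing import Callable, Tuple

    import torch

    def fake_env_step(action: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # Invented single-environment step: unbatched tensors in, unbatched tensors out.
        obs = torch.zeros(4)
        reward = torch.tensor(1.0)
        cost = torch.tensor(0.0)
        terminated = torch.tensor(False)
        truncated = torch.tensor(False)
        return obs, reward, cost, terminated, truncated

    def unsqueeze_step(step_fn: Callable, action: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        action = action.squeeze(0)  # (1, act_dim) -> (act_dim,) for the inner environment
        outputs = step_fn(action)
        # Re-add the leading batch dimension so downstream code sees batched tensors again.
        return tuple(x.unsqueeze(0) for x in outputs)

    batched = unsqueeze_step(fake_env_step, torch.zeros(1, 2))
    print([tuple(x.shape) for x in batched])  # [(1, 4), (1,), (1,), (1,), (1,)]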
diff --git a/omnisafe/utils/math.py b/omnisafe/utils/math.py
index b8e936f5f..0e43c467c 100644
--- a/omnisafe/utils/math.py
+++ b/omnisafe/utils/math.py
@@ -73,7 +73,7 @@ def gaussian_kl(
         \mu_q) - k + log(\frac{det(\Sigma_p)}{det(\Sigma_q)}))
 
     where :math:`\mu_p` and :math:`\mu_q` are the mean of :math:`p` and :math:`q`, respectively.
-    :math:`\Sigma_p` and :math:`\Sigma_q` are the covariance of :math:`p` and :math:`q`, respectively.
+    :math:`\Sigma_p` and :math:`\Sigma_q` are the co-variance of :math:`p` and :math:`q`, respectively.
     :math:`k` is the dimension of the distribution.
 
     For more details,
@@ -83,8 +83,8 @@ def gaussian_kl(
     Args:
         mean_p (torch.Tensor): mean of the first distribution, shape (B, n)
         mean_q (torch.Tensor): mean of the second distribution, shape (B, n)
-        var_p (torch.Tensor): covariance of the first distribution, shape (B, n, n)
-        var_q (torch.Tensor): covariance of the second distribution, shape (B, n, n)
+        var_p (torch.Tensor): co-variance of the first distribution, shape (B, n, n)
+        var_q (torch.Tensor): co-variance of the second distribution, shape (B, n, n)
     """
     len_q = var_q.size(-1)
     mean_p = mean_p.unsqueeze(-1)  # (B, n, 1)
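(Aside, not quoted from the patch: the docstring fragment shown in the first hunk above is the tail of the standard closed-form KL divergence between two multivariate Gaussians q = N(mu_q, Sigma_q) and p = N(mu_p, Sigma_p) of dimension k. Written out in full in LaTeX it reads as below, matching the Args shapes (B, n) and (B, n, n) with k = n.)

    \[
    D_{\mathrm{KL}}\bigl(q \,\|\, p\bigr)
      = \frac{1}{2}\Bigl(
          \operatorname{tr}\bigl(\Sigma_p^{-1}\Sigma_q\bigr)
          + (\mu_p - \mu_q)^{\top}\Sigma_p^{-1}(\mu_p - \mu_q)
          - k
          + \log\frac{\det(\Sigma_p)}{\det(\Sigma_q)}
        \Bigr)
    \]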
From 1a4f3fb207188c6e5a30f994fc25b52e4a9166fd Mon Sep 17 00:00:00 2001
From: Gaiejj <524339208@qq.com>
Date: Mon, 6 Mar 2023 01:33:18 +0800
Subject: [PATCH 2/5] fix(logger): fix wandb error

---
 omnisafe/common/logger.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/omnisafe/common/logger.py b/omnisafe/common/logger.py
index 5938cc600..5a0c36b29 100644
--- a/omnisafe/common/logger.py
+++ b/omnisafe/common/logger.py
@@ -144,12 +144,11 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         if self._use_tensorboard and self._maste_proc:
             self._tensorboard_writer = SummaryWriter(log_dir=os.path.join(self._log_dir, 'tb'))
+            project: str = self._config.logger_cfgs.get('wandb_project', 'omnisafe')
+            name: str = f'{exp_name}-{relpath}'
+            print('project', project, 'name', name)
+            wandb.init(project=project, name=name, dir=self._log_dir, config=config)
-        if self._use_wandb and self._maste_proc:
-            project: str = self._config.get('wandb_project', 'omnisafe')
-            name: str = self._config.get('wandb_name', f'{exp_name}/{relpath}')
-            entity: str = self._config.get('wandb_entity', None)
-            wandb.init(project=project, name=name, entity=entity, dir=self._log_dir, config=config)
+
             if config is not None:
                 wandb.config.update(config)
             if models is not None:

From ef20ec62db8ceda82a6b3195f9fa8651ead1ae05 Mon Sep 17 00:00:00 2001
From: Gaiejj <524339208@qq.com>
Date: Mon, 6 Mar 2023 01:40:33 +0800
Subject: [PATCH 3/5] wip

---
 omnisafe/common/logger.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/omnisafe/common/logger.py b/omnisafe/common/logger.py
index 5a0c36b29..0d08f51f3 100644
--- a/omnisafe/common/logger.py
+++ b/omnisafe/common/logger.py
@@ -148,7 +148,7 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-locals
             name: str = f'{exp_name}-{relpath}'
             print('project', project, 'name', name)
             wandb.init(project=project, name=name, dir=self._log_dir, config=config)
-
+        if self._use_wandb and self._maste_proc:
             if config is not None:
                 wandb.config.update(config)
             if models is not None:

From 1616c934f823984de0ec7f886bbaa385a3aa4131 Mon Sep 17 00:00:00 2001
From: Gaiejj <524339208@qq.com>
Date: Mon, 6 Mar 2023 01:48:29 +0800
Subject: [PATCH 4/5] wip

---
 omnisafe/common/logger.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/omnisafe/common/logger.py b/omnisafe/common/logger.py
index 0d08f51f3..7f40574f6 100644
--- a/omnisafe/common/logger.py
+++ b/omnisafe/common/logger.py
@@ -144,11 +144,11 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         if self._use_tensorboard and self._maste_proc:
             self._tensorboard_writer = SummaryWriter(log_dir=os.path.join(self._log_dir, 'tb'))
+        if self._use_wandb and self._maste_proc:
             project: str = self._config.logger_cfgs.get('wandb_project', 'omnisafe')
             name: str = f'{exp_name}-{relpath}'
             print('project', project, 'name', name)
             wandb.init(project=project, name=name, dir=self._log_dir, config=config)
-        if self._use_wandb and self._maste_proc:
             if config is not None:
                 wandb.config.update(config)
             if models is not None:

From 2b37f170b26fef38d83b4af0e397b0192eb32a6b Mon Sep 17 00:00:00 2001
From: Gaiejj <524339208@qq.com>
Date: Mon, 6 Mar 2023 01:49:10 +0800
Subject: [PATCH 5/5] wip

---
 omnisafe/common/logger.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/omnisafe/common/logger.py b/omnisafe/common/logger.py
index 7f40574f6..0a7b56892 100644
--- a/omnisafe/common/logger.py
+++ b/omnisafe/common/logger.py
@@ -144,6 +144,7 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         if self._use_tensorboard and self._maste_proc:
             self._tensorboard_writer = SummaryWriter(log_dir=os.path.join(self._log_dir, 'tb'))
+
         if self._use_wandb and self._maste_proc:
             project: str = self._config.logger_cfgs.get('wandb_project', 'omnisafe')
             name: str = f'{exp_name}-{relpath}'