Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(logger, wrapper): support csv file and velocity tasks #131

Merged
merged 5 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions omnisafe/algorithms/on_policy/base/policy_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Implementation of the Policy Gradient algorithm."""

import time
from typing import Dict, Tuple, Union
from typing import Any, Dict, Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -98,12 +98,11 @@ def _init_log(self) -> None:
config=self._cfgs,
)

obs_normalizer = self._env.save()['obs_normalizer']
what_to_save = {
'pi': self._actor_critic.actor,
'obs_normalizer': obs_normalizer,
}

what_to_save: Dict[str, Any] = {}
what_to_save['pi'] = self._actor_critic.actor
if self._cfgs.algo_cfgs.obs_normalize:
obs_normalizer = self._env.save()['obs_normalizer']
what_to_save['obs_normalizer'] = obs_normalizer
self._logger.setup_torch_saver(what_to_save)
self._logger.torch_save()

Expand Down
16 changes: 9 additions & 7 deletions omnisafe/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Implementation of the Logger."""

import atexit
import csv
import os
import time
from collections import deque
Expand Down Expand Up @@ -96,7 +97,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
self,
output_dir: str,
exp_name: str,
output_fname: str = 'progress.txt',
output_fname: str = 'progress.csv',
verbose: bool = True,
seed: int = 0,
use_tensorboard: bool = True,
Expand All @@ -123,6 +124,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
)
atexit.register(self._output_file.close)
self.log(f'Logging data to {self._output_file.name}', 'cyan', bold=True)
self._csv_writer = csv.writer(self._output_file)

self._epoch: int = 0
self._first_row: bool = True
Expand All @@ -144,10 +146,10 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
self._tensorboard_writer = SummaryWriter(log_dir=os.path.join(self._log_dir, 'tb'))

if self._use_wandb and self._maste_proc:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix

project: str = self._config.logger_cfgs.get('wandb_project', 'omnisafe')
name: str = f'{exp_name}-{relpath}'
print('project', project, 'name', name)
wandb.init(project=project, name=name, dir=self._log_dir, config=config)
project: str = self._config.get('wandb_project', 'omnisafe')
name: str = self._config.get('wandb_name', f'{exp_name}/{relpath}')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix

entity: str = self._config.get('wandb_entity', None)
wandb.init(project=project, name=name, entity=entity, dir=self._log_dir, config=config)
if config is not None:
wandb.config.update(config)
if models is not None:
Expand Down Expand Up @@ -277,9 +279,9 @@ def dump_tabular(self) -> None:
self._proc_bar.update(1)

if self._first_row:
self._output_file.write(' '.join(self._current_row.keys()) + '\n')
self._csv_writer.writerow(self._current_row.keys())
self._first_row = False
self._output_file.write(' '.join(map(str, self._current_row.values())) + '\n')
self._csv_writer.writerow(self._current_row.values())
self._output_file.flush()

if self._use_tensorboard:
Expand Down
2 changes: 1 addition & 1 deletion omnisafe/configs/on-policy/CPO.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ defaults:
# batch size for each iteration
batch_size: 16384
# target kl divergence
target_kl: 0.01
target_kl: 0.02
# entropy coefficient
entropy_coef: 0.0
# normalize reward
Expand Down
3 changes: 2 additions & 1 deletion omnisafe/envs/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def step(
action = self._old_min_action + (self._old_max_action - self._old_min_action) * (
action - self._min_action
) / (self._max_action - self._min_action)
return super().step(action)
return super().step(action.numpy())


class Unsqueeze(Wrapper):
Expand All @@ -283,6 +283,7 @@ def __init__(self, env: CMDP) -> None:
def step(
self, action: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict]:
action = action.squeeze(0)
obs, reward, cost, terminated, truncated, info = super().step(action)
obs, reward, cost, terminated, truncated = map(
lambda x: x.unsqueeze(0), (obs, reward, cost, terminated, truncated)
Expand Down
6 changes: 3 additions & 3 deletions omnisafe/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def gaussian_kl(
\mu_q) - k + log(\frac{det(\Sigma_p)}{det(\Sigma_q)}))

where :math:`\mu_p` and :math:`\mu_q` are the mean of :math:`p` and :math:`q`, respectively.
:math:`\Sigma_p` and :math:`\Sigma_q` are the covariance of :math:`p` and :math:`q`, respectively.
:math:`\Sigma_p` and :math:`\Sigma_q` are the co-variance of :math:`p` and :math:`q`, respectively.
:math:`k` is the dimension of the distribution.

For more details,
Expand All @@ -83,8 +83,8 @@ def gaussian_kl(
Args:
mean_p (torch.Tensor): mean of the first distribution, shape (B, n)
mean_q (torch.Tensor): mean of the second distribution, shape (B, n)
var_p (torch.Tensor): covariance of the first distribution, shape (B, n, n)
var_q (torch.Tensor): covariance of the second distribution, shape (B, n, n)
var_p (torch.Tensor): co-variance of the first distribution, shape (B, n, n)
var_q (torch.Tensor): co-variance of the second distribution, shape (B, n, n)
"""
len_q = var_q.size(-1)
mean_p = mean_p.unsqueeze(-1) # (B, n, 1)
Expand Down