Skip to content

Commit

Permalink
feat(crabs): refine rendering for saved crabs policies (#330)
Browse files Browse the repository at this point in the history
  • Loading branch information
muchvo authored May 6, 2024
1 parent 93d4975 commit 080e6b8
Show file tree
Hide file tree
Showing 13 changed files with 220 additions and 78 deletions.
10 changes: 6 additions & 4 deletions omnisafe/adapter/crabs_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from omnisafe.adapter.offpolicy_adapter import OffPolicyAdapter
from omnisafe.common.buffer import VectorOffPolicyBuffer
from omnisafe.common.control_barrier_function.crabs.models import MeanPolicy
from omnisafe.common.logger import Logger
from omnisafe.envs.crabs_env import CRABSEnv
from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic
Expand Down Expand Up @@ -55,14 +56,15 @@ def __init__( # pylint: disable=too-many-arguments
"""Initialize a instance of :class:`CRABSAdapter`."""
super().__init__(env_id, num_envs, seed, cfgs)
self._env: CRABSEnv
self._eval_env: CRABSEnv
self.n_expl_episodes = 0
self._max_ep_len = self._env.env.spec.max_episode_steps # type: ignore
self.horizon = self._max_ep_len

def eval_policy( # pylint: disable=too-many-locals
self,
episode: int,
agent: ConstraintActorQCritic,
agent: ConstraintActorQCritic | MeanPolicy,
logger: Logger,
) -> None:
"""Rollout the environment with deterministic agent action.
Expand All @@ -74,13 +76,13 @@ def eval_policy( # pylint: disable=too-many-locals
"""
for _ in range(episode):
ep_ret, ep_cost, ep_len = 0.0, 0.0, 0
obs, _ = self._eval_env.reset() # type: ignore
obs, _ = self._eval_env.reset()
obs = obs.to(self._device)

done = False
while not done:
act = agent.step(obs, deterministic=False)
obs, reward, cost, terminated, truncated, info = self._eval_env.step(act) # type: ignore
act = agent.step(obs, deterministic=True)
obs, reward, cost, terminated, truncated, info = self._eval_env.step(act)
obs, reward, cost, terminated, truncated = (
torch.as_tensor(x, dtype=torch.float32, device=self._device)
for x in (obs, reward, cost, terminated, truncated)
Expand Down
77 changes: 26 additions & 51 deletions omnisafe/algorithms/off_policy/crabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,9 @@
from omnisafe.common.control_barrier_function.crabs.models import (
AddGaussianNoise,
CrabsCore,
EnsembleModel,
ExplorationPolicy,
GatedTransitionModel,
MeanPolicy,
MultiLayerPerceptron,
TransitionModel,
UniformPolicy,
)
from omnisafe.common.control_barrier_function.crabs.optimizers import (
Expand All @@ -46,7 +43,11 @@
SLangevinOptimizer,
StateBox,
)
from omnisafe.common.control_barrier_function.crabs.utils import Normalizer, get_pretrained_model
from omnisafe.common.control_barrier_function.crabs.utils import (
Normalizer,
create_model_and_trainer,
get_pretrained_model,
)
from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic


Expand Down Expand Up @@ -115,48 +116,13 @@ def _init_model(self) -> None:
).to(self._device)
self.mean_policy = MeanPolicy(self._actor_critic.actor)

if self._cfgs.transition_model_cfgs.type == 'GatedTransitionModel':

def make_model(i):
return GatedTransitionModel(
self.dim_state,
self.normalizer,
[self.dim_state + self.dim_action, 256, 256, 256, 256, self.dim_state * 2],
self._cfgs.transition_model_cfgs.train,
name=f'model-{i}',
)

self.model = EnsembleModel(
[make_model(i) for i in range(self._cfgs.transition_model_cfgs.n_ensemble)],
).to(self._device)
self.model_trainer = pl.Trainer(
max_epochs=0,
accelerator='gpu',
devices=[int(str(self._device)[-1])],
default_root_dir=self._cfgs.logger_cfgs.log_dir,
)
elif self._cfgs.transition_model_cfgs.type == 'TransitionModel':

def make_model(i):
return TransitionModel(
self.dim_state,
self.normalizer,
[self.dim_state + self.dim_action, 256, 256, 256, 256, self.dim_state * 2],
self._cfgs.transition_model_cfgs.train,
name=f'model-{i}',
)

self.model = EnsembleModel(
[make_model(i) for i in range(self._cfgs.transition_model_cfgs.n_ensemble)],
).to(self._device)
self.model_trainer = pl.Trainer(
max_epochs=0,
accelerator='gpu',
devices=[int(str(self._device)[-1])],
default_root_dir=self._cfgs.logger_cfgs.log_dir,
)
else:
raise AssertionError(f'unknown model type {self._cfgs.transition_model_cfgs.type}')
self.model, self.model_trainer = create_model_and_trainer(
self._cfgs,
self.dim_state,
self.dim_action,
self.normalizer,
self._device,
)

def _init_log(self) -> None:
super()._init_log()
Expand All @@ -167,9 +133,18 @@ def _init_log(self) -> None:
what_to_save['obs_normalizer'] = self.normalizer
self._logger.setup_torch_saver(what_to_save)
self._logger.torch_save()
self._logger.register_key('Metrics/RawPolicyEpRet', window_length=50)
self._logger.register_key('Metrics/RawPolicyEpCost', window_length=50)
self._logger.register_key('Metrics/RawPolicyEpLen', window_length=50)
self._logger.register_key(
'Metrics/RawPolicyEpRet',
window_length=self._cfgs.logger_cfgs.window_lens,
)
self._logger.register_key(
'Metrics/RawPolicyEpCost',
window_length=self._cfgs.logger_cfgs.window_lens,
)
self._logger.register_key(
'Metrics/RawPolicyEpLen',
window_length=self._cfgs.logger_cfgs.window_lens,
)

def _init(self) -> None:
"""The initialization of the algorithm.
Expand Down Expand Up @@ -282,7 +257,7 @@ def learn(self):
eval_start = time.time()
self._env.eval_policy(
episode=self._cfgs.train_cfgs.raw_policy_episodes,
agent=self._actor_critic,
agent=self.mean_policy,
logger=self._logger,
)

Expand Down Expand Up @@ -330,7 +305,7 @@ def learn(self):
eval_start = time.time()
self._env.eval_policy(
episode=self._cfgs.train_cfgs.raw_policy_episodes,
agent=self.mean_policy, # type: ignore
agent=self.mean_policy,
logger=self._logger,
)
eval_time += time.time() - eval_start
Expand Down
30 changes: 24 additions & 6 deletions omnisafe/algorithms/off_policy/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,32 @@ def _init_log(self) -> None:
self._logger.setup_torch_saver(what_to_save)
self._logger.torch_save()

self._logger.register_key('Metrics/EpRet', window_length=50)
self._logger.register_key('Metrics/EpCost', window_length=50)
self._logger.register_key('Metrics/EpLen', window_length=50)
self._logger.register_key(
'Metrics/EpRet',
window_length=self._cfgs.logger_cfgs.window_lens,
)
self._logger.register_key(
'Metrics/EpCost',
window_length=self._cfgs.logger_cfgs.window_lens,
)
self._logger.register_key(
'Metrics/EpLen',
window_length=self._cfgs.logger_cfgs.window_lens,
)

if self._cfgs.train_cfgs.eval_episodes > 0:
self._logger.register_key('Metrics/TestEpRet', window_length=50)
self._logger.register_key('Metrics/TestEpCost', window_length=50)
self._logger.register_key('Metrics/TestEpLen', window_length=50)
self._logger.register_key(
'Metrics/TestEpRet',
window_length=self._cfgs.logger_cfgs.window_lens,
)
self._logger.register_key(
'Metrics/TestEpCost',
window_length=self._cfgs.logger_cfgs.window_lens,
)
self._logger.register_key(
'Metrics/TestEpLen',
window_length=self._cfgs.logger_cfgs.window_lens,
)

self._logger.register_key('Train/Epoch')
self._logger.register_key('Train/LR')
Expand Down
15 changes: 12 additions & 3 deletions omnisafe/algorithms/on_policy/base/policy_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,18 @@ def _init_log(self) -> None:
self._logger.setup_torch_saver(what_to_save)
self._logger.torch_save()

self._logger.register_key('Metrics/EpRet', window_length=50)
self._logger.register_key('Metrics/EpCost', window_length=50)
self._logger.register_key('Metrics/EpLen', window_length=50)
self._logger.register_key(
'Metrics/EpRet',
window_length=self._cfgs.logger_cfgs.window_lens,
)
self._logger.register_key(
'Metrics/EpCost',
window_length=self._cfgs.logger_cfgs.window_lens,
)
self._logger.register_key(
'Metrics/EpLen',
window_length=self._cfgs.logger_cfgs.window_lens,
)

self._logger.register_key('Train/Epoch')
self._logger.register_key('Train/Entropy')
Expand Down
65 changes: 65 additions & 0 deletions omnisafe/common/control_barrier_function/crabs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,22 @@
# ==============================================================================
"""Utils for CRABS."""
# pylint: disable=all
from __future__ import annotations

import os

import pytorch_lightning as pl
import requests
import torch
import torch.nn as nn
from torch import load

from omnisafe.common.control_barrier_function.crabs.models import (
EnsembleModel,
GatedTransitionModel,
TransitionModel,
)


class Normalizer(nn.Module):
"""Normalizes input data to have zero mean and unit variance.
Expand Down Expand Up @@ -119,3 +128,59 @@ def get_pretrained_model(model_path, model_url, device):
print('Model found locally.')

return load(model_path, map_location=device)


def create_model_and_trainer(cfgs, dim_state, dim_action, normalizer, device):
    """Create the ensemble world model and its pytorch-lightning trainer.

    Args:
        cfgs: Configs object; reads ``transition_model_cfgs`` (``type``,
            ``n_ensemble``, ``train``) and ``logger_cfgs.log_dir``.
        dim_state: Dimension of the state.
        dim_action: Dimension of the action.
        normalizer: Observation normalizer passed to each transition model.
        device: Device to place the model on (e.g. ``'cpu'``, ``'cuda:0'``).

    Returns:
        Tuple[nn.Module, pl.Trainer]: World model ensemble and trainer.

    Raises:
        AssertionError: If ``transition_model_cfgs.type`` is unknown.
    """

    def make_model(i, model_type) -> nn.Module:
        # Both variants share the same MLP layout; only the model class differs.
        hidden_dims = [dim_state + dim_action, 256, 256, 256, 256, dim_state * 2]
        if model_type == 'GatedTransitionModel':
            return GatedTransitionModel(
                dim_state,
                normalizer,
                hidden_dims,
                cfgs.transition_model_cfgs.train,
                name=f'model-{i}',
            )
        if model_type == 'TransitionModel':
            return TransitionModel(
                dim_state,
                normalizer,
                hidden_dims,
                cfgs.transition_model_cfgs.train,
                name=f'model-{i}',
            )
        raise AssertionError(f'unknown model type {model_type}')

    model_type = cfgs.transition_model_cfgs.type
    models = [make_model(i, model_type) for i in range(cfgs.transition_model_cfgs.n_ensemble)]

    model = EnsembleModel(models).to(device)

    devices: list[int] | int

    if str(device).startswith('cuda'):
        accelerator = 'gpu'
        # Parse the CUDA index robustly: `int(str(device)[-1])` is wrong for
        # indices >= 10 (e.g. 'cuda:12' -> 2) and raises for a bare 'cuda'.
        index = torch.device(device).index
        devices = [index if index is not None else 0]
    else:
        accelerator = 'cpu'
        devices = torch.get_num_threads()
    trainer = pl.Trainer(
        max_epochs=0,
        accelerator=accelerator,
        devices=devices,
        default_root_dir=cfgs.logger_cfgs.log_dir,
    )

    return model, trainer
4 changes: 2 additions & 2 deletions omnisafe/common/offline/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__( # pylint: disable=too-many-branches
# Load data from local .npz file
try:
data = np.load(dataset_name)
except Exception as e:
except (ValueError, OSError) as e:
raise ValueError(f'Failed to load data from {dataset_name}') from e

else:
Expand Down Expand Up @@ -284,7 +284,7 @@ def __init__( # pylint: disable=too-many-branches, super-init-not-called
# Load data from local .npz file
try:
data = np.load(dataset_name)
except Exception as e:
except (ValueError, OSError) as e:
raise ValueError(f'Failed to load data from {dataset_name}') from e

else:
Expand Down
2 changes: 1 addition & 1 deletion omnisafe/configs/off-policy/CRABS.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ defaults:
# save logger path
log_dir: "./runs"
# save model path
window_lens: 10
window_lens: 6
# model configurations
model_cfgs:
# weight initialization mode
Expand Down
2 changes: 1 addition & 1 deletion omnisafe/configs/off-policy/DDPG.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ defaults:
# save logger path
log_dir: "./runs"
# save model path
window_lens: 10
window_lens: 50
# model configurations
model_cfgs:
# weight initialization mode
Expand Down
2 changes: 1 addition & 1 deletion omnisafe/configs/on-policy/PolicyGradient.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ defaults:
# save logger path
log_dir: "./runs"
# save model path
window_lens: 100
window_lens: 50
# model configurations
model_cfgs:
# weight initialization mode
Expand Down
2 changes: 0 additions & 2 deletions omnisafe/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# ==============================================================================
"""Environment API for OmniSafe."""

from contextlib import suppress

from omnisafe.envs import classic_control
from omnisafe.envs.core import CMDP, env_register, make, support_envs
from omnisafe.envs.crabs_env import CRABSEnv
Expand Down
Loading

0 comments on commit 080e6b8

Please sign in to comment.