Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(crabs): refine rendering for saved crabs policies #330

Merged
merged 37 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
bcdaba1
feat(cbf): support crabs as a representative control barrier function…
muchvo Apr 18, 2024
182071f
update style
muchvo Apr 19, 2024
9a31ee3
update style
muchvo Apr 21, 2024
fecf4db
Merge branch 'main' into dev-crabs
muchvo Apr 21, 2024
8d352d0
update
muchvo Apr 22, 2024
ac0b10a
Merge branch 'dev-crabs' of https://github.com/muchvo/omnisafe into d…
muchvo Apr 22, 2024
67eae76
update
muchvo Apr 22, 2024
320ea2f
update
muchvo Apr 22, 2024
1af97d5
update
muchvo Apr 22, 2024
f28c16a
update
muchvo Apr 22, 2024
84738e5
update
muchvo Apr 22, 2024
22abb0c
update
muchvo Apr 22, 2024
8f8b459
update
muchvo Apr 22, 2024
5e1b3e3
update
muchvo Apr 22, 2024
0336290
update
muchvo Apr 22, 2024
06a0f10
update
muchvo Apr 22, 2024
c1f7620
update
muchvo Apr 22, 2024
50afc8d
update
muchvo Apr 23, 2024
9611e30
update
muchvo Apr 30, 2024
b0c0e2c
update
muchvo Apr 30, 2024
9787bdb
update
muchvo Apr 30, 2024
bba02c7
update
muchvo Apr 30, 2024
a0d10b8
update
muchvo Apr 30, 2024
9a284a7
feat(crabs): refine rendering for saved crabs policies.
muchvo May 5, 2024
76fd811
Merge branch 'main' of https://github.com/muchvo/omnisafe into refine…
muchvo May 5, 2024
cb45d31
fix pylint
muchvo May 5, 2024
5287814
update .coveragerc
muchvo May 5, 2024
cd571c1
update pyproject.toml
muchvo May 5, 2024
5be00f5
improve code style
muchvo May 5, 2024
960ef39
improve code style
muchvo May 5, 2024
fd48922
improve code style
muchvo May 5, 2024
e06df1c
add docstring
muchvo May 5, 2024
e76602a
improve code style
muchvo May 5, 2024
b5895a7
fix comment
muchvo May 6, 2024
860d588
fix comment
muchvo May 6, 2024
f6a046c
fix comment
muchvo May 6, 2024
7cf39ca
fix comment
muchvo May 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions omnisafe/adapter/crabs_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from omnisafe.adapter.offpolicy_adapter import OffPolicyAdapter
from omnisafe.common.buffer import VectorOffPolicyBuffer
from omnisafe.common.control_barrier_function.crabs.models import MeanPolicy
from omnisafe.common.logger import Logger
from omnisafe.envs.crabs_env import CRABSEnv
from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic
Expand Down Expand Up @@ -55,14 +56,15 @@ def __init__( # pylint: disable=too-many-arguments
"""Initialize a instance of :class:`CRABSAdapter`."""
super().__init__(env_id, num_envs, seed, cfgs)
self._env: CRABSEnv
self._eval_env: CRABSEnv
self.n_expl_episodes = 0
self._max_ep_len = self._env.env.spec.max_episode_steps # type: ignore
self.horizon = self._max_ep_len

def eval_policy( # pylint: disable=too-many-locals
self,
episode: int,
agent: ConstraintActorQCritic,
agent: ConstraintActorQCritic | MeanPolicy,
logger: Logger,
) -> None:
"""Rollout the environment with deterministic agent action.
Expand All @@ -74,13 +76,13 @@ def eval_policy( # pylint: disable=too-many-locals
"""
for _ in range(episode):
ep_ret, ep_cost, ep_len = 0.0, 0.0, 0
obs, _ = self._eval_env.reset() # type: ignore
obs, _ = self._eval_env.reset()
obs = obs.to(self._device)

done = False
while not done:
act = agent.step(obs, deterministic=False)
obs, reward, cost, terminated, truncated, info = self._eval_env.step(act) # type: ignore
act = agent.step(obs, deterministic=True)
obs, reward, cost, terminated, truncated, info = self._eval_env.step(act)
obs, reward, cost, terminated, truncated = (
torch.as_tensor(x, dtype=torch.float32, device=self._device)
for x in (obs, reward, cost, terminated, truncated)
Expand Down
68 changes: 17 additions & 51 deletions omnisafe/algorithms/off_policy/crabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,9 @@
from omnisafe.common.control_barrier_function.crabs.models import (
AddGaussianNoise,
CrabsCore,
EnsembleModel,
ExplorationPolicy,
GatedTransitionModel,
MeanPolicy,
MultiLayerPerceptron,
TransitionModel,
UniformPolicy,
)
from omnisafe.common.control_barrier_function.crabs.optimizers import (
Expand All @@ -46,7 +43,11 @@
SLangevinOptimizer,
StateBox,
)
from omnisafe.common.control_barrier_function.crabs.utils import Normalizer, get_pretrained_model
from omnisafe.common.control_barrier_function.crabs.utils import (
Normalizer,
create_model_and_trainer,
get_pretrained_model,
)
from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic


Expand Down Expand Up @@ -115,48 +116,13 @@ def _init_model(self) -> None:
).to(self._device)
self.mean_policy = MeanPolicy(self._actor_critic.actor)

if self._cfgs.transition_model_cfgs.type == 'GatedTransitionModel':

def make_model(i):
return GatedTransitionModel(
self.dim_state,
self.normalizer,
[self.dim_state + self.dim_action, 256, 256, 256, 256, self.dim_state * 2],
self._cfgs.transition_model_cfgs.train,
name=f'model-{i}',
)

self.model = EnsembleModel(
[make_model(i) for i in range(self._cfgs.transition_model_cfgs.n_ensemble)],
).to(self._device)
self.model_trainer = pl.Trainer(
max_epochs=0,
accelerator='gpu',
devices=[int(str(self._device)[-1])],
default_root_dir=self._cfgs.logger_cfgs.log_dir,
)
elif self._cfgs.transition_model_cfgs.type == 'TransitionModel':

def make_model(i):
return TransitionModel(
self.dim_state,
self.normalizer,
[self.dim_state + self.dim_action, 256, 256, 256, 256, self.dim_state * 2],
self._cfgs.transition_model_cfgs.train,
name=f'model-{i}',
)

self.model = EnsembleModel(
[make_model(i) for i in range(self._cfgs.transition_model_cfgs.n_ensemble)],
).to(self._device)
self.model_trainer = pl.Trainer(
max_epochs=0,
accelerator='gpu',
devices=[int(str(self._device)[-1])],
default_root_dir=self._cfgs.logger_cfgs.log_dir,
)
else:
raise AssertionError(f'unknown model type {self._cfgs.transition_model_cfgs.type}')
self.model, self.model_trainer = create_model_and_trainer(
self._cfgs,
self.dim_state,
self.dim_action,
self.normalizer,
self._device,
)

def _init_log(self) -> None:
super()._init_log()
Expand All @@ -167,9 +133,9 @@ def _init_log(self) -> None:
what_to_save['obs_normalizer'] = self.normalizer
self._logger.setup_torch_saver(what_to_save)
self._logger.torch_save()
self._logger.register_key('Metrics/RawPolicyEpRet', window_length=50)
self._logger.register_key('Metrics/RawPolicyEpCost', window_length=50)
self._logger.register_key('Metrics/RawPolicyEpLen', window_length=50)
self._logger.register_key('Metrics/RawPolicyEpRet', window_length=6)
self._logger.register_key('Metrics/RawPolicyEpCost', window_length=6)
self._logger.register_key('Metrics/RawPolicyEpLen', window_length=6)

def _init(self) -> None:
"""The initialization of the algorithm.
Expand Down Expand Up @@ -282,7 +248,7 @@ def learn(self):
eval_start = time.time()
self._env.eval_policy(
episode=self._cfgs.train_cfgs.raw_policy_episodes,
agent=self._actor_critic,
agent=self.mean_policy,
logger=self._logger,
)

Expand Down Expand Up @@ -330,7 +296,7 @@ def learn(self):
eval_start = time.time()
self._env.eval_policy(
episode=self._cfgs.train_cfgs.raw_policy_episodes,
agent=self.mean_policy, # type: ignore
agent=self.mean_policy,
logger=self._logger,
)
eval_time += time.time() - eval_start
Expand Down
6 changes: 3 additions & 3 deletions omnisafe/algorithms/off_policy/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,9 @@ def _init_log(self) -> None:
self._logger.setup_torch_saver(what_to_save)
self._logger.torch_save()

self._logger.register_key('Metrics/EpRet', window_length=50)
self._logger.register_key('Metrics/EpCost', window_length=50)
self._logger.register_key('Metrics/EpLen', window_length=50)
self._logger.register_key('Metrics/EpRet', window_length=6)
self._logger.register_key('Metrics/EpCost', window_length=6)
self._logger.register_key('Metrics/EpLen', window_length=6)

Gaiejj marked this conversation as resolved.
Show resolved Hide resolved
if self._cfgs.train_cfgs.eval_episodes > 0:
self._logger.register_key('Metrics/TestEpRet', window_length=50)
Expand Down
64 changes: 64 additions & 0 deletions omnisafe/common/control_barrier_function/crabs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,20 @@
"""Utils for CRABS."""
# pylint: disable=all
import os
from typing import List, Union
Gaiejj marked this conversation as resolved.
Show resolved Hide resolved

import pytorch_lightning as pl
import requests
import torch
import torch.nn as nn
from torch import load

from omnisafe.common.control_barrier_function.crabs.models import (
EnsembleModel,
GatedTransitionModel,
TransitionModel,
)


class Normalizer(nn.Module):
"""Normalizes input data to have zero mean and unit variance.
Expand Down Expand Up @@ -119,3 +127,59 @@ def get_pretrained_model(model_path, model_url, device):
print('Model found locally.')

return load(model_path, map_location=device)


def create_model_and_trainer(cfgs, dim_state, dim_action, normalizer, device):
    """Create world model ensemble and its PyTorch Lightning trainer.

    Args:
        cfgs: Configs; reads ``transition_model_cfgs`` (``type``, ``train``,
            ``n_ensemble``) and ``logger_cfgs.log_dir``.
        dim_state: Dimension of the state.
        dim_action: Dimension of the action.
        normalizer: Observation normalizer.
        device: Device to load the model (e.g. ``'cpu'``, ``'cuda'``, ``'cuda:1'``).

    Returns:
        Tuple[nn.Module, pl.Trainer]: World model and trainer.

    Raises:
        AssertionError: If ``transition_model_cfgs.type`` is not a known model type.
    """
    # Same MLP layout for every ensemble member; output is mean+std, hence 2x state dim.
    net_arch = [dim_state + dim_action, 256, 256, 256, 256, dim_state * 2]

    def make_model(i, model_type) -> nn.Module:
        # Build one ensemble member of the configured transition-model class.
        if model_type == 'GatedTransitionModel':
            return GatedTransitionModel(
                dim_state,
                normalizer,
                net_arch,
                cfgs.transition_model_cfgs.train,
                name=f'model-{i}',
            )
        if model_type == 'TransitionModel':
            return TransitionModel(
                dim_state,
                normalizer,
                net_arch,
                cfgs.transition_model_cfgs.train,
                name=f'model-{i}',
            )
        raise AssertionError(f'unknown model type {model_type}')

    model_type = cfgs.transition_model_cfgs.type
    models = [make_model(i, model_type) for i in range(cfgs.transition_model_cfgs.n_ensemble)]

    model = EnsembleModel(models).to(device)

    devices: Union[List[int], int]

    device_str = str(device)
    if device_str.startswith('cuda'):
        accelerator = 'gpu'
        # Parse the ordinal after 'cuda:'; a bare 'cuda' means the default device 0.
        # (The naive `int(device_str[-1])` crashed on 'cuda' and truncated 'cuda:12' to 2.)
        devices = [int(device_str.split(':', 1)[1])] if ':' in device_str else [0]
    else:
        accelerator = 'cpu'
        devices = torch.get_num_threads()
    trainer = pl.Trainer(
        max_epochs=0,
        accelerator=accelerator,
        devices=devices,
        default_root_dir=cfgs.logger_cfgs.log_dir,
    )

    return model, trainer
4 changes: 2 additions & 2 deletions omnisafe/common/offline/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__( # pylint: disable=too-many-branches
# Load data from local .npz file
try:
data = np.load(dataset_name)
except Exception as e:
except (ValueError, OSError) as e:
raise ValueError(f'Failed to load data from {dataset_name}') from e

else:
Expand Down Expand Up @@ -284,7 +284,7 @@ def __init__( # pylint: disable=too-many-branches, super-init-not-called
# Load data from local .npz file
try:
data = np.load(dataset_name)
except Exception as e:
except (ValueError, OSError) as e:
raise ValueError(f'Failed to load data from {dataset_name}') from e

else:
Expand Down
2 changes: 0 additions & 2 deletions omnisafe/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# ==============================================================================
"""Environment API for OmniSafe."""

from contextlib import suppress

from omnisafe.envs import classic_control
from omnisafe.envs.core import CMDP, env_register, make, support_envs
from omnisafe.envs.crabs_env import CRABSEnv
Expand Down
10 changes: 7 additions & 3 deletions omnisafe/envs/classic_control/envs_from_crabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,20 @@ def __init__(
task='upright',
random_reset=False,
violation_penalty=10,
**kwargs,
) -> None:
"""Initialize the environment."""
self.threshold = threshold
self.task = task
self.random_reset = random_reset
self.violation_penalty = violation_penalty
super().__init__()
super().__init__(**kwargs)
EzPickle.__init__(
self,
threshold=threshold,
task=task,
random_reset=random_reset,
**kwargs,
) # deepcopy calls `get_state`

def reset_model(self):
Expand Down Expand Up @@ -156,9 +158,10 @@ def __init__(
task='swing',
random_reset=False,
violation_penalty=10,
**kwargs,
) -> None:
"""Initialize the environment."""
super().__init__(threshold=threshold, task=task)
super().__init__(threshold=threshold, task=task, **kwargs)


class SafeInvertedPendulumMoveEnv(SafeInvertedPendulumEnv):
Expand All @@ -170,9 +173,10 @@ def __init__(
task='move',
random_reset=False,
violation_penalty=10,
**kwargs,
) -> None:
"""Initialize the environment."""
super().__init__(threshold=threshold, task=task)
super().__init__(threshold=threshold, task=task, **kwargs)


register(id='SafeInvertedPendulum-v2', entry_point=SafeInvertedPendulumEnv, max_episode_steps=1000)
Expand Down
Loading
Loading