From ef3061675d2f4b5b48c05d757b4fb475be1c7b9b Mon Sep 17 00:00:00 2001 From: Gaiejj Date: Fri, 3 May 2024 02:37:30 +0800 Subject: [PATCH 1/3] chore: fix code style --- omnisafe/adapter/offpolicy_latent_adapter.py | 272 +++++++++++++++ omnisafe/algorithms/__init__.py | 1 + omnisafe/algorithms/off_policy/__init__.py | 2 + omnisafe/algorithms/off_policy/ddpg.py | 4 +- omnisafe/algorithms/off_policy/safe_slac.py | 293 ++++++++++++++++ omnisafe/common/buffer/__init__.py | 3 +- omnisafe/common/buffer/base.py | 51 +++ omnisafe/common/buffer/offpolicy_buffer.py | 79 ++++- omnisafe/common/latent.py | 344 +++++++++++++++++++ omnisafe/configs/off-policy/SafeSLAC.yaml | 148 ++++++++ omnisafe/envs/__init__.py | 1 + omnisafe/envs/safety_gymnasium_vision_env.py | 194 +++++++++++ omnisafe/utils/model.py | 45 ++- omnisafe/utils/tools.py | 76 +++- 14 files changed, 1507 insertions(+), 6 deletions(-) create mode 100644 omnisafe/adapter/offpolicy_latent_adapter.py create mode 100644 omnisafe/algorithms/off_policy/safe_slac.py create mode 100644 omnisafe/common/latent.py create mode 100644 omnisafe/configs/off-policy/SafeSLAC.yaml create mode 100644 omnisafe/envs/safety_gymnasium_vision_env.py diff --git a/omnisafe/adapter/offpolicy_latent_adapter.py b/omnisafe/adapter/offpolicy_latent_adapter.py new file mode 100644 index 000000000..105ee8ad1 --- /dev/null +++ b/omnisafe/adapter/offpolicy_latent_adapter.py @@ -0,0 +1,272 @@ +# Copyright 2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""OffPolicy Latent Adapter for OmniSafe.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import torch +from gymnasium.spaces import Box + +from omnisafe.adapter.online_adapter import OnlineAdapter +from omnisafe.common.buffer import OffPolicySequenceBuffer +from omnisafe.common.latent import CostLatentModel +from omnisafe.common.logger import Logger +from omnisafe.envs.wrapper import ( + ActionRepeat, + ActionScale, + AutoReset, + CostNormalize, + ObsNormalize, + RewardNormalize, + TimeLimit, + Unsqueeze, +) +from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic +from omnisafe.utils.config import Config +from omnisafe.utils.model import ObservationConcator + + +class OffPolicyLatentAdapter(OnlineAdapter): + _current_obs: torch.Tensor + _ep_ret: torch.Tensor + _ep_cost: torch.Tensor + _ep_len: torch.Tensor + + def __init__( # pylint: disable=too-many-arguments + self, + env_id: str, + num_envs: int, + seed: int, + cfgs: Config, + ) -> None: + """Initialize a instance of :class:`OffPolicyAdapter`.""" + super().__init__(env_id, num_envs, seed, cfgs) + self._observation_concator: ObservationConcator = ObservationConcator( + self._cfgs.algo_cfgs.latent_dim_1 + self._cfgs.algo_cfgs.latent_dim_2, + self.action_space.shape, + self._cfgs.algo_cfgs.num_sequences, + device=self._device, + ) + self._current_obs, _ = self.reset() + self._max_ep_len: int = 1000 + self._reset_log() + self.z1 = None + self.z2 = None + self._reset_sequence_queue = False + + def _wrapper( + self, + obs_normalize: bool = True, + reward_normalize: bool = True, + cost_normalize: bool = True, + ) -> None: + """Wrapper the environment. + + .. hint:: + OmniSafe supports the following wrappers: + + +-----------------+--------------------------------------------------------+ + | Wrapper | Description | + +=================+========================================================+ + | TimeLimit | Limit the time steps of the environment. | + +-----------------+--------------------------------------------------------+ + | AutoReset | Reset the environment when the episode is done. | + +-----------------+--------------------------------------------------------+ + | ObsNormalize | Normalize the observation. | + +-----------------+--------------------------------------------------------+ + | RewardNormalize | Normalize the reward. | + +-----------------+--------------------------------------------------------+ + | CostNormalize | Normalize the cost. | + +-----------------+--------------------------------------------------------+ + | ActionScale | Scale the action. | + +-----------------+--------------------------------------------------------+ + | Unsqueeze | Unsqueeze the step result for single environment case. | + +-----------------+--------------------------------------------------------+ + + + Args: + obs_normalize (bool, optional): Whether to normalize the observation. Defaults to True. + reward_normalize (bool, optional): Whether to normalize the reward. Defaults to True. + cost_normalize (bool, optional): Whether to normalize the cost. Defaults to True. 
+ """ + if self._env.need_time_limit_wrapper: + self._env = TimeLimit(self._env, device=self._device, time_limit=1000) + if self._env.need_auto_reset_wrapper: + self._env = AutoReset(self._env, device=self._device) + if obs_normalize: + self._env = ObsNormalize(self._env, device=self._device) + if reward_normalize: + self._env = RewardNormalize(self._env, device=self._device) + if cost_normalize: + self._env = CostNormalize(self._env, device=self._device) + self._env = ActionScale(self._env, device=self._device, low=-1.0, high=1.0) + self._env = ActionRepeat(self._env, times=2, device=self._device) + + if self._env.num_envs == 1: + self._env = Unsqueeze(self._env, device=self._device) + + @property + def latent_space(self) -> Box: + """Get the latent space.""" + return Box( + low=-np.inf, + high=np.inf, + shape=(self._cfgs.algo_cfgs.latent_dim_1 + self._cfgs.algo_cfgs.latent_dim_2,), + ) + + def eval_policy( # pylint: disable=too-many-locals + self, + episode: int, + agent: ConstraintActorQCritic, + logger: Logger, + ) -> None: + for _ in range(episode): + ep_ret, ep_cost, ep_len = 0.0, 0.0, 0 + obs, _ = self._eval_env.reset() + obs = obs.to(self._device) + + done = False + while not done: + act = agent.step(obs, deterministic=True) + obs, reward, cost, terminated, truncated, info = self._eval_env.step(act) + obs, reward, cost, terminated, truncated = ( + torch.as_tensor(x, dtype=torch.float32, device=self._device) + for x in (obs, reward, cost, terminated, truncated) + ) + ep_ret += info.get('original_reward', reward).cpu() + ep_cost += info.get('original_cost', cost).cpu() + ep_len += 1 + done = bool(terminated[0].item()) or bool(truncated[0].item()) + + logger.store( + { + 'Metrics/TestEpRet': ep_ret, + 'Metrics/TestEpCost': ep_cost, + 'Metrics/TestEpLen': ep_len, + }, + ) + + def pre_process(self, latent_model, concated_obs): + with torch.no_grad(): + feature = latent_model.encoder(concated_obs.last_state) + + if self.z2 is None: + z1_mean, z1_std = latent_model.z1_posterior_init(feature) + self.z1 = z1_mean + torch.randn_like(z1_std) * z1_std + z2_mean, z2_std = latent_model.z2_posterior_init(self.z1) + self.z2 = z2_mean + torch.randn_like(z2_std) * z2_std + else: + z1_mean, z1_std = latent_model.z1_posterior( + torch.cat([feature.squeeze(), self.z2.squeeze(), concated_obs.last_action], dim=-1) + ) + self.z1 = z1_mean + torch.randn_like(z1_std) * z1_std + z2_mean, z2_std = latent_model.z2_posterior( + torch.cat([self.z1.squeeze(), self.z2.squeeze(), concated_obs.last_action], dim=-1) + ) + self.z2 = z2_mean + torch.randn_like(z2_std) * z2_std + + return torch.cat([self.z1, self.z2], dim=-1).squeeze() + + def rollout( # pylint: disable=too-many-locals + self, + rollout_step: int, + agent: ConstraintActorQCritic, + latent_model: CostLatentModel, + buffer: OffPolicySequenceBuffer, + logger: Logger, + use_rand_action: bool, + ) -> None: + for step in range(rollout_step): + if not self._reset_sequence_queue: + buffer.reset_sequence_queue(self._current_obs) + self._observation_concator.reset_episode(self._current_obs) + self._reset_sequence_queue = True + + if use_rand_action: + act = act = (torch.rand(self.action_space.shape) * 2 - 1).to(self._device) # type: ignore + else: + act = agent.step( + self.pre_process(latent_model, self._observation_concator), deterministic=False + ) + + next_obs, reward, cost, terminated, truncated, info = self.step(act) + step += info.get('num_step', 1) - 1 + + real_next_obs = next_obs.clone() + + self._observation_concator.append(next_obs, act) + + 
self._log_value(reward=reward, cost=cost, info=info) + + for idx, done in enumerate(torch.logical_or(terminated, truncated)): + if done: + self._log_metrics(logger, idx) + self._reset_log(idx) + self.z1 = None + self.z2 = None + self._reset_sequence_queue = False + if 'final_observation' in info: + real_next_obs[idx] = info['final_observation'][idx] + + buffer.store( + obs=real_next_obs, + act=act, + reward=reward, + cost=cost, + done=torch.logical_and(terminated, torch.logical_xor(terminated, truncated)), + ) + + self._current_obs = next_obs + + def _log_value( + self, + reward: torch.Tensor, + cost: torch.Tensor, + info: dict[str, Any], + ) -> None: + self._ep_ret += info.get('original_reward', reward).cpu() + self._ep_cost += info.get('original_cost', cost).cpu() + self._ep_len += info.get('num_step', 1) + + def _log_metrics(self, logger: Logger, idx: int) -> None: + logger.store( + { + 'Metrics/EpRet': self._ep_ret[idx], + 'Metrics/EpCost': self._ep_cost[idx], + 'Metrics/EpLen': self._ep_len[idx], + }, + ) + + def _reset_log(self, idx: int | None = None) -> None: + if idx is None: + self._ep_ret = torch.zeros(self._env.num_envs) + self._ep_cost = torch.zeros(self._env.num_envs) + self._ep_len = torch.zeros(self._env.num_envs) + else: + self._ep_ret[idx] = 0.0 + self._ep_cost[idx] = 0.0 + self._ep_len[idx] = 0.0 + + def reset( + self, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, dict[str, Any]]: + obs, info = self._env.reset(seed=seed, options=options) + self._observation_concator.reset_episode(obs) + return obs, info diff --git a/omnisafe/algorithms/__init__.py b/omnisafe/algorithms/__init__.py index df6832226..4adf7d964 100644 --- a/omnisafe/algorithms/__init__.py +++ b/omnisafe/algorithms/__init__.py @@ -34,6 +34,7 @@ TD3PID, DDPGLag, SACLag, + SafeSLAC, TD3Lag, ) diff --git a/omnisafe/algorithms/off_policy/__init__.py b/omnisafe/algorithms/off_policy/__init__.py index 80e48e1a0..ddde90c8e 100644 --- a/omnisafe/algorithms/off_policy/__init__.py +++ b/omnisafe/algorithms/off_policy/__init__.py @@ -21,6 +21,7 @@ from omnisafe.algorithms.off_policy.sac import SAC from omnisafe.algorithms.off_policy.sac_lag import SACLag from omnisafe.algorithms.off_policy.sac_pid import SACPID +from omnisafe.algorithms.off_policy.safe_slac import SafeSLAC from omnisafe.algorithms.off_policy.td3 import TD3 from omnisafe.algorithms.off_policy.td3_lag import TD3Lag from omnisafe.algorithms.off_policy.td3_pid import TD3PID @@ -37,4 +38,5 @@ 'TD3PID', 'SACPID', 'CRABS', + 'SafeSLAC', ] diff --git a/omnisafe/algorithms/off_policy/ddpg.py b/omnisafe/algorithms/off_policy/ddpg.py index 850c787b2..8c799e322 100644 --- a/omnisafe/algorithms/off_policy/ddpg.py +++ b/omnisafe/algorithms/off_policy/ddpg.py @@ -258,7 +258,7 @@ def learn(self) -> tuple[float, float, float]: for sample_step in range( epoch * self._samples_per_epoch, - (epoch + 1) * self._samples_per_epoch, + (epoch + 1) * self._samples_per_epoch + 1, ): step = sample_step * self._update_cycle * self._cfgs.train_cfgs.vector_env_nums @@ -306,7 +306,7 @@ def learn(self) -> tuple[float, float, float]: self._logger.store( { - 'TotalEnvSteps': step + 1, + 'TotalEnvSteps': step, 'Time/FPS': self._cfgs.algo_cfgs.steps_per_epoch / (time.time() - epoch_time), 'Time/Total': (time.time() - start_time), 'Time/Epoch': (time.time() - epoch_time), diff --git a/omnisafe/algorithms/off_policy/safe_slac.py b/omnisafe/algorithms/off_policy/safe_slac.py new file mode 100644 index 000000000..61b9522f2 --- /dev/null +++ 
b/omnisafe/algorithms/off_policy/safe_slac.py @@ -0,0 +1,293 @@ +# Copyright 2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implementation of the Safe Stochastic Latent Actor-Critic algorithm.""" + + +from __future__ import annotations + +import time + +import torch +from rich.progress import track +from torch import optim +from torch.nn.utils.clip_grad import clip_grad_norm_ + +from omnisafe.adapter.offpolicy_latent_adapter import OffPolicyLatentAdapter +from omnisafe.algorithms import registry +from omnisafe.algorithms.off_policy.sac_lag import SACLag +from omnisafe.common.buffer import OffPolicySequenceBuffer +from omnisafe.common.lagrange import Lagrange +from omnisafe.common.latent import CostLatentModel +from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic + + +@registry.register +# pylint: disable-next=too-many-instance-attributes, too-few-public-methods +class SafeSLAC(SACLag): + def _init(self) -> None: + if self._cfgs.algo_cfgs.auto_alpha: + self._target_entropy = -torch.prod(torch.Tensor(self._env.action_space.shape)).item() + self._log_alpha = torch.zeros(1, requires_grad=True, device=self._device) + + assert self._cfgs.model_cfgs.critic.lr is not None + self._alpha_optimizer = optim.Adam( + [self._log_alpha], + lr=self._cfgs.model_cfgs.critic.lr, + ) + else: + self._log_alpha = torch.log( + torch.tensor(self._cfgs.algo_cfgs.alpha, device=self._device), + ) + + self._lagrange: Lagrange = Lagrange(**self._cfgs.lagrange_cfgs) + + self._buf: OffPolicySequenceBuffer = OffPolicySequenceBuffer( + obs_space=self._env.observation_space, + act_space=self._env.action_space, + size=self._cfgs.algo_cfgs.size, + batch_size=self._cfgs.algo_cfgs.batch_size, + device=self._device, + num_sequences=self._cfgs.algo_cfgs.num_sequences, + ) + self._is_latent_model_init_learned = False + + def _init_env(self) -> None: + self._env: OffPolicyLatentAdapter = OffPolicyLatentAdapter( + self._env_id, + self._cfgs.train_cfgs.vector_env_nums, + self._seed, + self._cfgs, + ) + assert ( + self._cfgs.algo_cfgs.steps_per_epoch % self._cfgs.train_cfgs.vector_env_nums == 0 + ), 'The number of steps per epoch is not divisible by the number of environments.' + + assert ( + int(self._cfgs.train_cfgs.total_steps) % self._cfgs.algo_cfgs.steps_per_epoch == 0 + ), 'The total number of steps is not divisible by the number of steps per epoch.' + self._epochs: int = int( + self._cfgs.train_cfgs.total_steps // self._cfgs.algo_cfgs.steps_per_epoch, + ) + self._epoch: int = 0 + self._steps_per_epoch: int = ( + self._cfgs.algo_cfgs.steps_per_epoch // self._cfgs.train_cfgs.vector_env_nums + ) + + self._update_cycle: int = self._cfgs.algo_cfgs.update_cycle + assert ( + self._steps_per_epoch % self._update_cycle == 0 + ), 'The number of steps per epoch is not divisible by the number of steps per sample.' 
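+        # One "sample" corresponds to a rollout of ``update_cycle`` environment steps,
+        # so each epoch runs roughly ``steps_per_epoch // update_cycle`` rollout/update rounds.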
+ self._samples_per_epoch: int = self._steps_per_epoch // self._update_cycle + self._update_count: int = 0 + self._update_latent_count = 0 + + def _init_model(self) -> None: + self._cfgs.model_cfgs.critic['num_critics'] = 2 + + self._latent_model = CostLatentModel( + obs_shape=self._env.observation_space.shape, + act_shape=self._env.action_space.shape, + feature_dim=self._cfgs.algo_cfgs.feature_dim, + latent_dim_1=self._cfgs.algo_cfgs.latent_dim_1, + latent_dim_2=self._cfgs.algo_cfgs.latent_dim_2, + hidden_sizes=self._cfgs.algo_cfgs.hidden_sizes, + image_noise=self._cfgs.algo_cfgs.image_noise, + ).to(self._device) + self._update_latent_count = 0 + + self._actor_critic: ConstraintActorQCritic = ConstraintActorQCritic( + obs_space=self._env.latent_space, + act_space=self._env.action_space, + model_cfgs=self._cfgs.model_cfgs, + epochs=self._epochs, + ).to(self._device) + + self._actor_critic = torch.compile(self._actor_critic) + self._latent_model = torch.compile(self._latent_model) + + self._latent_model_optimizer = optim.Adam( + self._latent_model.parameters(), + lr=1e-4, + ) + + def learn(self) -> tuple[float, float, float]: + """This is main function for algorithm update. + + It is divided into the following steps: + + - :meth:`rollout`: collect interactive data from environment. + - :meth:`update`: perform actor/critic updates. + - :meth:`log`: epoch/update information for visualization and terminal log print. + + Returns: + ep_ret: average episode return in final epoch. + ep_cost: average episode cost in final epoch. + ep_len: average episode length in final epoch. + """ + self._logger.log('INFO: Start training') + start_time = time.time() + step = 0 + for epoch in range(self._epochs): + self._epoch = epoch + rollout_time = 0.0 + update_time = 0.0 + epoch_time = time.time() + + for sample_step in range( + epoch * self._samples_per_epoch, + (epoch + 1) * self._samples_per_epoch + 1, + ): + step = sample_step * self._update_cycle * self._cfgs.train_cfgs.vector_env_nums + + rollout_start = time.time() + # set noise for exploration + if self._cfgs.algo_cfgs.use_exploration_noise: + self._actor_critic.actor.noise = self._cfgs.algo_cfgs.exploration_noise + + # collect data from environment + self._env.rollout( + rollout_step=self._update_cycle, + agent=self._actor_critic, + buffer=self._buf, + logger=self._logger, + latent_model=self._latent_model, + use_rand_action=(step <= self._cfgs.algo_cfgs.start_learning_steps), + ) + rollout_time += time.time() - rollout_start + + # update parameters + update_start = time.time() + if step > self._cfgs.algo_cfgs.start_learning_steps: + self._update() + # if we haven't updated the network, log 0 for the loss + else: + self._log_when_not_update() + update_time += time.time() - update_start + + eval_start = time.time() + self._env.eval_policy( + episode=self._cfgs.train_cfgs.eval_episodes, + agent=self._actor_critic, + logger=self._logger, + ) + eval_time = time.time() - eval_start + + self._logger.store({'Time/Update': update_time}) + self._logger.store({'Time/Rollout': rollout_time}) + self._logger.store({'Time/Evaluate': eval_time}) + + if ( + step > self._cfgs.algo_cfgs.start_learning_steps + and self._cfgs.model_cfgs.linear_lr_decay + ): + self._actor_critic.actor_scheduler.step() + + self._logger.store( + { + 'TotalEnvSteps': step, + 'Time/FPS': self._cfgs.algo_cfgs.steps_per_epoch / (time.time() - epoch_time), + 'Time/Total': (time.time() - start_time), + 'Time/Epoch': (time.time() - epoch_time), + 'Train/Epoch': epoch, + 'Train/LR': 
self._actor_critic.actor_scheduler.get_last_lr()[0], + }, + ) + + self._logger.dump_tabular() + + # save model to disk + if (epoch + 1) % self._cfgs.logger_cfgs.save_model_freq == 0: + self._logger.torch_save() + + ep_ret = self._logger.get_stats('Metrics/EpRet')[0] + ep_cost = self._logger.get_stats('Metrics/EpCost')[0] + ep_len = self._logger.get_stats('Metrics/EpLen')[0] + self._logger.close() + + return ep_ret, ep_cost, ep_len + + def _prepare_batch(self, obs_, action_): + with torch.no_grad(): + feature_ = self._latent_model.encoder(obs_) + z_ = torch.cat(self._latent_model.sample_posterior(feature_, action_)[2:4], dim=-1) + + z, next_z = z_[:, -2], z_[:, -1] + action = action_[:, -1] + + return z, next_z, action + + def _update(self) -> None: + if not self._is_latent_model_init_learned: + for _ in track( + range(self._cfgs.algo_cfgs.latent_model_init_learning_steps), + description='initial updating of latent model...', + ): + self._update_latent_model() + self._is_latent_model_init_learned = True + + Jc = self._logger.get_stats('Metrics/EpCost')[0] + if self._epoch > self._cfgs.algo_cfgs.warmup_epochs: + self._lagrange.update_lagrange_multiplier(Jc) + self._logger.store( + { + 'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier.data.item(), + }, + ) + + for _ in range(self._cfgs.algo_cfgs.update_iters): + self._update_latent_model() + + data = self._buf.sample_batch(64) + self._update_count += 1 + obs_, act_, reward, cost, done = ( + data['obs'], + data['act'], + data['reward'][:, -1].squeeze(), + data['cost'][:, -1].squeeze(), + data['done'][:, -1].squeeze(), + ) + obs, next_obs, act = self._prepare_batch(obs_, act_) + self._update_reward_critic(obs, act, reward, done, next_obs) + self._update_cost_critic(obs, act, cost, done, next_obs) + + if self._update_count % self._cfgs.algo_cfgs.policy_delay == 0: + self._update_actor(obs) + self._actor_critic.polyak_update(self._cfgs.algo_cfgs.polyak) + + def _update_latent_model( + self, + ): + data = self._buf.sample_batch(32) + obs, act, reward, cost, done = ( + data['obs'], + data['act'], + data['reward'], + data['cost'], + data['done'], + ) + + self._update_latent_count += 1 + loss_kld, loss_image, loss_reward, loss_cost = self._latent_model.calculate_loss( + obs, act, reward, done, cost + ) + + self._latent_model_optimizer.zero_grad() + (loss_kld + loss_image + loss_reward + loss_cost).backward() + if self._cfgs.algo_cfgs.max_grad_norm: + clip_grad_norm_( + self._latent_model.parameters(), + self._cfgs.algo_cfgs.max_grad_norm, + ) + self._latent_model_optimizer.step() diff --git a/omnisafe/common/buffer/__init__.py b/omnisafe/common/buffer/__init__.py index 669770849..4e47a06db 100644 --- a/omnisafe/common/buffer/__init__.py +++ b/omnisafe/common/buffer/__init__.py @@ -15,7 +15,7 @@ """Implementation of Buffer.""" from omnisafe.common.buffer.base import BaseBuffer -from omnisafe.common.buffer.offpolicy_buffer import OffPolicyBuffer +from omnisafe.common.buffer.offpolicy_buffer import OffPolicyBuffer, OffPolicySequenceBuffer from omnisafe.common.buffer.onpolicy_buffer import OnPolicyBuffer from omnisafe.common.buffer.vector_offpolicy_buffer import VectorOffPolicyBuffer from omnisafe.common.buffer.vector_onpolicy_buffer import VectorOnPolicyBuffer @@ -24,6 +24,7 @@ __all__ = [ 'BaseBuffer', 'OffPolicyBuffer', + 'OffPolicySequenceBuffer', 'OnPolicyBuffer', 'VectorOffPolicyBuffer', 'VectorOnPolicyBuffer', diff --git a/omnisafe/common/buffer/base.py b/omnisafe/common/buffer/base.py index 08864ecb0..0e521eacf 100644 --- 
a/omnisafe/common/buffer/base.py +++ b/omnisafe/common/buffer/base.py @@ -22,6 +22,7 @@ from gymnasium.spaces import Box from omnisafe.typing import DEVICE_CPU, OmnisafeSpace +from omnisafe.utils.tools import SequenceQueue class BaseBuffer(ABC): @@ -132,3 +133,53 @@ def store(self, **data: torch.Tensor) -> None: Args: data (torch.Tensor): The data to store. """ + + +class BaseSequenceBuffer(BaseBuffer): + def __init__( + self, + obs_space: OmnisafeSpace, + act_space: OmnisafeSpace, + size: int, + num_sequences: int, + device: torch.device = DEVICE_CPU, + ) -> None: + """Initialize an instance of :class:`BaseBuffer`.""" + self._device: torch.device = device + self._num_sequences = num_sequences + if isinstance(obs_space, Box): + obs_buf = [None] * size + else: + raise NotImplementedError + if isinstance(act_space, Box): + act_buf = torch.zeros( + (size, num_sequences, *act_space.shape), + dtype=torch.float32, + device=device, + ) + else: + raise NotImplementedError + + self.data: dict[str, torch.Tensor | list] = { + 'obs': obs_buf, + 'act': act_buf, + 'reward': torch.zeros(size, num_sequences, 1, dtype=torch.float32, device=device), + 'cost': torch.zeros(size, num_sequences, 1, dtype=torch.float32, device=device), + 'done': torch.zeros(size, num_sequences, 1, dtype=torch.float32, device=device), + } + + self.sequence_queue = SequenceQueue( + obs_space=obs_space, + num_sequences=num_sequences, + device=device, + ) + + self._size: int = size + self._observation_shape = obs_space.shape + + def add_field(self, name: str, shape: tuple[int, ...], dtype: torch.dtype) -> None: + self.data[name] = torch.zeros( + (self._size, self._num_sequences, *shape), + dtype=dtype, + device=self._device, + ) diff --git a/omnisafe/common/buffer/offpolicy_buffer.py b/omnisafe/common/buffer/offpolicy_buffer.py index 18759d41e..2dd3468b7 100644 --- a/omnisafe/common/buffer/offpolicy_buffer.py +++ b/omnisafe/common/buffer/offpolicy_buffer.py @@ -16,10 +16,11 @@ from __future__ import annotations +import numpy as np import torch from gymnasium.spaces import Box -from omnisafe.common.buffer.base import BaseBuffer +from omnisafe.common.buffer.base import BaseBuffer, BaseSequenceBuffer from omnisafe.typing import DEVICE_CPU, OmnisafeSpace @@ -119,3 +120,79 @@ def sample_batch(self) -> dict[str, torch.Tensor]: """ idxs = torch.randint(0, self._size, (self._batch_size,)) return {key: value[idxs] for key, value in self.data.items()} + + +class OffPolicySequenceBuffer(BaseSequenceBuffer): + def __init__( # pylint: disable=too-many-arguments + self, + obs_space: OmnisafeSpace, + act_space: OmnisafeSpace, + size: int, + batch_size: int, + num_sequences: int, + device: torch.device = DEVICE_CPU, + ) -> None: + """Initialize an instance of :class:`OffPolicySequenceBuffer`.""" + super().__init__(obs_space, act_space, size, num_sequences, device) + + self._ptr: int = 0 + self._size: int = 0 + self._max_size: int = size + self._batch_size: int = batch_size + + assert ( + self._max_size > self._batch_size + ), 'The size of the buffer must be larger than the batch size.' + + @property + def max_size(self) -> int: + """Return the max size of the buffer.""" + return self._max_size + + @property + def size(self) -> int: + """Return the current size of the buffer.""" + return self._size + + @property + def batch_size(self) -> int: + """Return the batch size of the buffer.""" + return self._batch_size + + def store(self, **data: torch.Tensor) -> None: + """Store data into the buffer. + + .. hint:: + The ReplayBuffer is a circular buffer. 
When the buffer is full, the oldest data will be + overwritten. + + Args: + data (torch.Tensor): The data to be stored. + """ + self.sequence_queue.append(**data) + if self.sequence_queue.is_full(): + sequece_data = self.sequence_queue.get() + for key, value in sequece_data.items(): + self.data[key][self._ptr] = value + self._ptr = (self._ptr + 1) % self._max_size + self._size = min(self._size + 1, self._max_size) + + def sample_batch(self, batch_size: int | None) -> dict[str, torch.Tensor]: + """Sample a batch of data from the buffer. + + Returns: + The sampled batch of data. + """ + batch_size = batch_size or self._batch_size + idxs = torch.randint(0, self._size, (batch_size,)) + returns = {key: value[idxs] for key, value in self.data.items() if key != 'obs'} + obs = np.empty((batch_size, self._num_sequences + 1, *self._observation_shape)) + for i, idx in enumerate(idxs): + obs[i, ...] = self.data['obs'][idx] + obs = torch.tensor(obs, dtype=torch.float32, device=self._device) + returns.update({'obs': obs}) + return returns + + def reset_sequence_queue(self, obs: torch.Tensor) -> None: + """Reset the sequence queue.""" + self.sequence_queue.reset_sequence_queue(obs) diff --git a/omnisafe/common/latent.py b/omnisafe/common/latent.py new file mode 100644 index 000000000..18e8ee921 --- /dev/null +++ b/omnisafe/common/latent.py @@ -0,0 +1,344 @@ +import math + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + + +def calculate_kl_divergence(p_mean, p_std, q_mean, q_std): + var_ratio = (p_std / q_std).pow_(2) + t1 = ((p_mean - q_mean) / q_std).pow_(2) + return 0.5 * (var_ratio + t1 - 1 - var_ratio.log()) + + +def build_mlp( + input_dim, + output_dim, + hidden_sizes=None, + hidden_activation=nn.Tanh(), + output_activation=None, +): + if hidden_sizes is None: + hidden_sizes = [64, 64] + layers = [] + units = input_dim + for next_units in hidden_sizes: + layers.append(nn.Linear(units, next_units)) + layers.append(hidden_activation) + units = next_units + model = nn.Sequential(*layers) + model.add_module('last_linear', nn.Linear(units, output_dim)) + if output_activation is not None: + model.add_module('output_activation', output_activation) + return model + + +def initialize_weight(m): + if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + nn.init.xavier_uniform_(m.weight, gain=1.0) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +class FixedGaussian(torch.nn.Module): + """ + Fixed diagonal gaussian distribution. + """ + + def __init__(self, output_dim, std) -> None: + super().__init__() + self.output_dim = output_dim + self.std = std + + def forward(self, x): + mean = torch.zeros(x.size(0), self.output_dim, device=x.device) + std = torch.ones(x.size(0), self.output_dim, device=x.device).mul_(self.std) + return mean, std + + +class Gaussian(torch.nn.Module): + """ + Diagonal gaussian distribution with state dependent variances. + """ + + def __init__(self, input_dim, output_dim, hidden_sizes=(256, 256)) -> None: + super().__init__() + self.net = build_mlp( + input_dim=input_dim, + output_dim=2 * output_dim, + hidden_sizes=hidden_sizes, + hidden_activation=nn.ELU(), + ).apply(initialize_weight) + + def forward(self, x): + if x.ndim == 3: + B, S, _ = x.size() + x = self.net(x.view(B * S, _)).view(B, S, -1) + else: + x = self.net(x) + mean, std = torch.chunk(x, 2, dim=-1) + std = F.softplus(std) + 1e-5 + return mean, std + + +class Bernoulli(torch.nn.Module): + """ + Diagonal gaussian distribution with state dependent variances. 
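+
+    The network outputs a sigmoid probability, i.e. the parameter of a Bernoulli
+    distribution, which the latent model below uses for the binary (sign of the)
+    cost signal.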
+ """ + + def __init__(self, input_dim, output_dim, hidden_sizes=(256, 256)) -> None: + super().__init__() + self.net = build_mlp( + input_dim=input_dim, + output_dim=output_dim, + hidden_sizes=hidden_sizes, + hidden_activation=nn.ELU(), + ).apply(initialize_weight) + + def forward(self, x): + if x.ndim == 3: + B, S, _ = x.size() + x = self.net(x.view(B * S, _)).view(B, S, -1) + else: + x = self.net(x) + return torch.sigmoid(x) + + +class Decoder(torch.nn.Module): + """ + Decoder. + """ + + def __init__(self, input_dim=288, output_dim=3, std=1.0) -> None: + super().__init__() + + self.net = nn.Sequential( + # (32+256, 1, 1) -> (256, 4, 4) + nn.ConvTranspose2d(input_dim, 256, 4), + nn.LeakyReLU(inplace=True, negative_slope=0.2), + # (256, 4, 4) -> (128, 8, 8) + nn.ConvTranspose2d(256, 128, 3, 2, 1, 1), + nn.LeakyReLU(inplace=True, negative_slope=0.2), + # (128, 8, 8) -> (64, 16, 16) + nn.ConvTranspose2d(128, 64, 3, 2, 1, 1), + nn.LeakyReLU(inplace=True, negative_slope=0.2), + # (64, 16, 16) -> (32, 32, 32) + nn.ConvTranspose2d(64, 32, 3, 2, 1, 1), + nn.LeakyReLU(inplace=True, negative_slope=0.2), + # (32, 32, 32) -> (3, 64, 64) + nn.ConvTranspose2d(32, output_dim, 5, 2, 2, 1), + nn.LeakyReLU(inplace=True, negative_slope=0.2), + ).apply(initialize_weight) + self.std = std + + def forward(self, x): + B, S, latent_dim = x.size() + x = x.view(B * S, latent_dim, 1, 1) + x = self.net(x) + _, C, W, H = x.size() + x = x.view(B, S, C, W, H) + return x, torch.ones_like(x).mul_(self.std) + + +class Encoder(torch.nn.Module): + """ + Encoder. + """ + + def __init__(self, input_dim=3, output_dim=256) -> None: + super().__init__() + + self.net = nn.Sequential( + # (3, 64, 64) -> (32, 32, 32) + nn.Conv2d(input_dim, 32, 5, 2, 2), + nn.ELU(inplace=True), + # (32, 32, 32) -> (64, 16, 16) + nn.Conv2d(32, 64, 3, 2, 1), + nn.ELU(inplace=True), + # (64, 16, 16) -> (128, 8, 8) + nn.Conv2d(64, 128, 3, 2, 1), + nn.ELU(inplace=True), + # (128, 8, 8) -> (256, 4, 4) + nn.Conv2d(128, 256, 3, 2, 1), + nn.ELU(inplace=True), + # (256, 4, 4) -> (256, 1, 1) + nn.Conv2d(256, output_dim, 4), + nn.ELU(inplace=True), + ).apply(initialize_weight) + + def forward(self, x): + B, S, C, H, W = x.size() + x = x.view(B * S, C, H, W) + x = self.net(x) + return x.view(B, S, -1) + + +class CostLatentModel(torch.nn.Module): + """ + Stochastic latent variable model to estimate latent dynamics, reward and cost. 
+ """ + + def __init__( + self, + obs_shape, + act_shape, + feature_dim=256, + latent_dim_1=32, + latent_dim_2=256, + hidden_sizes=(256, 256), + image_noise=0.1, + ) -> None: + super().__init__() + self.bceloss = torch.nn.BCELoss(reduction='none') + # p(z1(0)) = N(0, I) + self.z1_prior_init = FixedGaussian(latent_dim_1, 1.0) + # p(z2(0) | z1(0)) + self.z2_prior_init = Gaussian(latent_dim_1, latent_dim_2, hidden_sizes) + # p(z1(t+1) | z2(t), a(t)) + self.z1_prior = Gaussian( + latent_dim_2 + act_shape[0], + latent_dim_1, + hidden_sizes, + ) + # p(z2(t+1) | z1(t+1), z2(t), a(t)) + self.z2_prior = Gaussian( + latent_dim_1 + latent_dim_2 + act_shape[0], + latent_dim_2, + hidden_sizes, + ) + + # q(z1(0) | feat(0)) + self.z1_posterior_init = Gaussian(feature_dim, latent_dim_1, hidden_sizes) + # q(z2(0) | z1(0)) = p(z2(0) | z1(0)) + self.z2_posterior_init = self.z2_prior_init + # q(z1(t+1) | feat(t+1), z2(t), a(t)) + self.z1_posterior = Gaussian( + feature_dim + latent_dim_2 + act_shape[0], + latent_dim_1, + hidden_sizes, + ) + # q(z2(t+1) | z1(t+1), z2(t), a(t)) = p(z2(t+1) | z1(t+1), z2(t), a(t)) + self.z2_posterior = self.z2_prior + + # p(r(t) | z1(t), z2(t), a(t), z1(t+1), z2(t+1)) + self.reward = Gaussian( + 2 * latent_dim_1 + 2 * latent_dim_2 + act_shape[0], + 1, + hidden_sizes, + ) + + self.cost = Bernoulli( + 2 * latent_dim_1 + 2 * latent_dim_2 + act_shape[0], + 1, + hidden_sizes, + ) + + # feat(t) = Encoder(x(t)) + self.encoder = Encoder(obs_shape[0], feature_dim) + # p(x(t) | z1(t), z2(t)) + self.decoder = Decoder( + latent_dim_1 + latent_dim_2, + obs_shape[0], + std=np.sqrt(image_noise), + ) + self.apply(initialize_weight) + + def sample_prior(self, actions_, z2_post_): + # p(z1(0)) = N(0, I) + z1_mean_init, z1_std_init = self.z1_prior_init(actions_[:, 0]) + # p(z1(t) | z2(t-1), a(t-1)) + z1_mean_, z1_std_ = self.z1_prior( + torch.cat([z2_post_[:, : actions_.size(1)], actions_], dim=-1) + ) + # Concatenate initial and consecutive latent variables + z1_mean_ = torch.cat([z1_mean_init.unsqueeze(1), z1_mean_], dim=1) + z1_std_ = torch.cat([z1_std_init.unsqueeze(1), z1_std_], dim=1) + return (z1_mean_, z1_std_) + + def sample_posterior(self, features_, actions_): + # p(z1(0)) = N(0, I) + z1_mean, z1_std = self.z1_posterior_init(features_[:, 0]) + z1 = z1_mean + torch.randn_like(z1_std) * z1_std + # p(z2(0) | z1(0)) + z2_mean, z2_std = self.z2_posterior_init(z1) + z2 = z2_mean + torch.randn_like(z2_std) * z2_std + + z1_mean_ = [z1_mean] + z1_std_ = [z1_std] + z1_ = [z1] + z2_ = [z2] + + for t in range(1, actions_.size(1) + 1): + # q(z1(t) | feat(t), z2(t-1), a(t-1)) + z1_mean, z1_std = self.z1_posterior( + torch.cat([features_[:, t], z2, actions_[:, t - 1]], dim=1) + ) + z1 = z1_mean + torch.randn_like(z1_std) * z1_std + # q(z2(t) | z1(t), z2(t-1), a(t-1)) + z2_mean, z2_std = self.z2_posterior(torch.cat([z1, z2, actions_[:, t - 1]], dim=1)) + z2 = z2_mean + torch.randn_like(z2_std) * z2_std + + z1_mean_.append(z1_mean) + z1_std_.append(z1_std) + z1_.append(z1) + z2_.append(z2) + + z1_mean_ = torch.stack(z1_mean_, dim=1) + z1_std_ = torch.stack(z1_std_, dim=1) + z1_ = torch.stack(z1_, dim=1) + z2_ = torch.stack(z2_, dim=1) + return (z1_mean_, z1_std_, z1_, z2_) + + # + def calculate_loss(self, state_, action_, reward_, done_, cost_): + # Calculate the sequence of features. + feature_ = self.encoder(state_) + + # Sample from latent variable model. 
+ z1_mean_post_, z1_std_post_, z1_, z2_ = self.sample_posterior(feature_, action_) + z1_mean_pri_, z1_std_pri_ = self.sample_prior(action_, z2_) + + # Calculate KL divergence loss. + loss_kld = ( + calculate_kl_divergence(z1_mean_post_, z1_std_post_, z1_mean_pri_, z1_std_pri_) + .mean(dim=0) + .sum() + ) + + # Prediction loss of images. + z_ = torch.cat([z1_, z2_], dim=-1) + state_mean_, state_std_ = self.decoder(z_) + state_noise_ = (state_ - state_mean_) / (state_std_ + 1e-8) + log_likelihood_ = (-0.5 * state_noise_.pow(2) - state_std_.log()) - 0.5 * math.log( + 2 * math.pi + ) + loss_image = -log_likelihood_.mean(dim=0).sum() + + # Prediction loss of rewards. + x = torch.cat([z_[:, :-1], action_, z_[:, 1:]], dim=-1) + B, S, X = x.shape + reward_mean_, reward_std_ = self.reward(x.view(B * S, X)) + reward_mean_ = reward_mean_.view(B, S, 1) + reward_std_ = reward_std_.view(B, S, 1) + reward_noise_ = (reward_ - reward_mean_) / (reward_std_ + 1e-8) + log_likelihood_reward_ = (-0.5 * reward_noise_.pow(2) - reward_std_.log()) - 0.5 * math.log( + 2 * math.pi + ) + loss_reward = -log_likelihood_reward_.mul_(1 - done_).mean(dim=0).sum() + + p = self.cost(x.view(B * S, X)).view(B, S, 1) + q = 1 - p + weight_p = 100 + binary_cost_ = torch.sign(cost_) + loss_cost = ( + -30 + * ( + weight_p * binary_cost_ * torch.log(p + 1e-6) + + (1 - binary_cost_) * torch.log(q + 1e-6) + ) + .mean(dim=0) + .sum() + ) + + return loss_kld, loss_image, loss_reward, loss_cost diff --git a/omnisafe/configs/off-policy/SafeSLAC.yaml b/omnisafe/configs/off-policy/SafeSLAC.yaml new file mode 100644 index 000000000..a11626fe3 --- /dev/null +++ b/omnisafe/configs/off-policy/SafeSLAC.yaml @@ -0,0 +1,148 @@ +# Copyright 2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +defaults: + # seed for random number generator + seed: 0 + # training configurations + train_cfgs: + # device to use for training, options: cpu, cuda, cuda:0, cuda:0,1, etc. 
+ device: cpu + # number of threads for torch + torch_threads: 16 + # number of vectorized environments + vector_env_nums: 1 + # number of parallel agent, similar to a3c + parallel: 1 + # total number of steps to train + total_steps: 1000000 + # number of evaluate episodes + eval_episodes: 0 + # algorithm configurations + algo_cfgs: + # number of times each action is repeated in the environment + action_repeat: 2 + # initial learning steps for the latent model + latent_model_init_learning_steps: 30000 + # number of sequences used for training + num_sequences: 10 + # amount of noise added to input images as a form of augmentation + image_noise: 0.4 + # list of sizes for hidden layers of dynamics model + hidden_sizes: [256, 256] + # dimensionality of feature vectors initially + feature_dim: 256 + # dimensionality of the first latent space + latent_dim_1: 32 + # dimensionality of the second latent space + latent_dim_2: 200 + # dimensionality of feature vectors + feature_dim: 200 + # number of steps to update the policy + steps_per_epoch: 2000 + # number of steps per sample + update_cycle: 100 + # number of iterations to update the policy + update_iters: 100 + # The size of replay buffer + size: 2000000 + # The size of batch + batch_size: 64 + # normalize reward + reward_normalize: False + # normalize cost + cost_normalize: False + # normalize observation + obs_normalize: False + # max gradient norm + max_grad_norm: 40.0 + # use critic norm + use_critic_norm: False + # critic norm coefficient + critic_norm_coeff: 0.001 + # The soft update coefficient + polyak: 0.005 + # The discount factor of GAE + gamma: 0.995 + # Actor perdorm random action before `start_learning_steps` steps + start_learning_steps: 30000 + # The delay step of policy update + policy_delay: 2 + # Whether to use the exploration noise + use_exploration_noise: False + # The exploration noise + exploration_noise: 0.1 + # The policy noise + policy_noise: 0.2 + # policy_noise_clip + policy_noise_clip: 0.5 + # The value of alpha + alpha: 0.004 + # Whether to use auto alpha + auto_alpha: False + # use cost + use_cost: True + # warm up epoch + warmup_epochs: 100 + # logger configurations + logger_cfgs: + # use wandb for logging + use_wandb: False + # wandb project name + wandb_project: omnisafe + # use tensorboard for logging + use_tensorboard: True + # save model frequency + save_model_freq: 100 + # save logger path + log_dir: "./runs" + # save model path + window_lens: 10 + # model configurations + model_cfgs: + # weight initialization mode + weight_initialization_mode: "kaiming_uniform" + # actor type + actor_type: gaussian_sac + # linear learning rate decay + linear_lr_decay: False + # Configuration of Actor network + actor: + # Size of hidden layers + hidden_sizes: [256, 256] + # Activation function + activation: relu + # The learning rate of Actor network + lr: 0.000005 + # Configuration of Critic network + critic: + # The number of critic networks + num_critics: 2 + # Size of hidden layers + hidden_sizes: [256, 256] + # Activation function + activation: relu + # The learning rate of Critic network + lr: 0.001 + # lagrangian configurations + lagrange_cfgs: + # Tolerance of constraint violation + cost_limit: 25.0 + # Initial value of lagrangian multiplier + lagrangian_multiplier_init: 0.000 + # Learning rate of lagrangian multiplier + lambda_lr: 0.0002 + # Type of lagrangian optimizer + lambda_optimizer: "Adam" diff --git a/omnisafe/envs/__init__.py b/omnisafe/envs/__init__.py index df8b94c7c..53c488ffd 100644 --- 
a/omnisafe/envs/__init__.py +++ b/omnisafe/envs/__init__.py @@ -25,3 +25,4 @@ from omnisafe.envs.safety_gymnasium_env import SafetyGymnasiumEnv from omnisafe.envs.safety_gymnasium_modelbased import SafetyGymnasiumModelBased from omnisafe.envs.safety_isaac_gym_env import SafetyIsaacGymEnv +from omnisafe.envs.safety_gymnasium_vision_env import SafetyGymnasiumVisionEnv diff --git a/omnisafe/envs/safety_gymnasium_vision_env.py b/omnisafe/envs/safety_gymnasium_vision_env.py new file mode 100644 index 000000000..778ddacb4 --- /dev/null +++ b/omnisafe/envs/safety_gymnasium_vision_env.py @@ -0,0 +1,194 @@ +# Copyright 2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Environments in the Vision-Based Safety-Gymnasium.""" + +from __future__ import annotations + +import os +from typing import Any, ClassVar + +import numpy as np +import safety_gymnasium +import torch + +from omnisafe.envs.core import CMDP, env_register +from omnisafe.typing import DEVICE_CPU, Box + + +@env_register +class SafetyGymnasiumVisionEnv(CMDP): + need_auto_reset_wrapper: bool = False + need_time_limit_wrapper: bool = False + + _support_envs: ClassVar[list[str]] = [ + 'SafetyCarGoal1Vision-v0', + 'SafetyPointGoal1Vision-v0', + 'SafetyPointButton1Vision-v0', + 'SafetyPointPush1Vision-v0', + 'SafetyPointGoal2Vision-v0', + 'SafetyPointButton2Vision-v0', + 'SafetyPointPush2Vision-v0', + ] + + def __init__( + self, + env_id: str, + num_envs: int = 1, + device: torch.device = DEVICE_CPU, + **kwargs: Any, + ) -> None: + """Initialize an instance of :class:`SafetyGymnasiumVisionEnv`.""" + super().__init__(env_id) + self._num_envs = num_envs + self._device = torch.device(device) + if 'MUJOCO_GL' not in os.environ: + os.environ['MUJOCO_GL'] = 'osmesa' + self.need_time_limit_wrapper = True + self.need_auto_reset_wrapper = True + self._env = safety_gymnasium.make( + id=env_id, + autoreset=True, + render_mode='rgb_array', + camera_name='vision', + width=64, + height=64, + **kwargs, + ) + + self._observation_space = Box(shape=(3, 64, 64), low=0, high=255, dtype=np.uint8) + self._action_space = self._env.action_space + + self._metadata = self._env.metadata + + def step( + self, + action: torch.Tensor, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + dict[str, Any], + ]: + """Step the environment. + + .. note:: + OmniSafe uses auto reset wrapper to reset the environment when the episode is + terminated. So the ``obs`` will be the first observation of the next episode. And the + true ``final_observation`` in ``info`` will be stored in the ``final_observation`` key + of ``info``. + + Args: + action (torch.Tensor): Action to take. + + Returns: + observation: The agent's observation of the current environment. + reward: The amount of reward returned after previous action. + cost: The amount of cost returned after previous action. + terminated: Whether the episode has ended. 
+ truncated: Whether the episode has been truncated due to a time limit. + info: Some information logged by the environment. + """ + obs, reward, cost, terminated, truncated, info = self._env.step( + action.detach().cpu().numpy(), + ) + + reward, cost, terminated, truncated = ( + torch.as_tensor(x, dtype=torch.float32, device=self._device) + for x in (reward, cost, terminated, truncated) + ) + obs = ( + torch.as_tensor(obs['vision'].copy(), dtype=torch.uint8, device=self._device) + .float() + .div_(255.0) + .transpose(0, -1) + ) + if 'final_observation' in info: + info['final_observation'] = np.array( + [ + array if array is not None else np.zeros(obs.shape[-1]) + for array in info['final_observation']['vision'].copy() + ], + ) + info['final_observation'] = ( + torch.as_tensor( + info['final_observation'], + dtype=torch.int8, + device=self._device, + ) + .float() + .div_(255.0) + .transpose(0, -1) + ) + + return obs, reward, cost, terminated, truncated, info + + def reset( + self, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, dict[str, Any]]: + """Reset the environment. + + Args: + seed (int, optional): The random seed. Defaults to None. + options (dict[str, Any], optional): The options for the environment. Defaults to None. + + Returns: + observation: Agent's observation of the current environment. + info: Some information logged by the environment. + """ + obs, info = self._env.reset(seed=seed, options=options) + return ( + torch.as_tensor(obs['vision'].copy(), dtype=torch.uint8, device=self._device) + .float() + .div_(255.0) + .transpose(0, -1), + info, + ) + + def set_seed(self, seed: int) -> None: + """Set the seed for the environment. + + Args: + seed (int): Seed to set. + """ + self.reset(seed=seed) + + def sample_action(self) -> torch.Tensor: + """Sample a random action. + + Returns: + A random action. + """ + return torch.as_tensor( + self._env.action_space.sample(), + dtype=torch.float32, + device=self._device, + ) + + def render(self) -> Any: + """Compute the render frames as specified by :attr:`render_mode` during the initialization of the environment. + + Returns: + The render frames: we recommend to use `np.ndarray` + which could construct video by moviepy. 
+ """ + return self._env.render() + + def close(self) -> None: + """Close the environment.""" + self._env.close() diff --git a/omnisafe/utils/model.py b/omnisafe/utils/model.py index 4d01c5ff5..5e1b44a0c 100644 --- a/omnisafe/utils/model.py +++ b/omnisafe/utils/model.py @@ -16,10 +16,13 @@ from __future__ import annotations +from collections import deque + import numpy as np +import torch from torch import nn -from omnisafe.typing import Activation, InitFunction +from omnisafe.typing import DEVICE_CPU, Activation, InitFunction def initialize_layer(init_function: InitFunction, layer: nn.Linear) -> None: @@ -109,3 +112,43 @@ def build_mlp_network( initialize_layer(weight_initialization_mode, affine_layer) layers += [affine_layer, act_fn()] return nn.Sequential(*layers) + + +class ObservationConcator: + def __init__(self, state_shape, action_shape, num_sequences, device=DEVICE_CPU) -> None: + self.state_shape = state_shape + self.action_shape = action_shape + self.num_sequences = num_sequences + self.device = device + + def reset_episode(self, state): + self._state = deque(maxlen=self.num_sequences) + self._action = deque(maxlen=self.num_sequences - 1) + for _ in range(self.num_sequences - 1): + self._state.append( + torch.zeros(self.state_shape, dtype=torch.float32, device=self.device), + ) + self._action.append( + torch.zeros(self.action_shape, dtype=torch.float32, device=self.device), + ) + self._state.append(state) + + def append(self, state, action): + self._state.append(state) + self._action.append(action) + + @property + def state(self): + return self._state[None, ...] + + @property + def last_state(self): + return self._state[-1][None, ...] + + @property + def action(self): + return self._action.reshape(1, -1) + + @property + def last_action(self): + return self._action[-1] diff --git a/omnisafe/utils/tools.py b/omnisafe/utils/tools.py index 2c0c626eb..ba8e18694 100644 --- a/omnisafe/utils/tools.py +++ b/omnisafe/utils/tools.py @@ -21,6 +21,7 @@ import os import random import sys +from collections import deque from typing import Any import numpy as np @@ -29,7 +30,7 @@ import yaml from rich.console import Console -from omnisafe.typing import DEVICE_CPU +from omnisafe.typing import DEVICE_CPU, OmnisafeSpace def get_flat_params_from(model: torch.nn.Module) -> torch.Tensor: @@ -356,3 +357,76 @@ def get_device(device: torch.device | str | int = DEVICE_CPU) -> torch.device: return torch.device('cpu') return device + + +def create_feature_actions(feature_, action_): + N = feature_.size(0) + # Flatten sequence of features. + # f (batch_size, (num_sequences)*feature_dim) + f = feature_[:, :-1].view(N, -1) + n_f = feature_[:, 1:].view(N, -1) + # Flatten sequence of actions. + a = action_[:, :-1].view(N, -1) + n_a = action_[:, 1:].view(N, -1) + # Concatenate feature and action. 
+ fa = torch.cat([f, a], dim=-1) + n_fa = torch.cat([n_f, n_a], dim=-1) + return fa, n_fa + + +class LazyFrames: + def __init__(self, frames) -> None: + self._frames = list(frames) + + def __array__(self, dtype): + return np.array(self._frames, dtype=dtype) + + def __len__(self) -> int: + return len(self._frames) + + +class SequenceQueue: + def __init__(self, obs_space: OmnisafeSpace, num_sequences: int = 8, device=DEVICE_CPU) -> None: + self.num_sequences = num_sequences + self._reset_episode = False + self._obs_space = obs_space + self._device = device + self.data = {} + self.data['obs'] = deque(maxlen=self.num_sequences + 1) + self.data['act'] = deque(maxlen=self.num_sequences) + self.data['reward'] = deque(maxlen=self.num_sequences) + self.data['done'] = deque(maxlen=self.num_sequences) + self.data['cost'] = deque(maxlen=self.num_sequences) + + def reset_sequence_queue(self, obs): + for k in self.data: + self.data[k].clear() + self._reset_episode = True + self.data['obs'].append(obs.detach().cpu().numpy()) + + def append(self, **data: torch.Tensor): + assert self._reset_episode, self._reset_episode + for key, value in data.items(): + self.data[key].append(value.detach().cpu().numpy()) + + def _process_get(self, key: str) -> LazyFrames | torch.Tensor: + if key == 'obs': + return np.array(LazyFrames(self.data['obs']), dtype=np.float32).swapaxes(0, 1).squeeze() + else: + return torch.tensor( + np.array(self.data[key], dtype=np.float32), + dtype=torch.float32, + device=self._device, + ) + + def get(self): + return {key: self._process_get(key) for key in self.data} + + def is_empty(self): + return len(self.data['reward']) == 0 + + def is_full(self): + return len(self.data['reward']) == self.num_sequences + + def __len__(self) -> int: + return len(self.data['reward']) From 9fd5fd7997aaddda8775cbda6e136fc095b9c056 Mon Sep 17 00:00:00 2001 From: Gaiejj Date: Sun, 5 May 2024 21:49:08 +0800 Subject: [PATCH 2/3] style: improve code style --- docs/source/spelling_wordlist.txt | 12 + omnisafe/adapter/__init__.py | 1 + omnisafe/adapter/offpolicy_latent_adapter.py | 106 ++++- omnisafe/algorithms/off_policy/safe_slac.py | 34 +- omnisafe/common/buffer/base.py | 43 +- omnisafe/common/buffer/offpolicy_buffer.py | 2 + omnisafe/common/latent.py | 396 +++++++++++++------ omnisafe/common/normalizer.py | 6 +- omnisafe/configs/off-policy/SafeSLAC.yaml | 10 +- omnisafe/envs/safety_gymnasium_vision_env.py | 21 + omnisafe/utils/model.py | 53 ++- omnisafe/utils/tools.py | 131 ++++-- 12 files changed, 619 insertions(+), 196 deletions(-) diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index 460cabd1a..75d324245 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -486,3 +486,15 @@ UpdateDynamics mathbb meger Jupyter +LazyFrames +SLAC +Leibler +Kullback +slac +Tal +Nils +Simão +Hogewind +Yannick +Kachman +Thiago diff --git a/omnisafe/adapter/__init__.py b/omnisafe/adapter/__init__.py index ba768a7eb..eae6697d2 100644 --- a/omnisafe/adapter/__init__.py +++ b/omnisafe/adapter/__init__.py @@ -18,6 +18,7 @@ from omnisafe.adapter.modelbased_adapter import ModelBasedAdapter from omnisafe.adapter.offline_adapter import OfflineAdapter from omnisafe.adapter.offpolicy_adapter import OffPolicyAdapter +from omnisafe.adapter.offpolicy_latent_adapter import OffPolicyLatentAdapter from omnisafe.adapter.online_adapter import OnlineAdapter from omnisafe.adapter.onpolicy_adapter import OnPolicyAdapter from omnisafe.adapter.saute_adapter import SauteAdapter diff 
--git a/omnisafe/adapter/offpolicy_latent_adapter.py b/omnisafe/adapter/offpolicy_latent_adapter.py index 105ee8ad1..8db394abc 100644 --- a/omnisafe/adapter/offpolicy_latent_adapter.py +++ b/omnisafe/adapter/offpolicy_latent_adapter.py @@ -42,6 +42,18 @@ class OffPolicyLatentAdapter(OnlineAdapter): + """OffPolicy Adapter on Latent Space for OmniSafe. + + :class:`OffPolicyLatentAdapter` is used to adapt the vision-based environment to the off-policy + training. + + Args: + env_id (str): The environment id. + num_envs (int): The number of environments. + seed (int): The random seed. + cfgs (Config): The configuration. + """ + _current_obs: torch.Tensor _ep_ret: torch.Tensor _ep_cost: torch.Tensor @@ -54,8 +66,9 @@ def __init__( # pylint: disable=too-many-arguments seed: int, cfgs: Config, ) -> None: - """Initialize a instance of :class:`OffPolicyAdapter`.""" + """Initialize a instance of :class:`OffPolicyLatentAdapter`.""" super().__init__(env_id, num_envs, seed, cfgs) + assert self.action_space.shape self._observation_concator: ObservationConcator = ObservationConcator( self._cfgs.algo_cfgs.latent_dim_1 + self._cfgs.algo_cfgs.latent_dim_2, self.action_space.shape, @@ -65,8 +78,9 @@ def __init__( # pylint: disable=too-many-arguments self._current_obs, _ = self.reset() self._max_ep_len: int = 1000 self._reset_log() - self.z1 = None - self.z2 = None + self.z1: torch.Tensor = torch.zeros(1) + self.z2: torch.Tensor = torch.zeros(1) + self._initialized: bool = False self._reset_sequence_queue = False def _wrapper( @@ -135,6 +149,13 @@ def eval_policy( # pylint: disable=too-many-locals agent: ConstraintActorQCritic, logger: Logger, ) -> None: + """Rollout the environment with deterministic agent action. + + Args: + episode (int): Number of episodes. + agent (ConstraintActorCritic): Agent. + logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``. + """ for _ in range(episode): ep_ret, ep_cost, ep_len = 0.0, 0.0, 0 obs, _ = self._eval_env.reset() @@ -161,22 +182,37 @@ def eval_policy( # pylint: disable=too-many-locals }, ) - def pre_process(self, latent_model, concated_obs): + def pre_process( + self, + latent_model: CostLatentModel, + concated_obs: ObservationConcator, + ) -> torch.Tensor: + """Processes the concatenated observations to produce latent representation. + + Args: + latent_model (CostLatentModel): The latent model containing the encoder and decoder. + concated_obs (ObservationConcator): An object that encapsulates the concatenated observations. + + Returns: + A tensor combining the latent variables z1 and z2, representing the current state of + the system in the latent space. 
+ """ with torch.no_grad(): feature = latent_model.encoder(concated_obs.last_state) - if self.z2 is None: + if not self._initialized: z1_mean, z1_std = latent_model.z1_posterior_init(feature) self.z1 = z1_mean + torch.randn_like(z1_std) * z1_std z2_mean, z2_std = latent_model.z2_posterior_init(self.z1) self.z2 = z2_mean + torch.randn_like(z2_std) * z2_std + self._initialized = True else: z1_mean, z1_std = latent_model.z1_posterior( - torch.cat([feature.squeeze(), self.z2.squeeze(), concated_obs.last_action], dim=-1) + torch.cat([feature.squeeze(), self.z2.squeeze(), concated_obs.last_action], dim=-1), ) self.z1 = z1_mean + torch.randn_like(z1_std) * z1_std z2_mean, z2_std = latent_model.z2_posterior( - torch.cat([self.z1.squeeze(), self.z2.squeeze(), concated_obs.last_action], dim=-1) + torch.cat([self.z1.squeeze(), self.z2.squeeze(), concated_obs.last_action], dim=-1), ) self.z2 = z2_mean + torch.randn_like(z2_std) * z2_std @@ -191,6 +227,17 @@ def rollout( # pylint: disable=too-many-locals logger: Logger, use_rand_action: bool, ) -> None: + """Rollout the environment and store the data in the buffer. + + Args: + rollout_step (int): Number of rollout steps. + agent (ConstraintActorCritic): Constraint actor-critic, including actor, reward critic, + and cost critic. + latent_model (CostLatentModel): Latent model, including encoder and decoder. + buffer (VectorOnPolicyBuffer): Vector on-policy buffer. + logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``. + use_rand_action (bool): Whether to use random action. + """ for step in range(rollout_step): if not self._reset_sequence_queue: buffer.reset_sequence_queue(self._current_obs) @@ -198,10 +245,11 @@ def rollout( # pylint: disable=too-many-locals self._reset_sequence_queue = True if use_rand_action: - act = act = (torch.rand(self.action_space.shape) * 2 - 1).to(self._device) # type: ignore + act = (torch.rand(self.action_space.shape) * 2 - 1).to(self._device) # type: ignore else: act = agent.step( - self.pre_process(latent_model, self._observation_concator), deterministic=False + self.pre_process(latent_model, self._observation_concator), + deterministic=False, ) next_obs, reward, cost, terminated, truncated, info = self.step(act) @@ -217,8 +265,9 @@ def rollout( # pylint: disable=too-many-locals if done: self._log_metrics(logger, idx) self._reset_log(idx) - self.z1 = None - self.z2 = None + self.z1 = torch.zeros(1) + self.z2 = torch.zeros(1) + self._initialized = False self._reset_sequence_queue = False if 'final_observation' in info: real_next_obs[idx] = info['final_observation'][idx] @@ -239,11 +288,30 @@ def _log_value( cost: torch.Tensor, info: dict[str, Any], ) -> None: + """Log value. + + .. note:: + OmniSafe uses :class:`RewardNormalizer` wrapper, so the original reward and cost will + be stored in ``info['original_reward']`` and ``info['original_cost']``. + + Args: + reward (torch.Tensor): The immediate step reward. + cost (torch.Tensor): The immediate step cost. + info (dict[str, Any]): Some information logged by the environment. + """ self._ep_ret += info.get('original_reward', reward).cpu() self._ep_cost += info.get('original_cost', cost).cpu() self._ep_len += info.get('num_step', 1) def _log_metrics(self, logger: Logger, idx: int) -> None: + """Log metrics, including ``EpRet``, ``EpCost``, ``EpLen``. + + Args: + logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``. + idx (int): The index of the environment. 
+ """ + if hasattr(self._env, 'spec_log'): + self._env.spec_log(logger) logger.store( { 'Metrics/EpRet': self._ep_ret[idx], @@ -253,6 +321,12 @@ def _log_metrics(self, logger: Logger, idx: int) -> None: ) def _reset_log(self, idx: int | None = None) -> None: + """Reset the episode return, episode cost and episode length. + + Args: + idx (int or None, optional): The index of the environment. Defaults to None + (single environment). + """ if idx is None: self._ep_ret = torch.zeros(self._env.num_envs) self._ep_cost = torch.zeros(self._env.num_envs) @@ -267,6 +341,16 @@ def reset( seed: int | None = None, options: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, dict[str, Any]]: + """Reset the environment and returns an initial observation. + + Args: + seed (int, optional): The random seed. Defaults to None. + options (dict[str, Any], optional): The options for the environment. Defaults to None. + + Returns: + observation: The initial observation of the space. + info: Some information logged by the environment. + """ obs, info = self._env.reset(seed=seed, options=options) self._observation_concator.reset_episode(obs) return obs, info diff --git a/omnisafe/algorithms/off_policy/safe_slac.py b/omnisafe/algorithms/off_policy/safe_slac.py index 61b9522f2..f8efad839 100644 --- a/omnisafe/algorithms/off_policy/safe_slac.py +++ b/omnisafe/algorithms/off_policy/safe_slac.py @@ -1,4 +1,4 @@ -# Copyright 2023 OmniSafe Team. All Rights Reserved. +# Copyright 2024 OmniSafe Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,6 +36,16 @@ @registry.register # pylint: disable-next=too-many-instance-attributes, too-few-public-methods class SafeSLAC(SACLag): + """Safe SLAC algorithms for vision-based safe RL tasks. + + References: + - Title: Safe Reinforcement Learning From Pixels Using a Stochastic Latent Representation. + - Authors: Yannick Hogewind, Thiago D. Simão, Tal Kachman, Nils Jansen. 
+ - URL: `Safe SLAC `_ + """ + + _is_latent_model_init_learned: bool + def _init(self) -> None: if self._cfgs.algo_cfgs.auto_alpha: self._target_entropy = -torch.prod(torch.Tensor(self._env.action_space.shape)).item() @@ -53,7 +63,7 @@ def _init(self) -> None: self._lagrange: Lagrange = Lagrange(**self._cfgs.lagrange_cfgs) - self._buf: OffPolicySequenceBuffer = OffPolicySequenceBuffer( + self._buf: OffPolicySequenceBuffer = OffPolicySequenceBuffer( # type: ignore obs_space=self._env.observation_space, act_space=self._env.action_space, size=self._cfgs.algo_cfgs.size, @@ -64,7 +74,7 @@ def _init(self) -> None: self._is_latent_model_init_learned = False def _init_env(self) -> None: - self._env: OffPolicyLatentAdapter = OffPolicyLatentAdapter( + self._env: OffPolicyLatentAdapter = OffPolicyLatentAdapter( # type: ignore self._env_id, self._cfgs.train_cfgs.vector_env_nums, self._seed, @@ -96,6 +106,9 @@ def _init_env(self) -> None: def _init_model(self) -> None: self._cfgs.model_cfgs.critic['num_critics'] = 2 + assert self._env.observation_space.shape + assert self._env.action_space.shape + self._latent_model = CostLatentModel( obs_shape=self._env.observation_space.shape, act_shape=self._env.action_space.shape, @@ -114,9 +127,6 @@ def _init_model(self) -> None: epochs=self._epochs, ).to(self._device) - self._actor_critic = torch.compile(self._actor_critic) - self._latent_model = torch.compile(self._latent_model) - self._latent_model_optimizer = optim.Adam( self._latent_model.parameters(), lr=1e-4, @@ -218,7 +228,7 @@ def learn(self) -> tuple[float, float, float]: return ep_ret, ep_cost, ep_len - def _prepare_batch(self, obs_, action_): + def _prepare_batch(self, obs_: torch.Tensor, action_: torch.Tensor) -> tuple[torch.Tensor, ...]: with torch.no_grad(): feature_ = self._latent_model.encoder(obs_) z_ = torch.cat(self._latent_model.sample_posterior(feature_, action_)[2:4], dim=-1) @@ -266,9 +276,7 @@ def _update(self) -> None: self._update_actor(obs) self._actor_critic.polyak_update(self._cfgs.algo_cfgs.polyak) - def _update_latent_model( - self, - ): + def _update_latent_model(self) -> None: data = self._buf.sample_batch(32) obs, act, reward, cost, done = ( data['obs'], @@ -280,7 +288,11 @@ def _update_latent_model( self._update_latent_count += 1 loss_kld, loss_image, loss_reward, loss_cost = self._latent_model.calculate_loss( - obs, act, reward, done, cost + obs, + act, + reward, + done, + cost, ) self._latent_model_optimizer.zero_grad() diff --git a/omnisafe/common/buffer/base.py b/omnisafe/common/buffer/base.py index 0e521eacf..a44afecdf 100644 --- a/omnisafe/common/buffer/base.py +++ b/omnisafe/common/buffer/base.py @@ -135,7 +135,19 @@ def store(self, **data: torch.Tensor) -> None: """ -class BaseSequenceBuffer(BaseBuffer): +class BaseSequenceBuffer(ABC): + r"""Abstract base class for sequence buffer. + + Attributes: + sequence_queue (SequenceQueue): The queue for storing the data. + + Args: + obs_space (OmnisafeSpace): The observation space. + act_space (OmnisafeSpace): The action space. + size (int): The size of the buffer. + device (torch.device): The device of the buffer. Defaults to ``torch.device('cpu')``. + """ + def __init__( self, obs_space: OmnisafeSpace, @@ -178,8 +190,37 @@ def __init__( self._observation_shape = obs_space.shape def add_field(self, name: str, shape: tuple[int, ...], dtype: torch.dtype) -> None: + """Add a field to the buffer. + + Args: + name (str): The name of the field. + shape (tuple of int): The shape of the field. 
+ dtype (torch.dtype): The dtype of the field. + """ self.data[name] = torch.zeros( (self._size, self._num_sequences, *shape), dtype=dtype, device=self._device, ) + + @property + def device(self) -> torch.device: + """The device of the buffer.""" + return self._device + + @property + def size(self) -> int: + """The size of the buffer.""" + return self._size + + def __len__(self) -> int: + """Return the length of the buffer.""" + return self._size + + @abstractmethod + def store(self, **data: torch.Tensor) -> None: + """Store a transition in the buffer. + + Args: + data (torch.Tensor): The data to store. + """ diff --git a/omnisafe/common/buffer/offpolicy_buffer.py b/omnisafe/common/buffer/offpolicy_buffer.py index 2dd3468b7..a7976c7c9 100644 --- a/omnisafe/common/buffer/offpolicy_buffer.py +++ b/omnisafe/common/buffer/offpolicy_buffer.py @@ -123,6 +123,8 @@ def sample_batch(self) -> dict[str, torch.Tensor]: class OffPolicySequenceBuffer(BaseSequenceBuffer): + """Sequence-based Replay buffer for off-policy algorithms.""" + def __init__( # pylint: disable=too-many-arguments self, obs_space: OmnisafeSpace, diff --git a/omnisafe/common/latent.py b/omnisafe/common/latent.py index 18e8ee921..8c9ac83e8 100644 --- a/omnisafe/common/latent.py +++ b/omnisafe/common/latent.py @@ -1,4 +1,24 @@ +# Copyright 2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implementation of the Latent Model for Safe SLAC.""" + + +from __future__ import annotations + import math +from typing import Any import numpy as np import torch @@ -6,22 +26,48 @@ from torch.nn import functional as F -def calculate_kl_divergence(p_mean, p_std, q_mean, q_std): +def calculate_kl_divergence( + p_mean: torch.Tensor, + p_std: torch.Tensor, + q_mean: torch.Tensor, + q_std: torch.Tensor, +) -> torch.Tensor: + """Calculate the KL divergence between two normal distributions. + + Args: + p_mean (torch.Tensor): Mean of the first normal distribution. + p_std (torch.Tensor): Standard deviation of the first normal distribution. + q_mean (torch.Tensor): Mean of the second normal distribution. + q_std (torch.Tensor): Standard deviation of the second normal distribution. + + Returns: + torch.Tensor: The KL divergence between the two distributions. + """ var_ratio = (p_std / q_std).pow_(2) t1 = ((p_mean - q_mean) / q_std).pow_(2) return 0.5 * (var_ratio + t1 - 1 - var_ratio.log()) def build_mlp( - input_dim, - output_dim, - hidden_sizes=None, - hidden_activation=nn.Tanh(), - output_activation=None, -): + input_dim: int, + output_dim: int, + hidden_activation: nn.Module, + hidden_sizes: list[int] | None = None, +) -> nn.Sequential: + """Build a multi-layer perceptron (MLP) model using PyTorch. + + Args: + input_dim (int): Dimension of the input features. + output_dim (int): Dimension of the output. + hidden_sizes (list[int], optional): List of integers defining the number of units in each hidden layer. 
+ hidden_activation (nn.Module): Activation function to use after each hidden layer. + + Returns: + nn.Sequential: The constructed MLP model. + """ if hidden_sizes is None: hidden_sizes = [64, 64] - layers = [] + layers: list[Any] = [] units = input_dim for next_units in hidden_sizes: layers.append(nn.Linear(units, next_units)) @@ -29,238 +75,335 @@ def build_mlp( units = next_units model = nn.Sequential(*layers) model.add_module('last_linear', nn.Linear(units, output_dim)) - if output_activation is not None: - model.add_module('output_activation', output_activation) return model -def initialize_weight(m): +def initialize_weight(m: nn.Module) -> None: + """Initializes the weights of the module using Xavier uniform initialization. + + Args: + m (nn.Module): The module whose weights need to be initialized. + """ if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): nn.init.xavier_uniform_(m.weight, gain=1.0) if m.bias is not None: nn.init.constant_(m.bias, 0) -class FixedGaussian(torch.nn.Module): - """ - Fixed diagonal gaussian distribution. +class FixedGaussian(nn.Module): + """Represents a fixed diagonal Gaussian distribution. + + Attributes: + output_dim (int): Dimension of the output distribution. + std (float): Standard deviation of the Gaussian distribution. """ - def __init__(self, output_dim, std) -> None: + def __init__(self, output_dim: int, std: float) -> None: + """Initialize an instance of the Fixed Gaussian.""" super().__init__() self.output_dim = output_dim self.std = std - def forward(self, x): + def forward(self, x: torch.Tensor) -> tuple: + """Generates a mean and standard deviation tensor based on the fixed parameters. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + tuple: Mean and standard deviation tensors. + """ mean = torch.zeros(x.size(0), self.output_dim, device=x.device) - std = torch.ones(x.size(0), self.output_dim, device=x.device).mul_(self.std) + std = torch.ones(x.size(0), self.output_dim, device=x.device) * self.std return mean, std -class Gaussian(torch.nn.Module): - """ - Diagonal gaussian distribution with state dependent variances. +class Gaussian(nn.Module): + """Represents a diagonal Gaussian distribution with state dependent variances. + + Attributes: + net (nn.Module): Neural network module to compute means and log standard deviations. """ - def __init__(self, input_dim, output_dim, hidden_sizes=(256, 256)) -> None: + def __init__( + self, + input_dim: int, + output_dim: int, + hidden_sizes: list[int], + ) -> None: + """Initialize an instance of the Gaussian distribution.""" super().__init__() self.net = build_mlp( input_dim=input_dim, output_dim=2 * output_dim, - hidden_sizes=hidden_sizes, hidden_activation=nn.ELU(), + hidden_sizes=hidden_sizes, ).apply(initialize_weight) - def forward(self, x): + def forward(self, x: torch.Tensor) -> tuple: + """Computes the mean and standard deviation for the Gaussian distribution. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + tuple: Mean and log standard deviation tensors. + """ if x.ndim == 3: - B, S, _ = x.size() - x = self.net(x.view(B * S, _)).view(B, S, -1) + batch_size, seq_length, _ = x.size() + x = self.net(x.view(batch_size * seq_length, _)).view(batch_size, seq_length, -1) else: x = self.net(x) mean, std = torch.chunk(x, 2, dim=-1) - std = F.softplus(std) + 1e-5 + std = F.softplus(std) + 1e-5 # pylint: disable=not-callable return mean, std -class Bernoulli(torch.nn.Module): - """ - Diagonal gaussian distribution with state dependent variances. 
+class Bernoulli(nn.Module): + """A module representing a Bernoulli distribution. + + This class builds a multi-layer perceptron (MLP) that outputs parameters for a Bernoulli + distribution. The output of the MLP is transformed using a sigmoid function to ensure it lies + between 0 and 1, representing probabilities. + + Attributes: + net (nn.Module): The neural network that builds the Bernoulli distribution. + + Args: + input_dim (int): The number of input features to the MLP. + output_dim (int): The number of output features from the MLP. + hidden_sizes (list[int, int]): The sizes of the hidden layers in the MLP. Defaults to + (256, 256). """ - def __init__(self, input_dim, output_dim, hidden_sizes=(256, 256)) -> None: + def __init__( + self, + input_dim: int, + output_dim: int, + hidden_sizes: list[int], + ) -> None: + """Initializes the Bernoulli module with the specified architecture.""" super().__init__() self.net = build_mlp( input_dim=input_dim, output_dim=output_dim, - hidden_sizes=hidden_sizes, hidden_activation=nn.ELU(), + hidden_sizes=hidden_sizes, ).apply(initialize_weight) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Defines the forward pass of the Bernoulli module. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor representing probabilities after applying the sigmoid function. + """ if x.ndim == 3: - B, S, _ = x.size() - x = self.net(x.view(B * S, _)).view(B, S, -1) + batch_size, seq_length, _ = x.size() + x = self.net(x.view(batch_size * seq_length, -1)).view(batch_size, seq_length, -1) else: x = self.net(x) return torch.sigmoid(x) class Decoder(torch.nn.Module): - """ - Decoder. + """The image processing decoder module. + + This decoder module that takes in a latent vector and outputs reconstructed images, + also outputs a tensor of the same shape with a constant standard deviation value. + + Attributes: + net (torch.nn.Sequential): The neural network layers. + std (float): The standard deviation value to be used for the output tensor. + + Args: + input_dim (int): Dimension of the input latent vector. Defaults to 288. + output_dim (int): The number of output channels. Defaults to 3. + std (float): Standard deviation for the generated tensor. Defaults to 1.0. """ - def __init__(self, input_dim=288, output_dim=3, std=1.0) -> None: + def __init__(self, input_dim: int = 288, output_dim: int = 3, std: float = 1.0) -> None: + """Initializes the decoder module.""" super().__init__() self.net = nn.Sequential( - # (32+256, 1, 1) -> (256, 4, 4) nn.ConvTranspose2d(input_dim, 256, 4), nn.LeakyReLU(inplace=True, negative_slope=0.2), - # (256, 4, 4) -> (128, 8, 8) nn.ConvTranspose2d(256, 128, 3, 2, 1, 1), nn.LeakyReLU(inplace=True, negative_slope=0.2), - # (128, 8, 8) -> (64, 16, 16) nn.ConvTranspose2d(128, 64, 3, 2, 1, 1), nn.LeakyReLU(inplace=True, negative_slope=0.2), - # (64, 16, 16) -> (32, 32, 32) nn.ConvTranspose2d(64, 32, 3, 2, 1, 1), nn.LeakyReLU(inplace=True, negative_slope=0.2), - # (32, 32, 32) -> (3, 64, 64) nn.ConvTranspose2d(32, output_dim, 5, 2, 2, 1), nn.LeakyReLU(inplace=True, negative_slope=0.2), ).apply(initialize_weight) self.std = std - def forward(self, x): - B, S, latent_dim = x.size() - x = x.view(B * S, latent_dim, 1, 1) + def forward(self, x: torch.Tensor) -> tuple: + """Forward pass for generating output image tensor and a tensor filled with std. + + Args: + x (torch.Tensor): Input latent vector tensor. 
+ + Returns: + tuple: A tuple containing the reconstructed image tensor and a tensor filled with std. + """ + batch_size, seq_length, latent_dim = x.size() + x = x.view(batch_size * seq_length, latent_dim, 1, 1) x = self.net(x) - _, C, W, H = x.size() - x = x.view(B, S, C, W, H) + _, channels, width, height = x.size() + x = x.view(batch_size, seq_length, channels, width, height) return x, torch.ones_like(x).mul_(self.std) class Encoder(torch.nn.Module): - """ - Encoder. + """An encoder module that takes in images and outputs a latent vector representation. + + Attributes: + net (torch.nn.Sequential): The neural network layers. + + Args: + input_dim (int): Number of input channels in the image. Defaults to 3. + output_dim (int): Dimension of the output latent vector. Defaults to 256. """ - def __init__(self, input_dim=3, output_dim=256) -> None: + def __init__(self, input_dim: int = 3, output_dim: int = 256) -> None: + """Initialize the Encoder module.""" super().__init__() self.net = nn.Sequential( - # (3, 64, 64) -> (32, 32, 32) nn.Conv2d(input_dim, 32, 5, 2, 2), nn.ELU(inplace=True), - # (32, 32, 32) -> (64, 16, 16) nn.Conv2d(32, 64, 3, 2, 1), nn.ELU(inplace=True), - # (64, 16, 16) -> (128, 8, 8) nn.Conv2d(64, 128, 3, 2, 1), nn.ELU(inplace=True), - # (128, 8, 8) -> (256, 4, 4) nn.Conv2d(128, 256, 3, 2, 1), nn.ELU(inplace=True), - # (256, 4, 4) -> (256, 1, 1) nn.Conv2d(256, output_dim, 4), nn.ELU(inplace=True), ).apply(initialize_weight) - def forward(self, x): - B, S, C, H, W = x.size() - x = x.view(B * S, C, H, W) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass to transform input images into a flat latent vector representation. + + Args: + x (torch.Tensor): Input tensor of images. + + Returns: + torch.Tensor: Output tensor of latent vectors. + """ + batch_size, seq_length, channels, height, width = x.size() + x = x.view(batch_size * seq_length, channels, height, width) x = self.net(x) - return x.view(B, S, -1) + return x.view(batch_size, seq_length, -1) -class CostLatentModel(torch.nn.Module): - """ - Stochastic latent variable model to estimate latent dynamics, reward and cost. +# pylint: disable-next=too-many-instance-attributes +class CostLatentModel(nn.Module): + """The latent model for cost prediction. + + Stochastic latent variable model that estimates latent dynamics, rewards, and costs + for a given observation and action space using variational inference techniques. + + Args: + obs_shape (tuple of int): Shape of the observations. + act_shape (tuple of int): Shape of the actions. + feature_dim (int): Dimension of the feature vector from observations. Defaults to 256. + latent_dim_1 (int): Dimension of the first set of latent variables. Defaults to 32. + latent_dim_2 (int): Dimension of the second set of latent variables. Defaults to 256. + hidden_sizes (list of int): Sizes of hidden layers in the networks. + image_noise (float): Standard deviation of noise in image reconstruction. Defaults to 0.1. 
""" def __init__( self, - obs_shape, - act_shape, - feature_dim=256, - latent_dim_1=32, - latent_dim_2=256, - hidden_sizes=(256, 256), - image_noise=0.1, + obs_shape: tuple[int, ...], + act_shape: tuple[int, ...], + hidden_sizes: list[int], + feature_dim: int = 256, + latent_dim_1: int = 32, + latent_dim_2: int = 256, + image_noise: float = 0.1, ) -> None: + """Initialize the instance of CostLatentModel.""" super().__init__() self.bceloss = torch.nn.BCELoss(reduction='none') - # p(z1(0)) = N(0, I) self.z1_prior_init = FixedGaussian(latent_dim_1, 1.0) - # p(z2(0) | z1(0)) self.z2_prior_init = Gaussian(latent_dim_1, latent_dim_2, hidden_sizes) - # p(z1(t+1) | z2(t), a(t)) self.z1_prior = Gaussian( - latent_dim_2 + act_shape[0], - latent_dim_1, + int(latent_dim_2 + act_shape[0]), + int(latent_dim_1), hidden_sizes, ) - # p(z2(t+1) | z1(t+1), z2(t), a(t)) self.z2_prior = Gaussian( - latent_dim_1 + latent_dim_2 + act_shape[0], - latent_dim_2, + int(latent_dim_1 + latent_dim_2 + act_shape[0]), + int(latent_dim_2), hidden_sizes, ) - # q(z1(0) | feat(0)) self.z1_posterior_init = Gaussian(feature_dim, latent_dim_1, hidden_sizes) - # q(z2(0) | z1(0)) = p(z2(0) | z1(0)) self.z2_posterior_init = self.z2_prior_init - # q(z1(t+1) | feat(t+1), z2(t), a(t)) self.z1_posterior = Gaussian( - feature_dim + latent_dim_2 + act_shape[0], - latent_dim_1, + int(feature_dim + latent_dim_2 + act_shape[0]), + int(latent_dim_1), hidden_sizes, ) - # q(z2(t+1) | z1(t+1), z2(t), a(t)) = p(z2(t+1) | z1(t+1), z2(t), a(t)) self.z2_posterior = self.z2_prior - # p(r(t) | z1(t), z2(t), a(t), z1(t+1), z2(t+1)) self.reward = Gaussian( - 2 * latent_dim_1 + 2 * latent_dim_2 + act_shape[0], + int(2 * latent_dim_1 + 2 * latent_dim_2 + act_shape[0]), 1, hidden_sizes, ) self.cost = Bernoulli( - 2 * latent_dim_1 + 2 * latent_dim_2 + act_shape[0], + int(2 * latent_dim_1 + 2 * latent_dim_2 + act_shape[0]), 1, hidden_sizes, ) - # feat(t) = Encoder(x(t)) self.encoder = Encoder(obs_shape[0], feature_dim) - # p(x(t) | z1(t), z2(t)) self.decoder = Decoder( - latent_dim_1 + latent_dim_2, - obs_shape[0], + int(latent_dim_1 + latent_dim_2), + int(obs_shape[0]), std=np.sqrt(image_noise), ) self.apply(initialize_weight) - def sample_prior(self, actions_, z2_post_): - # p(z1(0)) = N(0, I) + def sample_prior(self, actions_: torch.Tensor, z2_post_: torch.Tensor) -> tuple: + """Sample the prior distribution for latent variables using initial and recurrent models. + + Args: + actions_ (torch.Tensor): The actions taken at each step of the sequence. + z2_post_ (torch.Tensor): The posterior samples of the second set of latent variables. + + Returns: + tuple: A tuple containing means and standard deviations for the first set of latent variables. + """ z1_mean_init, z1_std_init = self.z1_prior_init(actions_[:, 0]) - # p(z1(t) | z2(t-1), a(t-1)) z1_mean_, z1_std_ = self.z1_prior( - torch.cat([z2_post_[:, : actions_.size(1)], actions_], dim=-1) + torch.cat([z2_post_[:, : actions_.size(1)], actions_], dim=-1), ) - # Concatenate initial and consecutive latent variables z1_mean_ = torch.cat([z1_mean_init.unsqueeze(1), z1_mean_], dim=1) z1_std_ = torch.cat([z1_std_init.unsqueeze(1), z1_std_], dim=1) return (z1_mean_, z1_std_) - def sample_posterior(self, features_, actions_): - # p(z1(0)) = N(0, I) + def sample_posterior(self, features_: torch.Tensor, actions_: torch.Tensor) -> tuple: + """Sample the posterior distribution for latent variables. + + Args: + features_ (torch.Tensor): The features extracted from observations at each timestep. 
+ actions_ (torch.Tensor): The actions taken at each step of the sequence. + + Returns: + tuple: A tuple of tensors containing means, standard deviations, and samples of the latent variables. + """ z1_mean, z1_std = self.z1_posterior_init(features_[:, 0]) z1 = z1_mean + torch.randn_like(z1_std) * z1_std - # p(z2(0) | z1(0)) z2_mean, z2_std = self.z2_posterior_init(z1) z2 = z2_mean + torch.randn_like(z2_std) * z2_std @@ -270,12 +413,10 @@ def sample_posterior(self, features_, actions_): z2_ = [z2] for t in range(1, actions_.size(1) + 1): - # q(z1(t) | feat(t), z2(t-1), a(t-1)) z1_mean, z1_std = self.z1_posterior( - torch.cat([features_[:, t], z2, actions_[:, t - 1]], dim=1) + torch.cat([features_[:, t], z2, actions_[:, t - 1]], dim=1), ) z1 = z1_mean + torch.randn_like(z1_std) * z1_std - # q(z2(t) | z1(t), z2(t-1), a(t-1)) z2_mean, z2_std = self.z2_posterior(torch.cat([z1, z2, actions_[:, t - 1]], dim=1)) z2 = z2_mean + torch.randn_like(z2_std) * z2_std @@ -290,44 +431,63 @@ def sample_posterior(self, features_, actions_): z2_ = torch.stack(z2_, dim=1) return (z1_mean_, z1_std_, z1_, z2_) - # - def calculate_loss(self, state_, action_, reward_, done_, cost_): - # Calculate the sequence of features. - feature_ = self.encoder(state_) + # pylint: disable-next=too-many-locals + def calculate_loss( + self, + state_: torch.Tensor, + action_: torch.Tensor, + reward_: torch.Tensor, + done_: torch.Tensor, + cost_: torch.Tensor, + ) -> tuple: + """Calculate the loss for the model. + + Args: + state_ (torch.Tensor): Observed states over a sequence of timesteps. + action_ (torch.Tensor): Actions taken at each timestep. + reward_ (torch.Tensor): Observed rewards at each timestep. + done_ (torch.Tensor): Done flags indicating the end of an episode. + cost_ (torch.Tensor): Observed costs at each timestep. + + Returns: + tuple: A tuple containing the KL divergence loss, image reconstruction loss, + reward prediction loss, and cost classification loss. + """ + feature_ = self.forward(state_) - # Sample from latent variable model. z1_mean_post_, z1_std_post_, z1_, z2_ = self.sample_posterior(feature_, action_) z1_mean_pri_, z1_std_pri_ = self.sample_prior(action_, z2_) - # Calculate KL divergence loss. loss_kld = ( calculate_kl_divergence(z1_mean_post_, z1_std_post_, z1_mean_pri_, z1_std_pri_) .mean(dim=0) .sum() ) - # Prediction loss of images. z_ = torch.cat([z1_, z2_], dim=-1) state_mean_, state_std_ = self.decoder(z_) state_noise_ = (state_ - state_mean_) / (state_std_ + 1e-8) log_likelihood_ = (-0.5 * state_noise_.pow(2) - state_std_.log()) - 0.5 * math.log( - 2 * math.pi + 2 * math.pi, ) loss_image = -log_likelihood_.mean(dim=0).sum() - # Prediction loss of rewards. 
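+        # Gaussian log-likelihood per element: -0.5 * ((x - mean) / std) ** 2 - log(std) - 0.5 * log(2 * pi).
+        # The reward head below consumes [z(t), a(t), z(t+1)], and its negative log-likelihood
+        # is masked by (1 - done_) so that terminal transitions do not contribute.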
x = torch.cat([z_[:, :-1], action_, z_[:, 1:]], dim=-1) - B, S, X = x.shape - reward_mean_, reward_std_ = self.reward(x.view(B * S, X)) - reward_mean_ = reward_mean_.view(B, S, 1) - reward_std_ = reward_std_.view(B, S, 1) + batch_size, seq_length, concated_shape = x.shape + reward_mean_, reward_std_ = self.reward(x.view(batch_size * seq_length, concated_shape)) + reward_mean_ = reward_mean_.view(batch_size, seq_length, 1) + reward_std_ = reward_std_.view(batch_size, seq_length, 1) reward_noise_ = (reward_ - reward_mean_) / (reward_std_ + 1e-8) log_likelihood_reward_ = (-0.5 * reward_noise_.pow(2) - reward_std_.log()) - 0.5 * math.log( - 2 * math.pi + 2 * math.pi, ) loss_reward = -log_likelihood_reward_.mul_(1 - done_).mean(dim=0).sum() - p = self.cost(x.view(B * S, X)).view(B, S, 1) + p = self.cost(x.view(batch_size * seq_length, concated_shape)).view( + batch_size, + seq_length, + 1, + ) q = 1 - p weight_p = 100 binary_cost_ = torch.sign(cost_) @@ -342,3 +502,7 @@ def calculate_loss(self, state_, action_, reward_, done_, cost_): ) return loss_kld, loss_image, loss_reward, loss_cost + + def forward(self, obs: torch.Tensor) -> torch.Tensor: + """Get the encoded state of observation.""" + return self.encoder(obs) diff --git a/omnisafe/common/normalizer.py b/omnisafe/common/normalizer.py index 38bbe5f77..e90909089 100644 --- a/omnisafe/common/normalizer.py +++ b/omnisafe/common/normalizer.py @@ -144,12 +144,12 @@ def load_state_dict( strict: bool = True, assign: bool = False, ) -> Any: - """Load the state_dict to the normalizer. + """Load the parameters to the normalizer. Args: state_dict (Mapping[str, Any]): The state_dict to be loaded. - strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict`. - Defaults to True. + strict (bool, optional): Whether to strictly enforce that the keys in + :attr:`state_dict`. Defaults to True. Returns: The loaded normalizer. diff --git a/omnisafe/configs/off-policy/SafeSLAC.yaml b/omnisafe/configs/off-policy/SafeSLAC.yaml index a11626fe3..48df653f4 100644 --- a/omnisafe/configs/off-policy/SafeSLAC.yaml +++ b/omnisafe/configs/off-policy/SafeSLAC.yaml @@ -42,16 +42,14 @@ defaults: image_noise: 0.4 # list of sizes for hidden layers of dynamics model hidden_sizes: [256, 256] - # dimensionality of feature vectors initially - feature_dim: 256 - # dimensionality of the first latent space + # dimension of the first latent space latent_dim_1: 32 - # dimensionality of the second latent space + # dimension of the second latent space latent_dim_2: 200 - # dimensionality of feature vectors + # dimension of feature vectors feature_dim: 200 # number of steps to update the policy - steps_per_epoch: 2000 + steps_per_epoch: 2000 # number of steps per sample update_cycle: 100 # number of iterations to update the policy diff --git a/omnisafe/envs/safety_gymnasium_vision_env.py b/omnisafe/envs/safety_gymnasium_vision_env.py index 778ddacb4..4ab6cc48e 100644 --- a/omnisafe/envs/safety_gymnasium_vision_env.py +++ b/omnisafe/envs/safety_gymnasium_vision_env.py @@ -29,6 +29,27 @@ @env_register class SafetyGymnasiumVisionEnv(CMDP): + """Safety Gymnasium Vision-Based Environment. + + Args: + env_id (str): Environment id. + num_envs (int, optional): Number of environments. Defaults to 1. + device (torch.device, optional): Device to store the data. Defaults to + ``torch.device('cpu')``. + + Keyword Args: + render_mode (str, optional): The render mode ranges from 'human' to 'rgb_array' and 'rgb_array_list'. + Defaults to 'rgb_array'. 
+ camera_name (str, optional): The camera name. + camera_id (int, optional): The camera id. + width (int, optional): The width of the rendered image. Defaults to 256. + height (int, optional): The height of the rendered image. Defaults to 256. + + Attributes: + need_auto_reset_wrapper (bool): Whether to use auto reset wrapper. + need_time_limit_wrapper (bool): Whether to use time limit wrapper. + """ + need_auto_reset_wrapper: bool = False need_time_limit_wrapper: bool = False diff --git a/omnisafe/utils/model.py b/omnisafe/utils/model.py index 5e1b44a0c..cdc52b1b2 100644 --- a/omnisafe/utils/model.py +++ b/omnisafe/utils/model.py @@ -115,13 +115,36 @@ def build_mlp_network( class ObservationConcator: - def __init__(self, state_shape, action_shape, num_sequences, device=DEVICE_CPU) -> None: + """A class designed to concatenate observations and actions over a specified time steps.""" + + def __init__( + self, + state_shape: tuple, + action_shape: tuple, + num_sequences: int, + device: torch.device = DEVICE_CPU, + ) -> None: + """Initialize the ObservationConcator with given shapes and device. + + Args: + state_shape (tuple): Shape of the state space. + action_shape (tuple): Shape of the action space. + num_sequences (int): Number of sequences to maintain in the history. + device (str): The device (CPU/GPU) on which to create tensors. + """ self.state_shape = state_shape self.action_shape = action_shape self.num_sequences = num_sequences self.device = device + self._state: deque = deque(maxlen=self.num_sequences) + self._action: deque = deque(maxlen=self.num_sequences - 1) - def reset_episode(self, state): + def reset_episode(self, state: torch.Tensor) -> None: + """Reset the history of states and actions for a new episode. + + Args: + state (torch.Tensor): The initial state for the new episode. + """ self._state = deque(maxlen=self.num_sequences) self._action = deque(maxlen=self.num_sequences - 1) for _ in range(self.num_sequences - 1): @@ -133,22 +156,30 @@ def reset_episode(self, state): ) self._state.append(state) - def append(self, state, action): + def append(self, state: torch.Tensor, action: torch.Tensor) -> None: + """Append a new state and action to the queue. + + Args: + state (torch.Tensor): State to be appended. + action (torch.Tensor): Action to be appended. + """ self._state.append(state) self._action.append(action) @property - def state(self): - return self._state[None, ...] + def last_state(self) -> torch.Tensor: + """Returns the most recent state. - @property - def last_state(self): + Returns: + torch.Tensor: The most recent state. + """ return self._state[-1][None, ...] @property - def action(self): - return self._action.reshape(1, -1) + def last_action(self) -> torch.Tensor: + """Returns the most recent action. - @property - def last_action(self): + Returns: + torch.Tensor: The most recent action. + """ return self._action[-1] diff --git a/omnisafe/utils/tools.py b/omnisafe/utils/tools.py index ba8e18694..b3e5a9fad 100644 --- a/omnisafe/utils/tools.py +++ b/omnisafe/utils/tools.py @@ -352,81 +352,138 @@ def get_device(device: torch.device | str | int = DEVICE_CPU) -> torch.device: # Force conversion to torch.device device = torch.device(device) - # Cuda not available if not torch.cuda.is_available() and device.type == torch.device('cuda').type: return torch.device('cpu') return device -def create_feature_actions(feature_, action_): - N = feature_.size(0) - # Flatten sequence of features. 
- # f (batch_size, (num_sequences)*feature_dim) - f = feature_[:, :-1].view(N, -1) - n_f = feature_[:, 1:].view(N, -1) - # Flatten sequence of actions. - a = action_[:, :-1].view(N, -1) - n_a = action_[:, 1:].view(N, -1) - # Concatenate feature and action. - fa = torch.cat([f, a], dim=-1) - n_fa = torch.cat([n_f, n_a], dim=-1) - return fa, n_fa +class LazyFrames: + """A class that lazily handles a list of frames. + Attributes: + _frames (List[Any]): Private storage for frame data, stored as a list. + """ -class LazyFrames: - def __init__(self, frames) -> None: + def __init__(self, frames: list[Any]) -> None: + """Initializes a new instance of LazyFrames. + + Args: + frames (list[Any]): A list of frame data. + """ self._frames = list(frames) - def __array__(self, dtype): + def __array__(self, dtype: type[np.dtype]) -> np.ndarray: + """Returns an array representation of the frames with the specified data type. + + Args: + dtype (Type[np.dtype]): The desired data type for the numpy array. + + Returns: + np.ndarray: An array representation of the frames. + """ return np.array(self._frames, dtype=dtype) def __len__(self) -> int: + """Returns the number of frames. + + Returns: + int: The number of frames. + """ return len(self._frames) class SequenceQueue: - def __init__(self, obs_space: OmnisafeSpace, num_sequences: int = 8, device=DEVICE_CPU) -> None: + """A queue that manages sequences of observations, actions, rewards, etc. for RL agents. + + Attributes: + num_sequences (int): The number of sequences to store. + _reset_episode (bool): Flag to indicate whether the queue needs to be reset. + _obs_space (OmnisafeSpace): The space of the observations. + _device (str): The device (CPU or GPU) on which to process the tensors. + data (dict): A dictionary containing observations, actions, rewards, costs and terminated + or truncated flags. + """ + + def __init__( + self, + obs_space: OmnisafeSpace, + num_sequences: int = 8, + device: torch.device = DEVICE_CPU, + ) -> None: + """Initialize a new instance of SequenceQueue. + + Args: + obs_space (OmnisafeSpace): The observation space of the environment. + num_sequences (int, optional): Maximum number of sequences to store. Defaults to 8. + device (str, optional): The computation device ('cpu' or 'gpu'). Defaults to 'cpu'. + """ self.num_sequences = num_sequences self._reset_episode = False self._obs_space = obs_space self._device = device - self.data = {} + self.data: dict[str, Any] = {} self.data['obs'] = deque(maxlen=self.num_sequences + 1) self.data['act'] = deque(maxlen=self.num_sequences) self.data['reward'] = deque(maxlen=self.num_sequences) self.data['done'] = deque(maxlen=self.num_sequences) self.data['cost'] = deque(maxlen=self.num_sequences) - def reset_sequence_queue(self, obs): - for k in self.data: - self.data[k].clear() + def reset_sequence_queue(self, obs: torch.Tensor) -> None: + """Reset the sequence queue with the initial observation. + + Args: + obs (torch.Tensor): The initial observation from the environment. + """ + for _, v in self.data.items(): + v.clear() self._reset_episode = True self.data['obs'].append(obs.detach().cpu().numpy()) - def append(self, **data: torch.Tensor): - assert self._reset_episode, self._reset_episode + def append(self, **data: torch.Tensor) -> None: + """Append new data to the queue. + + Args: + data (torch.Tensor): Keyword arguments where keys are data types ('obs', 'act', etc.) + and values are torch Tensors. + """ + assert self._reset_episode, 'Reset the sequence queue before appending data.' 
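+        # Values are detached and moved to CPU NumPy arrays, so the rolling deques never
+        # keep references to the autograd graph.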
for key, value in data.items(): self.data[key].append(value.detach().cpu().numpy()) - def _process_get(self, key: str) -> LazyFrames | torch.Tensor: + def _process_get(self, key: str) -> Any: + """Process the requested data for retrieval. + + Args: + key (str): The key of the data to retrieve. + + Returns: + Union[LazyFrames, torch.Tensor]: The processed data, either as LazyFrames or a torch Tensor. + """ if key == 'obs': return np.array(LazyFrames(self.data['obs']), dtype=np.float32).swapaxes(0, 1).squeeze() - else: - return torch.tensor( - np.array(self.data[key], dtype=np.float32), - dtype=torch.float32, - device=self._device, - ) - - def get(self): + return torch.tensor( + np.array(self.data[key], dtype=np.float32), + dtype=torch.float32, + device=self._device, + ) + + def get(self) -> dict: + """Retrieve all stored data. + + Returns: + dict: A dictionary of processed data for each stored type. + """ return {key: self._process_get(key) for key in self.data} - def is_empty(self): - return len(self.data['reward']) == 0 - - def is_full(self): + def is_full(self) -> bool: + """Return whether the sequence queue is full.""" return len(self.data['reward']) == self.num_sequences def __len__(self) -> int: + """Get the number of rewards stored, which reflects the number of complete sequences. + + Returns: + int: The number of sequences. + """ return len(self.data['reward']) From 277b6e170e9d241ff2c735be33f9b6ac777fcc08 Mon Sep 17 00:00:00 2001 From: Gaiejj Date: Sun, 12 May 2024 22:45:15 +0800 Subject: [PATCH 3/3] feat: support evaluation --- examples/evaluate_saved_policy.py | 3 +- omnisafe/adapter/offpolicy_latent_adapter.py | 77 ++++++++++- omnisafe/algorithms/off_policy/ddpg.py | 58 +++++--- omnisafe/algorithms/off_policy/safe_slac.py | 14 ++ .../control_barrier_function/crabs/models.py | 4 +- omnisafe/configs/off-policy/SafeSLAC.yaml | 2 +- omnisafe/envs/__init__.py | 2 +- omnisafe/envs/safety_gymnasium_vision_env.py | 6 +- omnisafe/evaluator.py | 126 +++++++++++++++++- 9 files changed, 264 insertions(+), 28 deletions(-) diff --git a/examples/evaluate_saved_policy.py b/examples/evaluate_saved_policy.py index e87c314a0..d478383bc 100644 --- a/examples/evaluate_saved_policy.py +++ b/examples/evaluate_saved_policy.py @@ -21,7 +21,7 @@ # Just fill your experiment's log directory in here. # Such as: ~/omnisafe/examples/runs/PPOLag-{SafetyPointGoal1-v0}/seed-000-2023-03-07-20-25-48 -LOG_DIR = '' +LOG_DIR = '/home/jiayi/SLAC/omnisafe_zjy/examples/runs/SafeSLAC-{SafetyPointGoal1Vision-v0}/seed-000-2024-05-12-20-50-23' if __name__ == '__main__': evaluator = omnisafe.Evaluator(render_mode='rgb_array') scan_dir = os.scandir(os.path.join(LOG_DIR, 'torch_save')) @@ -34,6 +34,5 @@ width=256, height=256, ) - evaluator.render(num_episodes=1) evaluator.evaluate(num_episodes=1) scan_dir.close() diff --git a/omnisafe/adapter/offpolicy_latent_adapter.py b/omnisafe/adapter/offpolicy_latent_adapter.py index 8db394abc..37d3b480d 100644 --- a/omnisafe/adapter/offpolicy_latent_adapter.py +++ b/omnisafe/adapter/offpolicy_latent_adapter.py @@ -134,6 +134,36 @@ def _wrapper( if self._env.num_envs == 1: self._env = Unsqueeze(self._env, device=self._device) + def _wrapper_eval( + self, + obs_normalize: bool = True, + ) -> None: + """Wrapper the environment for evaluation. + + Args: + obs_normalize (bool, optional): Whether to normalize the observation. Defaults to True. + reward_normalize (bool, optional): Whether to normalize the reward. Defaults to True. 
+ cost_normalize (bool, optional): Whether to normalize the cost. Defaults to True. + """ + assert self._eval_env, 'Your environment for evaluation does not exist!' + if self._env.need_time_limit_wrapper: + assert ( + self._eval_env.max_episode_steps + ), 'You must define max_episode_steps as an\ + \ninteger or cancel the use of the time_limit wrapper.' + self._eval_env = TimeLimit( + self._eval_env, + time_limit=self._eval_env.max_episode_steps, + device=self._device, + ) + if self._env.need_auto_reset_wrapper: + self._eval_env = AutoReset(self._eval_env, device=self._device) + if obs_normalize: + self._eval_env = ObsNormalize(self._eval_env, device=self._device) + self._eval_env = ActionScale(self._eval_env, low=-1.0, high=1.0, device=self._device) + self._eval_env = ActionRepeat(self._eval_env, times=2, device=self._device) + self._eval_env = Unsqueeze(self._eval_env, device=self._device) + @property def latent_space(self) -> Box: """Get the latent space.""" @@ -147,6 +177,7 @@ def eval_policy( # pylint: disable=too-many-locals self, episode: int, agent: ConstraintActorQCritic, + latent_model: CostLatentModel, logger: Logger, ) -> None: """Rollout the environment with deterministic agent action. @@ -154,24 +185,66 @@ def eval_policy( # pylint: disable=too-many-locals Args: episode (int): Number of episodes. agent (ConstraintActorCritic): Agent. + latent_model (CostLatentModel): Latent model, including encoder and decoder. logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``. """ + assert self._eval_env, 'Your environment for evaluation does not exist!' + assert self.action_space.shape + eval_observation_concator: ObservationConcator = ObservationConcator( + self._cfgs.algo_cfgs.latent_dim_1 + self._cfgs.algo_cfgs.latent_dim_2, + self.action_space.shape, + self._cfgs.algo_cfgs.num_sequences, + device=self._device, + ) for _ in range(episode): ep_ret, ep_cost, ep_len = 0.0, 0.0, 0 obs, _ = self._eval_env.reset() obs = obs.to(self._device) + eval_observation_concator.reset_episode(obs) + + with torch.no_grad(): + feature = latent_model.encoder(eval_observation_concator.last_state) + + z1_mean, z1_std = latent_model.z1_posterior_init(feature) + z1 = z1_mean + torch.randn_like(z1_std) * z1_std + z2_mean, z2_std = latent_model.z2_posterior_init(z1) + z2 = z2_mean + torch.randn_like(z2_std) * z2_std + + latent_obs = torch.cat((z1, z2), dim=-1).squeeze() done = False while not done: - act = agent.step(obs, deterministic=True) + act = agent.step(latent_obs, deterministic=True) obs, reward, cost, terminated, truncated, info = self._eval_env.step(act) obs, reward, cost, terminated, truncated = ( torch.as_tensor(x, dtype=torch.float32, device=self._device) for x in (obs, reward, cost, terminated, truncated) ) + + eval_observation_concator.append(obs, act) + + with torch.no_grad(): + feature = latent_model.encoder(eval_observation_concator.last_state) + z1_mean, z1_std = latent_model.z1_posterior( + torch.cat( + [feature.squeeze(), z2.squeeze(), eval_observation_concator.last_action], + dim=-1, + ), + ) + z1 = z1_mean + torch.randn_like(z1_std) * z1_std + z2_mean, z2_std = latent_model.z2_posterior( + torch.cat( + [z1.squeeze(), z2.squeeze(), eval_observation_concator.last_action], + dim=-1, + ), + ) + z2 = z2_mean + torch.randn_like(z2_std) * z2_std + latent_obs = torch.cat((z1, z2), dim=-1).squeeze() + ep_ret += info.get('original_reward', reward).cpu() ep_cost += info.get('original_cost', cost).cpu() - ep_len += 1 + ep_len += info.get('num_step', 1) + done = bool(terminated[0].item()) 
or bool(truncated[0].item()) logger.store( diff --git a/omnisafe/algorithms/off_policy/ddpg.py b/omnisafe/algorithms/off_policy/ddpg.py index 8c799e322..0ce31f286 100644 --- a/omnisafe/algorithms/off_policy/ddpg.py +++ b/omnisafe/algorithms/off_policy/ddpg.py @@ -188,23 +188,36 @@ def _init_log(self) -> None: config=self._cfgs, ) - what_to_save: dict[str, Any] = {} - what_to_save['pi'] = self._actor_critic.actor - if self._cfgs.algo_cfgs.obs_normalize: - obs_normalizer = self._env.save()['obs_normalizer'] - what_to_save['obs_normalizer'] = obs_normalizer - - self._logger.setup_torch_saver(what_to_save) + self._log_what_to_save() self._logger.torch_save() + self._specific_save() - self._logger.register_key('Metrics/EpRet', window_length=50) - self._logger.register_key('Metrics/EpCost', window_length=50) - self._logger.register_key('Metrics/EpLen', window_length=50) + self._logger.register_key( + 'Metrics/EpRet', + window_length=self._cfgs.logger_cfgs.window_lens, + ) + self._logger.register_key( + 'Metrics/EpCost', + window_length=self._cfgs.logger_cfgs.window_lens, + ) + self._logger.register_key( + 'Metrics/EpLen', + window_length=self._cfgs.logger_cfgs.window_lens, + ) if self._cfgs.train_cfgs.eval_episodes > 0: - self._logger.register_key('Metrics/TestEpRet', window_length=50) - self._logger.register_key('Metrics/TestEpCost', window_length=50) - self._logger.register_key('Metrics/TestEpLen', window_length=50) + self._logger.register_key( + 'Metrics/TestEpRet', + window_length=self._cfgs.logger_cfgs.window_lens, + ) + self._logger.register_key( + 'Metrics/TestEpCost', + window_length=self._cfgs.logger_cfgs.window_lens, + ) + self._logger.register_key( + 'Metrics/TestEpLen', + window_length=self._cfgs.logger_cfgs.window_lens, + ) self._logger.register_key('Train/Epoch') self._logger.register_key('Train/LR') @@ -258,7 +271,7 @@ def learn(self) -> tuple[float, float, float]: for sample_step in range( epoch * self._samples_per_epoch, - (epoch + 1) * self._samples_per_epoch + 1, + (epoch + 1) * self._samples_per_epoch, ): step = sample_step * self._update_cycle * self._cfgs.train_cfgs.vector_env_nums @@ -306,7 +319,7 @@ def learn(self) -> tuple[float, float, float]: self._logger.store( { - 'TotalEnvSteps': step, + 'TotalEnvSteps': step + 1, 'Time/FPS': self._cfgs.algo_cfgs.steps_per_epoch / (time.time() - epoch_time), 'Time/Total': (time.time() - start_time), 'Time/Epoch': (time.time() - epoch_time), @@ -320,6 +333,7 @@ def learn(self) -> tuple[float, float, float]: # save model to disk if (epoch + 1) % self._cfgs.logger_cfgs.save_model_freq == 0: self._logger.torch_save() + self._specific_save() ep_ret = self._logger.get_stats('Metrics/EpRet')[0] ep_cost = self._logger.get_stats('Metrics/EpCost')[0] @@ -544,3 +558,17 @@ def _log_when_not_update(self) -> None: 'Value/cost_critic': 0.0, }, ) + + def _log_what_to_save(self) -> None: + """Define what need to be saved below.""" + what_to_save: dict[str, Any] = {} + + what_to_save['pi'] = self._actor_critic.actor + if self._cfgs.algo_cfgs.obs_normalize: + obs_normalizer = self._env.save()['obs_normalizer'] + what_to_save['obs_normalizer'] = obs_normalizer + + self._logger.setup_torch_saver(what_to_save) + + def _specific_save(self) -> None: + """Save some algorithms specific models per epoch.""" diff --git a/omnisafe/algorithms/off_policy/safe_slac.py b/omnisafe/algorithms/off_policy/safe_slac.py index f8efad839..989b8f500 100644 --- a/omnisafe/algorithms/off_policy/safe_slac.py +++ b/omnisafe/algorithms/off_policy/safe_slac.py @@ -18,6 +18,7 @@ 
from __future__ import annotations import time +from typing import Any import torch from rich.progress import track @@ -190,6 +191,7 @@ def learn(self) -> tuple[float, float, float]: self._env.eval_policy( episode=self._cfgs.train_cfgs.eval_episodes, agent=self._actor_critic, + latent_model=self._latent_model, logger=self._logger, ) eval_time = time.time() - eval_start @@ -303,3 +305,15 @@ def _update_latent_model(self) -> None: self._cfgs.algo_cfgs.max_grad_norm, ) self._latent_model_optimizer.step() + + def _log_what_to_save(self) -> None: + """Define what need to be saved below.""" + what_to_save: dict[str, Any] = {} + + what_to_save['pi'] = self._actor_critic.actor + if self._cfgs.algo_cfgs.obs_normalize: + obs_normalizer = self._env.save()['obs_normalizer'] + what_to_save['obs_normalizer'] = obs_normalizer + what_to_save['latent_model'] = self._latent_model + + self._logger.setup_torch_saver(what_to_save) diff --git a/omnisafe/common/control_barrier_function/crabs/models.py b/omnisafe/common/control_barrier_function/crabs/models.py index a03aecb5e..ff98d3274 100644 --- a/omnisafe/common/control_barrier_function/crabs/models.py +++ b/omnisafe/common/control_barrier_function/crabs/models.py @@ -68,7 +68,7 @@ def forward(self, states, actions): zip(self.elites, np.array_split(perm, len(self.elites))), ): next_states.append(self.models[model_idx](states[indices], actions[indices])) - return torch.cat(next_states, dim=0)[inv_perm] + return torch.cat(next_states, dim=0)[inv_perm] # type: ignore def get_nlls(self, states, actions, next_states): """Get the negative log likelihoods. @@ -486,7 +486,7 @@ def forward(self, states: torch.Tensor): if np.min(all_u) <= 0: index = np.min(np.where(all_u <= 0)[0]) - action = actions[index] + action = actions[index] # type: ignore else: action = self.crabs.policy(states[0]) diff --git a/omnisafe/configs/off-policy/SafeSLAC.yaml b/omnisafe/configs/off-policy/SafeSLAC.yaml index 48df653f4..d548fc43b 100644 --- a/omnisafe/configs/off-policy/SafeSLAC.yaml +++ b/omnisafe/configs/off-policy/SafeSLAC.yaml @@ -29,7 +29,7 @@ defaults: # total number of steps to train total_steps: 1000000 # number of evaluate episodes - eval_episodes: 0 + eval_episodes: 1 # algorithm configurations algo_cfgs: # number of times each action is repeated in the environment diff --git a/omnisafe/envs/__init__.py b/omnisafe/envs/__init__.py index 53c488ffd..4d72845a9 100644 --- a/omnisafe/envs/__init__.py +++ b/omnisafe/envs/__init__.py @@ -24,5 +24,5 @@ from omnisafe.envs.mujoco_env import MujocoEnv from omnisafe.envs.safety_gymnasium_env import SafetyGymnasiumEnv from omnisafe.envs.safety_gymnasium_modelbased import SafetyGymnasiumModelBased -from omnisafe.envs.safety_isaac_gym_env import SafetyIsaacGymEnv from omnisafe.envs.safety_gymnasium_vision_env import SafetyGymnasiumVisionEnv +from omnisafe.envs.safety_isaac_gym_env import SafetyIsaacGymEnv diff --git a/omnisafe/envs/safety_gymnasium_vision_env.py b/omnisafe/envs/safety_gymnasium_vision_env.py index 4ab6cc48e..1fb42f256 100644 --- a/omnisafe/envs/safety_gymnasium_vision_env.py +++ b/omnisafe/envs/safety_gymnasium_vision_env.py @@ -85,7 +85,6 @@ def __init__( camera_name='vision', width=64, height=64, - **kwargs, ) self._observation_space = Box(shape=(3, 64, 64), low=0, high=255, dtype=np.uint8) @@ -181,6 +180,11 @@ def reset( info, ) + @property + def max_episode_steps(self) -> int: + """The max steps per episode.""" + return self._env.spec.max_episode_steps + def set_seed(self, seed: int) -> None: """Set the seed for the 
environment. diff --git a/omnisafe/evaluator.py b/omnisafe/evaluator.py index 70352dbb9..d0c2c5766 100644 --- a/omnisafe/evaluator.py +++ b/omnisafe/evaluator.py @@ -36,12 +36,14 @@ SafeARCPlanner, ) from omnisafe.common import Normalizer +from omnisafe.common.latent import CostLatentModel from omnisafe.envs.core import CMDP, make from omnisafe.envs.wrapper import ActionRepeat, ActionScale, ObsNormalize, TimeLimit from omnisafe.models.actor import ActorBuilder from omnisafe.models.actor_critic import ConstraintActorCritic, ConstraintActorQCritic from omnisafe.models.base import Actor from omnisafe.utils.config import Config +from omnisafe.utils.model import ObservationConcator class Evaluator: # pylint: disable=too-many-instance-attributes @@ -76,6 +78,7 @@ def __init__( self._actor: Actor | None = actor self._actor_critic: ConstraintActorCritic | ConstraintActorQCritic | None = actor_critic self._dynamics: EnsembleDynamicsModel | None = dynamics + self._latent_model: CostLatentModel | None = None self._planner = planner self._dividing_line: str = '\n' + '#' * 50 + '\n' @@ -278,6 +281,26 @@ def __load_model_and_env( high=np.hstack((observation_space.high, np.inf)), shape=(observation_space.shape[0] + 1,), ) + if self._cfgs['algo'] == 'SafeSLAC': + self._latent_model = CostLatentModel( + obs_shape=observation_space.shape, + act_shape=action_space.shape, + feature_dim=self._cfgs['algo_cfgs']['feature_dim'], + latent_dim_1=self._cfgs['algo_cfgs']['latent_dim_1'], + latent_dim_2=self._cfgs['algo_cfgs']['latent_dim_2'], + hidden_sizes=self._cfgs['algo_cfgs']['hidden_sizes'], + image_noise=self._cfgs['algo_cfgs']['image_noise'], + ) + self._latent_model.load_state_dict(model_params['latent_model']) + observation_space = Box( + low=-np.inf, + high=np.inf, + shape=( + self._cfgs['algo_cfgs']['latent_dim_1'] + + self._cfgs['algo_cfgs']['latent_dim_2'], + ), + ) + actor_type = self._cfgs['model_cfgs']['actor_type'] pi_cfg = self._cfgs['model_cfgs']['actor'] weight_initialization_mode = self._cfgs['model_cfgs']['weight_initialization_mode'] @@ -336,6 +359,7 @@ def load_saved( self.__load_model_and_env(save_dir, model_name, env_kwargs) + # pylint: disable-next=too-many-locals def evaluate( self, num_episodes: int = 10, @@ -358,6 +382,13 @@ def evaluate( 'The environment and the policy must be provided or created before evaluating the agent.', ) + if self._cfgs['algo'] == 'SafeSLAC': + assert self._env.action_space.shape + eval_observation_concator: ObservationConcator = ObservationConcator( + self._cfgs['algo_cfgs']['latent_dim_1'] + self._cfgs['algo_cfgs']['latent_dim_2'], + self._env.action_space.shape, + self._cfgs['algo_cfgs']['num_sequences'], + ) episode_rewards: list[float] = [] episode_costs: list[float] = [] episode_lengths: list[float] = [] @@ -367,6 +398,18 @@ def evaluate( self._safety_obs = torch.ones(1) ep_ret, ep_cost, length = 0.0, 0.0, 0.0 + if self._cfgs['algo'] == 'SafeSLAC': + assert self._latent_model, 'The latent model must be provided.' 
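+                # Encode the first frame and draw (z1, z2) from the *_init posteriors; the policy
+                # then acts on their concatenation instead of the raw image observation.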
+ obs = obs.unsqueeze(0) + eval_observation_concator.reset_episode(obs) + with torch.no_grad(): + feature = self._latent_model.encoder(eval_observation_concator.last_state) + z1_mean, z1_std = self._latent_model.z1_posterior_init(feature) + z1 = z1_mean + torch.randn_like(z1_std) * z1_std + z2_mean, z2_std = self._latent_model.z2_posterior_init(z1) + z2 = z2_mean + torch.randn_like(z2_std) * z2_std + obs = torch.cat((z1, z2), dim=-1).squeeze() + done = False while not done: if 'Saute' in self._cfgs['algo'] or 'Simmer' in self._cfgs['algo']: @@ -387,7 +430,34 @@ def evaluate( raise ValueError( 'The policy must be provided or created before evaluating the agent.', ) - obs, rew, cost, terminated, truncated, _ = self._env.step(act) + obs, rew, cost, terminated, truncated, info = self._env.step(act) + if self._cfgs['algo'] == 'SafeSLAC': + assert self._latent_model, 'The latent model must be provided.' + obs = obs.unsqueeze(0) + eval_observation_concator.append(obs, act) + + with torch.no_grad(): + feature = self._latent_model.encoder(eval_observation_concator.last_state) + z1_mean, z1_std = self._latent_model.z1_posterior( + torch.cat( + [ + feature.squeeze(), + z2.squeeze(), + eval_observation_concator.last_action, + ], + dim=-1, + ), + ) + z1 = z1_mean + torch.randn_like(z1_std) * z1_std + z2_mean, z2_std = self._latent_model.z2_posterior( + torch.cat( + [z1.squeeze(), z2.squeeze(), eval_observation_concator.last_action], + dim=-1, + ), + ) + z2 = z2_mean + torch.randn_like(z2_std) * z2_std + obs = torch.cat((z1, z2), dim=-1).squeeze() + if 'Saute' in self._cfgs['algo'] or 'Simmer' in self._cfgs['algo']: self._safety_obs -= cost.unsqueeze(-1) / self._safety_budget self._safety_obs /= self._cfgs.algo_cfgs.saute_gamma @@ -399,7 +469,7 @@ def evaluate( and ep_cost >= self._cfgs.algo_cfgs.cost_limit ): terminated = torch.as_tensor(True) - length += 1 + length += info.get('num_step', 1) done = bool(terminated or truncated) @@ -472,6 +542,14 @@ def render( # pylint: disable=too-many-locals,too-many-arguments,too-many-branc print(f'Saving the replay video to {save_replay_path},\n and the result to {result_path}.') print(self._dividing_line) + if self._cfgs['algo'] == 'SafeSLAC': + assert self._env.action_space.shape + eval_observation_concator: ObservationConcator = ObservationConcator( + self._cfgs['algo_cfgs']['latent_dim_1'] + self._cfgs['algo_cfgs']['latent_dim_2'], + self._env.action_space.shape, + self._cfgs['algo_cfgs']['num_sequences'], + ) + horizon = 1000 frames = [] obs, _ = self._env.reset() @@ -489,6 +567,19 @@ def render( # pylint: disable=too-many-locals,too-many-arguments,too-many-branc step = 0 done = False ep_ret, ep_cost, length = 0.0, 0.0, 0.0 + + if self._cfgs['algo'] == 'SafeSLAC': + assert self._latent_model, 'The latent model must be provided.' 
+ obs = obs.unsqueeze(0) + eval_observation_concator.reset_episode(obs) + with torch.no_grad(): + feature = self._latent_model.encoder(eval_observation_concator.last_state) + z1_mean, z1_std = self._latent_model.z1_posterior_init(feature) + z1 = z1_mean + torch.randn_like(z1_std) * z1_std + z2_mean, z2_std = self._latent_model.z2_posterior_init(z1) + z2 = z2_mean + torch.randn_like(z2_std) * z2_std + obs = torch.cat((z1, z2), dim=-1).squeeze() + while ( not done and step <= max_render_steps ): # a big number to make sure the episode will end @@ -510,7 +601,34 @@ def render( # pylint: disable=too-many-locals,too-many-arguments,too-many-branc raise ValueError( 'The policy must be provided or created before evaluating the agent.', ) - obs, rew, cost, terminated, truncated, _ = self._env.step(act) + obs, rew, cost, terminated, truncated, info = self._env.step(act) + if self._cfgs['algo'] == 'SafeSLAC': + assert self._latent_model, 'The latent model must be provided.' + obs = obs.unsqueeze(0) + eval_observation_concator.append(obs, act) + + with torch.no_grad(): + feature = self._latent_model.encoder(eval_observation_concator.last_state) + z1_mean, z1_std = self._latent_model.z1_posterior( + torch.cat( + [ + feature.squeeze(), + z2.squeeze(), + eval_observation_concator.last_action, + ], + dim=-1, + ), + ) + z1 = z1_mean + torch.randn_like(z1_std) * z1_std + z2_mean, z2_std = self._latent_model.z2_posterior( + torch.cat( + [z1.squeeze(), z2.squeeze(), eval_observation_concator.last_action], + dim=-1, + ), + ) + z2 = z2_mean + torch.randn_like(z2_std) * z2_std + obs = torch.cat((z1, z2), dim=-1).squeeze() + if 'Saute' in self._cfgs['algo'] or 'Simmer' in self._cfgs['algo']: self._safety_obs -= cost.unsqueeze(-1) / self._safety_budget self._safety_obs /= self._cfgs.algo_cfgs.saute_gamma @@ -523,7 +641,7 @@ def render( # pylint: disable=too-many-locals,too-many-arguments,too-many-branc and ep_cost >= self._cfgs.algo_cfgs.cost_limit ): terminated = torch.as_tensor(True) - length += 1 + length += info.get('num_step', 1) if self._render_mode == 'rgb_array': frames.append(self._env.render())
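
The evaluation paths above (eval_policy in the adapter, and evaluate/render in the Evaluator) all repeat the same belief update: the first frame initializes (z1, z2) through the z1_posterior_init and z2_posterior_init heads, and every later (frame, action) pair advances them through z1_posterior and z2_posterior before the policy acts on the concatenation of z1 and z2. A minimal sketch of that recursion, written against the heads defined by CostLatentModel and a concator exposing last_state and last_action (the helper name filter_latent is illustrative, not part of this patch):

    import torch

    def filter_latent(latent_model, concator, prev=None):
        """One posterior filtering step; returns (z1, z2) and the latent observation for the policy."""
        with torch.no_grad():
            feature = latent_model.encoder(concator.last_state)
            if prev is None:
                # First step of an episode: sample from the initial posteriors.
                z1_mean, z1_std = latent_model.z1_posterior_init(feature)
                z1 = z1_mean + torch.randn_like(z1_std) * z1_std
                z2_mean, z2_std = latent_model.z2_posterior_init(z1)
            else:
                # Later steps: condition on the previous belief and the most recent action.
                _, z2_prev = prev
                z1_mean, z1_std = latent_model.z1_posterior(
                    torch.cat([feature.squeeze(), z2_prev.squeeze(), concator.last_action], dim=-1),
                )
                z1 = z1_mean + torch.randn_like(z1_std) * z1_std
                z2_mean, z2_std = latent_model.z2_posterior(
                    torch.cat([z1.squeeze(), z2_prev.squeeze(), concator.last_action], dim=-1),
                )
            z2 = z2_mean + torch.randn_like(z2_std) * z2_std
        return (z1, z2), torch.cat((z1, z2), dim=-1).squeeze()

The caller keeps prev = (z1, z2) between steps and resets it to None when an episode ends, which is what the adapter's _initialized flag and the per-episode resets in the evaluator accomplish.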