From 17b8507002c8e576d9138e3d509d240e69e5ae8c Mon Sep 17 00:00:00 2001 From: Gaiejj Date: Sun, 5 May 2024 21:49:08 +0800 Subject: [PATCH] style: improve code style --- docs/source/spelling_wordlist.txt | 12 + omnisafe/adapter/__init__.py | 1 + omnisafe/adapter/offpolicy_latent_adapter.py | 106 ++++- omnisafe/algorithms/off_policy/safe_slac.py | 34 +- omnisafe/common/buffer/base.py | 43 +- omnisafe/common/buffer/offpolicy_buffer.py | 2 + omnisafe/common/latent.py | 396 +++++++++++++------ omnisafe/common/normalizer.py | 6 +- omnisafe/configs/off-policy/SafeSLAC.yaml | 10 +- omnisafe/envs/safety_gymnasium_vision_env.py | 21 + omnisafe/utils/model.py | 53 ++- omnisafe/utils/tools.py | 141 ++++--- 12 files changed, 619 insertions(+), 206 deletions(-) diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index 460cabd1a..75d324245 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -486,3 +486,15 @@ UpdateDynamics mathbb meger Jupyter +LazyFrames +SLAC +Leibler +Kullback +slac +Tal +Nils +Simão +Hogewind +Yannick +Kachman +Thiago diff --git a/omnisafe/adapter/__init__.py b/omnisafe/adapter/__init__.py index ba768a7eb..eae6697d2 100644 --- a/omnisafe/adapter/__init__.py +++ b/omnisafe/adapter/__init__.py @@ -18,6 +18,7 @@ from omnisafe.adapter.modelbased_adapter import ModelBasedAdapter from omnisafe.adapter.offline_adapter import OfflineAdapter from omnisafe.adapter.offpolicy_adapter import OffPolicyAdapter +from omnisafe.adapter.offpolicy_latent_adapter import OffPolicyLatentAdapter from omnisafe.adapter.online_adapter import OnlineAdapter from omnisafe.adapter.onpolicy_adapter import OnPolicyAdapter from omnisafe.adapter.saute_adapter import SauteAdapter diff --git a/omnisafe/adapter/offpolicy_latent_adapter.py b/omnisafe/adapter/offpolicy_latent_adapter.py index 105ee8ad1..8db394abc 100644 --- a/omnisafe/adapter/offpolicy_latent_adapter.py +++ b/omnisafe/adapter/offpolicy_latent_adapter.py @@ -42,6 +42,18 @@ class OffPolicyLatentAdapter(OnlineAdapter): + """OffPolicy Adapter on Latent Space for OmniSafe. + + :class:`OffPolicyLatentAdapter` is used to adapt the vision-based environment to the off-policy + training. + + Args: + env_id (str): The environment id. + num_envs (int): The number of environments. + seed (int): The random seed. + cfgs (Config): The configuration. + """ + _current_obs: torch.Tensor _ep_ret: torch.Tensor _ep_cost: torch.Tensor @@ -54,8 +66,9 @@ def __init__( # pylint: disable=too-many-arguments seed: int, cfgs: Config, ) -> None: - """Initialize a instance of :class:`OffPolicyAdapter`.""" + """Initialize a instance of :class:`OffPolicyLatentAdapter`.""" super().__init__(env_id, num_envs, seed, cfgs) + assert self.action_space.shape self._observation_concator: ObservationConcator = ObservationConcator( self._cfgs.algo_cfgs.latent_dim_1 + self._cfgs.algo_cfgs.latent_dim_2, self.action_space.shape, @@ -65,8 +78,9 @@ def __init__( # pylint: disable=too-many-arguments self._current_obs, _ = self.reset() self._max_ep_len: int = 1000 self._reset_log() - self.z1 = None - self.z2 = None + self.z1: torch.Tensor = torch.zeros(1) + self.z2: torch.Tensor = torch.zeros(1) + self._initialized: bool = False self._reset_sequence_queue = False def _wrapper( @@ -135,6 +149,13 @@ def eval_policy( # pylint: disable=too-many-locals agent: ConstraintActorQCritic, logger: Logger, ) -> None: + """Rollout the environment with deterministic agent action. + + Args: + episode (int): Number of episodes. + agent (ConstraintActorCritic): Agent. + logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``. + """ for _ in range(episode): ep_ret, ep_cost, ep_len = 0.0, 0.0, 0 obs, _ = self._eval_env.reset() @@ -161,22 +182,37 @@ def eval_policy( # pylint: disable=too-many-locals }, ) - def pre_process(self, latent_model, concated_obs): + def pre_process( + self, + latent_model: CostLatentModel, + concated_obs: ObservationConcator, + ) -> torch.Tensor: + """Processes the concatenated observations to produce latent representation. + + Args: + latent_model (CostLatentModel): The latent model containing the encoder and decoder. + concated_obs (ObservationConcator): An object that encapsulates the concatenated observations. + + Returns: + A tensor combining the latent variables z1 and z2, representing the current state of + the system in the latent space. + """ with torch.no_grad(): feature = latent_model.encoder(concated_obs.last_state) - if self.z2 is None: + if not self._initialized: z1_mean, z1_std = latent_model.z1_posterior_init(feature) self.z1 = z1_mean + torch.randn_like(z1_std) * z1_std z2_mean, z2_std = latent_model.z2_posterior_init(self.z1) self.z2 = z2_mean + torch.randn_like(z2_std) * z2_std + self._initialized = True else: z1_mean, z1_std = latent_model.z1_posterior( - torch.cat([feature.squeeze(), self.z2.squeeze(), concated_obs.last_action], dim=-1) + torch.cat([feature.squeeze(), self.z2.squeeze(), concated_obs.last_action], dim=-1), ) self.z1 = z1_mean + torch.randn_like(z1_std) * z1_std z2_mean, z2_std = latent_model.z2_posterior( - torch.cat([self.z1.squeeze(), self.z2.squeeze(), concated_obs.last_action], dim=-1) + torch.cat([self.z1.squeeze(), self.z2.squeeze(), concated_obs.last_action], dim=-1), ) self.z2 = z2_mean + torch.randn_like(z2_std) * z2_std @@ -191,6 +227,17 @@ def rollout( # pylint: disable=too-many-locals logger: Logger, use_rand_action: bool, ) -> None: + """Rollout the environment and store the data in the buffer. + + Args: + rollout_step (int): Number of rollout steps. + agent (ConstraintActorCritic): Constraint actor-critic, including actor, reward critic, + and cost critic. + latent_model (CostLatentModel): Latent model, including encoder and decoder. + buffer (VectorOnPolicyBuffer): Vector on-policy buffer. + logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``. + use_rand_action (bool): Whether to use random action. + """ for step in range(rollout_step): if not self._reset_sequence_queue: buffer.reset_sequence_queue(self._current_obs) @@ -198,10 +245,11 @@ def rollout( # pylint: disable=too-many-locals self._reset_sequence_queue = True if use_rand_action: - act = act = (torch.rand(self.action_space.shape) * 2 - 1).to(self._device) # type: ignore + act = (torch.rand(self.action_space.shape) * 2 - 1).to(self._device) # type: ignore else: act = agent.step( - self.pre_process(latent_model, self._observation_concator), deterministic=False + self.pre_process(latent_model, self._observation_concator), + deterministic=False, ) next_obs, reward, cost, terminated, truncated, info = self.step(act) @@ -217,8 +265,9 @@ def rollout( # pylint: disable=too-many-locals if done: self._log_metrics(logger, idx) self._reset_log(idx) - self.z1 = None - self.z2 = None + self.z1 = torch.zeros(1) + self.z2 = torch.zeros(1) + self._initialized = False self._reset_sequence_queue = False if 'final_observation' in info: real_next_obs[idx] = info['final_observation'][idx] @@ -239,11 +288,30 @@ def _log_value( cost: torch.Tensor, info: dict[str, Any], ) -> None: + """Log value. + + .. note:: + OmniSafe uses :class:`RewardNormalizer` wrapper, so the original reward and cost will + be stored in ``info['original_reward']`` and ``info['original_cost']``. + + Args: + reward (torch.Tensor): The immediate step reward. + cost (torch.Tensor): The immediate step cost. + info (dict[str, Any]): Some information logged by the environment. + """ self._ep_ret += info.get('original_reward', reward).cpu() self._ep_cost += info.get('original_cost', cost).cpu() self._ep_len += info.get('num_step', 1) def _log_metrics(self, logger: Logger, idx: int) -> None: + """Log metrics, including ``EpRet``, ``EpCost``, ``EpLen``. + + Args: + logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``. + idx (int): The index of the environment. + """ + if hasattr(self._env, 'spec_log'): + self._env.spec_log(logger) logger.store( { 'Metrics/EpRet': self._ep_ret[idx], @@ -253,6 +321,12 @@ def _log_metrics(self, logger: Logger, idx: int) -> None: ) def _reset_log(self, idx: int | None = None) -> None: + """Reset the episode return, episode cost and episode length. + + Args: + idx (int or None, optional): The index of the environment. Defaults to None + (single environment). + """ if idx is None: self._ep_ret = torch.zeros(self._env.num_envs) self._ep_cost = torch.zeros(self._env.num_envs) @@ -267,6 +341,16 @@ def reset( seed: int | None = None, options: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, dict[str, Any]]: + """Reset the environment and returns an initial observation. + + Args: + seed (int, optional): The random seed. Defaults to None. + options (dict[str, Any], optional): The options for the environment. Defaults to None. + + Returns: + observation: The initial observation of the space. + info: Some information logged by the environment. + """ obs, info = self._env.reset(seed=seed, options=options) self._observation_concator.reset_episode(obs) return obs, info diff --git a/omnisafe/algorithms/off_policy/safe_slac.py b/omnisafe/algorithms/off_policy/safe_slac.py index 61b9522f2..f8efad839 100644 --- a/omnisafe/algorithms/off_policy/safe_slac.py +++ b/omnisafe/algorithms/off_policy/safe_slac.py @@ -1,4 +1,4 @@ -# Copyright 2023 OmniSafe Team. All Rights Reserved. +# Copyright 2024 OmniSafe Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,6 +36,16 @@ @registry.register # pylint: disable-next=too-many-instance-attributes, too-few-public-methods class SafeSLAC(SACLag): + """Safe SLAC algorithms for vision-based safe RL tasks. + + References: + - Title: Safe Reinforcement Learning From Pixels Using a Stochastic Latent Representation. + - Authors: Yannick Hogewind, Thiago D. Simão, Tal Kachman, Nils Jansen. + - URL: `Safe SLAC `_ + """ + + _is_latent_model_init_learned: bool + def _init(self) -> None: if self._cfgs.algo_cfgs.auto_alpha: self._target_entropy = -torch.prod(torch.Tensor(self._env.action_space.shape)).item() @@ -53,7 +63,7 @@ def _init(self) -> None: self._lagrange: Lagrange = Lagrange(**self._cfgs.lagrange_cfgs) - self._buf: OffPolicySequenceBuffer = OffPolicySequenceBuffer( + self._buf: OffPolicySequenceBuffer = OffPolicySequenceBuffer( # type: ignore obs_space=self._env.observation_space, act_space=self._env.action_space, size=self._cfgs.algo_cfgs.size, @@ -64,7 +74,7 @@ def _init(self) -> None: self._is_latent_model_init_learned = False def _init_env(self) -> None: - self._env: OffPolicyLatentAdapter = OffPolicyLatentAdapter( + self._env: OffPolicyLatentAdapter = OffPolicyLatentAdapter( # type: ignore self._env_id, self._cfgs.train_cfgs.vector_env_nums, self._seed, @@ -96,6 +106,9 @@ def _init_env(self) -> None: def _init_model(self) -> None: self._cfgs.model_cfgs.critic['num_critics'] = 2 + assert self._env.observation_space.shape + assert self._env.action_space.shape + self._latent_model = CostLatentModel( obs_shape=self._env.observation_space.shape, act_shape=self._env.action_space.shape, @@ -114,9 +127,6 @@ def _init_model(self) -> None: epochs=self._epochs, ).to(self._device) - self._actor_critic = torch.compile(self._actor_critic) - self._latent_model = torch.compile(self._latent_model) - self._latent_model_optimizer = optim.Adam( self._latent_model.parameters(), lr=1e-4, @@ -218,7 +228,7 @@ def learn(self) -> tuple[float, float, float]: return ep_ret, ep_cost, ep_len - def _prepare_batch(self, obs_, action_): + def _prepare_batch(self, obs_: torch.Tensor, action_: torch.Tensor) -> tuple[torch.Tensor, ...]: with torch.no_grad(): feature_ = self._latent_model.encoder(obs_) z_ = torch.cat(self._latent_model.sample_posterior(feature_, action_)[2:4], dim=-1) @@ -266,9 +276,7 @@ def _update(self) -> None: self._update_actor(obs) self._actor_critic.polyak_update(self._cfgs.algo_cfgs.polyak) - def _update_latent_model( - self, - ): + def _update_latent_model(self) -> None: data = self._buf.sample_batch(32) obs, act, reward, cost, done = ( data['obs'], @@ -280,7 +288,11 @@ def _update_latent_model( self._update_latent_count += 1 loss_kld, loss_image, loss_reward, loss_cost = self._latent_model.calculate_loss( - obs, act, reward, done, cost + obs, + act, + reward, + done, + cost, ) self._latent_model_optimizer.zero_grad() diff --git a/omnisafe/common/buffer/base.py b/omnisafe/common/buffer/base.py index 0e521eacf..a44afecdf 100644 --- a/omnisafe/common/buffer/base.py +++ b/omnisafe/common/buffer/base.py @@ -135,7 +135,19 @@ def store(self, **data: torch.Tensor) -> None: """ -class BaseSequenceBuffer(BaseBuffer): +class BaseSequenceBuffer(ABC): + r"""Abstract base class for sequence buffer. + + Attributes: + sequence_queue (SequenceQueue): The queue for storing the data. + + Args: + obs_space (OmnisafeSpace): The observation space. + act_space (OmnisafeSpace): The action space. + size (int): The size of the buffer. + device (torch.device): The device of the buffer. Defaults to ``torch.device('cpu')``. + """ + def __init__( self, obs_space: OmnisafeSpace, @@ -178,8 +190,37 @@ def __init__( self._observation_shape = obs_space.shape def add_field(self, name: str, shape: tuple[int, ...], dtype: torch.dtype) -> None: + """Add a field to the buffer. + + Args: + name (str): The name of the field. + shape (tuple of int): The shape of the field. + dtype (torch.dtype): The dtype of the field. + """ self.data[name] = torch.zeros( (self._size, self._num_sequences, *shape), dtype=dtype, device=self._device, ) + + @property + def device(self) -> torch.device: + """The device of the buffer.""" + return self._device + + @property + def size(self) -> int: + """The size of the buffer.""" + return self._size + + def __len__(self) -> int: + """Return the length of the buffer.""" + return self._size + + @abstractmethod + def store(self, **data: torch.Tensor) -> None: + """Store a transition in the buffer. + + Args: + data (torch.Tensor): The data to store. + """ diff --git a/omnisafe/common/buffer/offpolicy_buffer.py b/omnisafe/common/buffer/offpolicy_buffer.py index 9cd62dbbf..e02483b13 100644 --- a/omnisafe/common/buffer/offpolicy_buffer.py +++ b/omnisafe/common/buffer/offpolicy_buffer.py @@ -119,6 +119,8 @@ def sample_batch(self) -> dict[str, torch.Tensor]: class OffPolicySequenceBuffer(BaseSequenceBuffer): + """Sequence-based Replay buffer for off-policy algorithms.""" + def __init__( # pylint: disable=too-many-arguments self, obs_space: OmnisafeSpace, diff --git a/omnisafe/common/latent.py b/omnisafe/common/latent.py index 18e8ee921..8c9ac83e8 100644 --- a/omnisafe/common/latent.py +++ b/omnisafe/common/latent.py @@ -1,4 +1,24 @@ +# Copyright 2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implementation of the Latent Model for Safe SLAC.""" + + +from __future__ import annotations + import math +from typing import Any import numpy as np import torch @@ -6,22 +26,48 @@ from torch.nn import functional as F -def calculate_kl_divergence(p_mean, p_std, q_mean, q_std): +def calculate_kl_divergence( + p_mean: torch.Tensor, + p_std: torch.Tensor, + q_mean: torch.Tensor, + q_std: torch.Tensor, +) -> torch.Tensor: + """Calculate the KL divergence between two normal distributions. + + Args: + p_mean (torch.Tensor): Mean of the first normal distribution. + p_std (torch.Tensor): Standard deviation of the first normal distribution. + q_mean (torch.Tensor): Mean of the second normal distribution. + q_std (torch.Tensor): Standard deviation of the second normal distribution. + + Returns: + torch.Tensor: The KL divergence between the two distributions. + """ var_ratio = (p_std / q_std).pow_(2) t1 = ((p_mean - q_mean) / q_std).pow_(2) return 0.5 * (var_ratio + t1 - 1 - var_ratio.log()) def build_mlp( - input_dim, - output_dim, - hidden_sizes=None, - hidden_activation=nn.Tanh(), - output_activation=None, -): + input_dim: int, + output_dim: int, + hidden_activation: nn.Module, + hidden_sizes: list[int] | None = None, +) -> nn.Sequential: + """Build a multi-layer perceptron (MLP) model using PyTorch. + + Args: + input_dim (int): Dimension of the input features. + output_dim (int): Dimension of the output. + hidden_sizes (list[int], optional): List of integers defining the number of units in each hidden layer. + hidden_activation (nn.Module): Activation function to use after each hidden layer. + + Returns: + nn.Sequential: The constructed MLP model. + """ if hidden_sizes is None: hidden_sizes = [64, 64] - layers = [] + layers: list[Any] = [] units = input_dim for next_units in hidden_sizes: layers.append(nn.Linear(units, next_units)) @@ -29,238 +75,335 @@ def build_mlp( units = next_units model = nn.Sequential(*layers) model.add_module('last_linear', nn.Linear(units, output_dim)) - if output_activation is not None: - model.add_module('output_activation', output_activation) return model -def initialize_weight(m): +def initialize_weight(m: nn.Module) -> None: + """Initializes the weights of the module using Xavier uniform initialization. + + Args: + m (nn.Module): The module whose weights need to be initialized. + """ if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): nn.init.xavier_uniform_(m.weight, gain=1.0) if m.bias is not None: nn.init.constant_(m.bias, 0) -class FixedGaussian(torch.nn.Module): - """ - Fixed diagonal gaussian distribution. +class FixedGaussian(nn.Module): + """Represents a fixed diagonal Gaussian distribution. + + Attributes: + output_dim (int): Dimension of the output distribution. + std (float): Standard deviation of the Gaussian distribution. """ - def __init__(self, output_dim, std) -> None: + def __init__(self, output_dim: int, std: float) -> None: + """Initialize an instance of the Fixed Gaussian.""" super().__init__() self.output_dim = output_dim self.std = std - def forward(self, x): + def forward(self, x: torch.Tensor) -> tuple: + """Generates a mean and standard deviation tensor based on the fixed parameters. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + tuple: Mean and standard deviation tensors. + """ mean = torch.zeros(x.size(0), self.output_dim, device=x.device) - std = torch.ones(x.size(0), self.output_dim, device=x.device).mul_(self.std) + std = torch.ones(x.size(0), self.output_dim, device=x.device) * self.std return mean, std -class Gaussian(torch.nn.Module): - """ - Diagonal gaussian distribution with state dependent variances. +class Gaussian(nn.Module): + """Represents a diagonal Gaussian distribution with state dependent variances. + + Attributes: + net (nn.Module): Neural network module to compute means and log standard deviations. """ - def __init__(self, input_dim, output_dim, hidden_sizes=(256, 256)) -> None: + def __init__( + self, + input_dim: int, + output_dim: int, + hidden_sizes: list[int], + ) -> None: + """Initialize an instance of the Gaussian distribution.""" super().__init__() self.net = build_mlp( input_dim=input_dim, output_dim=2 * output_dim, - hidden_sizes=hidden_sizes, hidden_activation=nn.ELU(), + hidden_sizes=hidden_sizes, ).apply(initialize_weight) - def forward(self, x): + def forward(self, x: torch.Tensor) -> tuple: + """Computes the mean and standard deviation for the Gaussian distribution. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + tuple: Mean and log standard deviation tensors. + """ if x.ndim == 3: - B, S, _ = x.size() - x = self.net(x.view(B * S, _)).view(B, S, -1) + batch_size, seq_length, _ = x.size() + x = self.net(x.view(batch_size * seq_length, _)).view(batch_size, seq_length, -1) else: x = self.net(x) mean, std = torch.chunk(x, 2, dim=-1) - std = F.softplus(std) + 1e-5 + std = F.softplus(std) + 1e-5 # pylint: disable=not-callable return mean, std -class Bernoulli(torch.nn.Module): - """ - Diagonal gaussian distribution with state dependent variances. +class Bernoulli(nn.Module): + """A module representing a Bernoulli distribution. + + This class builds a multi-layer perceptron (MLP) that outputs parameters for a Bernoulli + distribution. The output of the MLP is transformed using a sigmoid function to ensure it lies + between 0 and 1, representing probabilities. + + Attributes: + net (nn.Module): The neural network that builds the Bernoulli distribution. + + Args: + input_dim (int): The number of input features to the MLP. + output_dim (int): The number of output features from the MLP. + hidden_sizes (list[int, int]): The sizes of the hidden layers in the MLP. Defaults to + (256, 256). """ - def __init__(self, input_dim, output_dim, hidden_sizes=(256, 256)) -> None: + def __init__( + self, + input_dim: int, + output_dim: int, + hidden_sizes: list[int], + ) -> None: + """Initializes the Bernoulli module with the specified architecture.""" super().__init__() self.net = build_mlp( input_dim=input_dim, output_dim=output_dim, - hidden_sizes=hidden_sizes, hidden_activation=nn.ELU(), + hidden_sizes=hidden_sizes, ).apply(initialize_weight) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Defines the forward pass of the Bernoulli module. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor representing probabilities after applying the sigmoid function. + """ if x.ndim == 3: - B, S, _ = x.size() - x = self.net(x.view(B * S, _)).view(B, S, -1) + batch_size, seq_length, _ = x.size() + x = self.net(x.view(batch_size * seq_length, -1)).view(batch_size, seq_length, -1) else: x = self.net(x) return torch.sigmoid(x) class Decoder(torch.nn.Module): - """ - Decoder. + """The image processing decoder module. + + This decoder module that takes in a latent vector and outputs reconstructed images, + also outputs a tensor of the same shape with a constant standard deviation value. + + Attributes: + net (torch.nn.Sequential): The neural network layers. + std (float): The standard deviation value to be used for the output tensor. + + Args: + input_dim (int): Dimension of the input latent vector. Defaults to 288. + output_dim (int): The number of output channels. Defaults to 3. + std (float): Standard deviation for the generated tensor. Defaults to 1.0. """ - def __init__(self, input_dim=288, output_dim=3, std=1.0) -> None: + def __init__(self, input_dim: int = 288, output_dim: int = 3, std: float = 1.0) -> None: + """Initializes the decoder module.""" super().__init__() self.net = nn.Sequential( - # (32+256, 1, 1) -> (256, 4, 4) nn.ConvTranspose2d(input_dim, 256, 4), nn.LeakyReLU(inplace=True, negative_slope=0.2), - # (256, 4, 4) -> (128, 8, 8) nn.ConvTranspose2d(256, 128, 3, 2, 1, 1), nn.LeakyReLU(inplace=True, negative_slope=0.2), - # (128, 8, 8) -> (64, 16, 16) nn.ConvTranspose2d(128, 64, 3, 2, 1, 1), nn.LeakyReLU(inplace=True, negative_slope=0.2), - # (64, 16, 16) -> (32, 32, 32) nn.ConvTranspose2d(64, 32, 3, 2, 1, 1), nn.LeakyReLU(inplace=True, negative_slope=0.2), - # (32, 32, 32) -> (3, 64, 64) nn.ConvTranspose2d(32, output_dim, 5, 2, 2, 1), nn.LeakyReLU(inplace=True, negative_slope=0.2), ).apply(initialize_weight) self.std = std - def forward(self, x): - B, S, latent_dim = x.size() - x = x.view(B * S, latent_dim, 1, 1) + def forward(self, x: torch.Tensor) -> tuple: + """Forward pass for generating output image tensor and a tensor filled with std. + + Args: + x (torch.Tensor): Input latent vector tensor. + + Returns: + tuple: A tuple containing the reconstructed image tensor and a tensor filled with std. + """ + batch_size, seq_length, latent_dim = x.size() + x = x.view(batch_size * seq_length, latent_dim, 1, 1) x = self.net(x) - _, C, W, H = x.size() - x = x.view(B, S, C, W, H) + _, channels, width, height = x.size() + x = x.view(batch_size, seq_length, channels, width, height) return x, torch.ones_like(x).mul_(self.std) class Encoder(torch.nn.Module): - """ - Encoder. + """An encoder module that takes in images and outputs a latent vector representation. + + Attributes: + net (torch.nn.Sequential): The neural network layers. + + Args: + input_dim (int): Number of input channels in the image. Defaults to 3. + output_dim (int): Dimension of the output latent vector. Defaults to 256. """ - def __init__(self, input_dim=3, output_dim=256) -> None: + def __init__(self, input_dim: int = 3, output_dim: int = 256) -> None: + """Initialize the Encoder module.""" super().__init__() self.net = nn.Sequential( - # (3, 64, 64) -> (32, 32, 32) nn.Conv2d(input_dim, 32, 5, 2, 2), nn.ELU(inplace=True), - # (32, 32, 32) -> (64, 16, 16) nn.Conv2d(32, 64, 3, 2, 1), nn.ELU(inplace=True), - # (64, 16, 16) -> (128, 8, 8) nn.Conv2d(64, 128, 3, 2, 1), nn.ELU(inplace=True), - # (128, 8, 8) -> (256, 4, 4) nn.Conv2d(128, 256, 3, 2, 1), nn.ELU(inplace=True), - # (256, 4, 4) -> (256, 1, 1) nn.Conv2d(256, output_dim, 4), nn.ELU(inplace=True), ).apply(initialize_weight) - def forward(self, x): - B, S, C, H, W = x.size() - x = x.view(B * S, C, H, W) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass to transform input images into a flat latent vector representation. + + Args: + x (torch.Tensor): Input tensor of images. + + Returns: + torch.Tensor: Output tensor of latent vectors. + """ + batch_size, seq_length, channels, height, width = x.size() + x = x.view(batch_size * seq_length, channels, height, width) x = self.net(x) - return x.view(B, S, -1) + return x.view(batch_size, seq_length, -1) -class CostLatentModel(torch.nn.Module): - """ - Stochastic latent variable model to estimate latent dynamics, reward and cost. +# pylint: disable-next=too-many-instance-attributes +class CostLatentModel(nn.Module): + """The latent model for cost prediction. + + Stochastic latent variable model that estimates latent dynamics, rewards, and costs + for a given observation and action space using variational inference techniques. + + Args: + obs_shape (tuple of int): Shape of the observations. + act_shape (tuple of int): Shape of the actions. + feature_dim (int): Dimension of the feature vector from observations. Defaults to 256. + latent_dim_1 (int): Dimension of the first set of latent variables. Defaults to 32. + latent_dim_2 (int): Dimension of the second set of latent variables. Defaults to 256. + hidden_sizes (list of int): Sizes of hidden layers in the networks. + image_noise (float): Standard deviation of noise in image reconstruction. Defaults to 0.1. """ def __init__( self, - obs_shape, - act_shape, - feature_dim=256, - latent_dim_1=32, - latent_dim_2=256, - hidden_sizes=(256, 256), - image_noise=0.1, + obs_shape: tuple[int, ...], + act_shape: tuple[int, ...], + hidden_sizes: list[int], + feature_dim: int = 256, + latent_dim_1: int = 32, + latent_dim_2: int = 256, + image_noise: float = 0.1, ) -> None: + """Initialize the instance of CostLatentModel.""" super().__init__() self.bceloss = torch.nn.BCELoss(reduction='none') - # p(z1(0)) = N(0, I) self.z1_prior_init = FixedGaussian(latent_dim_1, 1.0) - # p(z2(0) | z1(0)) self.z2_prior_init = Gaussian(latent_dim_1, latent_dim_2, hidden_sizes) - # p(z1(t+1) | z2(t), a(t)) self.z1_prior = Gaussian( - latent_dim_2 + act_shape[0], - latent_dim_1, + int(latent_dim_2 + act_shape[0]), + int(latent_dim_1), hidden_sizes, ) - # p(z2(t+1) | z1(t+1), z2(t), a(t)) self.z2_prior = Gaussian( - latent_dim_1 + latent_dim_2 + act_shape[0], - latent_dim_2, + int(latent_dim_1 + latent_dim_2 + act_shape[0]), + int(latent_dim_2), hidden_sizes, ) - # q(z1(0) | feat(0)) self.z1_posterior_init = Gaussian(feature_dim, latent_dim_1, hidden_sizes) - # q(z2(0) | z1(0)) = p(z2(0) | z1(0)) self.z2_posterior_init = self.z2_prior_init - # q(z1(t+1) | feat(t+1), z2(t), a(t)) self.z1_posterior = Gaussian( - feature_dim + latent_dim_2 + act_shape[0], - latent_dim_1, + int(feature_dim + latent_dim_2 + act_shape[0]), + int(latent_dim_1), hidden_sizes, ) - # q(z2(t+1) | z1(t+1), z2(t), a(t)) = p(z2(t+1) | z1(t+1), z2(t), a(t)) self.z2_posterior = self.z2_prior - # p(r(t) | z1(t), z2(t), a(t), z1(t+1), z2(t+1)) self.reward = Gaussian( - 2 * latent_dim_1 + 2 * latent_dim_2 + act_shape[0], + int(2 * latent_dim_1 + 2 * latent_dim_2 + act_shape[0]), 1, hidden_sizes, ) self.cost = Bernoulli( - 2 * latent_dim_1 + 2 * latent_dim_2 + act_shape[0], + int(2 * latent_dim_1 + 2 * latent_dim_2 + act_shape[0]), 1, hidden_sizes, ) - # feat(t) = Encoder(x(t)) self.encoder = Encoder(obs_shape[0], feature_dim) - # p(x(t) | z1(t), z2(t)) self.decoder = Decoder( - latent_dim_1 + latent_dim_2, - obs_shape[0], + int(latent_dim_1 + latent_dim_2), + int(obs_shape[0]), std=np.sqrt(image_noise), ) self.apply(initialize_weight) - def sample_prior(self, actions_, z2_post_): - # p(z1(0)) = N(0, I) + def sample_prior(self, actions_: torch.Tensor, z2_post_: torch.Tensor) -> tuple: + """Sample the prior distribution for latent variables using initial and recurrent models. + + Args: + actions_ (torch.Tensor): The actions taken at each step of the sequence. + z2_post_ (torch.Tensor): The posterior samples of the second set of latent variables. + + Returns: + tuple: A tuple containing means and standard deviations for the first set of latent variables. + """ z1_mean_init, z1_std_init = self.z1_prior_init(actions_[:, 0]) - # p(z1(t) | z2(t-1), a(t-1)) z1_mean_, z1_std_ = self.z1_prior( - torch.cat([z2_post_[:, : actions_.size(1)], actions_], dim=-1) + torch.cat([z2_post_[:, : actions_.size(1)], actions_], dim=-1), ) - # Concatenate initial and consecutive latent variables z1_mean_ = torch.cat([z1_mean_init.unsqueeze(1), z1_mean_], dim=1) z1_std_ = torch.cat([z1_std_init.unsqueeze(1), z1_std_], dim=1) return (z1_mean_, z1_std_) - def sample_posterior(self, features_, actions_): - # p(z1(0)) = N(0, I) + def sample_posterior(self, features_: torch.Tensor, actions_: torch.Tensor) -> tuple: + """Sample the posterior distribution for latent variables. + + Args: + features_ (torch.Tensor): The features extracted from observations at each timestep. + actions_ (torch.Tensor): The actions taken at each step of the sequence. + + Returns: + tuple: A tuple of tensors containing means, standard deviations, and samples of the latent variables. + """ z1_mean, z1_std = self.z1_posterior_init(features_[:, 0]) z1 = z1_mean + torch.randn_like(z1_std) * z1_std - # p(z2(0) | z1(0)) z2_mean, z2_std = self.z2_posterior_init(z1) z2 = z2_mean + torch.randn_like(z2_std) * z2_std @@ -270,12 +413,10 @@ def sample_posterior(self, features_, actions_): z2_ = [z2] for t in range(1, actions_.size(1) + 1): - # q(z1(t) | feat(t), z2(t-1), a(t-1)) z1_mean, z1_std = self.z1_posterior( - torch.cat([features_[:, t], z2, actions_[:, t - 1]], dim=1) + torch.cat([features_[:, t], z2, actions_[:, t - 1]], dim=1), ) z1 = z1_mean + torch.randn_like(z1_std) * z1_std - # q(z2(t) | z1(t), z2(t-1), a(t-1)) z2_mean, z2_std = self.z2_posterior(torch.cat([z1, z2, actions_[:, t - 1]], dim=1)) z2 = z2_mean + torch.randn_like(z2_std) * z2_std @@ -290,44 +431,63 @@ def sample_posterior(self, features_, actions_): z2_ = torch.stack(z2_, dim=1) return (z1_mean_, z1_std_, z1_, z2_) - # - def calculate_loss(self, state_, action_, reward_, done_, cost_): - # Calculate the sequence of features. - feature_ = self.encoder(state_) + # pylint: disable-next=too-many-locals + def calculate_loss( + self, + state_: torch.Tensor, + action_: torch.Tensor, + reward_: torch.Tensor, + done_: torch.Tensor, + cost_: torch.Tensor, + ) -> tuple: + """Calculate the loss for the model. + + Args: + state_ (torch.Tensor): Observed states over a sequence of timesteps. + action_ (torch.Tensor): Actions taken at each timestep. + reward_ (torch.Tensor): Observed rewards at each timestep. + done_ (torch.Tensor): Done flags indicating the end of an episode. + cost_ (torch.Tensor): Observed costs at each timestep. + + Returns: + tuple: A tuple containing the KL divergence loss, image reconstruction loss, + reward prediction loss, and cost classification loss. + """ + feature_ = self.forward(state_) - # Sample from latent variable model. z1_mean_post_, z1_std_post_, z1_, z2_ = self.sample_posterior(feature_, action_) z1_mean_pri_, z1_std_pri_ = self.sample_prior(action_, z2_) - # Calculate KL divergence loss. loss_kld = ( calculate_kl_divergence(z1_mean_post_, z1_std_post_, z1_mean_pri_, z1_std_pri_) .mean(dim=0) .sum() ) - # Prediction loss of images. z_ = torch.cat([z1_, z2_], dim=-1) state_mean_, state_std_ = self.decoder(z_) state_noise_ = (state_ - state_mean_) / (state_std_ + 1e-8) log_likelihood_ = (-0.5 * state_noise_.pow(2) - state_std_.log()) - 0.5 * math.log( - 2 * math.pi + 2 * math.pi, ) loss_image = -log_likelihood_.mean(dim=0).sum() - # Prediction loss of rewards. x = torch.cat([z_[:, :-1], action_, z_[:, 1:]], dim=-1) - B, S, X = x.shape - reward_mean_, reward_std_ = self.reward(x.view(B * S, X)) - reward_mean_ = reward_mean_.view(B, S, 1) - reward_std_ = reward_std_.view(B, S, 1) + batch_size, seq_length, concated_shape = x.shape + reward_mean_, reward_std_ = self.reward(x.view(batch_size * seq_length, concated_shape)) + reward_mean_ = reward_mean_.view(batch_size, seq_length, 1) + reward_std_ = reward_std_.view(batch_size, seq_length, 1) reward_noise_ = (reward_ - reward_mean_) / (reward_std_ + 1e-8) log_likelihood_reward_ = (-0.5 * reward_noise_.pow(2) - reward_std_.log()) - 0.5 * math.log( - 2 * math.pi + 2 * math.pi, ) loss_reward = -log_likelihood_reward_.mul_(1 - done_).mean(dim=0).sum() - p = self.cost(x.view(B * S, X)).view(B, S, 1) + p = self.cost(x.view(batch_size * seq_length, concated_shape)).view( + batch_size, + seq_length, + 1, + ) q = 1 - p weight_p = 100 binary_cost_ = torch.sign(cost_) @@ -342,3 +502,7 @@ def calculate_loss(self, state_, action_, reward_, done_, cost_): ) return loss_kld, loss_image, loss_reward, loss_cost + + def forward(self, obs: torch.Tensor) -> torch.Tensor: + """Get the encoded state of observation.""" + return self.encoder(obs) diff --git a/omnisafe/common/normalizer.py b/omnisafe/common/normalizer.py index 38bbe5f77..e90909089 100644 --- a/omnisafe/common/normalizer.py +++ b/omnisafe/common/normalizer.py @@ -144,12 +144,12 @@ def load_state_dict( strict: bool = True, assign: bool = False, ) -> Any: - """Load the state_dict to the normalizer. + """Load the parameters to the normalizer. Args: state_dict (Mapping[str, Any]): The state_dict to be loaded. - strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict`. - Defaults to True. + strict (bool, optional): Whether to strictly enforce that the keys in + :attr:`state_dict`. Defaults to True. Returns: The loaded normalizer. diff --git a/omnisafe/configs/off-policy/SafeSLAC.yaml b/omnisafe/configs/off-policy/SafeSLAC.yaml index a11626fe3..48df653f4 100644 --- a/omnisafe/configs/off-policy/SafeSLAC.yaml +++ b/omnisafe/configs/off-policy/SafeSLAC.yaml @@ -42,16 +42,14 @@ defaults: image_noise: 0.4 # list of sizes for hidden layers of dynamics model hidden_sizes: [256, 256] - # dimensionality of feature vectors initially - feature_dim: 256 - # dimensionality of the first latent space + # dimension of the first latent space latent_dim_1: 32 - # dimensionality of the second latent space + # dimension of the second latent space latent_dim_2: 200 - # dimensionality of feature vectors + # dimension of feature vectors feature_dim: 200 # number of steps to update the policy - steps_per_epoch: 2000 + steps_per_epoch: 2000 # number of steps per sample update_cycle: 100 # number of iterations to update the policy diff --git a/omnisafe/envs/safety_gymnasium_vision_env.py b/omnisafe/envs/safety_gymnasium_vision_env.py index 778ddacb4..4ab6cc48e 100644 --- a/omnisafe/envs/safety_gymnasium_vision_env.py +++ b/omnisafe/envs/safety_gymnasium_vision_env.py @@ -29,6 +29,27 @@ @env_register class SafetyGymnasiumVisionEnv(CMDP): + """Safety Gymnasium Vision-Based Environment. + + Args: + env_id (str): Environment id. + num_envs (int, optional): Number of environments. Defaults to 1. + device (torch.device, optional): Device to store the data. Defaults to + ``torch.device('cpu')``. + + Keyword Args: + render_mode (str, optional): The render mode ranges from 'human' to 'rgb_array' and 'rgb_array_list'. + Defaults to 'rgb_array'. + camera_name (str, optional): The camera name. + camera_id (int, optional): The camera id. + width (int, optional): The width of the rendered image. Defaults to 256. + height (int, optional): The height of the rendered image. Defaults to 256. + + Attributes: + need_auto_reset_wrapper (bool): Whether to use auto reset wrapper. + need_time_limit_wrapper (bool): Whether to use time limit wrapper. + """ + need_auto_reset_wrapper: bool = False need_time_limit_wrapper: bool = False diff --git a/omnisafe/utils/model.py b/omnisafe/utils/model.py index 5e1b44a0c..cdc52b1b2 100644 --- a/omnisafe/utils/model.py +++ b/omnisafe/utils/model.py @@ -115,13 +115,36 @@ def build_mlp_network( class ObservationConcator: - def __init__(self, state_shape, action_shape, num_sequences, device=DEVICE_CPU) -> None: + """A class designed to concatenate observations and actions over a specified time steps.""" + + def __init__( + self, + state_shape: tuple, + action_shape: tuple, + num_sequences: int, + device: torch.device = DEVICE_CPU, + ) -> None: + """Initialize the ObservationConcator with given shapes and device. + + Args: + state_shape (tuple): Shape of the state space. + action_shape (tuple): Shape of the action space. + num_sequences (int): Number of sequences to maintain in the history. + device (str): The device (CPU/GPU) on which to create tensors. + """ self.state_shape = state_shape self.action_shape = action_shape self.num_sequences = num_sequences self.device = device + self._state: deque = deque(maxlen=self.num_sequences) + self._action: deque = deque(maxlen=self.num_sequences - 1) - def reset_episode(self, state): + def reset_episode(self, state: torch.Tensor) -> None: + """Reset the history of states and actions for a new episode. + + Args: + state (torch.Tensor): The initial state for the new episode. + """ self._state = deque(maxlen=self.num_sequences) self._action = deque(maxlen=self.num_sequences - 1) for _ in range(self.num_sequences - 1): @@ -133,22 +156,30 @@ def reset_episode(self, state): ) self._state.append(state) - def append(self, state, action): + def append(self, state: torch.Tensor, action: torch.Tensor) -> None: + """Append a new state and action to the queue. + + Args: + state (torch.Tensor): State to be appended. + action (torch.Tensor): Action to be appended. + """ self._state.append(state) self._action.append(action) @property - def state(self): - return self._state[None, ...] + def last_state(self) -> torch.Tensor: + """Returns the most recent state. - @property - def last_state(self): + Returns: + torch.Tensor: The most recent state. + """ return self._state[-1][None, ...] @property - def action(self): - return self._action.reshape(1, -1) + def last_action(self) -> torch.Tensor: + """Returns the most recent action. - @property - def last_action(self): + Returns: + torch.Tensor: The most recent action. + """ return self._action[-1] diff --git a/omnisafe/utils/tools.py b/omnisafe/utils/tools.py index 8fc0818da..b3e5a9fad 100644 --- a/omnisafe/utils/tools.py +++ b/omnisafe/utils/tools.py @@ -29,7 +29,6 @@ import torch.backends.cudnn import yaml from rich.console import Console -from torch.version import cuda as cuda_version from omnisafe.typing import DEVICE_CPU, OmnisafeSpace @@ -154,15 +153,6 @@ def seed_all(seed: int) -> None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) - try: - # torch.use_deterministic_algorithms(True) - # torch.backends.cudnn.enabled = False - # torch.backends.cudnn.benchmark = False - if cuda_version is not None and float(cuda_version) >= 10.2: - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' - os.environ['PYTHONHASHSEED'] = str(seed) - except AttributeError: # pragma: no cover - pass def custom_cfgs_to_dict(key_list: str, value: Any) -> dict[str, Any]: @@ -362,81 +352,138 @@ def get_device(device: torch.device | str | int = DEVICE_CPU) -> torch.device: # Force conversion to torch.device device = torch.device(device) - # Cuda not available if not torch.cuda.is_available() and device.type == torch.device('cuda').type: return torch.device('cpu') return device -def create_feature_actions(feature_, action_): - N = feature_.size(0) - # Flatten sequence of features. - # f (batch_size, (num_sequences)*feature_dim) - f = feature_[:, :-1].view(N, -1) - n_f = feature_[:, 1:].view(N, -1) - # Flatten sequence of actions. - a = action_[:, :-1].view(N, -1) - n_a = action_[:, 1:].view(N, -1) - # Concatenate feature and action. - fa = torch.cat([f, a], dim=-1) - n_fa = torch.cat([n_f, n_a], dim=-1) - return fa, n_fa +class LazyFrames: + """A class that lazily handles a list of frames. + Attributes: + _frames (List[Any]): Private storage for frame data, stored as a list. + """ -class LazyFrames: - def __init__(self, frames) -> None: + def __init__(self, frames: list[Any]) -> None: + """Initializes a new instance of LazyFrames. + + Args: + frames (list[Any]): A list of frame data. + """ self._frames = list(frames) - def __array__(self, dtype): + def __array__(self, dtype: type[np.dtype]) -> np.ndarray: + """Returns an array representation of the frames with the specified data type. + + Args: + dtype (Type[np.dtype]): The desired data type for the numpy array. + + Returns: + np.ndarray: An array representation of the frames. + """ return np.array(self._frames, dtype=dtype) def __len__(self) -> int: + """Returns the number of frames. + + Returns: + int: The number of frames. + """ return len(self._frames) class SequenceQueue: - def __init__(self, obs_space: OmnisafeSpace, num_sequences: int = 8, device=DEVICE_CPU) -> None: + """A queue that manages sequences of observations, actions, rewards, etc. for RL agents. + + Attributes: + num_sequences (int): The number of sequences to store. + _reset_episode (bool): Flag to indicate whether the queue needs to be reset. + _obs_space (OmnisafeSpace): The space of the observations. + _device (str): The device (CPU or GPU) on which to process the tensors. + data (dict): A dictionary containing observations, actions, rewards, costs and terminated + or truncated flags. + """ + + def __init__( + self, + obs_space: OmnisafeSpace, + num_sequences: int = 8, + device: torch.device = DEVICE_CPU, + ) -> None: + """Initialize a new instance of SequenceQueue. + + Args: + obs_space (OmnisafeSpace): The observation space of the environment. + num_sequences (int, optional): Maximum number of sequences to store. Defaults to 8. + device (str, optional): The computation device ('cpu' or 'gpu'). Defaults to 'cpu'. + """ self.num_sequences = num_sequences self._reset_episode = False self._obs_space = obs_space self._device = device - self.data = {} + self.data: dict[str, Any] = {} self.data['obs'] = deque(maxlen=self.num_sequences + 1) self.data['act'] = deque(maxlen=self.num_sequences) self.data['reward'] = deque(maxlen=self.num_sequences) self.data['done'] = deque(maxlen=self.num_sequences) self.data['cost'] = deque(maxlen=self.num_sequences) - def reset_sequence_queue(self, obs): - for k in self.data: - self.data[k].clear() + def reset_sequence_queue(self, obs: torch.Tensor) -> None: + """Reset the sequence queue with the initial observation. + + Args: + obs (torch.Tensor): The initial observation from the environment. + """ + for _, v in self.data.items(): + v.clear() self._reset_episode = True self.data['obs'].append(obs.detach().cpu().numpy()) - def append(self, **data: torch.Tensor): - assert self._reset_episode, self._reset_episode + def append(self, **data: torch.Tensor) -> None: + """Append new data to the queue. + + Args: + data (torch.Tensor): Keyword arguments where keys are data types ('obs', 'act', etc.) + and values are torch Tensors. + """ + assert self._reset_episode, 'Reset the sequence queue before appending data.' for key, value in data.items(): self.data[key].append(value.detach().cpu().numpy()) - def _process_get(self, key: str) -> LazyFrames | torch.Tensor: + def _process_get(self, key: str) -> Any: + """Process the requested data for retrieval. + + Args: + key (str): The key of the data to retrieve. + + Returns: + Union[LazyFrames, torch.Tensor]: The processed data, either as LazyFrames or a torch Tensor. + """ if key == 'obs': return np.array(LazyFrames(self.data['obs']), dtype=np.float32).swapaxes(0, 1).squeeze() - else: - return torch.tensor( - np.array(self.data[key], dtype=np.float32), - dtype=torch.float32, - device=self._device, - ) - - def get(self): + return torch.tensor( + np.array(self.data[key], dtype=np.float32), + dtype=torch.float32, + device=self._device, + ) + + def get(self) -> dict: + """Retrieve all stored data. + + Returns: + dict: A dictionary of processed data for each stored type. + """ return {key: self._process_get(key) for key in self.data} - def is_empty(self): - return len(self.data['reward']) == 0 - - def is_full(self): + def is_full(self) -> bool: + """Return whether the sequence queue is full.""" return len(self.data['reward']) == self.num_sequences def __len__(self) -> int: + """Get the number of rewards stored, which reflects the number of complete sequences. + + Returns: + int: The number of sequences. + """ return len(self.data['reward'])