From c430010a93c7c95c984fe57d835d463058d4dae1 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 17 Apr 2023 14:53:09 +0000 Subject: [PATCH] refactor(common/logger): refactor and simplify logger storage logic --- omnisafe/adapter/offpolicy_adapter.py | 4 ++-- omnisafe/adapter/onpolicy_adapter.py | 6 +++--- omnisafe/adapter/saute_adapter.py | 2 +- omnisafe/algorithms/off_policy/ddpg.py | 16 ++++++++-------- omnisafe/algorithms/off_policy/ddpg_lag.py | 4 ++-- omnisafe/algorithms/off_policy/sac.py | 10 +++++----- omnisafe/algorithms/off_policy/sac_lag.py | 4 ++-- omnisafe/algorithms/off_policy/td3.py | 2 +- omnisafe/algorithms/off_policy/td3_lag.py | 4 ++-- omnisafe/algorithms/on_policy/base/natural_pg.py | 6 +++--- .../algorithms/on_policy/base/policy_gradient.py | 14 +++++++------- omnisafe/algorithms/on_policy/base/trpo.py | 4 ++-- omnisafe/algorithms/on_policy/first_order/cup.py | 4 ++-- .../algorithms/on_policy/first_order/focops.py | 2 +- .../algorithms/on_policy/naive_lagrange/pdo.py | 2 +- .../on_policy/naive_lagrange/ppo_lag.py | 2 +- .../algorithms/on_policy/naive_lagrange/rcpo.py | 2 +- .../on_policy/naive_lagrange/trpo_lag.py | 2 +- .../algorithms/on_policy/penalty_function/ipo.py | 2 +- .../algorithms/on_policy/penalty_function/p3o.py | 2 +- .../on_policy/pid_lagrange/cppo_pid.py | 2 +- .../on_policy/pid_lagrange/trpo_pid.py | 2 +- .../algorithms/on_policy/second_order/cpo.py | 4 ++-- .../algorithms/on_policy/second_order/pcpo.py | 2 +- omnisafe/common/logger.py | 11 +++++++++-- 25 files changed, 61 insertions(+), 54 deletions(-) diff --git a/omnisafe/adapter/offpolicy_adapter.py b/omnisafe/adapter/offpolicy_adapter.py index 0fdacdc64..c56276ac9 100644 --- a/omnisafe/adapter/offpolicy_adapter.py +++ b/omnisafe/adapter/offpolicy_adapter.py @@ -111,7 +111,7 @@ def eval_policy( # pylint: disable=too-many-locals done = terminated or truncated if done: logger.store( - **{ + { 'Metrics/TestEpRet': ep_ret, 'Metrics/TestEpCost': ep_cost, 'Metrics/TestEpLen': ep_len, @@ -197,7 +197,7 @@ def _log_metrics(self, logger: Logger, idx: int) -> None: idx (int): The index of the environment. """ logger.store( - **{ + { 'Metrics/EpRet': self._ep_ret[idx], 'Metrics/EpCost': self._ep_cost[idx], 'Metrics/EpLen': self._ep_len[idx], diff --git a/omnisafe/adapter/onpolicy_adapter.py b/omnisafe/adapter/onpolicy_adapter.py index fc50a23e0..4b2940ef9 100644 --- a/omnisafe/adapter/onpolicy_adapter.py +++ b/omnisafe/adapter/onpolicy_adapter.py @@ -96,8 +96,8 @@ def roll_out( # pylint: disable=too-many-locals self._log_value(reward=reward, cost=cost, info=info) if self._cfgs.algo_cfgs.use_cost: - logger.store(**{'Value/cost': value_c}) - logger.store(**{'Value/reward': value_r}) + logger.store({'Value/cost': value_c}) + logger.store({'Value/reward': value_r}) buffer.store( obs=obs, @@ -169,7 +169,7 @@ def _log_metrics(self, logger: Logger, idx: int) -> None: idx (int): The index of the environment. 
""" logger.store( - **{ + { 'Metrics/EpRet': self._ep_ret[idx], 'Metrics/EpCost': self._ep_cost[idx], 'Metrics/EpLen': self._ep_len[idx], diff --git a/omnisafe/adapter/saute_adapter.py b/omnisafe/adapter/saute_adapter.py index 8607f2772..df381ef28 100644 --- a/omnisafe/adapter/saute_adapter.py +++ b/omnisafe/adapter/saute_adapter.py @@ -125,4 +125,4 @@ def _reset_log(self, idx: int | None = None) -> None: def _log_metrics(self, logger: Logger, idx: int) -> None: super()._log_metrics(logger, idx) - logger.store(**{'Metrics/EpBudget': self._ep_budget[idx]}) + logger.store({'Metrics/EpBudget': self._ep_budget[idx]}) diff --git a/omnisafe/algorithms/off_policy/ddpg.py b/omnisafe/algorithms/off_policy/ddpg.py index c747b35d5..c7d41e1c1 100644 --- a/omnisafe/algorithms/off_policy/ddpg.py +++ b/omnisafe/algorithms/off_policy/ddpg.py @@ -192,8 +192,8 @@ def learn(self) -> tuple[int | float, ...]: logger=self._logger, ) - self._logger.store(**{'Time/Update': update_time}) - self._logger.store(**{'Time/Rollout': roll_out_time}) + self._logger.store({'Time/Update': update_time}) + self._logger.store({'Time/Rollout': roll_out_time}) if ( step > self._cfgs.algo_cfgs.start_learning_steps @@ -202,7 +202,7 @@ def learn(self) -> tuple[int | float, ...]: self._actor_critic.actor_scheduler.step() self._logger.store( - **{ + { 'TotalEnvSteps': step + 1, 'Time/FPS': self._cfgs.algo_cfgs.steps_per_epoch / (time.time() - epoch_time), 'Time/Total': (time.time() - start_time), @@ -265,7 +265,7 @@ def _update_reward_critic( for param in self._actor_critic.reward_critic.parameters(): loss += param.pow(2).sum() * self._cfgs.algo_cfgs.critic_norm_coeff self._logger.store( - **{ + { 'Loss/Loss_reward_critic': loss.mean().item(), 'Value/reward_critic': q_value_r.mean().item(), }, @@ -312,7 +312,7 @@ def _update_cost_critic( self._actor_critic.cost_critic_optimizer.step() self._logger.store( - **{ + { 'Loss/Loss_cost_critic': loss.mean().item(), 'Value/cost_critic': q_value_c.mean().item(), }, @@ -332,7 +332,7 @@ def _update_actor( # pylint: disable=too-many-arguments ) self._actor_critic.actor_optimizer.step() self._logger.store( - **{ + { 'Loss/Loss_pi': loss.mean().item(), }, ) @@ -346,7 +346,7 @@ def _loss_pi( def _log_when_not_update(self) -> None: self._logger.store( - **{ + { 'Loss/Loss_reward_critic': 0.0, 'Loss/Loss_pi': 0.0, 'Value/reward_critic': 0.0, @@ -354,7 +354,7 @@ def _log_when_not_update(self) -> None: ) if self._cfgs.algo_cfgs.use_cost: self._logger.store( - **{ + { 'Loss/Loss_cost_critic': 0.0, 'Value/cost_critic': 0.0, }, diff --git a/omnisafe/algorithms/off_policy/ddpg_lag.py b/omnisafe/algorithms/off_policy/ddpg_lag.py index ccd81c2cd..b198525dc 100644 --- a/omnisafe/algorithms/off_policy/ddpg_lag.py +++ b/omnisafe/algorithms/off_policy/ddpg_lag.py @@ -48,7 +48,7 @@ def _update(self) -> None: Jc = self._logger.get_stats('Metrics/EpCost')[0] self._lagrange.update_lagrange_multiplier(Jc) self._logger.store( - **{ + { 'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier.data.item(), }, ) @@ -68,7 +68,7 @@ def _loss_pi( def _log_when_not_update(self) -> None: super()._log_when_not_update() self._logger.store( - **{ + { 'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier.data.item(), }, ) diff --git a/omnisafe/algorithms/off_policy/sac.py b/omnisafe/algorithms/off_policy/sac.py index 9a0b9fb2d..08f7a585a 100644 --- a/omnisafe/algorithms/off_policy/sac.py +++ b/omnisafe/algorithms/off_policy/sac.py @@ -117,7 +117,7 @@ def _update_reward_critic( 
distributed.avg_grads(self._actor_critic.reward_critic) self._actor_critic.reward_critic_optimizer.step() self._logger.store( - **{ + { 'Loss/Loss_reward_critic': loss.mean().item(), 'Value/reward_critic': q1_value_r.mean().item(), }, @@ -139,12 +139,12 @@ def _update_actor( alpha_loss.backward() self._alpha_optimizer.step() self._logger.store( - **{ + { 'Loss/alpha_loss': alpha_loss.mean().item(), }, ) self._logger.store( - **{ + { 'Value/alpha': self._alpha, }, ) @@ -161,13 +161,13 @@ def _loss_pi( def _log_when_not_update(self) -> None: super()._log_when_not_update() self._logger.store( - **{ + { 'Value/alpha': self._alpha, }, ) if self._cfgs.algo_cfgs.auto_alpha: self._logger.store( - **{ + { 'Loss/alpha_loss': 0.0, }, ) diff --git a/omnisafe/algorithms/off_policy/sac_lag.py b/omnisafe/algorithms/off_policy/sac_lag.py index 7d17768e3..f5e8eda86 100644 --- a/omnisafe/algorithms/off_policy/sac_lag.py +++ b/omnisafe/algorithms/off_policy/sac_lag.py @@ -46,7 +46,7 @@ def _update(self) -> None: Jc = self._logger.get_stats('Metrics/EpCost')[0] self._lagrange.update_lagrange_multiplier(Jc) self._logger.store( - **{ + { 'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier.data.item(), }, ) @@ -67,7 +67,7 @@ def _loss_pi( def _log_when_not_update(self) -> None: super()._log_when_not_update() self._logger.store( - **{ + { 'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier.data.item(), }, ) diff --git a/omnisafe/algorithms/off_policy/td3.py b/omnisafe/algorithms/off_policy/td3.py index 14c661362..277d19fdf 100644 --- a/omnisafe/algorithms/off_policy/td3.py +++ b/omnisafe/algorithms/off_policy/td3.py @@ -103,7 +103,7 @@ def _update_reward_critic( distributed.avg_grads(self._actor_critic.reward_critic) self._actor_critic.reward_critic_optimizer.step() self._logger.store( - **{ + { 'Loss/Loss_reward_critic': loss.mean().item(), 'Value/reward_critic': q1_value_r.mean().item(), }, diff --git a/omnisafe/algorithms/off_policy/td3_lag.py b/omnisafe/algorithms/off_policy/td3_lag.py index feab75f72..f77766c15 100644 --- a/omnisafe/algorithms/off_policy/td3_lag.py +++ b/omnisafe/algorithms/off_policy/td3_lag.py @@ -46,7 +46,7 @@ def _update(self) -> None: Jc = self._logger.get_stats('Metrics/EpCost')[0] self._lagrange.update_lagrange_multiplier(Jc) self._logger.store( - **{ + { 'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier.data.item(), }, ) @@ -65,7 +65,7 @@ def _loss_pi( def _log_when_not_update(self) -> None: super()._log_when_not_update() self._logger.store( - **{ + { 'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier.data.item(), }, ) diff --git a/omnisafe/algorithms/on_policy/base/natural_pg.py b/omnisafe/algorithms/on_policy/base/natural_pg.py index 6dff7ac51..4c669f03e 100644 --- a/omnisafe/algorithms/on_policy/base/natural_pg.py +++ b/omnisafe/algorithms/on_policy/base/natural_pg.py @@ -109,7 +109,7 @@ def _fvp(self, params: torch.Tensor) -> torch.Tensor: distributed.avg_tensor(flat_grad_grad_kl) self._logger.store( - **{ + { 'Train/KL': kl.item(), }, ) @@ -164,7 +164,7 @@ def _update_actor( # pylint: disable=too-many-arguments,too-many-locals loss, info = self._loss_pi(obs, act, logp, adv) self._logger.store( - **{ + { 'Train/Entropy': info['entropy'], 'Train/PolicyRatio': info['ratio'], 'Train/PolicyStd': info['std'], @@ -225,7 +225,7 @@ def _update(self) -> None: self._update_cost_critic(obs, target_value_c) self._logger.store( - **{ + { 'Train/StopIter': self._cfgs.algo_cfgs.update_iters, 'Value/Adv': adv_r.mean().item(), }, diff 
--git a/omnisafe/algorithms/on_policy/base/policy_gradient.py b/omnisafe/algorithms/on_policy/base/policy_gradient.py index e18548fa8..9757ad85a 100644 --- a/omnisafe/algorithms/on_policy/base/policy_gradient.py +++ b/omnisafe/algorithms/on_policy/base/policy_gradient.py @@ -256,11 +256,11 @@ def learn(self) -> tuple[int | float, ...]: buffer=self._buf, logger=self._logger, ) - self._logger.store(**{'Time/Rollout': time.time() - roll_out_time}) + self._logger.store({'Time/Rollout': time.time() - roll_out_time}) update_time = time.time() self._update() - self._logger.store(**{'Time/Update': time.time() - update_time}) + self._logger.store({'Time/Update': time.time() - update_time}) if self._cfgs.model_cfgs.exploration_noise_anneal: self._actor_critic.annealing(epoch) @@ -269,7 +269,7 @@ def learn(self) -> tuple[int | float, ...]: self._actor_critic.actor_scheduler.step() self._logger.store( - **{ + { 'TotalEnvSteps': (epoch + 1) * self._cfgs.algo_cfgs.steps_per_epoch, 'Time/FPS': self._cfgs.algo_cfgs.steps_per_epoch / (time.time() - epoch_time), 'Time/Total': (time.time() - start_time), @@ -390,7 +390,7 @@ def _update(self) -> None: break self._logger.store( - **{ + { 'Train/StopIter': update_counts, # pylint: disable=undefined-loop-variable 'Value/Adv': adv_r.mean().item(), 'Train/KL': final_kl, @@ -434,7 +434,7 @@ def _update_reward_critic(self, obs: torch.Tensor, target_value_r: torch.Tensor) distributed.avg_grads(self._actor_critic.reward_critic) self._actor_critic.reward_critic_optimizer.step() - self._logger.store(**{'Loss/Loss_reward_critic': loss.mean().item()}) + self._logger.store({'Loss/Loss_reward_critic': loss.mean().item()}) def _update_cost_critic(self, obs: torch.Tensor, target_value_c: torch.Tensor) -> None: r"""Update value network under a double for loop. 
@@ -473,7 +473,7 @@ def _update_cost_critic(self, obs: torch.Tensor, target_value_c: torch.Tensor) - distributed.avg_grads(self._actor_critic.cost_critic) self._actor_critic.cost_critic_optimizer.step() - self._logger.store(**{'Loss/Loss_cost_critic': loss.mean().item()}) + self._logger.store({'Loss/Loss_cost_critic': loss.mean().item()}) def _update_actor( # pylint: disable=too-many-arguments self, @@ -515,7 +515,7 @@ def _update_actor( # pylint: disable=too-many-arguments distributed.avg_grads(self._actor_critic.actor) self._actor_critic.actor_optimizer.step() self._logger.store( - **{ + { 'Train/Entropy': info['entropy'], 'Train/PolicyRatio': info['ratio'], 'Train/PolicyStd': info['std'], diff --git a/omnisafe/algorithms/on_policy/base/trpo.py b/omnisafe/algorithms/on_policy/base/trpo.py index 776fd2e29..33a52f504 100644 --- a/omnisafe/algorithms/on_policy/base/trpo.py +++ b/omnisafe/algorithms/on_policy/base/trpo.py @@ -128,7 +128,7 @@ def _search_step_size( set_param_values_to_model(self._actor_critic.actor, theta_old) self._logger.store( - **{ + { 'Train/KL': final_kl, }, ) @@ -199,7 +199,7 @@ def _update_actor( # pylint: disable=too-many-arguments,too-many-locals loss, info = self._loss_pi(obs, act, logp, adv) self._logger.store( - **{ + { 'Train/Entropy': info['entropy'], 'Train/PolicyRatio': info['ratio'], 'Train/PolicyStd': info['std'], diff --git a/omnisafe/algorithms/on_policy/first_order/cup.py b/omnisafe/algorithms/on_policy/first_order/cup.py index 1c1d2963e..634c0d10d 100644 --- a/omnisafe/algorithms/on_policy/first_order/cup.py +++ b/omnisafe/algorithms/on_policy/first_order/cup.py @@ -117,7 +117,7 @@ def _loss_pi_cost(self, obs, act, logp, adv_c): entropy = distribution.entropy().mean().item() info = {'entropy': entropy, 'ratio': ratio.mean().item(), 'std': std} - self._logger.store(**{'Loss/Loss_pi_c': loss.item()}) + self._logger.store({'Loss/Loss_pi_c': loss.item()}) return loss, info @@ -195,7 +195,7 @@ def _update(self) -> None: break self._logger.store( - **{ + { 'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier.item(), 'Train/SecondStepStopIter': i + 1, # pylint: disable=undefined-loop-variable 'Train/SecondStepEntropy': info['entropy'], diff --git a/omnisafe/algorithms/on_policy/first_order/focops.py b/omnisafe/algorithms/on_policy/first_order/focops.py index db24921fb..0b83ccabb 100644 --- a/omnisafe/algorithms/on_policy/first_order/focops.py +++ b/omnisafe/algorithms/on_policy/first_order/focops.py @@ -219,7 +219,7 @@ def _update(self) -> None: break self._logger.store( - **{ + { 'Train/StopIter': i + 1, # pylint: disable=undefined-loop-variable 'Value/Adv': adv_r.mean().item(), 'Train/KL': kl, diff --git a/omnisafe/algorithms/on_policy/naive_lagrange/pdo.py b/omnisafe/algorithms/on_policy/naive_lagrange/pdo.py index 2fdff81c9..8071b949a 100644 --- a/omnisafe/algorithms/on_policy/naive_lagrange/pdo.py +++ b/omnisafe/algorithms/on_policy/naive_lagrange/pdo.py @@ -60,7 +60,7 @@ def _update(self) -> None: # then update the policy and value function super()._update() - self._logger.store(**{'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier}) + self._logger.store({'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier}) def _compute_adv_surrogate(self, adv_r: torch.Tensor, adv_c: torch.Tensor) -> torch.Tensor: penalty = self._lagrange.lagrangian_multiplier.item() diff --git a/omnisafe/algorithms/on_policy/naive_lagrange/ppo_lag.py b/omnisafe/algorithms/on_policy/naive_lagrange/ppo_lag.py index 746062f03..9375e89b8 100644 
--- a/omnisafe/algorithms/on_policy/naive_lagrange/ppo_lag.py +++ b/omnisafe/algorithms/on_policy/naive_lagrange/ppo_lag.py @@ -78,7 +78,7 @@ def _update(self) -> None: # then update the policy and value function super()._update() - self._logger.store(**{'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier}) + self._logger.store({'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier}) def _compute_adv_surrogate(self, adv_r: torch.Tensor, adv_c: torch.Tensor) -> torch.Tensor: r"""Compute surrogate loss. diff --git a/omnisafe/algorithms/on_policy/naive_lagrange/rcpo.py b/omnisafe/algorithms/on_policy/naive_lagrange/rcpo.py index 2e00bd4ba..79f3e065b 100644 --- a/omnisafe/algorithms/on_policy/naive_lagrange/rcpo.py +++ b/omnisafe/algorithms/on_policy/naive_lagrange/rcpo.py @@ -63,7 +63,7 @@ def _update(self) -> None: # then update the policy and value function super()._update() - self._logger.store(**{'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier}) + self._logger.store({'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier}) def _compute_adv_surrogate(self, adv_r: torch.Tensor, adv_c: torch.Tensor) -> torch.Tensor: penalty = self._lagrange.lagrangian_multiplier.item() diff --git a/omnisafe/algorithms/on_policy/naive_lagrange/trpo_lag.py b/omnisafe/algorithms/on_policy/naive_lagrange/trpo_lag.py index 07fb5e737..f707e8acf 100644 --- a/omnisafe/algorithms/on_policy/naive_lagrange/trpo_lag.py +++ b/omnisafe/algorithms/on_policy/naive_lagrange/trpo_lag.py @@ -76,7 +76,7 @@ def _update(self) -> None: # then update the policy and value function super()._update() - self._logger.store(**{'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier}) + self._logger.store({'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier}) def _compute_adv_surrogate(self, adv_r: torch.Tensor, adv_c: torch.Tensor) -> torch.Tensor: r"""Compute surrogate loss. 
diff --git a/omnisafe/algorithms/on_policy/penalty_function/ipo.py b/omnisafe/algorithms/on_policy/penalty_function/ipo.py index 76ba5e558..2a4d784ad 100644 --- a/omnisafe/algorithms/on_policy/penalty_function/ipo.py +++ b/omnisafe/algorithms/on_policy/penalty_function/ipo.py @@ -67,6 +67,6 @@ def _compute_adv_surrogate(self, adv_r: torch.Tensor, adv_c: torch.Tensor) -> to if penalty < 0 or penalty > self._cfgs.algo_cfgs.penalty_max: penalty = self._cfgs.algo_cfgs.penalty_max - self._logger.store(**{'Misc/Penalty': penalty}) + self._logger.store({'Misc/Penalty': penalty}) return (adv_r - penalty * adv_c) / (1 + penalty) diff --git a/omnisafe/algorithms/on_policy/penalty_function/p3o.py b/omnisafe/algorithms/on_policy/penalty_function/p3o.py index 78b1e654c..e8c08c60c 100644 --- a/omnisafe/algorithms/on_policy/penalty_function/p3o.py +++ b/omnisafe/algorithms/on_policy/penalty_function/p3o.py @@ -129,7 +129,7 @@ def _update_actor( self._actor_critic.actor_optimizer.step() self._logger.store( - **{ + { 'Train/Entropy': info['entropy'], 'Train/PolicyRatio': info['ratio'], 'Train/PolicyStd': info['std'], diff --git a/omnisafe/algorithms/on_policy/pid_lagrange/cppo_pid.py b/omnisafe/algorithms/on_policy/pid_lagrange/cppo_pid.py index df82dde2f..9fa65cf04 100644 --- a/omnisafe/algorithms/on_policy/pid_lagrange/cppo_pid.py +++ b/omnisafe/algorithms/on_policy/pid_lagrange/cppo_pid.py @@ -79,7 +79,7 @@ def _update(self) -> None: # then update the policy and value function super()._update() - self._logger.store(**{'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier}) + self._logger.store({'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier}) def _compute_adv_surrogate(self, adv_r: torch.Tensor, adv_c: torch.Tensor) -> torch.Tensor: r"""Compute surrogate loss. diff --git a/omnisafe/algorithms/on_policy/pid_lagrange/trpo_pid.py b/omnisafe/algorithms/on_policy/pid_lagrange/trpo_pid.py index 7170c79da..36912380e 100644 --- a/omnisafe/algorithms/on_policy/pid_lagrange/trpo_pid.py +++ b/omnisafe/algorithms/on_policy/pid_lagrange/trpo_pid.py @@ -76,7 +76,7 @@ def _update(self) -> None: # then update the policy and value function super()._update() - self._logger.store(**{'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier}) + self._logger.store({'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier}) def _compute_adv_surrogate(self, adv_r: torch.Tensor, adv_c: torch.Tensor) -> torch.Tensor: r"""Compute surrogate loss. 
diff --git a/omnisafe/algorithms/on_policy/second_order/cpo.py b/omnisafe/algorithms/on_policy/second_order/cpo.py index 3ace9dc62..f81520686 100644 --- a/omnisafe/algorithms/on_policy/second_order/cpo.py +++ b/omnisafe/algorithms/on_policy/second_order/cpo.py @@ -166,7 +166,7 @@ def _cpo_search_step( acceptance_step = 0 self._logger.store( - **{ + { 'Train/KL': kl, }, ) @@ -413,7 +413,7 @@ def _update_actor( loss = loss_reward + loss_cost self._logger.store( - **{ + { 'Loss/Loss_pi': loss.item(), 'Train/Entropy': info['entropy'], 'Train/PolicyRatio': info['ratio'], diff --git a/omnisafe/algorithms/on_policy/second_order/pcpo.py b/omnisafe/algorithms/on_policy/second_order/pcpo.py index 890e803ee..6954d83bb 100644 --- a/omnisafe/algorithms/on_policy/second_order/pcpo.py +++ b/omnisafe/algorithms/on_policy/second_order/pcpo.py @@ -132,7 +132,7 @@ def _update_actor( loss = loss_reward + loss_cost self._logger.store( - **{ + { 'Loss/Loss_pi': loss.item(), 'Train/Entropy': info['entropy'], 'Train/PolicyRatio': info['ratio'], diff --git a/omnisafe/common/logger.py b/omnisafe/common/logger.py index 39e7f3dc3..097e526a2 100644 --- a/omnisafe/common/logger.py +++ b/omnisafe/common/logger.py @@ -202,7 +202,7 @@ def register_key( The logger can record the following data: - .. code-block:: bash + .. code-block:: text ---------------------------------------------------- | Name | Value | @@ -249,12 +249,19 @@ def register_key( self._data[key] = [] self._headers_windows[key] = None - def store(self, **kwargs: int | float | np.ndarray | torch.Tensor) -> None: + def store( + self, + data: dict[str, int | float | np.ndarray | torch.Tensor] | None = None, + /, + **kwargs: int | float | np.ndarray | torch.Tensor, + ) -> None: """Store the data to the logger. Args: **kwargs: The data to be stored. """ + if data is not None: + kwargs.update(data) for key, val in kwargs.items(): assert key in self._current_row, f'Key {key} has not been registered' if isinstance(val, (int, float)):
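Usage sketch (not part of the patch above): how the reworked Logger.store is expected to be called after this change. This assumes `logger` is an omnisafe Logger whose keys (e.g. 'Metrics/EpRet') were registered beforehand via register_key, and that ep_ret, ep_len, ep_cost, and step come from the surrounding training loop, as in the adapters and algorithms touched by this patch.

    # Before: keys such as 'Metrics/EpRet' are not valid Python identifiers,
    # so they could only be passed by **-unpacking a dict into **kwargs.
    logger.store(**{'Metrics/EpRet': ep_ret, 'Metrics/EpLen': ep_len})

    # After: pass the mapping directly as a positional-only argument.
    logger.store({'Metrics/EpRet': ep_ret, 'Metrics/EpLen': ep_len})

    # Keyword arguments still work for identifier-style keys and can be mixed
    # with the positional dict; on a key collision the dict entry wins, since
    # store() applies kwargs.update(data).
    logger.store({'Metrics/EpCost': ep_cost}, TotalEnvSteps=step)

Design note: the `/` marker in the new signature makes `data` positional-only, so a registered metric that happened to be named 'data' could still be passed as a keyword argument without clashing with the parameter.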