From edd452ba923f248f9d3e095ee2bc2a5b28b90d85 Mon Sep 17 00:00:00 2001 From: zmsn-2077 <73586554+zmsn-2077@users.noreply.github.com> Date: Sun, 18 Dec 2022 21:31:21 +0800 Subject: [PATCH] refactor: open pylint in pre-commit (#48) Co-authored-by: Xuehai Pan --- .github/workflows/ci.yml | 4 +++ .pre-commit-config.yaml | 31 ++++++++++--------- docs/source/spelling_wordlist.txt | 13 ++++++++ .../safety_gymnasium/bases/__init__.py | 4 +-- .../bases/base_mujoco_task.py | 2 +- .../safety_gymnasium/bases/base_task.py | 2 +- .../tasks/button/button_level0.py | 2 +- .../tasks/goal/goal_level0.py | 2 +- .../tasks/push/push_level0.py | 2 +- omnisafe/algorithms/off_policy/ddpg.py | 18 +++++------ omnisafe/algorithms/on_policy/cup.py | 1 - omnisafe/common/lagrange.py | 1 + tests/test_policy.py | 1 + 13 files changed, 51 insertions(+), 32 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b0e9bb5e7..d80edd7cc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -40,6 +40,10 @@ jobs: run: | python -m pip install -vvv --editable '.[lint]' + - name: Install safety_gymnasium + run: | + python -m pip install -vvv --editable 'envs/safety-gymnasium' + - name: pre-commit run: | python -m pre_commit run --all-files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index adede42af..9dd980dc7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,18 +28,19 @@ repos: hooks: - id: black-jupyter stages: [commit, push, manual] - # - repo: local - # hooks: - # - id: pylint - # name: pylint - # entry: pylint - # language: system - # types: [python] - # require_serial: true - # stages: [commit, push, manual] - # exclude: | - # (?x)( - # ^examples/| - # ^tests/| - # ^setup.py$ - # ) + - repo: local + hooks: + - id: pylint + name: pylint + entry: pylint + language: system + types: [python] + require_serial: true + stages: [commit, push, manual] + exclude: | + (?x)( + ^examples/| + ^tests/| + ^setup.py$| + ^docs/source/conf.py$ + ) diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index c9be8e826..0ddf1661d 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -168,3 +168,16 @@ vel quaternion Quaternions Jacobian +Lillicrap +Erez +Yuval +Tassa +Jiaming +Ji +Juntao +Dai +Linrui +Binbin +Zhou +Pengfei +Yaodong diff --git a/envs/safety-gymnasium/safety_gymnasium/bases/__init__.py b/envs/safety-gymnasium/safety_gymnasium/bases/__init__.py index 7a4b32874..79f07d605 100644 --- a/envs/safety-gymnasium/safety_gymnasium/bases/__init__.py +++ b/envs/safety-gymnasium/safety_gymnasium/bases/__init__.py @@ -14,5 +14,5 @@ # ============================================================================== """Base classes.""" -from safety_gymnasium.bases.base_mujoco_task import BaseMujocoTask -from safety_gymnasium.bases.base_task import BaseTask +# from safety_gymnasium.bases.base_mujoco_task import BaseMujocoTask +# from safety_gymnasium.bases.base_task import BaseTask diff --git a/envs/safety-gymnasium/safety_gymnasium/bases/base_mujoco_task.py b/envs/safety-gymnasium/safety_gymnasium/bases/base_mujoco_task.py index e11b604ac..026d48615 100644 --- a/envs/safety-gymnasium/safety_gymnasium/bases/base_mujoco_task.py +++ b/envs/safety-gymnasium/safety_gymnasium/bases/base_mujoco_task.py @@ -18,7 +18,7 @@ from copy import deepcopy from typing import Union -import gymnasium +# import gymnasium import mujoco import numpy as np from gymnasium.envs.mujoco.mujoco_rendering import RenderContextOffscreen, Viewer diff --git a/envs/safety-gymnasium/safety_gymnasium/bases/base_task.py b/envs/safety-gymnasium/safety_gymnasium/bases/base_task.py index 9256555d3..c6e0254e3 100644 --- a/envs/safety-gymnasium/safety_gymnasium/bases/base_task.py +++ b/envs/safety-gymnasium/safety_gymnasium/bases/base_task.py @@ -25,7 +25,7 @@ from safety_gymnasium.assets.mocaps import MOCAPS_REGISTER from safety_gymnasium.assets.objects import OBJS_REGISTER from safety_gymnasium.assets.robot import Robot -from safety_gymnasium.bases import BaseMujocoTask +from safety_gymnasium.bases.base_mujoco_task import BaseMujocoTask from safety_gymnasium.utils.common_utils import ResamplingError from safety_gymnasium.utils.task_utils import quat2mat, theta2vec diff --git a/envs/safety-gymnasium/safety_gymnasium/tasks/button/button_level0.py b/envs/safety-gymnasium/safety_gymnasium/tasks/button/button_level0.py index 273fbb72b..d77084638 100644 --- a/envs/safety-gymnasium/safety_gymnasium/tasks/button/button_level0.py +++ b/envs/safety-gymnasium/safety_gymnasium/tasks/button/button_level0.py @@ -21,7 +21,7 @@ import numpy as np from safety_gymnasium.assets.geoms import Buttons from safety_gymnasium.assets.group import GROUP -from safety_gymnasium.bases import BaseTask +from safety_gymnasium.bases.base_task import BaseTask # pylint: disable-next=too-many-instance-attributes diff --git a/envs/safety-gymnasium/safety_gymnasium/tasks/goal/goal_level0.py b/envs/safety-gymnasium/safety_gymnasium/tasks/goal/goal_level0.py index 396ca21de..284b4fdb7 100644 --- a/envs/safety-gymnasium/safety_gymnasium/tasks/goal/goal_level0.py +++ b/envs/safety-gymnasium/safety_gymnasium/tasks/goal/goal_level0.py @@ -15,7 +15,7 @@ """Goal level 0.""" from safety_gymnasium.assets.geoms import Goal -from safety_gymnasium.bases import BaseTask +from safety_gymnasium.bases.base_task import BaseTask class GoalLevel0(BaseTask): diff --git a/envs/safety-gymnasium/safety_gymnasium/tasks/push/push_level0.py b/envs/safety-gymnasium/safety_gymnasium/tasks/push/push_level0.py index bd9f2c13c..8c39483bb 100644 --- a/envs/safety-gymnasium/safety_gymnasium/tasks/push/push_level0.py +++ b/envs/safety-gymnasium/safety_gymnasium/tasks/push/push_level0.py @@ -17,7 +17,7 @@ import numpy as np from safety_gymnasium.assets.geoms import Goal from safety_gymnasium.assets.objects import PushBox -from safety_gymnasium.bases import BaseTask +from safety_gymnasium.bases.base_task import BaseTask class PushLevel0(BaseTask): diff --git a/omnisafe/algorithms/off_policy/ddpg.py b/omnisafe/algorithms/off_policy/ddpg.py index 16f65cabd..58ab15f12 100644 --- a/omnisafe/algorithms/off_policy/ddpg.py +++ b/omnisafe/algorithms/off_policy/ddpg.py @@ -30,12 +30,13 @@ @registry.register -class DDPG: +class DDPG: # pylint: disable=too-many-instance-attributes """Continuous control with deep reinforcement learning (DDPG) Algorithm. References: Paper Name: Continuous control with deep reinforcement learning. - Paper author: Timothy P. Lillicrap, Jonathan J. Hunt, Alexander Pritzel, Nicolas Heess, Tom Erez, Yuval Tassa, David Silver, Daan Wierstra. + Paper author: Timothy P. Lillicrap, Jonathan J. Hunt, Alexander Pritzel, Nicolas Heess, + Tom Erez, Yuval Tassa, David Silver, Daan Wierstra. Paper URL: https://arxiv.org/abs/1509.02971 """ @@ -96,7 +97,6 @@ def __init__( self.actor_critic = ConstraintActorQCritic( observation_space=self.env.observation_space, action_space=self.env.action_space, - scale_rewards=cfgs.scale_rewards, standardized_obs=cfgs.standardized_obs, model_cfgs=cfgs.model_cfgs, ) @@ -223,16 +223,16 @@ def compute_loss_v(self, data): data['obs_next'], data['done'], ) - q = self.actor_critic.critic(obs, act) + q_value = self.actor_critic.critic(obs, act) # Bellman backup for Q function with torch.no_grad(): act_targ, _ = self.ac_targ.actor.predict(obs, deterministic=True) q_targ = self.ac_targ.critic(obs_next, act_targ) backup = rew + self.cfgs.gamma * (1 - done) * q_targ # MSE loss against Bellman backup - loss_q = ((q - backup) ** 2).mean() + loss_q = ((q_value - backup) ** 2).mean() # Useful info for logging - q_info = dict(Q1Vals=q.detach().numpy()) + q_info = dict(Q1Vals=q_value.detach().numpy()) return loss_q, q_info def compute_loss_c(self, data): @@ -249,7 +249,7 @@ def compute_loss_c(self, data): data['obs_next'], data['done'], ) - qc = self.actor_critic.cost_critic(obs, act) + cost_q_value = self.actor_critic.cost_critic(obs, act) # Bellman backup for Q function with torch.no_grad(): @@ -257,9 +257,9 @@ def compute_loss_c(self, data): qc_targ = self.ac_targ.c(obs_next, action) backup = cost + self.cfgs.gamma * (1 - done) * qc_targ # MSE loss against Bellman backup - loss_qc = ((qc - backup) ** 2).mean() + loss_qc = ((cost_q_value - backup) ** 2).mean() # Useful info for logging - qc_info = dict(QCosts=qc.detach().numpy()) + qc_info = dict(QCosts=cost_q_value.detach().numpy()) return loss_qc, qc_info diff --git a/omnisafe/algorithms/on_policy/cup.py b/omnisafe/algorithms/on_policy/cup.py index 721382d26..bbaa32541 100644 --- a/omnisafe/algorithms/on_policy/cup.py +++ b/omnisafe/algorithms/on_policy/cup.py @@ -152,7 +152,6 @@ def slice_data(self, data) -> dict: 'adv': adv[i * batch_size : (i + 1) * batch_size], 'discounted_ret': discounted_ret[i * batch_size : (i + 1) * batch_size], 'cost_adv': cost_adv[i * batch_size : (i + 1) * batch_size], - 'target_v': target_v[i * batch_size : (i + 1) * batch_size], } ) diff --git a/omnisafe/common/lagrange.py b/omnisafe/common/lagrange.py index e6f11f0ed..5beef9ed6 100644 --- a/omnisafe/common/lagrange.py +++ b/omnisafe/common/lagrange.py @@ -22,6 +22,7 @@ class Lagrange(abc.ABC): """Abstract base class for Lagrangian-base Algorithms.""" + # pylint: disable-next=too-many-arguments def __init__( self, cost_limit: float, diff --git a/tests/test_policy.py b/tests/test_policy.py index a0d329bbc..1d27864fb 100644 --- a/tests/test_policy.py +++ b/tests/test_policy.py @@ -32,6 +32,7 @@ 'PCPO', 'FOCOPS', 'CPPOPid', + 'CUP', ] ) def test_on_policy(algo):