Skip to content

Commit

Permalink
refactor: open pylint in pre-commit (#48)
Browse files Browse the repository at this point in the history
Co-authored-by: Xuehai Pan <[email protected]>
  • Loading branch information
zmsn-2077 and XuehaiPan committed Dec 22, 2022
1 parent 1ed496b commit 12fbca9
Show file tree
Hide file tree
Showing 13 changed files with 51 additions and 32 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ jobs:
run: |
python -m pip install -vvv --editable '.[lint]'
- name: Install safety_gymnasium
run: |
python -m pip install -vvv --editable 'envs/safety-gymnasium'
- name: pre-commit
run: |
python -m pre_commit run --all-files
Expand Down
31 changes: 16 additions & 15 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,19 @@ repos:
hooks:
- id: black-jupyter
stages: [commit, push, manual]
# - repo: local
# hooks:
# - id: pylint
# name: pylint
# entry: pylint
# language: system
# types: [python]
# require_serial: true
# stages: [commit, push, manual]
# exclude: |
# (?x)(
# ^examples/|
# ^tests/|
# ^setup.py$
# )
- repo: local
hooks:
- id: pylint
name: pylint
entry: pylint
language: system
types: [python]
require_serial: true
stages: [commit, push, manual]
exclude: |
(?x)(
^examples/|
^tests/|
^setup.py$|
^docs/source/conf.py$
)
13 changes: 13 additions & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,16 @@ vel
quaternion
Quaternions
Jacobian
Lillicrap
Erez
Yuval
Tassa
Jiaming
Ji
Juntao
Dai
Linrui
Binbin
Zhou
Pengfei
Yaodong
4 changes: 2 additions & 2 deletions envs/safety-gymnasium/safety_gymnasium/bases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
# ==============================================================================
"""Base classes."""

from safety_gymnasium.bases.base_mujoco_task import BaseMujocoTask
from safety_gymnasium.bases.base_task import BaseTask
# from safety_gymnasium.bases.base_mujoco_task import BaseMujocoTask
# from safety_gymnasium.bases.base_task import BaseTask
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from copy import deepcopy
from typing import Union

import gymnasium
# import gymnasium
import mujoco
import numpy as np
from gymnasium.envs.mujoco.mujoco_rendering import RenderContextOffscreen, Viewer
Expand Down
2 changes: 1 addition & 1 deletion envs/safety-gymnasium/safety_gymnasium/bases/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from safety_gymnasium.assets.mocaps import MOCAPS_REGISTER
from safety_gymnasium.assets.objects import OBJS_REGISTER
from safety_gymnasium.assets.robot import Robot
from safety_gymnasium.bases import BaseMujocoTask
from safety_gymnasium.bases.base_mujoco_task import BaseMujocoTask
from safety_gymnasium.utils.common_utils import ResamplingError
from safety_gymnasium.utils.task_utils import quat2mat, theta2vec

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np
from safety_gymnasium.assets.geoms import Buttons
from safety_gymnasium.assets.group import GROUP
from safety_gymnasium.bases import BaseTask
from safety_gymnasium.bases.base_task import BaseTask


# pylint: disable-next=too-many-instance-attributes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Goal level 0."""

from safety_gymnasium.assets.geoms import Goal
from safety_gymnasium.bases import BaseTask
from safety_gymnasium.bases.base_task import BaseTask


class GoalLevel0(BaseTask):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
from safety_gymnasium.assets.geoms import Goal
from safety_gymnasium.assets.objects import PushBox
from safety_gymnasium.bases import BaseTask
from safety_gymnasium.bases.base_task import BaseTask


class PushLevel0(BaseTask):
Expand Down
18 changes: 9 additions & 9 deletions omnisafe/algorithms/off_policy/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@


@registry.register
class DDPG:
class DDPG: # pylint: disable=too-many-instance-attributes
"""Continuous control with deep reinforcement learning (DDPG) Algorithm.
References:
Paper Name: Continuous control with deep reinforcement learning.
Paper author: Timothy P. Lillicrap, Jonathan J. Hunt, Alexander Pritzel, Nicolas Heess, Tom Erez, Yuval Tassa, David Silver, Daan Wierstra.
Paper author: Timothy P. Lillicrap, Jonathan J. Hunt, Alexander Pritzel, Nicolas Heess,
Tom Erez, Yuval Tassa, David Silver, Daan Wierstra.
Paper URL: https://arxiv.org/abs/1509.02971
"""
Expand Down Expand Up @@ -96,7 +97,6 @@ def __init__(
self.actor_critic = ConstraintActorQCritic(
observation_space=self.env.observation_space,
action_space=self.env.action_space,
scale_rewards=cfgs.scale_rewards,
standardized_obs=cfgs.standardized_obs,
model_cfgs=cfgs.model_cfgs,
)
Expand Down Expand Up @@ -223,16 +223,16 @@ def compute_loss_v(self, data):
data['obs_next'],
data['done'],
)
q = self.actor_critic.critic(obs, act)
q_value = self.actor_critic.critic(obs, act)
# Bellman backup for Q function
with torch.no_grad():
act_targ, _ = self.ac_targ.actor.predict(obs, deterministic=True)
q_targ = self.ac_targ.critic(obs_next, act_targ)
backup = rew + self.cfgs.gamma * (1 - done) * q_targ
# MSE loss against Bellman backup
loss_q = ((q - backup) ** 2).mean()
loss_q = ((q_value - backup) ** 2).mean()
# Useful info for logging
q_info = dict(Q1Vals=q.detach().numpy())
q_info = dict(Q1Vals=q_value.detach().numpy())
return loss_q, q_info

def compute_loss_c(self, data):
Expand All @@ -249,17 +249,17 @@ def compute_loss_c(self, data):
data['obs_next'],
data['done'],
)
qc = self.actor_critic.cost_critic(obs, act)
cost_q_value = self.actor_critic.cost_critic(obs, act)

# Bellman backup for Q function
with torch.no_grad():
action, _ = self.ac_targ.pi.predict(obs_next, deterministic=True)
qc_targ = self.ac_targ.c(obs_next, action)
backup = cost + self.cfgs.gamma * (1 - done) * qc_targ
# MSE loss against Bellman backup
loss_qc = ((qc - backup) ** 2).mean()
loss_qc = ((cost_q_value - backup) ** 2).mean()
# Useful info for logging
qc_info = dict(QCosts=qc.detach().numpy())
qc_info = dict(QCosts=cost_q_value.detach().numpy())

return loss_qc, qc_info

Expand Down
1 change: 0 additions & 1 deletion omnisafe/algorithms/on_policy/cup.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ def slice_data(self, data) -> dict:
'adv': adv[i * batch_size : (i + 1) * batch_size],
'discounted_ret': discounted_ret[i * batch_size : (i + 1) * batch_size],
'cost_adv': cost_adv[i * batch_size : (i + 1) * batch_size],
'target_v': target_v[i * batch_size : (i + 1) * batch_size],
}
)

Expand Down
1 change: 1 addition & 0 deletions omnisafe/common/lagrange.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
class Lagrange(abc.ABC):
"""Abstract base class for Lagrangian-base Algorithms."""

# pylint: disable-next=too-many-arguments
def __init__(
self,
cost_limit: float,
Expand Down
1 change: 1 addition & 0 deletions tests/test_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
'PCPO',
'FOCOPS',
'CPPOPid',
'CUP',
]
)
def test_on_policy(algo):
Expand Down

0 comments on commit 12fbca9

Please sign in to comment.