Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: open pylint in pre-commit #48

Merged
merged 20 commits into from
Dec 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ jobs:
run: |
python -m pip install -vvv --editable '.[lint]'
- name: Install safety_gymnasium
run: |
python -m pip install -vvv --editable 'envs/safety-gymnasium'
- name: pre-commit
run: |
python -m pre_commit run --all-files
Expand Down
31 changes: 16 additions & 15 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,19 @@ repos:
hooks:
- id: black-jupyter
stages: [commit, push, manual]
# - repo: local
# hooks:
# - id: pylint
# name: pylint
# entry: pylint
# language: system
# types: [python]
# require_serial: true
# stages: [commit, push, manual]
# exclude: |
# (?x)(
# ^examples/|
# ^tests/|
# ^setup.py$
# )
- repo: local
hooks:
- id: pylint
name: pylint
entry: pylint
language: system
types: [python]
require_serial: true
stages: [commit, push, manual]
exclude: |
(?x)(
^examples/|
^tests/|
^setup.py$|
^docs/source/conf.py$
)
13 changes: 13 additions & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,16 @@ vel
quaternion
Quaternions
Jacobian
Lillicrap
Erez
Yuval
Tassa
Jiaming
Ji
Juntao
Dai
Linrui
Binbin
Zhou
Pengfei
Yaodong
4 changes: 2 additions & 2 deletions envs/safety-gymnasium/safety_gymnasium/bases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
# ==============================================================================
"""Base classes."""

from safety_gymnasium.bases.base_mujoco_task import BaseMujocoTask
from safety_gymnasium.bases.base_task import BaseTask
# from safety_gymnasium.bases.base_mujoco_task import BaseMujocoTask
# from safety_gymnasium.bases.base_task import BaseTask
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from copy import deepcopy
from typing import Union

import gymnasium
# import gymnasium
import mujoco
import numpy as np
from gymnasium.envs.mujoco.mujoco_rendering import RenderContextOffscreen, Viewer
Expand Down
2 changes: 1 addition & 1 deletion envs/safety-gymnasium/safety_gymnasium/bases/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from safety_gymnasium.assets.mocaps import MOCAPS_REGISTER
from safety_gymnasium.assets.objects import OBJS_REGISTER
from safety_gymnasium.assets.robot import Robot
from safety_gymnasium.bases import BaseMujocoTask
from safety_gymnasium.bases.base_mujoco_task import BaseMujocoTask
from safety_gymnasium.utils.common_utils import ResamplingError
from safety_gymnasium.utils.task_utils import quat2mat, theta2vec

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np
from safety_gymnasium.assets.geoms import Buttons
from safety_gymnasium.assets.group import GROUP
from safety_gymnasium.bases import BaseTask
from safety_gymnasium.bases.base_task import BaseTask


# pylint: disable-next=too-many-instance-attributes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Goal level 0."""

from safety_gymnasium.assets.geoms import Goal
from safety_gymnasium.bases import BaseTask
from safety_gymnasium.bases.base_task import BaseTask


class GoalLevel0(BaseTask):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
from safety_gymnasium.assets.geoms import Goal
from safety_gymnasium.assets.objects import PushBox
from safety_gymnasium.bases import BaseTask
from safety_gymnasium.bases.base_task import BaseTask


class PushLevel0(BaseTask):
Expand Down
18 changes: 9 additions & 9 deletions omnisafe/algorithms/off_policy/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@


@registry.register
class DDPG:
class DDPG: # pylint: disable=too-many-instance-attributes
"""Continuous control with deep reinforcement learning (DDPG) Algorithm.
References:
Paper Name: Continuous control with deep reinforcement learning.
Paper author: Timothy P. Lillicrap, Jonathan J. Hunt, Alexander Pritzel, Nicolas Heess, Tom Erez, Yuval Tassa, David Silver, Daan Wierstra.
Paper author: Timothy P. Lillicrap, Jonathan J. Hunt, Alexander Pritzel, Nicolas Heess,
Tom Erez, Yuval Tassa, David Silver, Daan Wierstra.
Paper URL: https://arxiv.org/abs/1509.02971
"""
Expand Down Expand Up @@ -96,7 +97,6 @@ def __init__(
self.actor_critic = ConstraintActorQCritic(
observation_space=self.env.observation_space,
action_space=self.env.action_space,
scale_rewards=cfgs.scale_rewards,
standardized_obs=cfgs.standardized_obs,
model_cfgs=cfgs.model_cfgs,
)
Expand Down Expand Up @@ -223,16 +223,16 @@ def compute_loss_v(self, data):
data['obs_next'],
data['done'],
)
q = self.actor_critic.critic(obs, act)
q_value = self.actor_critic.critic(obs, act)
# Bellman backup for Q function
with torch.no_grad():
act_targ, _ = self.ac_targ.actor.predict(obs, deterministic=True)
q_targ = self.ac_targ.critic(obs_next, act_targ)
backup = rew + self.cfgs.gamma * (1 - done) * q_targ
# MSE loss against Bellman backup
loss_q = ((q - backup) ** 2).mean()
loss_q = ((q_value - backup) ** 2).mean()
# Useful info for logging
q_info = dict(Q1Vals=q.detach().numpy())
q_info = dict(Q1Vals=q_value.detach().numpy())
return loss_q, q_info

def compute_loss_c(self, data):
Expand All @@ -249,17 +249,17 @@ def compute_loss_c(self, data):
data['obs_next'],
data['done'],
)
qc = self.actor_critic.cost_critic(obs, act)
cost_q_value = self.actor_critic.cost_critic(obs, act)

# Bellman backup for Q function
with torch.no_grad():
action, _ = self.ac_targ.pi.predict(obs_next, deterministic=True)
qc_targ = self.ac_targ.c(obs_next, action)
backup = cost + self.cfgs.gamma * (1 - done) * qc_targ
# MSE loss against Bellman backup
loss_qc = ((qc - backup) ** 2).mean()
loss_qc = ((cost_q_value - backup) ** 2).mean()
# Useful info for logging
qc_info = dict(QCosts=qc.detach().numpy())
qc_info = dict(QCosts=cost_q_value.detach().numpy())

return loss_qc, qc_info

Expand Down
1 change: 0 additions & 1 deletion omnisafe/algorithms/on_policy/cup.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ def slice_data(self, data) -> dict:
'adv': adv[i * batch_size : (i + 1) * batch_size],
'discounted_ret': discounted_ret[i * batch_size : (i + 1) * batch_size],
'cost_adv': cost_adv[i * batch_size : (i + 1) * batch_size],
'target_v': target_v[i * batch_size : (i + 1) * batch_size],
}
)

Expand Down
1 change: 1 addition & 0 deletions omnisafe/common/lagrange.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
class Lagrange(abc.ABC):
"""Abstract base class for Lagrangian-base Algorithms."""

# pylint: disable-next=too-many-arguments
def __init__(
self,
cost_limit: float,
Expand Down
1 change: 1 addition & 0 deletions tests/test_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
'PCPO',
'FOCOPS',
'CPPOPid',
'CUP',
]
)
def test_on_policy(algo):
Expand Down