Skip to content

Commit

Permalink
style: Bumps ruff to 0.1.9 and mypy to 1.8.0 (#226)
Browse files Browse the repository at this point in the history
* chore: Bumps ruff and mypy

* chore: Updates precommit

* chore: Updates gitignore

* style: Fixes lint

* style: Fixes typing

* ci: Updates CI jobs
  • Loading branch information
frgfm authored Dec 22, 2023
1 parent a4e5f2e commit b59be23
Show file tree
Hide file tree
Showing 22 changed files with 250 additions and 173 deletions.
9 changes: 3 additions & 6 deletions .github/collect_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,7 @@ def get_os(run_lambda):
def get_env_info():
run_lambda = run

if TORCHCAM_AVAILABLE:
torchcam_str = torchcam.__version__
else:
torchcam_str = "N/A"
torchcam_str = torchcam.__version__ if TORCHCAM_AVAILABLE else "N/A"

if TORCH_AVAILABLE:
torch_str = torch.__version__
Expand Down Expand Up @@ -258,14 +255,14 @@ def get_env_info():

def pretty_str(envinfo):
def replace_nones(dct, replacement="Could not collect"):
for key in dct.keys():
for key in dct:
if dct[key] is not None:
continue
dct[key] = replacement
return dct

def replace_bools(dct, true="Yes", false="No"):
for key in dct.keys():
for key in dct:
if dct[key] is True:
dct[key] = true
elif dct[key] is False:
Expand Down
3 changes: 2 additions & 1 deletion .github/verify_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def parse_args():
import argparse

parser = argparse.ArgumentParser(
description="PR label checker", formatter_class=argparse.ArgumentDefaultsHelpFormatter
description="PR label checker",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument("pr", type=int, help="PR number")
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
steps:
- uses: actions/checkout@v2
- name: Miniconda setup
uses: conda-incubator/setup-miniconda@v2
uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
python-version: 3.9
Expand Down
40 changes: 9 additions & 31 deletions .github/workflows/style.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@ jobs:
python: [3.9]
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v4
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}
architecture: x64
- name: Run ruff
run: |
pip install ruff==0.1.0
pip install ruff==0.1.9
ruff --version
ruff check --diff .
Expand All @@ -34,8 +33,7 @@ jobs:
python: [3.9]
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v4
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}
architecture: x64
Expand All @@ -53,40 +51,20 @@ jobs:
mypy --version
mypy
black:
ruff-format:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest]
python: [3.9]
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v4
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}
architecture: x64
- name: Run black
run: |
pip install "black==23.3.0"
black --version
black --check --diff .
bandit:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest]
python: [3.9]
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}
architecture: x64
- name: Run bandit
- name: Run ruff
run: |
pip install bandit[toml]
bandit --version
bandit -r . -c pyproject.toml
pip install ruff==0.1.9
ruff --version
ruff format --check --diff .
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,4 @@ torchcam/version.py

# Conda distribution
conda-dist/
.vscode/
7 changes: 2 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,10 @@ repos:
args: ['--branch', 'main']
- id: debug-statements
language_version: python3
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.0.290'
rev: 'v0.1.9'
hooks:
- id: ruff
args:
- --fix
- id: ruff-format
5 changes: 2 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# this target runs checks on all files
quality:
ruff format --check .
ruff check .
mypy
black --check .
bandit -r . -c pyproject.toml

# this target runs checks on all files and potentially modifies some of them
style:
black .
ruff format .
ruff --fix .

# Run tests for the library
Expand Down
21 changes: 18 additions & 3 deletions demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,17 @@
from torchcam.methods._utils import locate_candidate_layer
from torchcam.utils import overlay_mask

CAM_METHODS = ["CAM", "GradCAM", "GradCAMpp", "SmoothGradCAMpp", "ScoreCAM", "SSCAM", "ISCAM", "XGradCAM", "LayerCAM"]
CAM_METHODS = [
"CAM",
"GradCAM",
"GradCAMpp",
"SmoothGradCAMpp",
"ScoreCAM",
"SSCAM",
"ISCAM",
"XGradCAM",
"LayerCAM",
]
TV_MODELS = [
"resnet18",
"resnet50",
Expand Down Expand Up @@ -87,7 +97,8 @@ def main():
)
if cam_method is not None:
cam_extractor = methods.__dict__[cam_method](
model, target_layer=[s.strip() for s in target_layer.split("+")] if len(target_layer) > 0 else None
model,
target_layer=[s.strip() for s in target_layer.split("+")] if len(target_layer) > 0 else None,
)

class_choices = [f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)]
Expand All @@ -103,7 +114,11 @@ def main():
else:
with st.spinner("Analyzing..."):
# Preprocess image
img_tensor = normalize(to_tensor(resize(img, (224, 224))), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
img_tensor = normalize(
to_tensor(resize(img, (224, 224))),
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225],
)

if torch.cuda.is_available():
img_tensor = img_tensor.cuda()
Expand Down
6 changes: 2 additions & 4 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from datetime import datetime
from pathlib import Path

sys.path.insert(0, Path().resolve().parent.parent)
sys.path.insert(0, Path().cwd().parent.parent)
import torchcam

# -- Project information -----------------------------------------------------
Expand Down Expand Up @@ -121,9 +121,7 @@ def add_ga_javascript(app, pagename, templatename, context, doctree):
gtag('js', new Date());
gtag('config', '{0}');
</script>
""".format(
app.config.googleanalytics_id
)
""".format(app.config.googleanalytics_id)
context["metatags"] = metatags


Expand Down
56 changes: 33 additions & 23 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,9 @@ test = [
"pytest-pretty>=1.0.0,<2.0.0",
]
quality = [
"ruff==0.1.0",
"mypy==1.5.1",
"black==23.3.0",
"bandit[toml]>=1.7.0,<1.8.0",
"pre-commit>=2.17.0,<3.0.0",
"ruff==0.1.9",
"mypy==1.8.0",
"pre-commit>=3.0.0,<4.0.0",
]
docs = [
"sphinx>=3.0.0,!=3.5.0",
Expand All @@ -80,11 +78,9 @@ dev = [
"pytest-xdist>=3.0.0,<4.0.0",
"pytest-pretty>=1.0.0,<2.0.0",
# style
"ruff==0.1.0",
"mypy==1.5.1",
"black==23.3.0",
"bandit[toml]>=1.7.0,<1.8.0",
"pre-commit>=2.17.0,<3.0.0",
"ruff==0.1.9",
"mypy==1.8.0",
"pre-commit>=3.0.0,<4.0.0",
# docs
"sphinx>=3.0.0,!=3.5.0",
"furo>=2022.3.4",
Expand Down Expand Up @@ -133,6 +129,16 @@ select = [
"T20", # flake8-print
"PT", # flake8-pytest-style
"LOG", # flake8-logging
"SIM", # flake8-simplify
"YTT", # flake8-2020
"ANN", # flake8-annotations
"ASYNC", # flake8-async
"BLE", # flake8-blind-except
"A", # flake8-builtins
"ICN", # flake8-import-conventions
"PIE", # flake8-pie
"ARG", # flake8-unused-arguments
"FURB", # refurb
]
ignore = [
"E501", # line too long, handled by black
Expand All @@ -142,20 +148,31 @@ ignore = [
"F403", # star imports
"E731", # lambda assignment
"C416", # list comprehension to list()
"ANN101", # missing type annotations on self
"ANN102", # missing type annotations on cls
"ANN002", # missing type annotations on *args
"ANN003", # missing type annotations on **kwargs
"COM812", # trailing comma missing
"N812", # lowercase imported as non-lowercase
"ISC001", # implicit string concatenation (handled by format)
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
]
exclude = [".git"]
line-length = 120
target-version = "py39"
preview = true

[tool.ruff.format]
quote-style = "double"
indent-style = "space"

[tool.ruff.per-file-ignores]
"**/__init__.py" = ["I001", "F401", "CPY001"]
"scripts/**.py" = ["D", "T201", "N812"]
".github/**.py" = ["D", "T201", "S602"]
"docs/**.py" = ["E402", "D103"]
"tests/**.py" = ["D103", "CPY001", "S101", "PT011",]
"demo/**.py" = ["D103"]
"scripts/**.py" = ["D", "T201", "N812", "S101", "ANN"]
".github/**.py" = ["D", "T201", "S602", "S101", "ANN"]
"docs/**.py" = ["E402", "D103", "ANN", "A001", "ARG001"]
"tests/**.py" = ["D103", "CPY001", "S101", "PT011", "ANN"]
"demo/**.py" = ["D103", "ANN"]
"setup.py" = ["T201"]

[tool.ruff.flake8-quotes]
Expand All @@ -177,18 +194,11 @@ no_implicit_optional = true
check_untyped_defs = true
implicit_reexport = false
disallow_untyped_defs = true
explicit_package_bases = true

[[tool.mypy.overrides]]
module = [
"PIL",
"matplotlib"
]
ignore_missing_imports = true

[tool.black]
line-length = 120
target-version = ['py39']

[tool.bandit]
exclude_dirs = [".github/collect_env.py"]
skips = ["B101"]
30 changes: 20 additions & 10 deletions scripts/cam_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,15 @@ def main(args):
p.requires_grad_(False)

# Image
if args.img.startswith("http"):
img_path = BytesIO(requests.get(args.img, timeout=5).content)
else:
img_path = args.img
img_path = BytesIO(requests.get(args.img, timeout=5).content) if args.img.startswith("http") else args.img
pil_img = Image.open(img_path, mode="r").convert("RGB")

# Preprocess image
img_tensor = normalize(to_tensor(resize(pil_img, (224, 224))), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).to(
device=device
)
img_tensor = normalize(
to_tensor(resize(pil_img, (224, 224))),
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225],
).to(device=device)
img_tensor.requires_grad_(True)

if isinstance(args.method, str):
Expand Down Expand Up @@ -119,7 +118,8 @@ def main(args):

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Saliency Map comparison", formatter_class=argparse.ArgumentDefaultsHelpFormatter
description="Saliency Map comparison",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--arch", type=str, default="resnet18", help="Name of the architecture")
parser.add_argument(
Expand All @@ -129,13 +129,23 @@ def main(args):
help="The image to extract CAM from",
)
parser.add_argument("--class-idx", type=int, default=232, help="Index of the class to inspect")
parser.add_argument("--device", type=str, default=None, help="Default device to perform computation on")
parser.add_argument(
"--device",
type=str,
default=None,
help="Default device to perform computation on",
)
parser.add_argument("--savefig", type=str, default=None, help="Path to save figure")
parser.add_argument("--method", type=str, default=None, help="CAM method to use")
parser.add_argument("--target", type=str, default=None, help="the target layer")
parser.add_argument("--alpha", type=float, default=0.5, help="Transparency of the heatmap")
parser.add_argument("--rows", type=int, default=1, help="Number of rows for the layout")
parser.add_argument("--noblock", dest="noblock", help="Disables blocking visualization", action="store_true")
parser.add_argument(
"--noblock",
dest="noblock",
help="Disables blocking visualization",
action="store_true",
)
args = parser.parse_args()

main(args)
Loading

0 comments on commit b59be23

Please sign in to comment.