Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

metrics[7]: speedup binclf with numba [GSoC23 cont'd] #14

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ imgaug==0.4.0
jsonargparse[signatures]>=4.3
kornia>=0.6.6,<0.6.10
matplotlib>=3.4.3
numba==0.57.1
omegaconf>=2.1.1
opencv-python>=4.5.3.56
pandas>=1.1.0
pytorch-lightning>=1.7.0,<1.10.0
tbb>=2021.6.0
timm>=0.5.4,<=0.6.12
torchmetrics==0.10.3
130 changes: 130 additions & 0 deletions src/anomalib/utils/metrics/perimg/_binclf_numba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import numba
import numpy as np
from numpy import ndarray


@numba.jit(nopython=True)
def _binclf_curve_numba(scoremap: ndarray, mask: ndarray, thresholds: ndarray):
"""Compute the binary classification matrix curve of a single image for a given sequence of thresholds.

This does the same as `__binclf_curves_ndarray_itertools` but with numba using just-in-time compilation.

ATTENTION:
1. `thresholds` must be sorted in ascending order!
2. Argument validation is not done here!


Note: predicted as positive condition is `score >= th`.

Args:
D: number of pixels in each image
scoremap (ndarray): Anomaly score maps of shape (D,),
mask (ndarray): Binary (bool) ground truth mask of shape (D,),
thresholds (ndarray): Sequence of T thresholds to compute the binary classification matrix for.

Returns:
ndarray: Binary classification matrix of shape (T, 2, 2)
The last two dimensions are the confusion matrix for each threshold, organized as (true class, predicted class):
- `tps`: `[... , 1, 1]`
- `fps`: `[... , 0, 1]`
- `fns`: `[... , 1, 0]`
- `tns`: `[... , 0, 0]`

"""

num_th = len(thresholds)

# POSITIVES
scores_pos = scoremap[mask]
# the sorting is very important for the algorithm to work and the speedup
scores_pos = np.sort(scores_pos)
# start counting with lowest th, so everything is predicted as positive (this variable is updated in the loop)
num_pos = current_count_tp = len(scores_pos)

tps = np.empty((num_th,), dtype=np.int64)

# NEGATIVES
# same thing but for the negative samples
scores_neg = scoremap[~mask]
scores_neg = np.sort(scores_neg)
num_neg = current_count_fp = len(scores_neg)

fps = np.empty((num_th,), dtype=np.int64)

# it will progressively drop the scores that are below the current th
for thidx, th in enumerate(thresholds):
num_drop = 0
num_scores = len(scores_pos)
while num_drop < num_scores and scores_pos[num_drop] < th: # ! scores_pos !
num_drop += 1
# ---
scores_pos = scores_pos[num_drop:]
current_count_tp -= num_drop
tps[thidx] = current_count_tp

# same with the negatives
num_drop = 0
num_scores = len(scores_neg)
while num_drop < num_scores and scores_neg[num_drop] < th: # ! scores_neg !
num_drop += 1
# ---
scores_neg = scores_neg[num_drop:]
current_count_fp -= num_drop
fps[thidx] = current_count_fp

fns = num_pos * np.ones((num_th,), dtype=np.int64) - tps
tns = num_neg * np.ones((num_th,), dtype=np.int64) - fps

# sequence of dimensions is (thresholds, true class, predicted class)
# so `tps` is `confmat[:, 1, 1]`, `fps` is `confmat[:, 0, 1]`, etc.
return np.stack(
(
np.stack((tns, fps), axis=-1),
np.stack((fns, tps), axis=-1),
),
axis=-1,
).transpose(0, 2, 1)


@numba.jit(nopython=True, parallel=True)
def _binclf_curves_numba_parallel(scoremaps: ndarray, masks: ndarray, thresholds: ndarray):
"""Generalize the function above to a batch of images by parallelizing the loop over images.

This has the same role as

```
_binclf_curves_ndarray_itertools = np.vectorize(
__binclf_curves_ndarray_itertools,
signature="(n),(n),(k)->(k,2,2)",
)
```

but it leverages numba's parallelization.

ATTENTION:
1. `thresholds` must be sorted in ascending order!
2. Argument validation is not done here!

Args:
N: number of images
D: number of pixels in each image
scoremaps (ndarray): Anomaly score maps of shape (N, D),
masks (ndarray): Binary (bool) ground truth masks of shape (N, D),
thresholds (ndarray): Sequence of T thresholds to compute the binary classification matrix for.

Returns:
ndarray: Binary classification matrix of shape (N, T, 2, 2)
The last two dimensions are the confusion matrix for each threshold, organized as (true class, predicted class):
- `tps`: `[... , 1, 1]`
- `fps`: `[... , 0, 1]`
- `fns`: `[... , 1, 0]`
- `tns`: `[... , 0, 0]`
"""
num_imgs = scoremaps.shape[0]
num_th = len(thresholds)
ret = np.empty((num_imgs, num_th, 2, 2), dtype=np.int64)
for imgidx in numba.prange(num_imgs):
scoremap = scoremaps[imgidx]
mask = masks[imgidx]
ret[imgidx] = _binclf_curve_numba(scoremap, mask, thresholds)
return ret
21 changes: 18 additions & 3 deletions src/anomalib/utils/metrics/perimg/binclf_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from torchmetrics import Metric
from torchmetrics.utilities.data import dim_zero_cat

from ._binclf_numba import _binclf_curves_numba_parallel

# =========================================== ARGS VALIDATION ===========================================


Expand Down Expand Up @@ -261,6 +263,7 @@ def _perimg_binclf_curve_compute_cpu(
masks: Tensor,
threshold_bounds: Tensor | tuple[float, float],
num_thresholds: int,
algorithm: str = "numba-parallel",
):
"""Compute the binary classification matrix for a range of thresholds.

Expand All @@ -278,6 +281,8 @@ def _perimg_binclf_curve_compute_cpu(
masks (Tensor): Binary ground truth masks of shape (N, H, W)
threshold_bounds (Tensor | tuple[float, float]): Lower and upper bounds for the thresholds.
num_thresholds (int): Number of thresholds to compute between `threshold_bounds`.
algorithm (str): Algorithm to use for computing the binary classification matrix.
Options: `itertools`, `numba-parallel`.

Returns:
(Tensor, Tensor[int64]):
Expand Down Expand Up @@ -315,9 +320,19 @@ def _perimg_binclf_curve_compute_cpu(
)

# *** update() ***
binclf_curve_ndarray = _binclf_curves_ndarray_itertools(
anomaly_maps.numpy(), masks.numpy().astype(bool), thresholds.numpy()
)
if algorithm == "itertools":
binclf_curve_ndarray = _binclf_curves_ndarray_itertools(
anomaly_maps.numpy(), masks.numpy().astype(bool), thresholds.numpy()
)

elif algorithm == "numba-parallel":
binclf_curve_ndarray = _binclf_curves_numba_parallel(
anomaly_maps.numpy(), masks.numpy().astype(bool), thresholds.numpy()
)

else:
raise ValueError(f"Algorithm {algorithm} not recognized.")

return thresholds, torch.from_numpy(binclf_curve_ndarray).to(anomaly_maps.device).long()


Expand Down
6 changes: 3 additions & 3 deletions src/anomalib/utils/metrics/perimg/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,9 @@ def append_record(stat_, val_):
records.append(
dict(
statistic=stat_,
value=val_,
nearest=nearest,
imgidx=imgidx,
value=float(val_),
nearest=float(nearest),
imgidx=int(imgidx),
)
)

Expand Down
56 changes: 49 additions & 7 deletions tests/pre_merge/utils/metrics/test_perimg/test_binclf_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
import torch

from anomalib.utils.metrics.perimg._binclf_numba import _binclf_curve_numba, _binclf_curves_numba_parallel
from anomalib.utils.metrics.perimg.binclf_curve import (
PerImageBinClfCurve,
__binclf_curves_ndarray_itertools,
Expand Down Expand Up @@ -38,7 +39,7 @@ def pytest_generate_tests(metafunc):
axis=0,
).astype(int)

if metafunc.function is test___binclf_curves_ndarray_itertools:
if metafunc.function is test___binclf_curves_ndarray_itertools or metafunc.function is test__binclf_curve_numba:
metafunc.parametrize(
argnames=("pred", "mask", "thresholds", "expected"),
argvalues=[
Expand All @@ -52,7 +53,10 @@ def pytest_generate_tests(metafunc):
masks = np.stack([mask_anom, mask_norm], axis=0)
expecteds = np.stack([expected_anom, expected_norm], axis=0)

if metafunc.function is test__binclf_curves_ndarray_itertools:
if (
metafunc.function is test__binclf_curves_ndarray_itertools
or metafunc.function is test__binclf_curves_numba_parallel
):
metafunc.parametrize(
argnames=("preds", "masks", "thresholds", "expecteds"),
argvalues=[
Expand All @@ -68,13 +72,25 @@ def pytest_generate_tests(metafunc):

if metafunc.function is test__perimg_binclf_curve_compute_cpu:
metafunc.parametrize(
argnames=("preds", "masks", "expected_thresholds", "expecteds"),
argnames=(
"preds",
"masks",
"expected_thresholds",
"expecteds",
),
argvalues=[
(preds[:1], masks[:1], thresholds, expecteds[:1]),
(preds, masks, thresholds, expecteds),
(preds.reshape(2, 2, 2), masks.reshape(2, 2, 2), thresholds, expecteds),
],
)
metafunc.parametrize(
argnames=("algorithm",),
argvalues=[
("itertools",),
("numba-parallel",),
],
)

images_classes = torch.tensor([1, 0])

Expand All @@ -88,20 +104,42 @@ def pytest_generate_tests(metafunc):
)


# with double `_`
# ==================================================================================================
# itertools version


# with double `_` (for a single image)
def test___binclf_curves_ndarray_itertools(pred, mask, thresholds, expected):
computed = __binclf_curves_ndarray_itertools(pred, mask, thresholds)
assert computed.shape == (thresholds.size, 2, 2)
assert (computed == expected).all()


# with single `_`
# with single `_` (for a batch of images)
def test__binclf_curves_ndarray_itertools(preds, masks, thresholds, expecteds):
computed = _binclf_curves_ndarray_itertools(preds, masks, thresholds)
assert computed.shape == (preds.shape[0], thresholds.size, 2, 2)
assert (computed == expecteds).all()


# ==================================================================================================
# numba version


# for a single image
def test__binclf_curve_numba(pred, mask, thresholds, expected):
computed = _binclf_curve_numba(pred, mask, thresholds)
assert computed.shape == (thresholds.size, 2, 2)
assert (computed == expected).all()


# for a batch of images
def test__binclf_curves_numba_parallel(preds, masks, thresholds, expecteds):
computed = _binclf_curves_numba_parallel(preds, masks, thresholds)
assert computed.shape == (preds.shape[0], thresholds.size, 2, 2)
assert (computed == expecteds).all()


def test____binclf_curves_ndarray_itertools_validations():
# `pred` and `mask` must have the same length
with pytest.raises(ValueError):
Expand All @@ -116,11 +154,15 @@ def test____binclf_curves_ndarray_itertools_validations():
__binclf_curves_ndarray_itertools(np.arange(4), np.arange(4), np.arange(6).reshape(2, 3))


def test__perimg_binclf_curve_compute_cpu(preds, masks, expected_thresholds, expecteds):
def test__perimg_binclf_curve_compute_cpu(preds, masks, expected_thresholds, expecteds, algorithm):
th_bounds = torch.tensor((expected_thresholds[0], expected_thresholds[-1]))

computed_thresholds, computed = _perimg_binclf_curve_compute_cpu(
preds, masks, th_bounds, expected_thresholds.numel()
preds,
masks,
th_bounds,
expected_thresholds.numel(),
algorithm=algorithm,
)
assert computed.shape == expecteds.shape
assert (computed == expecteds).all()
Expand Down
Loading