-
Notifications
You must be signed in to change notification settings - Fork 674
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dataclasses and post-processing refactor #2098
Changes from 4 commits
68c5582
ddfcd5f
32e038d
675dd3f
5779ab7
0662558
fddbeb1
e32bd7d
90265e8
e3a9c1d
89f972c
08bdae2
2bc76fc
f7c7f9a
4160ab3
fd9eb24
87facb6
2269a78
9652b9f
082bbbc
dbabb20
b190cd3
ed904eb
773e54a
f8d999a
67046dd
d00b938
9fb4549
86cf632
fa3b874
2761600
c650dfc
ced34ca
12cd32d
b447cab
213c2b4
fb80feb
d2337a7
b53f1f7
5f16147
9203318
e99d630
631ba97
b750042
b37e265
86a365d
fcbb628
0fc3337
7ec9dd7
f5a48cd
afaec9b
e0a70c8
211d9f8
eb584eb
442c37f
bd59184
e17eda5
3140e8b
af99bed
039be2a
381e638
daead5b
987abe5
a709c6c
a37fa3b
14da4fa
beb3b97
a9d07db
6bcca36
014cb59
58df063
25845fb
1defdba
8d60276
0afb6d9
a26efb9
e7d9852
a4bcbfe
085c4aa
eff1f97
40bb4be
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -15,6 +15,9 @@ | |||||
from anomalib import TaskType | ||||||
from anomalib.metrics import AnomalibMetricCollection, create_metric_collection | ||||||
from anomalib.models import AnomalyModule | ||||||
from anomalib.dataclasses import BatchItem | ||||||
|
||||||
from dataclasses import asdict | ||||||
|
||||||
logger = logging.getLogger(__name__) | ||||||
|
||||||
|
@@ -121,7 +124,7 @@ def on_validation_batch_end( | |||||
del trainer, batch, batch_idx, dataloader_idx # Unused arguments. | ||||||
|
||||||
if outputs is not None: | ||||||
self._outputs_to_device(outputs) | ||||||
outputs = self._outputs_to_device(outputs) | ||||||
self._update_metrics(pl_module.image_metrics, pl_module.pixel_metrics, outputs) | ||||||
|
||||||
def on_validation_epoch_end( | ||||||
|
@@ -156,7 +159,7 @@ def on_test_batch_end( | |||||
del trainer, batch, batch_idx, dataloader_idx # Unused arguments. | ||||||
|
||||||
if outputs is not None: | ||||||
self._outputs_to_device(outputs) | ||||||
outputs = self._outputs_to_device(outputs) | ||||||
self._update_metrics(pl_module.image_metrics, pl_module.pixel_metrics, outputs) | ||||||
|
||||||
def on_test_epoch_end( | ||||||
|
@@ -179,15 +182,17 @@ def _update_metrics( | |||||
output: STEP_OUTPUT, | ||||||
) -> None: | ||||||
image_metric.to(self.device) | ||||||
image_metric.update(output["pred_scores"], output["label"].int()) | ||||||
if "mask" in output and "anomaly_maps" in output: | ||||||
image_metric.update(output.pred_score, output.gt_label.int()) | ||||||
if output.gt_mask is not None and output.anomaly_map is not None: | ||||||
pixel_metric.to(self.device) | ||||||
pixel_metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int())) | ||||||
pixel_metric.update(torch.squeeze(output.anomaly_map), torch.squeeze(output.gt_mask.int())) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]: | ||||||
if isinstance(output, dict): | ||||||
for key, value in output.items(): | ||||||
output[key] = self._outputs_to_device(value) | ||||||
elif isinstance(output, BatchItem): | ||||||
output = output.__class__(**self._outputs_to_device(asdict(output))) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would it be an idea to add a comment here? It might be difficult to understand for some readers |
||||||
elif isinstance(output, torch.Tensor): | ||||||
output = output.to(self.device) | ||||||
return output | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
|
||
from anomalib.data.utils import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes | ||
from anomalib.models import AnomalyModule | ||
from anomalib.dataclasses import BatchItem | ||
|
||
|
||
class _PostProcessorCallback(Callback): | ||
|
@@ -28,7 +29,7 @@ def on_validation_batch_end( | |
self, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need this callback as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This module is deprecated with the new design. I left it here for legacy purposes until we decide how to handle backward compatibility, but you can ignore it for now. |
||
trainer: Trainer, | ||
pl_module: AnomalyModule, | ||
outputs: STEP_OUTPUT | None, | ||
outputs: BatchItem, | ||
batch: Any, # noqa: ANN401 | ||
batch_idx: int, | ||
dataloader_idx: int = 0, | ||
|
@@ -42,7 +43,7 @@ def on_test_batch_end( | |
self, | ||
trainer: Trainer, | ||
pl_module: AnomalyModule, | ||
outputs: STEP_OUTPUT | None, | ||
outputs: BatchItem, | ||
batch: Any, # noqa: ANN401 | ||
batch_idx: int, | ||
dataloader_idx: int = 0, | ||
|
@@ -56,7 +57,7 @@ def on_predict_batch_end( | |
self, | ||
trainer: Trainer, | ||
pl_module: AnomalyModule, | ||
outputs: Any, # noqa: ANN401 | ||
outputs: BatchItem, # noqa: ANN401 | ||
batch: Any, # noqa: ANN401 | ||
batch_idx: int, | ||
dataloader_idx: int = 0, | ||
|
@@ -67,7 +68,7 @@ def on_predict_batch_end( | |
self.post_process(trainer, pl_module, outputs) | ||
|
||
def post_process(self, trainer: Trainer, pl_module: AnomalyModule, outputs: STEP_OUTPUT) -> None: | ||
if isinstance(outputs, dict): | ||
if isinstance(outputs, BatchItem): | ||
self._post_process(outputs) | ||
if trainer.predicting or trainer.testing: | ||
self._compute_scores_and_labels(pl_module, outputs) | ||
|
@@ -77,49 +78,49 @@ def _compute_scores_and_labels( | |
pl_module: AnomalyModule, | ||
outputs: dict[str, Any], | ||
) -> None: | ||
if "pred_scores" in outputs: | ||
outputs["pred_labels"] = outputs["pred_scores"] >= pl_module.image_threshold.value | ||
if "anomaly_maps" in outputs: | ||
outputs["pred_masks"] = outputs["anomaly_maps"] >= pl_module.pixel_threshold.value | ||
if "pred_boxes" not in outputs: | ||
outputs["pred_boxes"], outputs["box_scores"] = masks_to_boxes( | ||
outputs["pred_masks"], | ||
outputs["anomaly_maps"], | ||
if outputs.pred_score is not None: | ||
outputs.pred_label = outputs.pred_score >= pl_module.image_threshold.value | ||
if outputs.anomaly_map is not None: | ||
outputs.pred_mask = outputs.anomaly_map >= pl_module.pixel_threshold.value | ||
if outputs.pred_boxes is None: | ||
djdameln marked this conversation as resolved.
Show resolved
Hide resolved
|
||
outputs.pred_boxes, outputs.box_scores = masks_to_boxes( | ||
outputs.pred_mask, | ||
outputs.anomaly_map, | ||
) | ||
outputs["box_labels"] = [torch.ones(boxes.shape[0]) for boxes in outputs["pred_boxes"]] | ||
outputs.box_labels = [torch.ones(boxes.shape[0]) for boxes in outputs.pred_boxes] | ||
# apply thresholding to boxes | ||
if "box_scores" in outputs and "box_labels" not in outputs: | ||
if outputs.box_scores is not None and outputs.box_labels is None: | ||
# apply threshold to assign normal/anomalous label to boxes | ||
is_anomalous = [scores > pl_module.pixel_threshold.value for scores in outputs["box_scores"]] | ||
outputs["box_labels"] = [labels.int() for labels in is_anomalous] | ||
is_anomalous = [scores > pl_module.pixel_threshold.value for scores in outputs.box_scores] | ||
outputs.box_labels = [labels.int() for labels in is_anomalous] | ||
|
||
@staticmethod | ||
def _post_process(outputs: STEP_OUTPUT) -> None: | ||
def _post_process(outputs: BatchItem) -> None: | ||
"""Compute labels based on model predictions.""" | ||
if isinstance(outputs, dict): | ||
if "pred_scores" not in outputs and "anomaly_maps" in outputs: | ||
if isinstance(outputs, BatchItem): | ||
if outputs.pred_score is None and outputs.anomaly_map is not None: | ||
# infer image scores from anomaly maps | ||
outputs["pred_scores"] = ( | ||
outputs["anomaly_maps"] # noqa: PD011 | ||
.reshape(outputs["anomaly_maps"].shape[0], -1) | ||
outputs.pred_score = ( | ||
outputs.anomaly_map # noqa: PD011 | ||
.reshape(outputs.anomaly_map.shape[0], -1) | ||
.max(dim=1) | ||
.values | ||
) | ||
elif "pred_scores" not in outputs and "box_scores" in outputs and "label" in outputs: | ||
elif outputs.pred_score is None and outputs.box_score is not None and outputs.gt_label is not None: | ||
# infer image score from bbox confidence scores | ||
outputs["pred_scores"] = torch.zeros_like(outputs["label"]).float() | ||
for idx, (boxes, scores) in enumerate(zip(outputs["pred_boxes"], outputs["box_scores"], strict=True)): | ||
outputs.pred_score = torch.zeros_like(outputs.gt_label).float() | ||
for idx, (boxes, scores) in enumerate(zip(outputs.pred_boxes, outputs.box_scores, strict=True)): | ||
if boxes.numel(): | ||
outputs["pred_scores"][idx] = scores.max().item() | ||
outputs.pred_score[idx] = scores.max().item() | ||
|
||
if "pred_boxes" in outputs and "anomaly_maps" not in outputs: | ||
if outputs.pred_boxes is not None and outputs.anomaly_map is None: | ||
samet-akcay marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# create anomaly maps from bbox predictions for thresholding and evaluation | ||
image_size: tuple[int, int] = outputs["image"].shape[-2:] | ||
pred_boxes: torch.Tensor = outputs["pred_boxes"] | ||
box_scores: torch.Tensor = outputs["box_scores"] | ||
image_size: tuple[int, int] = outputs.image.shape[-2:] | ||
pred_boxes: torch.Tensor = outputs.pred_boxes | ||
box_scores: torch.Tensor = outputs.box_scores | ||
|
||
outputs["anomaly_maps"] = boxes_to_anomaly_maps(pred_boxes, box_scores, image_size) | ||
outputs.anomaly_map = boxes_to_anomaly_maps(pred_boxes, box_scores, image_size) | ||
|
||
if "boxes" in outputs: | ||
true_boxes: list[torch.Tensor] = outputs["boxes"] | ||
outputs["mask"] = boxes_to_masks(true_boxes, image_size) | ||
if outputs.gt_boxes is not None: | ||
true_boxes: list[torch.Tensor] = outputs.gt_boxes | ||
outputs.gt_mask = boxes_to_masks(true_boxes, image_size) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,9 @@ | |
from anomalib.metrics.threshold import BaseThreshold | ||
from anomalib.models import AnomalyModule | ||
from anomalib.utils.types import THRESHOLD | ||
from anomalib.dataclasses import BatchItem | ||
|
||
from dataclasses import asdict | ||
|
||
|
||
class _ThresholdCallback(Callback): | ||
|
@@ -53,7 +56,7 @@ def on_validation_batch_end( | |
) -> None: | ||
del trainer, batch, batch_idx, dataloader_idx # Unused arguments. | ||
if outputs is not None: | ||
self._outputs_to_cpu(outputs) | ||
outputs = self._outputs_to_cpu(outputs) | ||
self._update(pl_module, outputs) | ||
|
||
def on_validation_epoch_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None: | ||
|
@@ -178,16 +181,18 @@ def _outputs_to_cpu(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]: | |
if isinstance(output, dict): | ||
for key, value in output.items(): | ||
output[key] = self._outputs_to_cpu(value) | ||
elif isinstance(output, BatchItem): | ||
output = output.__class__(**self._outputs_to_cpu(asdict(output))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor design comment but can we move There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This module is also deprecated with the new design, so you can ignore the changes in this file. So far I hadn't considered device handling in the new design, but adding a |
||
elif isinstance(output, torch.Tensor): | ||
output = output.cpu() | ||
return output | ||
|
||
def _update(self, pl_module: AnomalyModule, outputs: STEP_OUTPUT) -> None: | ||
pl_module.image_threshold.cpu() | ||
pl_module.image_threshold.update(outputs["pred_scores"], outputs["label"].int()) | ||
if "mask" in outputs and "anomaly_maps" in outputs: | ||
pl_module.image_threshold.update(outputs.pred_score, outputs.gt_label.int()) | ||
if outputs.gt_mask is not None and outputs.anomaly_map is not None: | ||
pl_module.pixel_threshold.cpu() | ||
pl_module.pixel_threshold.update(outputs["anomaly_maps"], outputs["mask"].int()) | ||
pl_module.pixel_threshold.update(outputs.anomaly_map, outputs.gt_mask.int()) | ||
|
||
def _compute(self, pl_module: AnomalyModule) -> None: | ||
pl_module.image_threshold.compute() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is for a future reference... I hope to get rid of this device related stuff, and leave it to Lightning