Skip to content
This repository has been archived by the owner on Apr 17, 2023. It is now read-only.

[XAI] hot-fix of error in Detection XAI support #99

Merged
merged 5 commits into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mpa/cls/inferrer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def _infer(self, cfg, dump_features=False, dump_saliency_map=False):
assert len(eval_predictions) == len(feature_vectors) == len(saliency_maps), \
'Number of elements should be the same, however, number of outputs are ' \
f"{len(eval_predictions)}, {len(feature_vectors)}, and {len(saliency_maps)}"

outputs = dict(
eval_predictions=eval_predictions,
feature_vectors=feature_vectors,
Expand Down
5 changes: 5 additions & 0 deletions mpa/det/inferrer.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,16 @@ def dummy_dump_features_hook(mod, inp, out):

if isinstance(dataset, ImageTilingDataset):
feature_vectors = [feature_vectors[i] for i in range(dataset.num_samples)]
saliency_maps = [saliency_maps[i] for i in range(dataset.num_samples)]
if not dataset.merged_results:
eval_predictions = dataset.merge(eval_predictions)
else:
eval_predictions = dataset.merged_results

assert len(eval_predictions) == len(feature_vectors) == len(saliency_maps), \
'Number of elements should be the same, however, number of outputs are ' \
f"{len(eval_predictions)}, {len(feature_vectors)}, and {len(saliency_maps)}"

outputs = dict(
classes=target_classes,
detections=eval_predictions,
Expand Down
49 changes: 30 additions & 19 deletions mpa/modules/models/detectors/sam_detector_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from mmdet.models.detectors import BaseDetector
from mmdet.models.detectors import BaseDetector, TwoStageDetector
from mmdet.utils.deployment.export_helpers import get_feature_vector
from mmdet.integration.nncf.utils import no_nncf_trace
from mpa.modules.hooks.auxiliary_hooks import DetSaliencyMapHook
Expand All @@ -19,22 +19,33 @@ def train_step(self, data, optimizer, **kwargs):
self.current_batch = data
return super().train_step(data, optimizer, **kwargs)

def simple_test(self, img, img_metas, rescale=False, postprocess=True):
x = self.extract_feat(img)
outs = self.bbox_head(x)
with no_nncf_trace():
bbox_results = \
self.bbox_head.get_bboxes(*outs, img_metas, self.test_cfg, False)
if torch.onnx.is_in_onnx_export():
feature_vector = get_feature_vector(x)
cls_scores = outs[0]
saliency_map = DetSaliencyMapHook(self).func(cls_scores, cls_scores_provided=True)
feature = feature_vector, saliency_map
return bbox_results[0], feature
def simple_test(self,
img,
img_metas,
proposals=None,
rescale=False,
postprocess=True):
"""
Class-wise Saliency map for Single-Stage Detector, otherwise use class-ignore saliency map.
"""
if isinstance(self, TwoStageDetector):
return super().simple_test(img, img_metas, proposals, rescale, postprocess)
else:
x = self.extract_feat(img)
outs = self.bbox_head(x)
with no_nncf_trace():
bbox_results = \
self.bbox_head.get_bboxes(*outs, img_metas, self.test_cfg, False)
if torch.onnx.is_in_onnx_export():
feature_vector = get_feature_vector(x)
cls_scores = outs[0]
saliency_map = DetSaliencyMapHook(self).func(cls_scores, cls_scores_provided=True)
feature = feature_vector, saliency_map
return bbox_results[0], feature

if postprocess:
bbox_results = [
self.postprocess(det_bboxes, det_labels, None, img_metas, rescale=rescale)
for det_bboxes, det_labels in bbox_results
]
return bbox_results
if postprocess:
bbox_results = [
self.postprocess(det_bboxes, det_labels, None, img_metas, rescale=rescale)
for det_bboxes, det_labels in bbox_results
]
return bbox_results