Dataclasses and post-processing refactor #2098

Merged
Changes shown from 4 commits.

81 commits:
68c5582
use dataclass for model in- and outputs
djdameln May 30, 2024
ddfcd5f
split dataclass in image and video
djdameln May 31, 2024
32e038d
use dataclass in torch inferencer
djdameln Jul 3, 2024
675dd3f
use dataclass in openvino inferencer
djdameln Jul 3, 2024
5779ab7
add post_processor class
djdameln Jul 22, 2024
0662558
remove default metrics from CLI
djdameln Jul 22, 2024
fddbeb1
export post processing
djdameln Jul 23, 2024
e32bd7d
add post processor to patchcore
djdameln Jul 23, 2024
90265e8
use named tuple for inference outputs
djdameln Jul 23, 2024
e3a9c1d
validate and format inputs of PredictBatch
djdameln Jul 24, 2024
89f972c
update torch inference
djdameln Jul 24, 2024
08bdae2
remove base inferencer inheritance
djdameln Jul 24, 2024
2bc76fc
update openvino inference
djdameln Jul 24, 2024
f7c7f9a
fix visualization
djdameln Jul 24, 2024
4160ab3
PredictBatch -> Batch
djdameln Jul 24, 2024
fd9eb24
post processor as callback
djdameln Jul 24, 2024
87facb6
use callback methods to apply post processing
djdameln Jul 25, 2024
2269a78
temporary fix for visualization
djdameln Jul 25, 2024
9652b9f
add DatasetItem class
djdameln Jul 25, 2024
082bbbc
fix pred_score shape and add __len__
djdameln Jul 26, 2024
dbabb20
make batch iterable
djdameln Jul 26, 2024
b190cd3
add in place replace method
djdameln Jul 26, 2024
ed904eb
use dataset items in inference
djdameln Jul 26, 2024
773e54a
dataset_items -> items
djdameln Jul 26, 2024
f8d999a
use namedtuple as torch model outputs
djdameln Jul 31, 2024
67046dd
merge main
djdameln Jul 31, 2024
d00b938
formatting
djdameln Aug 6, 2024
9fb4549
Merge branch 'main' into refactor_outputs
djdameln Aug 7, 2024
86cf632
Merge branch 'main' into refactor_outputs
djdameln Aug 15, 2024
fa3b874
split dataclasses into input/output and image/video
djdameln Aug 19, 2024
2761600
merge input and output classes
djdameln Aug 19, 2024
c650dfc
use init_subclass for attribute checking
djdameln Aug 20, 2024
ced34ca
add descriptor class for validation
djdameln Aug 20, 2024
12cd32d
improve error handling
djdameln Aug 20, 2024
b447cab
DataClassDescriptor -> FieldDescriptor
djdameln Aug 20, 2024
213c2b4
add is_optional method
djdameln Aug 20, 2024
fb80feb
add input validation for torch image and batch
djdameln Aug 21, 2024
d2337a7
use image and video dataclasses in library
djdameln Aug 21, 2024
b53f1f7
add more validation
djdameln Aug 23, 2024
5f16147
add validation
djdameln Aug 26, 2024
9203318
make postprocessor configurable from engine
djdameln Aug 27, 2024
e99d630
fix post processing logic
djdameln Aug 27, 2024
631ba97
Merge branch 'main' into refactor_outputs_separate
djdameln Aug 27, 2024
b750042
fix data tests
djdameln Aug 27, 2024
b37e265
remove detection task type
djdameln Aug 27, 2024
86a365d
fix more tests
djdameln Aug 27, 2024
fcbb628
use separate normalization stats for image and pixel preds
djdameln Aug 27, 2024
0fc3337
add sensitivity parameters to one class pp
djdameln Aug 27, 2024
7ec9dd7
fix utils tests
djdameln Aug 28, 2024
f5a48cd
fix utils tests
djdameln Aug 28, 2024
afaec9b
remove metric serialization test
djdameln Aug 28, 2024
e0a70c8
remove normalization and thresholding args
djdameln Aug 28, 2024
211d9f8
set default post processor in base model
djdameln Aug 28, 2024
eb584eb
remove manual threshold test
djdameln Aug 28, 2024
442c37f
fix remaining unit tests
djdameln Aug 28, 2024
bd59184
add post_processor to CLI args
djdameln Aug 28, 2024
e17eda5
remove old post processing callbacks
djdameln Aug 28, 2024
3140e8b
remove comment
djdameln Aug 28, 2024
af99bed
remove references to old normalization and thresholding callbacks
djdameln Aug 28, 2024
039be2a
remove reshape in openvino inferencer
djdameln Aug 29, 2024
381e638
export lightning model directly
djdameln Aug 29, 2024
daead5b
make collate accessible from dataset
djdameln Aug 29, 2024
987abe5
fix tools integration tests
djdameln Aug 29, 2024
a709c6c
add update method to dataclasses
djdameln Aug 29, 2024
a37fa3b
allow missing pred_score or anomaly_map in post processor
djdameln Aug 29, 2024
14da4fa
fix exportable centercrop conversion
djdameln Aug 30, 2024
beb3b97
fix model tests
djdameln Aug 30, 2024
a9d07db
test all models
djdameln Aug 30, 2024
6bcca36
fix efficient_ad
djdameln Aug 30, 2024
014cb59
post processor as model arg
djdameln Aug 30, 2024
58df063
disable rkde tests
djdameln Aug 30, 2024
25845fb
fix winclip export
djdameln Aug 30, 2024
1defdba
add copyright notice
djdameln Aug 30, 2024
8d60276
add validation for numpy anomaly map
djdameln Sep 2, 2024
0afb6d9
fix getting started notebook
djdameln Sep 2, 2024
a26efb9
remove hardcoded path
djdameln Sep 2, 2024
e7d9852
update dataset notebooks
djdameln Sep 2, 2024
a4bcbfe
update model notebooks
djdameln Sep 2, 2024
085c4aa
Merge branch 'feature/design-simplifications' into refactor_outputs
djdameln Sep 2, 2024
eff1f97
fix logging notebooks
djdameln Sep 2, 2024
40bb4be
fix model notebook
djdameln Sep 2, 2024
15 changes: 10 additions & 5 deletions src/anomalib/callbacks/metrics.py
@@ -15,6 +15,9 @@
from anomalib import TaskType
from anomalib.metrics import AnomalibMetricCollection, create_metric_collection
from anomalib.models import AnomalyModule
from anomalib.dataclasses import BatchItem

from dataclasses import asdict

logger = logging.getLogger(__name__)

@@ -121,7 +124,7 @@ def on_validation_batch_end(
del trainer, batch, batch_idx, dataloader_idx # Unused arguments.

if outputs is not None:
self._outputs_to_device(outputs)
outputs = self._outputs_to_device(outputs)
Contributor: This is for future reference... I hope to get rid of this device-related stuff and leave it to Lightning.

self._update_metrics(pl_module.image_metrics, pl_module.pixel_metrics, outputs)

def on_validation_epoch_end(
@@ -156,7 +159,7 @@ def on_test_batch_end(
del trainer, batch, batch_idx, dataloader_idx # Unused arguments.

if outputs is not None:
self._outputs_to_device(outputs)
outputs = self._outputs_to_device(outputs)
self._update_metrics(pl_module.image_metrics, pl_module.pixel_metrics, outputs)

def on_test_epoch_end(
@@ -179,15 +182,17 @@ def _update_metrics(
output: STEP_OUTPUT,
) -> None:
image_metric.to(self.device)
image_metric.update(output["pred_scores"], output["label"].int())
if "mask" in output and "anomaly_maps" in output:
image_metric.update(output.pred_score, output.gt_label.int())
if output.gt_mask is not None and output.anomaly_map is not None:
pixel_metric.to(self.device)
pixel_metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int()))
pixel_metric.update(torch.squeeze(output.anomaly_map), torch.squeeze(output.gt_mask.int()))
Contributor (suggested change):
-        pixel_metric.update(torch.squeeze(output.anomaly_map), torch.squeeze(output.gt_mask.int()))
+        pixel_metric.update(output.anomaly_map.squeeze(), output.gt_mask.squeeze().int())


def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]:
if isinstance(output, dict):
for key, value in output.items():
output[key] = self._outputs_to_device(value)
elif isinstance(output, BatchItem):
output = output.__class__(**self._outputs_to_device(asdict(output)))
Contributor: Would it be an idea to add a comment here? It might be difficult to understand for some readers.

elif isinstance(output, torch.Tensor):
output = output.to(self.device)
return output
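
A note on the line discussed above: as a rough illustration only (using a hypothetical stand-in class rather than the real anomalib BatchItem), the recursion flattens the dataclass with asdict(), converts every field, and re-instantiates the same class from the converted fields:

from dataclasses import asdict, dataclass

import torch


@dataclass
class FakeItem:  # hypothetical stand-in for anomalib.dataclasses.BatchItem
    pred_score: torch.Tensor
    anomaly_map: torch.Tensor | None = None


def move_to(output, device: torch.device):
    """Recursively move tensors to `device`, rebuilding dataclass instances from their fields."""
    if isinstance(output, dict):
        return {key: move_to(value, device) for key, value in output.items()}
    if isinstance(output, FakeItem):
        # asdict() flattens the dataclass into a plain dict of field values; after
        # converting each value, the same class is re-instantiated from that dict.
        return output.__class__(**move_to(asdict(output), device))
    if isinstance(output, torch.Tensor):
        return output.to(device)
    return output


batch = FakeItem(pred_score=torch.rand(4), anomaly_map=torch.rand(4, 8, 8))
batch = move_to(batch, torch.device("cpu"))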
26 changes: 13 additions & 13 deletions src/anomalib/callbacks/normalization/min_max_normalization.py
@@ -55,12 +55,12 @@ def on_validation_batch_end(
"""Call when the validation batch ends, update the min and max observed values."""
del trainer, batch, batch_idx, dataloader_idx # These variables are not used.

if "anomaly_maps" in outputs:
pl_module.normalization_metrics(outputs["anomaly_maps"])
elif "box_scores" in outputs:
pl_module.normalization_metrics(torch.cat(outputs["box_scores"]))
elif "pred_scores" in outputs:
pl_module.normalization_metrics(outputs["pred_scores"])
if outputs.anomaly_map is not None:
pl_module.normalization_metrics(outputs.anomaly_map)
elif outputs.box_scores is not None:
pl_module.normalization_metrics(torch.cat(outputs.box_scores))
elif outputs.pred_score is not None:
pl_module.normalization_metrics(outputs.pred_score)
else:
msg = "No values found for normalization, provide anomaly maps, bbox scores, or image scores"
raise ValueError(msg)
@@ -99,11 +99,11 @@ def _normalize_batch(outputs: Any, pl_module: AnomalyModule) -> None: # noqa: A
image_threshold = pl_module.image_threshold.value.cpu()
pixel_threshold = pl_module.pixel_threshold.value.cpu()
stats = pl_module.normalization_metrics.cpu()
if "pred_scores" in outputs:
outputs["pred_scores"] = normalize(outputs["pred_scores"], image_threshold, stats.min, stats.max)
if "anomaly_maps" in outputs:
outputs["anomaly_maps"] = normalize(outputs["anomaly_maps"], pixel_threshold, stats.min, stats.max)
if "box_scores" in outputs:
outputs["box_scores"] = [
normalize(scores, pixel_threshold, stats.min, stats.max) for scores in outputs["box_scores"]
if outputs.pred_score is not None:
outputs.pred_score = normalize(outputs.pred_score, image_threshold, stats.min, stats.max)
if outputs.anomaly_map is not None:
outputs.anomaly_map = normalize(outputs.anomaly_map, pixel_threshold, stats.min, stats.max)
if outputs.box_scores is not None:
outputs.box_scores = [
normalize(scores, pixel_threshold, stats.min, stats.max) for scores in outputs.box_scores
]
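
For context, normalize() rescales raw scores against the observed min/max statistics so that the adaptive threshold lands at 0.5. A minimal sketch of that idea, stated as an assumption for illustration rather than the exact anomalib implementation:

import torch


def minmax_normalize(values: torch.Tensor, threshold: float, v_min: float, v_max: float) -> torch.Tensor:
    # Shift so the threshold maps to 0.5, scale by the observed range, clip to [0, 1].
    return ((values - threshold) / (v_max - v_min) + 0.5).clamp(0.0, 1.0)


scores = torch.tensor([0.2, 0.7, 1.3])
print(minmax_normalize(scores, threshold=0.7, v_min=0.1, v_max=1.5))  # the threshold maps to 0.5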
69 changes: 35 additions & 34 deletions src/anomalib/callbacks/post_processor.py
@@ -13,6 +13,7 @@

from anomalib.data.utils import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes
from anomalib.models import AnomalyModule
from anomalib.dataclasses import BatchItem


class _PostProcessorCallback(Callback):
@@ -28,7 +29,7 @@ def on_validation_batch_end(
self,
Collaborator: Do we need this callback, given that LightningModule inherits from ModelHooks, which already provides on_validation_epoch_end and on_validation_batch_end?

Contributor (author): This module is deprecated with the new design. I left it here for legacy purposes until we decide how to handle backward compatibility, but you can ignore it for now.

trainer: Trainer,
pl_module: AnomalyModule,
outputs: STEP_OUTPUT | None,
outputs: BatchItem,
batch: Any, # noqa: ANN401
batch_idx: int,
dataloader_idx: int = 0,
@@ -42,7 +43,7 @@ def on_test_batch_end(
self,
trainer: Trainer,
pl_module: AnomalyModule,
outputs: STEP_OUTPUT | None,
outputs: BatchItem,
batch: Any, # noqa: ANN401
batch_idx: int,
dataloader_idx: int = 0,
@@ -56,7 +57,7 @@ def on_predict_batch_end(
self,
trainer: Trainer,
pl_module: AnomalyModule,
outputs: Any, # noqa: ANN401
outputs: BatchItem, # noqa: ANN401
batch: Any, # noqa: ANN401
batch_idx: int,
dataloader_idx: int = 0,
@@ -67,7 +68,7 @@ def on_predict_batch_end(
self.post_process(trainer, pl_module, outputs)

def post_process(self, trainer: Trainer, pl_module: AnomalyModule, outputs: STEP_OUTPUT) -> None:
if isinstance(outputs, dict):
if isinstance(outputs, BatchItem):
self._post_process(outputs)
if trainer.predicting or trainer.testing:
self._compute_scores_and_labels(pl_module, outputs)
@@ -77,49 +78,49 @@ def _compute_scores_and_labels(
pl_module: AnomalyModule,
outputs: dict[str, Any],
) -> None:
if "pred_scores" in outputs:
outputs["pred_labels"] = outputs["pred_scores"] >= pl_module.image_threshold.value
if "anomaly_maps" in outputs:
outputs["pred_masks"] = outputs["anomaly_maps"] >= pl_module.pixel_threshold.value
if "pred_boxes" not in outputs:
outputs["pred_boxes"], outputs["box_scores"] = masks_to_boxes(
outputs["pred_masks"],
outputs["anomaly_maps"],
if outputs.pred_score is not None:
outputs.pred_label = outputs.pred_score >= pl_module.image_threshold.value
if outputs.anomaly_map is not None:
outputs.pred_mask = outputs.anomaly_map >= pl_module.pixel_threshold.value
if outputs.pred_boxes is None:
outputs.pred_boxes, outputs.box_scores = masks_to_boxes(
outputs.pred_mask,
outputs.anomaly_map,
)
outputs["box_labels"] = [torch.ones(boxes.shape[0]) for boxes in outputs["pred_boxes"]]
outputs.box_labels = [torch.ones(boxes.shape[0]) for boxes in outputs.pred_boxes]
# apply thresholding to boxes
if "box_scores" in outputs and "box_labels" not in outputs:
if outputs.box_scores is not None and outputs.box_labels is None:
# apply threshold to assign normal/anomalous label to boxes
is_anomalous = [scores > pl_module.pixel_threshold.value for scores in outputs["box_scores"]]
outputs["box_labels"] = [labels.int() for labels in is_anomalous]
is_anomalous = [scores > pl_module.pixel_threshold.value for scores in outputs.box_scores]
outputs.box_labels = [labels.int() for labels in is_anomalous]

@staticmethod
def _post_process(outputs: STEP_OUTPUT) -> None:
def _post_process(outputs: BatchItem) -> None:
"""Compute labels based on model predictions."""
if isinstance(outputs, dict):
if "pred_scores" not in outputs and "anomaly_maps" in outputs:
if isinstance(outputs, BatchItem):
if outputs.pred_score is None and outputs.anomaly_map is not None:
# infer image scores from anomaly maps
outputs["pred_scores"] = (
outputs["anomaly_maps"] # noqa: PD011
.reshape(outputs["anomaly_maps"].shape[0], -1)
outputs.pred_score = (
outputs.anomaly_map # noqa: PD011
.reshape(outputs.anomaly_map.shape[0], -1)
.max(dim=1)
.values
)
elif "pred_scores" not in outputs and "box_scores" in outputs and "label" in outputs:
elif outputs.pred_score is None and outputs.box_score is not None and outputs.gt_label is not None:
# infer image score from bbox confidence scores
outputs["pred_scores"] = torch.zeros_like(outputs["label"]).float()
for idx, (boxes, scores) in enumerate(zip(outputs["pred_boxes"], outputs["box_scores"], strict=True)):
outputs.pred_score = torch.zeros_like(outputs.gt_label).float()
for idx, (boxes, scores) in enumerate(zip(outputs.pred_boxes, outputs.box_scores, strict=True)):
if boxes.numel():
outputs["pred_scores"][idx] = scores.max().item()
outputs.pred_score[idx] = scores.max().item()

if "pred_boxes" in outputs and "anomaly_maps" not in outputs:
if outputs.pred_boxes is not None and outputs.anomaly_map is None:
# create anomaly maps from bbox predictions for thresholding and evaluation
image_size: tuple[int, int] = outputs["image"].shape[-2:]
pred_boxes: torch.Tensor = outputs["pred_boxes"]
box_scores: torch.Tensor = outputs["box_scores"]
image_size: tuple[int, int] = outputs.image.shape[-2:]
pred_boxes: torch.Tensor = outputs.pred_boxes
box_scores: torch.Tensor = outputs.box_scores

outputs["anomaly_maps"] = boxes_to_anomaly_maps(pred_boxes, box_scores, image_size)
outputs.anomaly_map = boxes_to_anomaly_maps(pred_boxes, box_scores, image_size)

if "boxes" in outputs:
true_boxes: list[torch.Tensor] = outputs["boxes"]
outputs["mask"] = boxes_to_masks(true_boxes, image_size)
if outputs.gt_boxes is not None:
true_boxes: list[torch.Tensor] = outputs.gt_boxes
outputs.gt_mask = boxes_to_masks(true_boxes, image_size)
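
Throughout this file, dict lookups such as outputs["pred_scores"] become attribute access on the new dataclasses, and absent keys become explicit None fields. A simplified, hypothetical sketch of such a batch container, with field names taken from this diff (the real anomalib.dataclasses classes add validation and further fields):

from dataclasses import dataclass

import torch


@dataclass
class SketchBatch:  # hypothetical stand-in, not the real anomalib class
    image: torch.Tensor
    gt_label: torch.Tensor | None = None
    gt_mask: torch.Tensor | None = None
    gt_boxes: list[torch.Tensor] | None = None
    anomaly_map: torch.Tensor | None = None
    pred_score: torch.Tensor | None = None
    pred_label: torch.Tensor | None = None
    pred_mask: torch.Tensor | None = None
    pred_boxes: list[torch.Tensor] | None = None
    box_scores: list[torch.Tensor] | None = None
    box_labels: list[torch.Tensor] | None = None


batch = SketchBatch(image=torch.rand(8, 3, 256, 256), pred_score=torch.rand(8))
has_map = batch.anomaly_map is not None  # replaces the old check: "anomaly_maps" in outputs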
13 changes: 9 additions & 4 deletions src/anomalib/callbacks/thresholding.py
@@ -14,6 +14,9 @@
from anomalib.metrics.threshold import BaseThreshold
from anomalib.models import AnomalyModule
from anomalib.utils.types import THRESHOLD
from anomalib.dataclasses import BatchItem

from dataclasses import asdict


class _ThresholdCallback(Callback):
@@ -53,7 +56,7 @@ def on_validation_batch_end(
) -> None:
del trainer, batch, batch_idx, dataloader_idx # Unused arguments.
if outputs is not None:
self._outputs_to_cpu(outputs)
outputs = self._outputs_to_cpu(outputs)
self._update(pl_module, outputs)

def on_validation_epoch_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
@@ -178,16 +181,18 @@ def _outputs_to_cpu(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]:
if isinstance(output, dict):
for key, value in output.items():
output[key] = self._outputs_to_cpu(value)
elif isinstance(output, BatchItem):
output = output.__class__(**self._outputs_to_cpu(asdict(output)))
Collaborator: Minor design comment, but could we move to_cpu to the Batch class so that it returns a copy of itself? That way we could just call output = output.to_cpu().

Contributor (author): This module is also deprecated with the new design, so you can ignore the changes in this file. So far I hadn't considered device handling in the new design, but adding a to_cpu() method could be a good idea!

elif isinstance(output, torch.Tensor):
output = output.cpu()
return output
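
Picking up the to_cpu() suggestion from the thread above, a hypothetical sketch of a dataclass method that returns a CPU copy of itself (not part of this PR) could look like:

from dataclasses import dataclass, fields, replace

import torch


@dataclass
class SketchItem:  # hypothetical stand-in, not the real anomalib class
    pred_score: torch.Tensor
    anomaly_map: torch.Tensor | None = None

    def to_cpu(self) -> "SketchItem":
        # Convert every tensor field and return a copy of the instance with the converted values.
        moved = {}
        for field in fields(self):
            value = getattr(self, field.name)
            moved[field.name] = value.cpu() if isinstance(value, torch.Tensor) else value
        return replace(self, **moved)


item = SketchItem(pred_score=torch.rand(4)).to_cpu()  # callbacks could then do: outputs = outputs.to_cpu()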

def _update(self, pl_module: AnomalyModule, outputs: STEP_OUTPUT) -> None:
pl_module.image_threshold.cpu()
pl_module.image_threshold.update(outputs["pred_scores"], outputs["label"].int())
if "mask" in outputs and "anomaly_maps" in outputs:
pl_module.image_threshold.update(outputs.pred_score, outputs.gt_label.int())
if outputs.gt_mask is not None and outputs.anomaly_map is not None:
pl_module.pixel_threshold.cpu()
pl_module.pixel_threshold.update(outputs["anomaly_maps"], outputs["mask"].int())
pl_module.pixel_threshold.update(outputs.anomaly_map, outputs.gt_mask.int())

def _compute(self, pl_module: AnomalyModule) -> None:
pl_module.image_threshold.compute()
25 changes: 15 additions & 10 deletions src/anomalib/data/base/datamodule.py
@@ -17,6 +17,8 @@

from anomalib.data.utils import TestSplitMode, ValSplitMode, random_split, split_by_label
from anomalib.data.utils.synthetic import SyntheticAnomalyDataset
from anomalib.dataclasses import BatchItem
from dataclasses import asdict

if TYPE_CHECKING:
from pandas import DataFrame
@@ -26,7 +28,7 @@
logger = logging.getLogger(__name__)


def collate_fn(batch: list) -> dict[str, Any]:
def collate_fn(batch: list[BatchItem]) -> dict[str, Any]:
"""Collate bounding boxes as lists.

Bounding boxes are collated as a list of tensors, while the default collate function is used for all other entries.
@@ -37,16 +39,18 @@ def collate_fn(batch: list) -> dict[str, Any]:
Returns:
dict[str, Any]: Dictionary containing the collated batch information.
"""
elem = batch[0] # sample an element from the batch to check the type.
# convert to list of dicts
batch_dict = [asdict(item) for item in batch]
elem = batch_dict[0] # sample an element from the batch to check the type.
out_dict = {}
if isinstance(elem, dict):
if "boxes" in elem:
# collate boxes as list
out_dict["boxes"] = [item.pop("boxes") for item in batch]
# collate other data normally
out_dict.update({key: default_collate([item[key] for item in batch]) for key in elem})
return out_dict
return default_collate(batch)
# if isinstance(elem, dict):
if "boxes" in elem:
# collate boxes as list
out_dict["boxes"] = [item.pop("boxes") for item in batch_dict]
# collate other data normally
out_dict.update({key: default_collate([item[key] for item in batch]) for key in elem if elem[key] is not None})
return batch[0].__class__(**out_dict)
# return default_collate(batch)
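
Because the number of boxes differs per image, they cannot be stacked by default_collate and are kept as a per-image list, while the fixed-shape fields are collated normally and the original dataclass type is rebuilt. A minimal sketch of the same pattern with a hypothetical two-field item, not the real anomalib classes:

from dataclasses import asdict, dataclass

import torch
from torch.utils.data import DataLoader, Dataset, default_collate


@dataclass
class SketchSample:  # hypothetical stand-in, not the real anomalib dataclass
    image: torch.Tensor
    # per sample: a (num_boxes, 4) tensor; after collation: a list of such tensors
    boxes: torch.Tensor | list[torch.Tensor]


def sketch_collate(batch: list[SketchSample]) -> SketchSample:
    batch_dicts = [asdict(item) for item in batch]
    boxes = [d.pop("boxes") for d in batch_dicts]  # keep variable-length boxes as a list
    collated = default_collate(batch_dicts)        # stack the remaining fixed-shape fields
    return batch[0].__class__(boxes=boxes, **collated)


class SketchDataset(Dataset):
    def __len__(self) -> int:
        return 4

    def __getitem__(self, index: int) -> SketchSample:
        return SketchSample(image=torch.rand(3, 32, 32), boxes=torch.rand(index + 1, 4))


loader = DataLoader(SketchDataset(), batch_size=2, collate_fn=sketch_collate)
first_batch = next(iter(loader))
print(first_batch.image.shape, [b.shape for b in first_batch.boxes])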


class AnomalibDataModule(LightningDataModule, ABC):
@@ -225,6 +229,7 @@ def train_dataloader(self) -> TRAIN_DATALOADERS:
shuffle=True,
batch_size=self.train_batch_size,
num_workers=self.num_workers,
collate_fn=collate_fn,
)

def val_dataloader(self) -> EVAL_DATALOADERS:
10 changes: 9 additions & 1 deletion src/anomalib/data/base/dataset.py
@@ -17,6 +17,7 @@
from torchvision.tv_tensors import Mask

from anomalib import TaskType
from anomalib.dataclasses import ImageBatch
from anomalib.data.utils import LabelName, masks_to_boxes, read_image, read_mask

_EXPECTED_COLUMNS_CLASSIFICATION = ["image_path", "split"]
@@ -189,7 +190,14 @@ def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]:
msg = f"Unknown task type: {self.task}"
raise ValueError(msg)

return item
# return item
return ImageBatch(
image=item["image"],
gt_mask=item["mask"],
gt_label=label_index,
image_path=image_path,
mask_path=mask_path,
)

def __add__(self, other_dataset: "AnomalibDataset") -> "AnomalibDataset":
"""Concatenate this dataset with another dataset.
42 changes: 21 additions & 21 deletions src/anomalib/data/base/video.py
@@ -129,16 +129,16 @@ def _select_targets(self, item: dict[str, Any]) -> dict[str, Any]:
msg = f"Unknown video target frame: {self.target_frame}"
raise ValueError(msg)

if item.get("mask") is not None:
item["mask"] = item["mask"][idx, ...]
if item.get("boxes") is not None:
item["boxes"] = item["boxes"][idx]
if item.get("label") is not None:
item["label"] = item["label"][idx]
if item.get("original_image") is not None:
item["original_image"] = item["original_image"][idx]
if item.get("frames") is not None:
item["frames"] = item["frames"][idx]
if item.mask is not None:
item.mask = item.mask[idx, ...]
if item.gt_boxes is not None:
item.gt_boxes = item.gt_boxes[idx]
if item.gt_label is not None:
item.gt_label = item.gt_label[idx]
if item.original_image is not None:
item.original_image = item.original_image[idx]
if item.frames is not None:
item.frames = item.frames[idx]
return item

def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]:
@@ -154,30 +154,30 @@ def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]:
msg = "self.indexer must be an instance of ClipsIndexer."
raise TypeError(msg)
item = self.indexer.get_item(index)
item["image"] = to_dtype_video(video=item["image"], scale=True)
item.image = to_dtype_video(video=item.image, scale=True)
# include the untransformed image for visualization
item["original_image"] = item["image"].to(torch.uint8)
item.original_image = item.image.to(torch.uint8)

# apply transforms
if item.get("mask") is not None:
if item.mask is not None:
if self.transform:
item["image"], item["mask"] = self.transform(item["image"], Mask(item["mask"]))
item["label"] = torch.Tensor([1 in frame for frame in item["mask"]]).int().squeeze(0)
item.image, item.mask = self.transform(item.image, Mask(item.mask))
item.gt_label = torch.Tensor([1 in frame for frame in item.mask]).int().squeeze(0)
if self.task == TaskType.DETECTION:
item["boxes"], _ = masks_to_boxes(item["mask"])
item["boxes"] = item["boxes"][0] if len(item["boxes"]) == 1 else item["boxes"]
item.gt_boxes, _ = masks_to_boxes(item.mask)
item.gt_boxes = item.gt_boxes[0] if len(item.gt_boxes) == 1 else item.gt_boxes
elif self.transform:
item["image"] = self.transform(item["image"])
item.image = self.transform(item.image)

# squeeze temporal dimensions in case clip length is 1
item["image"] = item["image"].squeeze(0)
item.image = item.image.squeeze(0)

# include only target frame in gt
if self.clip_length_in_frames > 1 and self.target_frame != VideoTargetFrame.ALL:
item = self._select_targets(item)

if item["mask"] is None:
item.pop("mask")
# if item.mask is None:
# item.pop("mask")

return item
