Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure unit tests and fix ruff issues #2306

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 36 additions & 18 deletions src/anomalib/data/dataclasses/numpy/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,27 @@ class NumpyImageItem(_ImageInputFields[str], NumpyItem):
>>> path = item.image_path
"""

def _validate_image(self, image: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_image(image: np.ndarray) -> np.ndarray:
assert image.ndim == 3, f"Expected 3D image, got {image.ndim}D image."
if image.shape[0] == 3:
image = image.transpose(1, 2, 0)
return image

def _validate_gt_label(self, gt_label: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_gt_label(gt_label: np.ndarray) -> np.ndarray:
return gt_label

def _validate_gt_mask(self, gt_mask: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_gt_mask(gt_mask: np.ndarray) -> np.ndarray:
return gt_mask

def _validate_mask_path(self, mask_path: str) -> str:
@staticmethod
def _validate_mask_path(mask_path: str) -> str:
return mask_path

def _validate_anomaly_map(self, anomaly_map: np.ndarray | None) -> np.ndarray | None:
@staticmethod
def _validate_anomaly_map(anomaly_map: np.ndarray | None) -> np.ndarray | None:
if anomaly_map is None:
return None
assert isinstance(anomaly_map, np.ndarray), f"Anomaly map must be a numpy array, got {type(anomaly_map)}."
Expand All @@ -66,21 +71,25 @@ def _validate_anomaly_map(self, anomaly_map: np.ndarray | None) -> np.ndarray |
anomaly_map = anomaly_map.squeeze(0)
return anomaly_map.astype(np.float32)

def _validate_pred_score(self, pred_score: np.ndarray | None) -> np.ndarray | None:
@staticmethod
def _validate_pred_score(pred_score: np.ndarray | None) -> np.ndarray | None:
if pred_score is None:
return None
if pred_score.ndim == 1:
assert len(pred_score) == 1, f"Expected single value for pred_score, got {len(pred_score)}."
pred_score = pred_score[0]
return pred_score

def _validate_pred_mask(self, pred_mask: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_pred_mask(pred_mask: np.ndarray) -> np.ndarray:
return pred_mask

def _validate_pred_label(self, pred_label: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_pred_label(pred_label: np.ndarray) -> np.ndarray:
return pred_label

def _validate_image_path(self, image_path: str) -> str:
@staticmethod
def _validate_image_path(image_path: str) -> str:
return image_path


Expand Down Expand Up @@ -115,29 +124,38 @@ class NumpyImageBatch(BatchIterateMixin[NumpyImageItem], _ImageInputFields[list[

item_class = NumpyImageItem

def _validate_image(self, image: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_image(image: np.ndarray) -> np.ndarray:
return image

def _validate_gt_label(self, gt_label: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_gt_label(gt_label: np.ndarray) -> np.ndarray:
return gt_label

def _validate_gt_mask(self, gt_mask: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_gt_mask(gt_mask: np.ndarray) -> np.ndarray:
return gt_mask

def _validate_mask_path(self, mask_path: list[str]) -> list[str]:
@staticmethod
def _validate_mask_path(mask_path: list[str]) -> list[str]:
return mask_path

def _validate_anomaly_map(self, anomaly_map: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_anomaly_map(anomaly_map: np.ndarray) -> np.ndarray:
return anomaly_map

def _validate_pred_score(self, pred_score: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_pred_score(pred_score: np.ndarray) -> np.ndarray:
return pred_score

def _validate_pred_mask(self, pred_mask: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_pred_mask(pred_mask: np.ndarray) -> np.ndarray:
return pred_mask

def _validate_pred_label(self, pred_label: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_pred_label(pred_label: np.ndarray) -> np.ndarray:
return pred_label

def _validate_image_path(self, image_path: list[str]) -> list[str]:
@staticmethod
def _validate_image_path(image_path: list[str]) -> list[str]:
return image_path
27 changes: 18 additions & 9 deletions src/anomalib/data/dataclasses/numpy/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,20 @@ class NumpyVideoItem(_VideoInputFields[np.ndarray, np.ndarray, np.ndarray, str],
for Anomalib's video-based models.
"""

def _validate_image(self, image: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_image(image: np.ndarray) -> np.ndarray:
return image

def _validate_gt_label(self, gt_label: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_gt_label(gt_label: np.ndarray) -> np.ndarray:
return gt_label

def _validate_gt_mask(self, gt_mask: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_gt_mask(gt_mask: np.ndarray) -> np.ndarray:
return gt_mask

def _validate_mask_path(self, mask_path: str) -> str:
@staticmethod
def _validate_mask_path(mask_path: str) -> str:
return mask_path


Expand All @@ -48,17 +52,22 @@ class NumpyVideoBatch(

item_class = NumpyVideoItem

def _validate_image(self, image: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_image(image: np.ndarray) -> np.ndarray:
return image

def _validate_gt_label(self, gt_label: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_gt_label(gt_label: np.ndarray) -> np.ndarray:
return gt_label

def _validate_gt_mask(self, gt_mask: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_gt_mask(gt_mask: np.ndarray) -> np.ndarray:
return gt_mask

def _validate_mask_path(self, mask_path: list[str]) -> list[str]:
@staticmethod
def _validate_mask_path(mask_path: list[str]) -> list[str]:
return mask_path

def _validate_anomaly_map(self, anomaly_map: np.ndarray) -> np.ndarray:
@staticmethod
def _validate_anomaly_map(anomaly_map: np.ndarray) -> np.ndarray:
return anomaly_map
66 changes: 44 additions & 22 deletions src/anomalib/data/dataclasses/torch/depth.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,37 +45,48 @@ class DepthItem(

numpy_class = NumpyImageItem

def _validate_image(self, image: Image) -> Image:
@staticmethod
def _validate_image(image: Image) -> Image:
return image

def _validate_gt_label(self, gt_label: torch.Tensor) -> torch.Tensor:
@staticmethod
def _validate_gt_label(gt_label: torch.Tensor) -> torch.Tensor:
return gt_label

def _validate_gt_mask(self, gt_mask: Mask) -> Mask:
@staticmethod
def _validate_gt_mask(gt_mask: Mask) -> Mask:
return gt_mask

def _validate_mask_path(self, mask_path: str) -> str:
@staticmethod
def _validate_mask_path(mask_path: str) -> str:
return mask_path

def _validate_anomaly_map(self, anomaly_map: torch.Tensor) -> torch.Tensor:
@staticmethod
def _validate_anomaly_map(anomaly_map: torch.Tensor) -> torch.Tensor:
return anomaly_map

def _validate_pred_score(self, pred_score: torch.Tensor) -> torch.Tensor:
@staticmethod
def _validate_pred_score(pred_score: torch.Tensor) -> torch.Tensor:
return pred_score

def _validate_pred_mask(self, pred_mask: torch.Tensor) -> torch.Tensor:
@staticmethod
def _validate_pred_mask(pred_mask: torch.Tensor) -> torch.Tensor:
return pred_mask

def _validate_pred_label(self, pred_label: torch.Tensor) -> torch.Tensor:
@staticmethod
def _validate_pred_label(pred_label: torch.Tensor) -> torch.Tensor:
return pred_label

def _validate_image_path(self, image_path: str) -> str:
@staticmethod
def _validate_image_path(image_path: str) -> str:
return image_path

def _validate_depth_map(self, depth_map: torch.Tensor) -> torch.Tensor:
@staticmethod
def _validate_depth_map(depth_map: torch.Tensor) -> torch.Tensor:
return depth_map

def _validate_depth_path(self, depth_path: str) -> str:
@staticmethod
def _validate_depth_path(depth_path: str) -> str:
return depth_path


Expand Down Expand Up @@ -110,35 +121,46 @@ class DepthBatch(

item_class = DepthItem

def _validate_image(self, image: Image) -> Image:
@staticmethod
def _validate_image(image: Image) -> Image:
return image

def _validate_gt_label(self, gt_label: torch.Tensor) -> torch.Tensor:
@staticmethod
def _validate_gt_label(gt_label: torch.Tensor) -> torch.Tensor:
return gt_label

def _validate_gt_mask(self, gt_mask: Mask) -> Mask:
@staticmethod
def _validate_gt_mask(gt_mask: Mask) -> Mask:
return gt_mask

def _validate_mask_path(self, mask_path: list[str]) -> list[str]:
@staticmethod
def _validate_mask_path(mask_path: list[str]) -> list[str]:
return mask_path

def _validate_anomaly_map(self, anomaly_map: torch.Tensor) -> torch.Tensor:
@staticmethod
def _validate_anomaly_map(anomaly_map: torch.Tensor) -> torch.Tensor:
return anomaly_map

def _validate_pred_score(self, pred_score: torch.Tensor) -> torch.Tensor:
@staticmethod
def _validate_pred_score(pred_score: torch.Tensor) -> torch.Tensor:
return pred_score

def _validate_pred_mask(self, pred_mask: torch.Tensor) -> torch.Tensor:
@staticmethod
def _validate_pred_mask(pred_mask: torch.Tensor) -> torch.Tensor:
return pred_mask

def _validate_pred_label(self, pred_label: torch.Tensor) -> torch.Tensor:
@staticmethod
def _validate_pred_label(pred_label: torch.Tensor) -> torch.Tensor:
return pred_label

def _validate_image_path(self, image_path: list[str]) -> list[str]:
@staticmethod
def _validate_image_path(image_path: list[str]) -> list[str]:
return image_path

def _validate_depth_map(self, depth_map: torch.Tensor) -> torch.Tensor:
@staticmethod
def _validate_depth_map(depth_map: torch.Tensor) -> torch.Tensor:
return depth_map

def _validate_depth_path(self, depth_path: list[str]) -> list[str]:
@staticmethod
def _validate_depth_path(depth_path: list[str]) -> list[str]:
return depth_path
36 changes: 24 additions & 12 deletions src/anomalib/data/dataclasses/torch/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,15 @@ class ImageItem(

numpy_class = NumpyImageItem

def _validate_image(self, image: torch.Tensor) -> Image:
@staticmethod
def _validate_image(image: torch.Tensor) -> Image:
assert isinstance(image, torch.Tensor), f"Image must be a torch.Tensor, got {type(image)}."
assert image.ndim == 3, f"Image must have shape [C, H, W], got shape {image.shape}."
assert image.shape[0] == 3, f"Image must have 3 channels, got {image.shape[0]}."
return to_dtype_image(image, torch.float32, scale=True)

def _validate_gt_label(self, gt_label: torch.Tensor | int | None) -> torch.Tensor:
@staticmethod
def _validate_gt_label(gt_label: torch.Tensor | int | None) -> torch.Tensor:
if gt_label is None:
return None
if isinstance(gt_label, int):
Expand All @@ -80,7 +82,8 @@ def _validate_gt_label(self, gt_label: torch.Tensor | int | None) -> torch.Tenso
assert not torch.is_floating_point(gt_label), f"Ground truth label must be boolean or integer, got {gt_label}."
return gt_label.bool()

def _validate_gt_mask(self, gt_mask: torch.Tensor | None) -> Mask | None:
@staticmethod
def _validate_gt_mask(gt_mask: torch.Tensor | None) -> Mask | None:
if gt_mask is None:
return None
assert isinstance(gt_mask, torch.Tensor), f"Ground truth mask must be a torch.Tensor, got {type(gt_mask)}."
Expand All @@ -93,12 +96,14 @@ def _validate_gt_mask(self, gt_mask: torch.Tensor | None) -> Mask | None:
gt_mask = gt_mask.squeeze(0)
return Mask(gt_mask, dtype=torch.bool)

def _validate_mask_path(self, mask_path: str | None) -> str | None:
@staticmethod
def _validate_mask_path(mask_path: str | None) -> str | None:
if mask_path is None:
return None
return str(mask_path)

def _validate_anomaly_map(self, anomaly_map: torch.Tensor | None) -> Mask | None:
@staticmethod
def _validate_anomaly_map(anomaly_map: torch.Tensor | None) -> Mask | None:
if anomaly_map is None:
return None
assert isinstance(anomaly_map, torch.Tensor), f"Anomaly map must be a torch.Tensor, got {type(anomaly_map)}."
Expand Down Expand Up @@ -126,7 +131,8 @@ def _validate_pred_score(self, pred_score: torch.Tensor | np.ndarray | None) ->
assert pred_score.ndim == 0, f"Predicted score must be a scalar, got shape {pred_score.shape}."
return pred_score.to(torch.float32)

def _validate_pred_mask(self, pred_mask: torch.Tensor | None) -> Mask | None:
@staticmethod
def _validate_pred_mask(pred_mask: torch.Tensor | None) -> Mask | None:
if pred_mask is None:
return None
assert isinstance(pred_mask, torch.Tensor), f"Predicted mask must be a torch.Tensor, got {type(pred_mask)}."
Expand All @@ -139,7 +145,8 @@ def _validate_pred_mask(self, pred_mask: torch.Tensor | None) -> Mask | None:
pred_mask = pred_mask.squeeze(0)
return Mask(pred_mask, dtype=torch.bool)

def _validate_pred_label(self, pred_label: torch.Tensor | np.ndarray | None) -> torch.Tensor | None:
@staticmethod
def _validate_pred_label(pred_label: torch.Tensor | np.ndarray | None) -> torch.Tensor | None:
if pred_label is None:
return None
if not isinstance(pred_label, torch.Tensor):
Expand All @@ -152,7 +159,8 @@ def _validate_pred_label(self, pred_label: torch.Tensor | np.ndarray | None) ->
assert pred_label.ndim == 0, f"Predicted label must be a scalar, got shape {pred_label.shape}."
return pred_label.to(torch.bool)

def _validate_image_path(self, image_path: str | None) -> str | None:
@staticmethod
def _validate_image_path(image_path: str | None) -> str | None:
if image_path is None:
return None
return str(image_path)
Expand Down Expand Up @@ -198,7 +206,8 @@ class ImageBatch(
item_class = ImageItem
numpy_class = NumpyImageBatch

def _validate_image(self, image: Image) -> Image:
@staticmethod
def _validate_image(image: Image) -> Image:
assert isinstance(image, torch.Tensor), f"Image must be a torch.Tensor, got {type(image)}."
assert image.ndim in {3, 4}, f"Image must have shape [C, H, W] or [N, C, H, W], got shape {image.shape}."
if image.ndim == 3:
Expand Down Expand Up @@ -286,11 +295,14 @@ def _validate_pred_score(self, pred_score: torch.Tensor | None) -> torch.Tensor
return torch.amax(self.anomaly_map, dim=(-2, -1))
return pred_score

def _validate_pred_mask(self, pred_mask: torch.Tensor) -> torch.Tensor | None:
@staticmethod
def _validate_pred_mask(pred_mask: torch.Tensor) -> torch.Tensor | None:
return pred_mask

def _validate_pred_label(self, pred_label: torch.Tensor) -> torch.Tensor | None:
@staticmethod
def _validate_pred_label(pred_label: torch.Tensor) -> torch.Tensor | None:
return pred_label

def _validate_image_path(self, image_path: list[str]) -> list[str] | None:
@staticmethod
def _validate_image_path(image_path: list[str]) -> list[str] | None:
return image_path
Loading
Loading