Skip to content

Commit

Permalink
fix cfa test
Browse files Browse the repository at this point in the history
  • Loading branch information
samet-akcay committed Sep 13, 2024
1 parent 45c513d commit d7da3e5
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/anomalib/data/validators/torch/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def validate_pred_score(
Raises:
TypeError: If the input is not a torch.Tensor or Sequence[float].
ValueError: If the prediction scores are not 1-dimensional.
ValueError: If the prediction scores have an invalid shape.
Examples:
>>> import torch
Expand All @@ -552,10 +552,10 @@ def validate_pred_score(
if not isinstance(pred_score, torch.Tensor):
msg = f"Prediction scores must be a torch.Tensor or Sequence[float], got {type(pred_score)}."
raise TypeError(msg)
if pred_score.ndim != 1:
msg = f"Prediction scores must be 1-dimensional, got shape {pred_score.shape}."
if pred_score.ndim > 2 or (pred_score.ndim == 2 and pred_score.shape[1] != 1):
msg = f"Prediction scores must be 1-dimensional or have shape (N, 1), got shape {pred_score.shape}."
raise ValueError(msg)
return pred_score.to(torch.float32)
return pred_score.squeeze().to(torch.float32)

@staticmethod
def validate_pred_mask(pred_mask: torch.Tensor | None, batch_size: int) -> Mask | None:
Expand Down

0 comments on commit d7da3e5

Please sign in to comment.