diff --git a/requirements/classification.txt b/requirements/classification.txt new file mode 100644 index 00000000000..f5368c4974d --- /dev/null +++ b/requirements/classification.txt @@ -0,0 +1 @@ +humanfriendly diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 0093f44cd70..5a0d78d6244 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -15,14 +15,14 @@ from enum import Enum from typing import Any, List, Optional, Tuple, Union -import numpy as np -import scipy +import humanfriendly import torch from torch import Tensor from typing_extensions import Literal from torchmetrics.functional.classification.precision_recall_curve import ( _adjust_threshold_arg, + _binary_clf_curve, _binary_precision_recall_curve_arg_validation, _binary_precision_recall_curve_compute, _binary_precision_recall_curve_format, @@ -45,60 +45,27 @@ from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat -DYNAMIC_THRESHOLDS_MODE_STR = "dynamic" -_DYNAMIC_THRESHOLDS_NBINS = 10**4 -_DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN = 10**4 -_DYNAMIC_THRESHOLDS_MIN_NSAMPLES = _DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN * _DYNAMIC_THRESHOLDS_NBINS -_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION = 10**7 +_DYNAMIC_THRESHOLDS_MIN_NSAMPLES = 1024**2 # 1Mi samples (i.e. an 8MiB budget); TODO: find a better way to estimate a reasonable minimum -def _validate_budget(budget: Optional[int]) -> int: - if budget is None: - raise ValueError("Budget must be specified when using dynamic thresholds mode.") +def _budget_bytes_to_nsamples(budget_bytes: int) -> int: + # assume that both preds and target ("* 2") will be of size (N, 1) and of type float32 (4 bytes) + return budget_bytes // (2 * 4) + +def _validate_memory_budget(budget: int) -> int: if budget <= 0: raise ValueError("Budget must be larger than 0.") - if budget <= 
_DYNAMIC_THRESHOLDS_MIN_NSAMPLES: - warnings.warn( - f"Budget is small ({budget/10**6:.3f} million samples) to use dynamic thresholds mode. " - "This mode is recommended for a number of samples larger than " - f"{_DYNAMIC_THRESHOLDS_MIN_NSAMPLES/10**6:.3f} million samples." - ) - - if budget > _DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION: + if _budget_bytes_to_nsamples(budget) <= _DYNAMIC_THRESHOLDS_MIN_NSAMPLES: warnings.warn( - f"Budget is {budget/10**6:.3f} million samples but the dynamic thresholds mode samples " - f"{_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION/10**6:.3f} million samples " - "to estimate the thresholds." + f"Budget is relatively small ({humanfriendly.format_size(budget, binary=True)}). " + "The dynamic mode is recommended for bigger samples." ) return budget -def _estimate_threhsholds(preds: Tensor) -> Tensor: - global _DYNAMIC_THRESHOLDS_NBINS, _DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN - npreds = preds.numel() - - # sample from the predictions if there are too many (computation of mquantiles can be very slow) - if npreds > _DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION: - indices = torch.randperm(npreds, device=preds.device) - indices = indices[:_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION] - preds_q_estimation = preds[indices] - else: - preds_q_estimation = preds - - preds_q_estimation = preds_q_estimation.cpu() - - thresholds = scipy.stats.mstats.mquantiles( - preds_q_estimation, # it has to be on the CPU - prob=np.linspace(0, 1, _DYNAMIC_THRESHOLDS_NBINS), - ) - - # remove the min/max so lower/higher values will go to the first/last "bin" - return thresholds[1:-1] - - class BinaryPrecisionRecallCurve(Metric): r"""Computes the precision-recall curve for binary tasks. The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen. 
@@ -130,6 +97,7 @@ class BinaryPrecisionRecallCurve(Metric): - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. + - If set to a `str`, the value is interpreted as a memory budget and the dynamic mode approach is used. validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. @@ -161,19 +129,27 @@ class BinaryPrecisionRecallCurve(Metric): higher_is_better: Optional[bool] = None full_state_update: bool = False - class _DynamicModeState(Enum): + class _ComputationMode(Enum): """Internal state of the dynamic mode.""" - NONE = "none" - NON_BINNED = "non_binned" BINNED = "binned" + NON_BINNED = "non-binned" + NON_BINNED_DYNAMIC = "non-binned-dynamic" + + @staticmethod + def _deduce_computation_mode(thresholds: Optional[Union[int, List[float], Tensor, str]]) -> _ComputationMode: + if isinstance(thresholds, str): + return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED_DYNAMIC + elif thresholds is None: + return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED + else: + return BinaryPrecisionRecallCurve._ComputationMode.BINNED def __init__( self, - thresholds: Optional[Union[int, List[float], Tensor]] = None, + thresholds: Optional[Union[int, List[float], Tensor, str]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, - budget: Optional[int] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -183,30 +159,21 @@ def __init__( self.ignore_index = ignore_index self.validate_args = validate_args + self._computation_mode = self._deduce_computation_mode(thresholds) thresholds = _adjust_threshold_arg(thresholds) - if isinstance(thresholds, str) and thresholds != DYNAMIC_THRESHOLDS_MODE_STR: - raise ValueError(f"Invalid thresholds mode '{thresholds}'.") - - self._dynamic_mode_state = 
self._DynamicModeState.NONE - - if thresholds == DYNAMIC_THRESHOLDS_MODE_STR: - - self.budget = _validate_budget(budget) - self._dynamic_mode_state = self._DynamicModeState.NON_BINNED - - # they are deleted after the switch to binned mode - self.preds = [] - self.target = [] - + if self._computation_mode == self._ComputationMode.NON_BINNED_DYNAMIC: + self._memory_budget_bytes = _validate_memory_budget(thresholds) # used after the switch to binned mode self.register_buffer("thresholds", None) self.add_state( - # "-1" here compenstes the lack min/max removed (see comments in update() to understand) "confmat", - default=torch.zeros(_DYNAMIC_THRESHOLDS_NBINS - 1, 2, 2, dtype=torch.long), + default=torch.zeros(int(_budget_bytes_to_nsamples(self._memory_budget_bytes)), 2, 2, dtype=torch.long), dist_reduce_fx="sum", ) + # they are deleted after the switch to binned mode + self.preds = [] + self.target = [] elif thresholds is None: self.thresholds = thresholds @@ -229,19 +196,22 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.preds.append(state[0]) self.target.append(state[1]) - if self._dynamic_mode_state != self._DynamicModeState.NON_BINNED: + if self._computation_mode != self._ComputationMode.NON_BINNED_DYNAMIC: return all_preds = dim_zero_cat(self.preds) + mem_used = all_preds.element_size() * all_preds.nelement() * 2 # 2 accounts for the target - if all_preds.numel() < self.budget: + if mem_used < self._memory_budget_bytes: return # switch to binned mode - self.thresholds = _estimate_threhsholds(all_preds) - self.confmat = _binary_precision_recall_curve_update(all_preds, dim_zero_cat(self.target), self.thresholds) + self.preds, self.target = all_preds, dim_zero_cat(self.target) + _, _, self.thresholds = _binary_clf_curve(self.preds, self.target) + # TODO(review): unfinished comment ("if the number of thr...") - complete or drop it + self.confmat = _binary_precision_recall_curve_update(self.preds, self.target, self.thresholds) del self.preds, self.target - self._dynamic_mode_state = 
self._DynamicModeState.BINNED + self._computation_mode = self._ComputationMode.BINNED def compute(self) -> Tuple[Tensor, Tensor, Tensor]: if self.thresholds is None: diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index a505898f040..792ab5c306d 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -14,6 +14,7 @@ from typing import List, Optional, Sequence, Tuple, Union +import humanfriendly import torch from torch import Tensor, tensor from torch.nn import functional as F @@ -81,13 +82,20 @@ def _binary_clf_curve( def _adjust_threshold_arg( - thresholds: Optional[Union[int, List[float], Tensor]] = None, device: Optional[torch.device] = None -) -> Optional[Tensor]: - """Utility function for converting the threshold arg for list and int to tensor format.""" + thresholds: Optional[Union[int, List[float], Tensor, str]] = None, device: Optional[torch.device] = None +) -> Optional[Union[Tensor, int]]: + """Utility function for converting the threshold arg. + + - list and int -> tensor + - None -> None + - str -> int (memory budget) in bytes + """ if isinstance(thresholds, int): thresholds = torch.linspace(0, 1, thresholds, device=device) if isinstance(thresholds, list): thresholds = torch.tensor(thresholds, device=device) + if isinstance(thresholds, str): + thresholds = humanfriendly.parse_size(thresholds, binary=True) return thresholds