Skip to content

Commit

Permalink
simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
jpcbertoldo committed Nov 3, 2022
1 parent 1e5f130 commit b98edf7
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 73 deletions.
1 change: 1 addition & 0 deletions requirements/classification.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
humanfriendly
110 changes: 40 additions & 70 deletions src/torchmetrics/classification/precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
from enum import Enum
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import scipy
import humanfriendly
import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.classification.precision_recall_curve import (
_adjust_threshold_arg,
_binary_clf_curve,
_binary_precision_recall_curve_arg_validation,
_binary_precision_recall_curve_compute,
_binary_precision_recall_curve_format,
Expand All @@ -45,60 +45,27 @@
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat

DYNAMIC_THRESHOLDS_MODE_STR = "dynamic"
_DYNAMIC_THRESHOLDS_NBINS = 10**4
_DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN = 10**4
_DYNAMIC_THRESHOLDS_MIN_NSAMPLES = _DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN * _DYNAMIC_THRESHOLDS_NBINS
_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION = 10**7
_DYNAMIC_THRESHOLDS_MIN_NSAMPLES = 1024**2 # 1MiB,TODO: find a better way to estimate a reasonable minimum


def _validate_budget(budget: Optional[int]) -> int:
if budget is None:
raise ValueError("Budget must be specified when using dynamic thresholds mode.")
def _budget_bytes_to_nsamples(budget_bytes: int):
# assume that both preds and target ("* 2") will be of size (N, 1) and of type float32 (4 bytes)
return budget_bytes / (2 * 4)


def _validate_memory_budget(budget: int):
if budget <= 0:
raise ValueError("Budget must be larger than 0.")

if budget <= _DYNAMIC_THRESHOLDS_MIN_NSAMPLES:
warnings.warn(
f"Budget is small ({budget/10**6:.3f} million samples) to use dynamic thresholds mode. "
"This mode is recommended for a number of samples larger than "
f"{_DYNAMIC_THRESHOLDS_MIN_NSAMPLES/10**6:.3f} million samples."
)

if budget > _DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION:
if _budget_bytes_to_nsamples(budget) <= _DYNAMIC_THRESHOLDS_MIN_NSAMPLES:
warnings.warn(
f"Budget is {budget/10**6:.3f} million samples but the dynamic thresholds mode samples "
f"{_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION/10**6:.3f} million samples "
"to estimate the thresholds."
f"Budget is relatively small ({humanfriendly.format_size(budget, binary=True)}). "
"The dynamic mode is recommended for bigger samples."
)

return budget


def _estimate_threhsholds(preds: Tensor) -> Tensor:
global _DYNAMIC_THRESHOLDS_NBINS, _DYNAMIC_THRESHOLDS_MIN_NSAMPLES_PERBIN
npreds = preds.numel()

# sample from the predictions if there are too many (computation of mquantiles can be very slow)
if npreds > _DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION:
indices = torch.randperm(npreds, device=preds.device)
indices = indices[:_DYNAMIC_THRESHOLDS_MAX_NSAMPLES_QUANTILE_ESTIMATION]
preds_q_estimation = preds[indices]
else:
preds_q_estimation = preds

preds_q_estimation = preds_q_estimation.cpu()

thresholds = scipy.stats.mstats.mquantiles(
preds_q_estimation, # it has to be on the CPU
prob=np.linspace(0, 1, _DYNAMIC_THRESHOLDS_NBINS),
)

# remove the min/max so lower/higher values will go to the first/last "bin"
return thresholds[1:-1]


class BinaryPrecisionRecallCurve(Metric):
r"""Computes the precision-recall curve for binary tasks. The curve consist of multiple pairs of precision and
recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen.
Expand Down Expand Up @@ -130,6 +97,7 @@ class BinaryPrecisionRecallCurve(Metric):
- If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
- If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
- If set to a `str`, the value is interpreted as a memory budget and the dynamic mode approach is used.
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
Expand Down Expand Up @@ -161,19 +129,27 @@ class BinaryPrecisionRecallCurve(Metric):
higher_is_better: Optional[bool] = None
full_state_update: bool = False

class _DynamicModeState(Enum):
class _ComputationMode(Enum):
"""Internal state of the dynamic mode."""

NONE = "none"
NON_BINNED = "non_binned"
BINNED = "binned"
NON_BINNED = "non-binned"
NON_BINNED_DYNAMIC = "non-binned-dynamic"

@staticmethod
def _deduce_computation_mode(thresholds: Optional[Union[int, List[float], Tensor, str]]) -> _ComputationMode:
if isinstance(thresholds, str):
return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED_DYNAMIC
elif thresholds is None:
return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED
else:
return BinaryPrecisionRecallCurve._ComputationMode.BINNED

def __init__(
self,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
thresholds: Optional[Union[int, List[float], Tensor, str]] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
budget: Optional[int] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -183,30 +159,21 @@ def __init__(
self.ignore_index = ignore_index
self.validate_args = validate_args

self._computation_mode = self._deduce_computation_mode(thresholds)
thresholds = _adjust_threshold_arg(thresholds)

if isinstance(thresholds, str) and thresholds != DYNAMIC_THRESHOLDS_MODE_STR:
raise ValueError(f"Invalid thresholds mode '{thresholds}'.")

self._dynamic_mode_state = self._DynamicModeState.NONE

if thresholds == DYNAMIC_THRESHOLDS_MODE_STR:

self.budget = _validate_budget(budget)
self._dynamic_mode_state = self._DynamicModeState.NON_BINNED

# they are deleted after the switch to binned mode
self.preds = []
self.target = []

if self._computation_mode == self._ComputationMode.NON_BINNED_DYNAMIC:
self._memory_budget_bytes = _validate_memory_budget(thresholds)
# used after the switch to binned mode
self.register_buffer("thresholds", None)
self.add_state(
# "-1" here compenstes the lack min/max removed (see comments in update() to understand)
"confmat",
default=torch.zeros(_DYNAMIC_THRESHOLDS_NBINS - 1, 2, 2, dtype=torch.long),
default=torch.zeros(_budget_bytes_to_nsamples(self._memory_budget_bytes), 2, 2, dtype=torch.long),
dist_reduce_fx="sum",
)
# they are deleted after the switch to binned mode
self.preds = []
self.target = []

elif thresholds is None:
self.thresholds = thresholds
Expand All @@ -229,19 +196,22 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
self.preds.append(state[0])
self.target.append(state[1])

if self._dynamic_mode_state != self._DynamicModeState.NON_BINNED:
if self._computation_mode != self._ComputationMode.NON_BINNED_DYNAMIC:
return

all_preds = dim_zero_cat(self.preds)
mem_used = all_preds.element_size() * all_preds.nelement() * 2 # 2 accounts for the target

if all_preds.numel() < self.budget:
if mem_used < self._memory_budget_bytes:
return

# switch to binned mode
self.thresholds = _estimate_threhsholds(all_preds)
self.confmat = _binary_precision_recall_curve_update(all_preds, dim_zero_cat(self.target), self.thresholds)
self.preds, self.target = all_preds, dim_zero_cat(self.target)
_, _, self.thresholds = _binary_clf_curve(self.preds, self.target)
# if the number of thr
self.confmat = _binary_precision_recall_curve_update(self.preds, self.target, self.thresholds)
del self.preds, self.target
self._dynamic_mode_state = self._DynamicModeState.BINNED
self._computation_mode = self._ComputationMode.BINNED

def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
if self.thresholds is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from typing import List, Optional, Sequence, Tuple, Union

import humandfriendly
import torch
from torch import Tensor, tensor
from torch.nn import functional as F
Expand Down Expand Up @@ -81,13 +82,20 @@ def _binary_clf_curve(


def _adjust_threshold_arg(
thresholds: Optional[Union[int, List[float], Tensor]] = None, device: Optional[torch.device] = None
) -> Optional[Tensor]:
"""Utility function for converting the threshold arg for list and int to tensor format."""
thresholds: Optional[Union[int, List[float], Tensor, str]] = None, device: Optional[torch.device] = None
) -> Optional[Union[Tensor, int]]:
"""Utility function for converting the threshold arg.
- list and int -> tensor
- None -> None
- str -> int (memory budget) in Mb
"""
if isinstance(thresholds, int):
thresholds = torch.linspace(0, 1, thresholds, device=device)
if isinstance(thresholds, list):
thresholds = torch.tensor(thresholds, device=device)
if isinstance(thresholds, str):
thresholds = humandfriendly.parse_size(thresholds, binary=True)
return thresholds


Expand Down

0 comments on commit b98edf7

Please sign in to comment.