Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BinaryPrecisionRecallCurve for large datasets (>100 million samples) #1309

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ numpy>=1.17.2
torch>=1.8.1
packaging
typing-extensions; python_version < '3.9'
humanfriendly
1 change: 1 addition & 0 deletions requirements/classification.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
humanfriendly
77 changes: 75 additions & 2 deletions src/torchmetrics/classification/precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from enum import Enum
from typing import Any, List, Optional, Tuple, Union

import humanfriendly
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets have this as optional

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.classification.precision_recall_curve import (
_adjust_threshold_arg,
_binary_clf_curve,
_binary_precision_recall_curve_arg_validation,
_binary_precision_recall_curve_compute,
_binary_precision_recall_curve_format,
Expand All @@ -41,6 +45,26 @@
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat

_DYNAMIC_THRESHOLDS_MIN_NSAMPLES = 1024**2 # 1MiB,TODO: find a better way to estimate a reasonable minimum


def _budget_bytes_to_nsamples(budget_bytes: int):
# assume that both preds and target ("* 2") will be of size (N, 1) and of type float32 (4 bytes)
return budget_bytes / (2 * 4)


def _validate_memory_budget(budget: int):
if budget <= 0:
raise ValueError("Budget must be larger than 0.")

if _budget_bytes_to_nsamples(budget) <= _DYNAMIC_THRESHOLDS_MIN_NSAMPLES:
warnings.warn(
f"Budget is relatively small ({humanfriendly.format_size(budget, binary=True)}). "
"The dynamic mode is recommended for bigger samples."
)

return budget
Comment on lines +53 to +63
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we only have the thresholds argument i guess all this logic can be moved to the _adjust_threshold_arg function

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the multiclass/label cases the estimation (number of samples) <-> (memory consumption) would be different.

I don't have a very strong opinion on this, i will put some more thought on the multi* cases first.



class BinaryPrecisionRecallCurve(Metric):
r"""Computes the precision-recall curve for binary tasks. The curve consist of multiple pairs of precision and
Expand Down Expand Up @@ -73,6 +97,7 @@ class BinaryPrecisionRecallCurve(Metric):
- If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
- If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
- If set to a `str`, the value is interpreted as a memory budget and the dynamic mode approach is used.

validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
Expand Down Expand Up @@ -104,9 +129,25 @@ class BinaryPrecisionRecallCurve(Metric):
higher_is_better: Optional[bool] = None
full_state_update: bool = False

class _ComputationMode(Enum):
"""Internal state of the dynamic mode."""

BINNED = "binned"
NON_BINNED = "non-binned"
NON_BINNED_DYNAMIC = "non-binned-dynamic"
Comment on lines +129 to +134
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems weird to me having a class inside a class def?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's kind of rare indeed but i usually do this in such cases where it is strictly only used internally 🤷‍♂️

Should i pop it out?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes please


@staticmethod
def _deduce_computation_mode(thresholds: Optional[Union[int, List[float], Tensor, str]]) -> _ComputationMode:
if isinstance(thresholds, str):
return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED_DYNAMIC
elif thresholds is None:
return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED
else:
return BinaryPrecisionRecallCurve._ComputationMode.BINNED
Comment on lines +138 to +143
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(thresholds, str):
return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED_DYNAMIC
elif thresholds is None:
return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED
else:
return BinaryPrecisionRecallCurve._ComputationMode.BINNED
if isinstance(thresholds, str):
return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED_DYNAMIC
if thresholds is None:
return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED
return BinaryPrecisionRecallCurve._ComputationMode.BINNED


def __init__(
self,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
thresholds: Optional[Union[int, List[float], Tensor, str]] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
**kwargs: Any,
Expand All @@ -118,8 +159,23 @@ def __init__(
self.ignore_index = ignore_index
self.validate_args = validate_args

self._computation_mode = self._deduce_computation_mode(thresholds)
thresholds = _adjust_threshold_arg(thresholds)
if thresholds is None:

if self._computation_mode == self._ComputationMode.NON_BINNED_DYNAMIC:
self._memory_budget_bytes = _validate_memory_budget(thresholds)
# used after the switch to binned mode
self.register_buffer("thresholds", None)
self.add_state(
"confmat",
default=torch.zeros(_budget_bytes_to_nsamples(self._memory_budget_bytes), 2, 2, dtype=torch.long),
dist_reduce_fx="sum",
)
# they are deleted after the switch to binned mode
self.preds = []
self.target = []

elif thresholds is None:
self.thresholds = thresholds
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
Expand All @@ -140,6 +196,23 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
self.preds.append(state[0])
self.target.append(state[1])

if self._computation_mode != self._ComputationMode.NON_BINNED_DYNAMIC:
return

all_preds = dim_zero_cat(self.preds)
mem_used = all_preds.element_size() * all_preds.nelement() * 2 # 2 accounts for the target

if mem_used < self._memory_budget_bytes:
return

# switch to binned mode
self.preds, self.target = all_preds, dim_zero_cat(self.target)
_, _, self.thresholds = _binary_clf_curve(self.preds, self.target)
# if the number of thr
self.confmat = _binary_precision_recall_curve_update(self.preds, self.target, self.thresholds)
del self.preds, self.target
self._computation_mode = self._ComputationMode.BINNED

def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
if self.thresholds is None:
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from typing import List, Optional, Sequence, Tuple, Union

import humanfriendly
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We really do not want to introduce any new dependencies.
The conversion from mb and gb seems to be something we can do ourself?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw that one coming haha.

So, i put it anyway because it feels like the kind of functionality pruned to silly mistakes while a tiny library like this has it neatly packed in.

I can try to make a minimal version of it based on the library's source code. Is that a better solution?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets make it conditional, if user already have it, then use it

if module_available("humanfriendly"):
    import humanfriendly
else:
    humanfriendly = None

import torch
from torch import Tensor, tensor
from torch.nn import functional as F
Expand Down Expand Up @@ -81,13 +82,20 @@ def _binary_clf_curve(


def _adjust_threshold_arg(
thresholds: Optional[Union[int, List[float], Tensor]] = None, device: Optional[torch.device] = None
) -> Optional[Tensor]:
"""Utility function for converting the threshold arg for list and int to tensor format."""
thresholds: Optional[Union[int, List[float], Tensor, str]] = None, device: Optional[torch.device] = None
) -> Optional[Union[Tensor, int]]:
"""Utility function for converting the threshold arg.

- list and int -> tensor
- None -> None
- str -> int (memory budget) in Mb
"""
if isinstance(thresholds, int):
thresholds = torch.linspace(0, 1, thresholds, device=device)
if isinstance(thresholds, list):
thresholds = torch.tensor(thresholds, device=device)
if isinstance(thresholds, str):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(thresholds, str):
if isinstance(thresholds, str) and humanfriendly:

thresholds = humanfriendly.parse_size(thresholds, binary=True)
return thresholds


Expand Down