From 4c445f0e9e1e818ddcdd94409738bb34553af43d Mon Sep 17 00:00:00 2001 From: "Ning Li (Seattle)" Date: Tue, 6 Dec 2022 14:28:04 -0800 Subject: [PATCH] update AUROC with weight input (#94) Summary: Pull Request resolved: https://github.com/pytorch/torcheval/pull/94 update AUROC metric with weight as input. Reviewed By: ananthsub Differential Revision: D41561499 fbshipit-source-id: 6495444013e0cfa0bf6743062fb768c26b5a43e3 --- tests/metrics/classification/test_auroc.py | 81 +++++++++++++---- .../functional/classification/test_auroc.py | 29 ++++-- tests/metrics/window/test_auroc.py | 88 ++++++++++++++++--- torcheval/metrics/classification/auroc.py | 16 +++- .../functional/classification/auroc.py | 28 ++++-- torcheval/metrics/window/auroc.py | 33 ++++++- 6 files changed, 234 insertions(+), 41 deletions(-) diff --git a/tests/metrics/classification/test_auroc.py b/tests/metrics/classification/test_auroc.py index de4d3822..9d26c17b 100644 --- a/tests/metrics/classification/test_auroc.py +++ b/tests/metrics/classification/test_auroc.py @@ -27,29 +27,58 @@ def _test_auroc_class_with_input( input: torch.Tensor, target: torch.Tensor, num_tasks: int = 1, + weight: Optional[torch.Tensor] = None, compute_result: Optional[torch.Tensor] = None, use_fbgemm: Optional[bool] = False, ) -> None: input_tensors = input.reshape(-1, 1) target_tensors = target.reshape(-1, 1) - if compute_result is None: - compute_result = torch.tensor(roc_auc_score(target_tensors, input_tensors)) + weight_tensors = weight.reshape(-1, 1) if weight is not None else None - self.run_class_implementation_tests( - metric=BinaryAUROC(num_tasks=num_tasks, use_fbgemm=use_fbgemm), - state_names={"inputs", "targets"}, - update_kwargs={ - "input": input, - "target": target, - }, - compute_result=compute_result, - test_devices=["cuda"] if use_fbgemm else None, - ) + if compute_result is None: + compute_result = ( + torch.tensor( + roc_auc_score( + target_tensors, input_tensors, sample_weight=weight_tensors + ) + ) + if weight_tensors is not None + else torch.tensor(roc_auc_score(target_tensors, input_tensors)) + ) + if weight is not None: + self.run_class_implementation_tests( + metric=BinaryAUROC(num_tasks=num_tasks, use_fbgemm=use_fbgemm), + state_names={"inputs", "targets", "weights"}, + update_kwargs={ + "input": input, + "target": target, + "weight": weight, + }, + compute_result=compute_result, + test_devices=["cuda"] if use_fbgemm else None, + ) + else: + self.run_class_implementation_tests( + metric=BinaryAUROC(num_tasks=num_tasks, use_fbgemm=use_fbgemm), + state_names={"inputs", "targets", "weights"}, + update_kwargs={ + "input": input, + "target": target, + }, + compute_result=compute_result, + test_devices=["cuda"] if use_fbgemm else None, + ) def _test_auroc_class_set(self, use_fbgemm: Optional[bool] = False) -> None: input = torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE) target = torch.randint(high=2, size=(NUM_TOTAL_UPDATES, BATCH_SIZE)) - self._test_auroc_class_with_input(input, target, use_fbgemm=use_fbgemm) + # fbgemm version does not support weight in AUROC calculation + weight = ( + torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE) if use_fbgemm is False else None + ) + self._test_auroc_class_with_input( + input, target, num_tasks=1, weight=weight, use_fbgemm=use_fbgemm + ) if not use_fbgemm: # Skip this test for use_fbgemm because FBGEMM AUC is an @@ -57,9 +86,12 @@ def _test_auroc_class_set(self, use_fbgemm: Optional[bool] = False) -> None: # result if input data is highly redundant input = torch.randint(high=2, size=(NUM_TOTAL_UPDATES, BATCH_SIZE)) target = torch.randint(high=2, size=(NUM_TOTAL_UPDATES, BATCH_SIZE)) + weight = torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE) self._test_auroc_class_with_input( input, target, + num_tasks=1, + weight=weight, use_fbgemm=use_fbgemm, ) @@ -101,26 +133,36 @@ def test_auroc_class_update_input_shape_different(self) -> None: torch.randint(high=num_classes, size=(2,)), torch.randint(high=num_classes, size=(5,)), ] + + update_weight = [ + torch.rand(5), + torch.rand(8), + torch.rand(2), + torch.rand(5), + ] + compute_result = torch.tensor( roc_auc_score( torch.cat(update_target, dim=0), torch.cat(update_input, dim=0), + sample_weight=torch.cat(update_weight, dim=0), ), ) self.run_class_implementation_tests( metric=BinaryAUROC(), - state_names={"inputs", "targets"}, + state_names={"inputs", "targets", "weights"}, update_kwargs={ "input": update_input, "target": update_target, + "weight": update_weight, }, compute_result=compute_result, num_total_updates=4, num_processes=2, ) - def test_auroc_class_invalid_input(self) -> None: + def test_binary_auroc_class_invalid_input(self) -> None: metric = BinaryAUROC() with self.assertRaisesRegex( ValueError, @@ -129,6 +171,13 @@ def test_auroc_class_invalid_input(self) -> None: ): metric.update(torch.rand(4), torch.rand(3)) + with self.assertRaisesRegex( + ValueError, + "The `weight` and `target` should have the same shape, " + r"got shapes torch.Size\(\[3\]\) and torch.Size\(\[4\]\).", + ): + metric.update(torch.rand(4), torch.rand(4), weight=torch.rand(3)) + with self.assertRaisesRegex( ValueError, "`num_tasks = 1`, `input` is expected to be one-dimensional tensor,", @@ -256,7 +305,7 @@ def test_auroc_class_update_input_shape_different(self) -> None: num_processes=2, ) - def test_auroc_class_invalid_input(self) -> None: + def test_multiclass_auroc_class_invalid_input(self) -> None: with self.assertRaisesRegex( ValueError, "`average` was not in the allowed value of .*, got micro." ): diff --git a/tests/metrics/functional/classification/test_auroc.py b/tests/metrics/functional/classification/test_auroc.py index 252358ce..23fcf979 100644 --- a/tests/metrics/functional/classification/test_auroc.py +++ b/tests/metrics/functional/classification/test_auroc.py @@ -21,15 +21,21 @@ def _test_auroc_with_input( input: torch.Tensor, target: torch.Tensor, num_tasks: int = 1, + weight: Optional[torch.Tensor] = None, compute_result: Optional[torch.Tensor] = None, use_fbgemm: Optional[bool] = False, ) -> None: if compute_result is None: - compute_result = torch.tensor(roc_auc_score(target, input)) + compute_result = ( + torch.tensor(roc_auc_score(target, input)) + if weight is None + else torch.tensor(roc_auc_score(target, input, sample_weight=weight)) + ) if torch.cuda.is_available(): my_compute_result = binary_auroc( input.to(device="cuda"), target.to(device="cuda"), + weight=weight if weight is None else weight.to(device="cuda"), num_tasks=num_tasks, use_fbgemm=use_fbgemm, ) @@ -48,7 +54,13 @@ def _test_auroc_with_input( def _test_auroc_set(self, use_fbgemm: Optional[bool] = False) -> None: input = torch.tensor([1, 1, 0, 0]) target = torch.tensor([1, 0, 1, 0]) + weight = torch.tensor([0.2, 0.2, 1.0, 1.0], dtype=torch.float64) self._test_auroc_with_input(input, target, use_fbgemm=use_fbgemm) + if use_fbgemm is False: + # TODO: use_fbgemm = True will fail the situation with weight input + self._test_auroc_with_input( + input, target, weight=weight, use_fbgemm=use_fbgemm + ) input = torch.rand(BATCH_SIZE) target = torch.randint(high=2, size=(BATCH_SIZE,)) @@ -59,8 +71,8 @@ def _test_auroc_set(self, use_fbgemm: Optional[bool] = False) -> None: self._test_auroc_with_input( input, target, - 2, - torch.tensor([0.7500, 0.2500], dtype=torch.float64), + num_tasks=2, + compute_result=torch.tensor([0.7500, 0.2500], dtype=torch.float64), use_fbgemm=use_fbgemm, ) @@ -73,7 +85,7 @@ def test_auroc_fbgemm(self) -> None: def test_auroc_base(self) -> None: self._test_auroc_set(use_fbgemm=False) - def test_auroc_invalid_input(self) -> None: + def test_binary_auroc_invalid_input(self) -> None: with self.assertRaisesRegex( ValueError, "The `input` and `target` should have the same shape, " @@ -81,6 +93,13 @@ def test_auroc_invalid_input(self) -> None: ): binary_auroc(torch.rand(4), torch.rand(3)) + with self.assertRaisesRegex( + ValueError, + "The `weight` and `target` should have the same shape, " + r"got shapes torch.Size\(\[3\]\) and torch.Size\(\[4\]\).", + ): + binary_auroc(torch.rand(4), torch.rand(4), weight=torch.rand(3)) + with self.assertRaisesRegex( ValueError, "`num_tasks = 1`, `input` is expected to be one-dimensional tensor,", @@ -147,7 +166,7 @@ def test_auroc_average_options(self) -> None: my_compute_result = multiclass_auroc(input, target, num_classes=3, average=None) torch.testing.assert_close(my_compute_result, expected_compute_result) - def test_auroc_invalid_input(self) -> None: + def test_multiclass_auroc_invalid_input(self) -> None: with self.assertRaisesRegex( ValueError, "`average` was not in the allowed value of .*, got micro." ): diff --git a/tests/metrics/window/test_auroc.py b/tests/metrics/window/test_auroc.py index 25924d39..326f0522 100644 --- a/tests/metrics/window/test_auroc.py +++ b/tests/metrics/window/test_auroc.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Optional + import torch from sklearn.metrics import roc_auc_score @@ -21,14 +23,23 @@ def _test_auroc_class_with_input( self, input: torch.Tensor, target: torch.Tensor, + weight: Optional[torch.Tensor] = None, max_num_samples: int = 100, ) -> None: input_tensors = input.reshape(-1)[-max_num_samples:] target_tensors = target.reshape(-1)[-max_num_samples:] - compute_result = torch.tensor(roc_auc_score(target_tensors, input_tensors)) + weight_tensors = ( + weight.reshape(-1)[-max_num_samples:] if weight is not None else None + ) + compute_result = torch.tensor( + roc_auc_score(target_tensors, input_tensors, sample_weight=weight_tensors) + if weight_tensors is not None + else roc_auc_score(target_tensors, input_tensors) + ) input_tensors = input.reshape(-1) target_tensors = target.reshape(-1) + weight_tensors = weight.reshape(-1) if weight is not None else None input_tensors = torch.cat( [ @@ -46,15 +57,30 @@ def _test_auroc_class_with_input( target_tensors[118:], ] ) + weight_tensors = ( + torch.cat( + [ + weight_tensors[22:32], + weight_tensors[54:64], + weight_tensors[86:96], + weight_tensors[118:], + ] + ) + if weight_tensors is not None + else None + ) merge_compute_result = torch.tensor( - roc_auc_score(target_tensors, input_tensors) + roc_auc_score(target_tensors, input_tensors, sample_weight=weight_tensors) + if weight is not None + else roc_auc_score(target_tensors, input_tensors) ) self.run_class_implementation_tests( metric=WindowedBinaryAUROC(max_num_samples=max_num_samples), - state_names={"inputs", "targets"}, + state_names={"inputs", "targets", "weights"}, update_kwargs={ "input": input, "target": target, + "weight": weight, }, compute_result=compute_result, merge_and_compute_result=merge_compute_result, @@ -66,17 +92,20 @@ def _test_auroc_class_with_input( def test_auroc_class_base(self) -> None: input = torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE) target = torch.randint(high=2, size=(NUM_TOTAL_UPDATES, BATCH_SIZE)) - self._test_auroc_class_with_input(input, target, 10) + weight = torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE) + self._test_auroc_class_with_input(input, target, weight, 10) input = torch.randint(high=2, size=(NUM_TOTAL_UPDATES, BATCH_SIZE)) target = torch.randint(high=2, size=(NUM_TOTAL_UPDATES, BATCH_SIZE)) - self._test_auroc_class_with_input(input, target, 10) + weight = torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE) + self._test_auroc_class_with_input(input, target, weight, 10) def test_auroc_class_multiple_tasks(self) -> None: num_tasks = 2 max_num_samples = 10 input = torch.rand(NUM_TOTAL_UPDATES, num_tasks, BATCH_SIZE) target = torch.randint(high=2, size=(NUM_TOTAL_UPDATES, num_tasks, BATCH_SIZE)) + weight = torch.rand(NUM_TOTAL_UPDATES, num_tasks, BATCH_SIZE) input_tensors = input.permute(1, 0, 2).reshape(num_tasks, -1)[ :, -max_num_samples: @@ -84,10 +113,16 @@ def test_auroc_class_multiple_tasks(self) -> None: target_tensors = target.permute(1, 0, 2).reshape(num_tasks, -1)[ :, -max_num_samples: ] - compute_result = binary_auroc(input_tensors, target_tensors, num_tasks=2) + weight_tensors = weight.permute(1, 0, 2).reshape(num_tasks, -1)[ + :, -max_num_samples: + ] + compute_result = binary_auroc( + input_tensors, target_tensors, num_tasks=2, weight=weight_tensors + ) input_tensors = input.permute(1, 0, 2).reshape(num_tasks, -1) target_tensors = target.permute(1, 0, 2).reshape(num_tasks, -1) + weight_tensors = weight.permute(1, 0, 2).reshape(num_tasks, -1) input_tensors = torch.cat( [ @@ -107,16 +142,28 @@ def test_auroc_class_multiple_tasks(self) -> None: ], dim=1, ) - merge_compute_result = binary_auroc(input_tensors, target_tensors, num_tasks=2) + weight_tensors = torch.cat( + [ + weight_tensors[:, 22:32], + weight_tensors[:, 54:64], + weight_tensors[:, 86:96], + weight_tensors[:, 118:], + ], + dim=1, + ) + merge_compute_result = binary_auroc( + input_tensors, target_tensors, num_tasks=2, weight=weight_tensors + ) self.run_class_implementation_tests( metric=WindowedBinaryAUROC( num_tasks=num_tasks, max_num_samples=max_num_samples ), - state_names={"inputs", "targets"}, + state_names={"inputs", "targets", "weights"}, update_kwargs={ "input": input, "target": target, + "weight": weight, }, compute_result=compute_result, merge_and_compute_result=merge_compute_result, @@ -140,23 +187,37 @@ def test_auroc_class_update_input_shape_different(self) -> None: torch.randint(high=num_classes, size=(2,)), torch.randint(high=num_classes, size=(5,)), ] + + update_weight = [ + torch.rand(5), + torch.rand(8), + torch.rand(2), + torch.rand(5), + ] + compute_result = binary_auroc( torch.cat(update_input, dim=0)[-6:], torch.cat(update_target, dim=0)[-6:], + weight=torch.cat(update_weight, dim=0)[-6:], ) update_target_tensors = torch.cat(update_target, dim=0) update_input_tensors = torch.cat(update_input, dim=0) + update_weight_tensors = torch.cat(update_weight, dim=0) merge_compute_result = binary_auroc( torch.cat([update_input_tensors[7:13], update_input_tensors[14:]], dim=0), torch.cat([update_target_tensors[7:13], update_target_tensors[14:]], dim=0), + weight=torch.cat( + [update_weight_tensors[7:13], update_weight_tensors[14:]], dim=0 + ), ) self.run_class_implementation_tests( metric=WindowedBinaryAUROC(max_num_samples=6), - state_names={"inputs", "targets"}, + state_names={"inputs", "targets", "weights"}, update_kwargs={ "input": update_input, "target": update_target, + "weight": update_weight, }, compute_result=compute_result, merge_and_compute_result=merge_compute_result, @@ -167,7 +228,7 @@ def test_auroc_class_update_input_shape_different(self) -> None: test_merge_with_one_update=False, ) - def test_auroc_class_invalid_input(self) -> None: + def test_binary_auroc_class_invalid_input(self) -> None: metric = WindowedBinaryAUROC() with self.assertRaisesRegex( ValueError, @@ -176,6 +237,13 @@ def test_auroc_class_invalid_input(self) -> None: ): metric.update(torch.rand(4), torch.rand(3)) + with self.assertRaisesRegex( + ValueError, + "The `weight` and `target` should have the same shape, " + r"got shapes torch.Size\(\[3\]\) and torch.Size\(\[4\]\).", + ): + metric.update(torch.rand(4), torch.rand(4), weight=torch.rand(3)) + with self.assertRaisesRegex( ValueError, "`num_tasks = 1`, `input` is expected to be one-dimensional tensor,", diff --git a/torcheval/metrics/classification/auroc.py b/torcheval/metrics/classification/auroc.py index 5fea9634..d0c38bc9 100644 --- a/torcheval/metrics/classification/auroc.py +++ b/torcheval/metrics/classification/auroc.py @@ -85,6 +85,7 @@ def __init__( self.num_tasks = num_tasks self._add_state("inputs", []) self._add_state("targets", []) + self._add_state("weights", []) self.use_fbgemm = use_fbgemm @torch.inference_mode() @@ -93,6 +94,7 @@ def update( self: TAUROC, input: torch.Tensor, target: torch.Tensor, + weight: Optional[torch.Tensor] = None, ) -> TAUROC: """ Update states with the ground truth labels and predictions. @@ -101,10 +103,14 @@ def update( input (Tensor): Tensor of label predictions It should be predicted label, probabilities or logits with shape of (num_tasks, n_sample) or (n_sample, ). target (Tensor): Tensor of ground truth labels with shape of (num_tasks, n_sample) or (n_sample, ). + weight (Tensor): Optional. A manual rescaling weight to match input tensor shape (num_tasks, num_samples) or (n_sample, ). """ - _binary_auroc_update_input_check(input, target, self.num_tasks) + if weight is None: + weight = torch.ones_like(input, dtype=torch.double) + _binary_auroc_update_input_check(input, target, self.num_tasks, weight) self.inputs.append(input) self.targets.append(target) + self.weights.append(weight) return self @torch.inference_mode() @@ -119,7 +125,10 @@ def compute( Tensor: The return value of AUROC for each task (num_tasks,). """ return _binary_auroc_compute( - torch.cat(self.inputs, -1), torch.cat(self.targets, -1), self.use_fbgemm + torch.cat(self.inputs, -1), + torch.cat(self.targets, -1), + torch.cat(self.weights, -1), + self.use_fbgemm, ) @torch.inference_mode() @@ -128,8 +137,10 @@ def merge_state(self: TAUROC, metrics: Iterable[TAUROC]) -> TAUROC: if metric.inputs: metric_inputs = torch.cat(metric.inputs, -1).to(self.device) metric_targets = torch.cat(metric.targets, -1).to(self.device) + metric_weights = torch.cat(metric.weights, -1).to(self.device) self.inputs.append(metric_inputs) self.targets.append(metric_targets) + self.weights.append(metric_weights) return self @torch.inference_mode() @@ -137,6 +148,7 @@ def _prepare_for_merge_state(self: TAUROC) -> None: if self.inputs and self.targets: self.inputs = [torch.cat(self.inputs, -1)] self.targets = [torch.cat(self.targets, -1)] + self.weights = [torch.cat(self.weights, -1)] class MulticlassAUROC(Metric[torch.Tensor]): diff --git a/torcheval/metrics/functional/classification/auroc.py b/torcheval/metrics/functional/classification/auroc.py index 41247bba..18f3fe6b 100644 --- a/torcheval/metrics/functional/classification/auroc.py +++ b/torcheval/metrics/functional/classification/auroc.py @@ -27,6 +27,7 @@ def binary_auroc( target: torch.Tensor, *, num_tasks: int = 1, + weight: Optional[torch.Tensor] = None, use_fbgemm: Optional[bool] = False, ) -> torch.Tensor: """ @@ -39,7 +40,8 @@ def binary_auroc( target (Tensor): Tensor of ground truth labels with shape of (num_tasks, n_sample) or (n_sample, ). num_tasks (int): Number of tasks that need BinaryAUROC calculation. Default value is 1. BinaryAUROC for each task will be calculated independently. - use_fbgemm (bool): If set to True, use ``fbgemm_gpu.metrics.auc`` (a + weight (Tensor): Optional. A manual rescaling weight to match input tensor shape (num_tasks, num_samples) or (n_sample, ). + use_fbgemm (bool): Optional. If set to True, use ``fbgemm_gpu.metrics.auc`` (a hand fused kernel). FBGEMM AUC is an approximation of AUC. It does not mask data in case that input values are redundant. For the highly redundant input case, FBGEMM AUC can give a significantly @@ -64,8 +66,8 @@ def binary_auroc( >>> binary_auroc(input, target, num_tasks=2) tensor([0.7500, 0.6667]) """ - _binary_auroc_update_input_check(input, target, num_tasks) - return _binary_auroc_compute(input, target, use_fbgemm) + _binary_auroc_update_input_check(input, target, num_tasks, weight) + return _binary_auroc_compute(input, target, weight, use_fbgemm) @torch.inference_mode() @@ -112,12 +114,18 @@ def multiclass_auroc( def _binary_auroc_compute_jit( input: torch.Tensor, target: torch.Tensor, + weight: Optional[torch.Tensor] = None, ) -> torch.Tensor: threshold, indices = input.sort(descending=True) mask = F.pad(threshold.diff(dim=-1) != 0, [0, 1], value=1.0) sorted_target = torch.gather(target, -1, indices) - cum_tp_before_pad = sorted_target.cumsum(-1) - cum_fp_before_pad = (1 - sorted_target).cumsum(-1) + sorted_weight = ( + torch.tensor(1.0, device=target.device) + if weight is None + else torch.gather(weight, -1, indices) + ) + cum_tp_before_pad = (sorted_weight * sorted_target).cumsum(-1) + cum_fp_before_pad = (sorted_weight * (1 - sorted_target)).cumsum(-1) shifted_mask = mask.sum(-1, keepdim=True) >= torch.arange( mask.size(-1), 0, -1, device=target.device @@ -145,6 +153,7 @@ def _binary_auroc_compute_jit( def _binary_auroc_compute( input: torch.Tensor, target: torch.Tensor, + weight: Optional[torch.Tensor] = None, use_fbgemm: Optional[bool] = False, ) -> torch.Tensor: if use_fbgemm: @@ -161,19 +170,26 @@ def _binary_auroc_compute( else: return auroc else: - return _binary_auroc_compute_jit(input, target) + return _binary_auroc_compute_jit(input, target, weight) def _binary_auroc_update_input_check( input: torch.Tensor, target: torch.Tensor, num_tasks: int, + weight: Optional[torch.Tensor] = None, ) -> None: if input.shape != target.shape: raise ValueError( "The `input` and `target` should have the same shape, " f"got shapes {input.shape} and {target.shape}." ) + if weight is not None and weight.shape != target.shape: + raise ValueError( + "The `weight` and `target` should have the same shape, " + f"got shapes {weight.shape} and {target.shape}." + ) + if num_tasks == 1: if len(input.shape) > 1: raise ValueError( diff --git a/torcheval/metrics/window/auroc.py b/torcheval/metrics/window/auroc.py index a0aa0987..a1913dc0 100644 --- a/torcheval/metrics/window/auroc.py +++ b/torcheval/metrics/window/auroc.py @@ -81,6 +81,10 @@ def __init__( "targets", torch.zeros(self.num_tasks, self.max_num_samples, device=self.device), ) + self._add_state( + "weights", + torch.zeros(self.num_tasks, self.max_num_samples, device=self.device), + ) @torch.inference_mode() # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any @@ -88,6 +92,7 @@ def update( self: TAUROC, input: torch.Tensor, target: torch.Tensor, + weight: Optional[torch.Tensor] = None, ) -> TAUROC: """ Update states with the ground truth labels and predictions. @@ -99,14 +104,18 @@ def update( target (Tensor): Tensor of ground truth labels with shape of (num_samples, ) or or (num_tasks, num_samples). """ - _binary_auroc_update_input_check(input, target, self.num_tasks) + if weight is None: + weight = torch.ones_like(input, dtype=torch.double) + _binary_auroc_update_input_check(input, target, self.num_tasks, weight) if input.ndim == 1: input = input.reshape(1, -1) target = target.reshape(1, -1) + weight = weight.reshape(1, -1) # If input size is greater than or equal to window size, replace it with the last max_num_samples size of input. if input.shape[1] >= self.max_num_samples: self.inputs.copy_(input[:, -self.max_num_samples :].detach()) self.targets.copy_(target[:, -self.max_num_samples :].detach()) + self.weights.copy_(weight[:, -self.max_num_samples :].detach()) self.next_inserted = 0 else: rest_window_size = self.max_num_samples - self.next_inserted @@ -118,6 +127,9 @@ def update( self.targets[ :, self.next_inserted : self.next_inserted + input.shape[1] ] = target.detach() + self.weights[ + :, self.next_inserted : self.next_inserted + input.shape[1] + ] = weight.detach() self.next_inserted += input.shape[1] else: # Otherwise, replace with the first half and the second half of input respectively. @@ -128,6 +140,9 @@ def update( self.targets[ :, self.next_inserted : self.next_inserted + rest_window_size ] = target[:, :rest_window_size].detach() + self.weights[ + :, self.next_inserted : self.next_inserted + rest_window_size + ] = weight[:, :rest_window_size].detach() # Put the second half of input to the front rest_window_size = input.shape[1] - rest_window_size @@ -137,6 +152,9 @@ def update( self.targets[:, :rest_window_size] = target[ :, -rest_window_size: ].detach() + self.weights[:, :rest_window_size] = weight[ + :, -rest_window_size: + ].detach() self.next_inserted = rest_window_size self.next_inserted %= self.max_num_samples @@ -159,9 +177,12 @@ def compute( return _binary_auroc_compute( self.inputs[:, : self.next_inserted].squeeze(), self.targets[:, : self.next_inserted].squeeze(), + self.weights[:, : self.next_inserted].squeeze(), ) else: - return _binary_auroc_compute(self.inputs.squeeze(), self.targets.squeeze()) + return _binary_auroc_compute( + self.inputs.squeeze(), self.targets.squeeze(), self.weights.squeeze() + ) @torch.inference_mode() def merge_state(self: TAUROC, metrics: Iterable[TAUROC]) -> TAUROC: @@ -179,6 +200,7 @@ def merge_state(self: TAUROC, metrics: Iterable[TAUROC]) -> TAUROC: merge_max_num_samples += metric.max_num_samples cur_inputs = self.inputs cur_targets = self.targets + cur_weights = self.weights self.inputs = torch.zeros( self.num_tasks, merge_max_num_samples, @@ -189,15 +211,22 @@ def merge_state(self: TAUROC, metrics: Iterable[TAUROC]) -> TAUROC: merge_max_num_samples, device=self.device, ) + self.weights = torch.zeros( + self.num_tasks, + merge_max_num_samples, + device=self.device, + ) cur_size = min(self.total_samples, self.max_num_samples) self.inputs[:, :cur_size] = cur_inputs[:, :cur_size] self.targets[:, :cur_size] = cur_targets[:, :cur_size] + self.weights[:, :cur_size] = cur_weights[:, :cur_size] idx = cur_size for metric in metrics: cur_size = min(metric.total_samples, metric.max_num_samples) self.inputs[:, idx : idx + cur_size] = metric.inputs[:, :cur_size] self.targets[:, idx : idx + cur_size] = metric.targets[:, :cur_size] + self.weights[:, idx : idx + cur_size] = metric.weights[:, :cur_size] self.total_samples += metric.total_samples idx += cur_size