Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update AUROC with weight input #94

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 65 additions & 16 deletions tests/metrics/classification/test_auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,39 +27,71 @@ def _test_auroc_class_with_input(
input: torch.Tensor,
target: torch.Tensor,
num_tasks: int = 1,
weight: Optional[torch.Tensor] = None,
compute_result: Optional[torch.Tensor] = None,
use_fbgemm: Optional[bool] = False,
) -> None:
input_tensors = input.reshape(-1, 1)
target_tensors = target.reshape(-1, 1)
if compute_result is None:
compute_result = torch.tensor(roc_auc_score(target_tensors, input_tensors))
weight_tensors = weight.reshape(-1, 1) if weight is not None else None

self.run_class_implementation_tests(
metric=BinaryAUROC(num_tasks=num_tasks, use_fbgemm=use_fbgemm),
state_names={"inputs", "targets"},
update_kwargs={
"input": input,
"target": target,
},
compute_result=compute_result,
test_devices=["cuda"] if use_fbgemm else None,
)
if compute_result is None:
compute_result = (
torch.tensor(
roc_auc_score(
target_tensors, input_tensors, sample_weight=weight_tensors
)
)
if weight_tensors is not None
else torch.tensor(roc_auc_score(target_tensors, input_tensors))
)
if weight is not None:
self.run_class_implementation_tests(
metric=BinaryAUROC(num_tasks=num_tasks, use_fbgemm=use_fbgemm),
state_names={"inputs", "targets", "weights"},
update_kwargs={
"input": input,
"target": target,
"weight": weight,
},
compute_result=compute_result,
test_devices=["cuda"] if use_fbgemm else None,
)
else:
self.run_class_implementation_tests(
metric=BinaryAUROC(num_tasks=num_tasks, use_fbgemm=use_fbgemm),
state_names={"inputs", "targets", "weights"},
update_kwargs={
"input": input,
"target": target,
},
compute_result=compute_result,
test_devices=["cuda"] if use_fbgemm else None,
)

def _test_auroc_class_set(self, use_fbgemm: Optional[bool] = False) -> None:
input = torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE)
target = torch.randint(high=2, size=(NUM_TOTAL_UPDATES, BATCH_SIZE))
self._test_auroc_class_with_input(input, target, use_fbgemm=use_fbgemm)
# fbgemm version does not support weight in AUROC calculation
weight = (
torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE) if use_fbgemm is False else None
)
self._test_auroc_class_with_input(
input, target, num_tasks=1, weight=weight, use_fbgemm=use_fbgemm
)

if not use_fbgemm:
# Skip this test for use_fbgemm because FBGEMM AUC is an
# approximation of AUC. It can give a significantly different
# result if input data is highly redundant
input = torch.randint(high=2, size=(NUM_TOTAL_UPDATES, BATCH_SIZE))
target = torch.randint(high=2, size=(NUM_TOTAL_UPDATES, BATCH_SIZE))
weight = torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE)
self._test_auroc_class_with_input(
input,
target,
num_tasks=1,
weight=weight,
use_fbgemm=use_fbgemm,
)

Expand Down Expand Up @@ -101,26 +133,36 @@ def test_auroc_class_update_input_shape_different(self) -> None:
torch.randint(high=num_classes, size=(2,)),
torch.randint(high=num_classes, size=(5,)),
]

update_weight = [
torch.rand(5),
torch.rand(8),
torch.rand(2),
torch.rand(5),
]

compute_result = torch.tensor(
roc_auc_score(
torch.cat(update_target, dim=0),
torch.cat(update_input, dim=0),
sample_weight=torch.cat(update_weight, dim=0),
),
)

self.run_class_implementation_tests(
metric=BinaryAUROC(),
state_names={"inputs", "targets"},
state_names={"inputs", "targets", "weights"},
update_kwargs={
"input": update_input,
"target": update_target,
"weight": update_weight,
},
compute_result=compute_result,
num_total_updates=4,
num_processes=2,
)

def test_auroc_class_invalid_input(self) -> None:
def test_binary_auroc_class_invalid_input(self) -> None:
metric = BinaryAUROC()
with self.assertRaisesRegex(
ValueError,
Expand All @@ -129,6 +171,13 @@ def test_auroc_class_invalid_input(self) -> None:
):
metric.update(torch.rand(4), torch.rand(3))

with self.assertRaisesRegex(
ValueError,
"The `weight` and `target` should have the same shape, "
r"got shapes torch.Size\(\[3\]\) and torch.Size\(\[4\]\).",
):
metric.update(torch.rand(4), torch.rand(4), weight=torch.rand(3))

with self.assertRaisesRegex(
ValueError,
"`num_tasks = 1`, `input` is expected to be one-dimensional tensor,",
Expand Down Expand Up @@ -256,7 +305,7 @@ def test_auroc_class_update_input_shape_different(self) -> None:
num_processes=2,
)

def test_auroc_class_invalid_input(self) -> None:
def test_multiclass_auroc_class_invalid_input(self) -> None:
with self.assertRaisesRegex(
ValueError, "`average` was not in the allowed value of .*, got micro."
):
Expand Down
29 changes: 24 additions & 5 deletions tests/metrics/functional/classification/test_auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,21 @@ def _test_auroc_with_input(
input: torch.Tensor,
target: torch.Tensor,
num_tasks: int = 1,
weight: Optional[torch.Tensor] = None,
compute_result: Optional[torch.Tensor] = None,
use_fbgemm: Optional[bool] = False,
) -> None:
if compute_result is None:
compute_result = torch.tensor(roc_auc_score(target, input))
compute_result = (
torch.tensor(roc_auc_score(target, input))
if weight is None
else torch.tensor(roc_auc_score(target, input, sample_weight=weight))
)
if torch.cuda.is_available():
my_compute_result = binary_auroc(
input.to(device="cuda"),
target.to(device="cuda"),
weight=weight if weight is None else weight.to(device="cuda"),
num_tasks=num_tasks,
use_fbgemm=use_fbgemm,
)
Expand All @@ -48,7 +54,13 @@ def _test_auroc_with_input(
def _test_auroc_set(self, use_fbgemm: Optional[bool] = False) -> None:
input = torch.tensor([1, 1, 0, 0])
target = torch.tensor([1, 0, 1, 0])
weight = torch.tensor([0.2, 0.2, 1.0, 1.0], dtype=torch.float64)
self._test_auroc_with_input(input, target, use_fbgemm=use_fbgemm)
if use_fbgemm is False:
# TODO: use_fbgemm=True currently fails when a weight input is provided
self._test_auroc_with_input(
input, target, weight=weight, use_fbgemm=use_fbgemm
)

input = torch.rand(BATCH_SIZE)
target = torch.randint(high=2, size=(BATCH_SIZE,))
Expand All @@ -59,8 +71,8 @@ def _test_auroc_set(self, use_fbgemm: Optional[bool] = False) -> None:
self._test_auroc_with_input(
input,
target,
2,
torch.tensor([0.7500, 0.2500], dtype=torch.float64),
num_tasks=2,
compute_result=torch.tensor([0.7500, 0.2500], dtype=torch.float64),
use_fbgemm=use_fbgemm,
)

Expand All @@ -73,14 +85,21 @@ def test_auroc_fbgemm(self) -> None:
def test_auroc_base(self) -> None:
self._test_auroc_set(use_fbgemm=False)

def test_auroc_invalid_input(self) -> None:
def test_binary_auroc_invalid_input(self) -> None:
with self.assertRaisesRegex(
ValueError,
"The `input` and `target` should have the same shape, "
r"got shapes torch.Size\(\[4\]\) and torch.Size\(\[3\]\).",
):
binary_auroc(torch.rand(4), torch.rand(3))

with self.assertRaisesRegex(
ValueError,
"The `weight` and `target` should have the same shape, "
r"got shapes torch.Size\(\[3\]\) and torch.Size\(\[4\]\).",
):
binary_auroc(torch.rand(4), torch.rand(4), weight=torch.rand(3))

with self.assertRaisesRegex(
ValueError,
"`num_tasks = 1`, `input` is expected to be one-dimensional tensor,",
Expand Down Expand Up @@ -147,7 +166,7 @@ def test_auroc_average_options(self) -> None:
my_compute_result = multiclass_auroc(input, target, num_classes=3, average=None)
torch.testing.assert_close(my_compute_result, expected_compute_result)

def test_auroc_invalid_input(self) -> None:
def test_multiclass_auroc_invalid_input(self) -> None:
with self.assertRaisesRegex(
ValueError, "`average` was not in the allowed value of .*, got micro."
):
Expand Down
Loading